diff options
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/importlib/__init__.py | 15 | ||||
-rw-r--r-- | Lib/test/test_importlib/test_api.py | 68 |
2 files changed, 79 insertions, 4 deletions
diff --git a/Lib/importlib/__init__.py b/Lib/importlib/__init__.py index 3e969bb..a4b1f55 100644 --- a/Lib/importlib/__init__.py +++ b/Lib/importlib/__init__.py @@ -153,10 +153,17 @@ def reload(module): _RELOADING[name] = module try: parent_name = name.rpartition('.')[0] - if parent_name and parent_name not in sys.modules: - msg = "parent {!r} not in sys.modules" - raise ImportError(msg.format(parent_name), name=parent_name) - spec = module.__spec__ = _bootstrap._find_spec(name, None, module) + if parent_name: + try: + parent = sys.modules[parent_name] + except KeyError: + msg = "parent {!r} not in sys.modules" + raise ImportError(msg.format(parent_name), name=parent_name) + else: + pkgpath = parent.__path__ + else: + pkgpath = None + spec = module.__spec__ = _bootstrap._find_spec(name, pkgpath, module) methods = _bootstrap._SpecMethods(spec) methods.exec(module) # The module may have replaced itself in sys.modules! diff --git a/Lib/test/test_importlib/test_api.py b/Lib/test/test_importlib/test_api.py index c6c2d47..792c82d 100644 --- a/Lib/test/test_importlib/test_api.py +++ b/Lib/test/test_importlib/test_api.py @@ -4,6 +4,7 @@ frozen_init, source_init = util.import_importlib('importlib') frozen_util, source_util = util.import_importlib('importlib.util') frozen_machinery, source_machinery = util.import_importlib('importlib.machinery') +from contextlib import contextmanager import os.path import sys from test import support @@ -11,6 +12,37 @@ import types import unittest +@contextmanager +def temp_module(name, content='', *, pkg=False): + conflicts = [n for n in sys.modules if n.partition('.')[0] == name] + with support.temp_cwd(None) as cwd: + with util.uncache(name, *conflicts): + with support.DirsOnSysPath(cwd): + frozen_init.invalidate_caches() + + location = os.path.join(cwd, name) + if pkg: + modpath = os.path.join(location, '__init__.py') + os.mkdir(name) + else: + modpath = location + '.py' + if content is None: + # Make sure the module file gets created. + content = '' + if content is not None: + # not a namespace package + with open(modpath, 'w') as modfile: + modfile.write(content) + yield location + + +def submodule(parent, name, pkg_dir, content=''): + path = os.path.join(pkg_dir, name + '.py') + with open(path, 'w') as subfile: + subfile.write(content) + return '{}.{}'.format(parent, name), path + + class ImportModuleTests: """Test importlib.import_module.""" @@ -246,6 +278,32 @@ class FindSpecTests: # None is returned upon failure to find a loader. self.assertIsNone(self.init.find_spec('nevergoingtofindthismodule')) + def test_find_submodule(self): + name = 'spam' + subname = 'ham' + with temp_module(name, pkg=True) as pkg_dir: + fullname, _ = submodule(name, subname, pkg_dir) + spec = self.init.find_spec(fullname, [pkg_dir]) + self.assertIsNot(spec, None) + self.assertNotIn(name, sorted(sys.modules)) + # Ensure successive calls behave the same. + spec_again = self.init.find_spec(fullname, [pkg_dir]) + # XXX Once #19927 is resolved, uncomment this line. + #self.assertEqual(spec_again, spec) + + def test_find_submodule_missing_path(self): + name = 'spam' + subname = 'ham' + with temp_module(name, pkg=True) as pkg_dir: + fullname, _ = submodule(name, subname, pkg_dir) + spec = self.init.find_spec(fullname) + self.assertIs(spec, None) + self.assertNotIn(name, sorted(sys.modules)) + # Ensure successive calls behave the same. + spec = self.init.find_spec(fullname) + self.assertIs(spec, None) + + class Frozen_FindSpecTests(FindSpecTests, unittest.TestCase): init = frozen_init machinery = frozen_machinery @@ -410,6 +468,16 @@ class ReloadTests: self.assertEqual(loader.path, init_path) self.assertEqual(ns, expected) + def test_reload_submodule(self): + # See #19851. + name = 'spam' + subname = 'ham' + with temp_module(name, pkg=True) as pkg_dir: + fullname, _ = submodule(name, subname, pkg_dir) + ham = self.init.import_module(fullname) + reloaded = self.init.reload(ham) + self.assertIs(reloaded, ham) + class Frozen_ReloadTests(ReloadTests, unittest.TestCase): init = frozen_init |