diff options
Diffstat (limited to 'Lib/test')
-rw-r--r-- | Lib/test/support/import_helper.py | 34 | ||||
-rw-r--r-- | Lib/test/support/os_helper.py | 4 | ||||
-rw-r--r-- | Lib/test/test_embed.py | 2 | ||||
-rw-r--r-- | Lib/test/test_frozen.py | 7 | ||||
-rw-r--r-- | Lib/test/test_importlib/frozen/test_finder.py | 8 | ||||
-rw-r--r-- | Lib/test/test_importlib/frozen/test_loader.py | 211 |
6 files changed, 164 insertions, 102 deletions
diff --git a/Lib/test/support/import_helper.py b/Lib/test/support/import_helper.py index 5d1e940..10f745a 100644 --- a/Lib/test/support/import_helper.py +++ b/Lib/test/support/import_helper.py @@ -1,4 +1,5 @@ import contextlib +import _imp import importlib import importlib.util import os @@ -109,7 +110,24 @@ def _save_and_block_module(name, orig_modules): return saved -def import_fresh_module(name, fresh=(), blocked=(), deprecated=False): +@contextlib.contextmanager +def frozen_modules(enabled=True): + """Force frozen modules to be used (or not). + + This only applies to modules that haven't been imported yet. + Also, some essential modules will always be imported frozen. + """ + _imp._override_frozen_modules_for_tests(1 if enabled else -1) + try: + yield + finally: + _imp._override_frozen_modules_for_tests(0) + + +def import_fresh_module(name, fresh=(), blocked=(), *, + deprecated=False, + usefrozen=False, + ): """Import and return a module, deliberately bypassing sys.modules. This function imports and returns a fresh copy of the named Python module @@ -133,6 +151,9 @@ def import_fresh_module(name, fresh=(), blocked=(), deprecated=False): This function will raise ImportError if the named module cannot be imported. + + If "usefrozen" is False (the default) then the frozen importer is + disabled (except for essential modules like importlib._bootstrap). """ # NOTE: test_heapq, test_json and test_warnings include extra sanity checks # to make sure that this utility function is working as expected @@ -148,7 +169,8 @@ def import_fresh_module(name, fresh=(), blocked=(), deprecated=False): for blocked_name in blocked: if not _save_and_block_module(blocked_name, orig_modules): names_to_remove.append(blocked_name) - fresh_module = importlib.import_module(name) + with frozen_modules(usefrozen): + fresh_module = importlib.import_module(name) except ImportError: fresh_module = None finally: @@ -169,9 +191,12 @@ class CleanImport(object): with CleanImport("foo"): importlib.import_module("foo") # new reference + + If "usefrozen" is False (the default) then the frozen importer is + disabled (except for essential modules like importlib._bootstrap). """ - def __init__(self, *module_names): + def __init__(self, *module_names, usefrozen=False): self.original_modules = sys.modules.copy() for module_name in module_names: if module_name in sys.modules: @@ -183,12 +208,15 @@ class CleanImport(object): if module.__name__ != module_name: del sys.modules[module.__name__] del sys.modules[module_name] + self._frozen_modules = frozen_modules(usefrozen) def __enter__(self): + self._frozen_modules.__enter__() return self def __exit__(self, *ignore_exc): sys.modules.update(self.original_modules) + self._frozen_modules.__exit__(*ignore_exc) class DirsOnSysPath(object): diff --git a/Lib/test/support/os_helper.py b/Lib/test/support/os_helper.py index d9807a1..ce01417 100644 --- a/Lib/test/support/os_helper.py +++ b/Lib/test/support/os_helper.py @@ -599,6 +599,10 @@ class EnvironmentVarGuard(collections.abc.MutableMapping): def unset(self, envvar): del self[envvar] + def copy(self): + # We do what os.environ.copy() does. + return dict(self) + def __enter__(self): return self diff --git a/Lib/test/test_embed.py b/Lib/test/test_embed.py index 8e3dd50..e5e7c83 100644 --- a/Lib/test/test_embed.py +++ b/Lib/test/test_embed.py @@ -12,6 +12,7 @@ import re import shutil import subprocess import sys +import sysconfig import tempfile import textwrap @@ -426,6 +427,7 @@ class InitConfigTests(EmbeddingTestsMixin, unittest.TestCase): 'pathconfig_warnings': 1, '_init_main': 1, '_isolated_interpreter': 0, + 'use_frozen_modules': False, } if MS_WINDOWS: CONFIG_COMPAT.update({ diff --git a/Lib/test/test_frozen.py b/Lib/test/test_frozen.py index 142f17d..52d8f7c 100644 --- a/Lib/test/test_frozen.py +++ b/Lib/test/test_frozen.py @@ -12,7 +12,7 @@ import sys import unittest -from test.support import captured_stdout +from test.support import captured_stdout, import_helper class TestFrozen(unittest.TestCase): @@ -20,8 +20,9 @@ class TestFrozen(unittest.TestCase): name = '__hello__' if name in sys.modules: del sys.modules[name] - with captured_stdout() as out: - import __hello__ + with import_helper.frozen_modules(): + with captured_stdout() as out: + import __hello__ self.assertEqual(out.getvalue(), 'Hello world!\n') diff --git a/Lib/test/test_importlib/frozen/test_finder.py b/Lib/test/test_importlib/frozen/test_finder.py index eb7a4d2..fbc3fc0 100644 --- a/Lib/test/test_importlib/frozen/test_finder.py +++ b/Lib/test/test_importlib/frozen/test_finder.py @@ -6,6 +6,8 @@ machinery = util.import_importlib('importlib.machinery') import unittest import warnings +from test.support import import_helper + class FindSpecTests(abc.FinderTests): @@ -13,7 +15,8 @@ class FindSpecTests(abc.FinderTests): def find(self, name, path=None): finder = self.machinery.FrozenImporter - return finder.find_spec(name, path) + with import_helper.frozen_modules(): + return finder.find_spec(name, path) def test_module(self): name = '__hello__' @@ -52,7 +55,8 @@ class FinderTests(abc.FinderTests): finder = self.machinery.FrozenImporter with warnings.catch_warnings(): warnings.simplefilter("ignore", DeprecationWarning) - return finder.find_module(name, path) + with import_helper.frozen_modules(): + return finder.find_module(name, path) def test_module(self): name = '__hello__' diff --git a/Lib/test/test_importlib/frozen/test_loader.py b/Lib/test/test_importlib/frozen/test_loader.py index f0cf179..1b0a56f 100644 --- a/Lib/test/test_importlib/frozen/test_loader.py +++ b/Lib/test/test_importlib/frozen/test_loader.py @@ -3,27 +3,54 @@ from .. import util machinery = util.import_importlib('importlib.machinery') -from test.support import captured_stdout +from test.support import captured_stdout, import_helper +import contextlib import types import unittest import warnings +@contextlib.contextmanager +def deprecated(): + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + yield + + +@contextlib.contextmanager +def fresh(name, *, oldapi=False): + with util.uncache(name): + with import_helper.frozen_modules(): + with captured_stdout() as stdout: + if oldapi: + with deprecated(): + yield stdout + else: + yield stdout + + class ExecModuleTests(abc.LoaderTests): def exec_module(self, name): - with util.uncache(name), captured_stdout() as stdout: - spec = self.machinery.ModuleSpec( - name, self.machinery.FrozenImporter, origin='frozen', - is_package=self.machinery.FrozenImporter.is_package(name)) - module = types.ModuleType(name) - module.__spec__ = spec - assert not hasattr(module, 'initialized') + with import_helper.frozen_modules(): + is_package = self.machinery.FrozenImporter.is_package(name) + spec = self.machinery.ModuleSpec( + name, + self.machinery.FrozenImporter, + origin='frozen', + is_package=is_package, + ) + module = types.ModuleType(name) + module.__spec__ = spec + assert not hasattr(module, 'initialized') + + with fresh(name) as stdout: self.machinery.FrozenImporter.exec_module(module) - self.assertTrue(module.initialized) - self.assertTrue(hasattr(module, '__spec__')) - self.assertEqual(module.__spec__.origin, 'frozen') - return module, stdout.getvalue() + + self.assertTrue(module.initialized) + self.assertTrue(hasattr(module, '__spec__')) + self.assertEqual(module.__spec__.origin, 'frozen') + return module, stdout.getvalue() def test_module(self): name = '__hello__' @@ -50,20 +77,19 @@ class ExecModuleTests(abc.LoaderTests): name = '__phello__.spam' with util.uncache('__phello__'): module, output = self.exec_module(name) - check = {'__name__': name} - for attr, value in check.items(): - attr_value = getattr(module, attr) - self.assertEqual(attr_value, value, - 'for {name}.{attr}, {given} != {expected!r}'.format( - name=name, attr=attr, given=attr_value, - expected=value)) - self.assertEqual(output, 'Hello world!\n') + check = {'__name__': name} + for attr, value in check.items(): + attr_value = getattr(module, attr) + self.assertEqual(attr_value, value, + 'for {name}.{attr}, {given} != {expected!r}'.format( + name=name, attr=attr, given=attr_value, + expected=value)) + self.assertEqual(output, 'Hello world!\n') def test_module_repr(self): name = '__hello__' module, output = self.exec_module(name) - with warnings.catch_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) + with deprecated(): repr_str = self.machinery.FrozenImporter.module_repr(module) self.assertEqual(repr_str, "<module '__hello__' (frozen)>") @@ -78,7 +104,8 @@ class ExecModuleTests(abc.LoaderTests): test_state_after_failure = None def test_unloadable(self): - assert self.machinery.FrozenImporter.find_spec('_not_real') is None + with import_helper.frozen_modules(): + assert self.machinery.FrozenImporter.find_spec('_not_real') is None with self.assertRaises(ImportError) as cm: self.exec_module('_not_real') self.assertEqual(cm.exception.name, '_not_real') @@ -91,84 +118,76 @@ class ExecModuleTests(abc.LoaderTests): class LoaderTests(abc.LoaderTests): + def load_module(self, name): + with fresh(name, oldapi=True) as stdout: + module = self.machinery.FrozenImporter.load_module(name) + return module, stdout + def test_module(self): - with util.uncache('__hello__'), captured_stdout() as stdout: - with warnings.catch_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) - module = self.machinery.FrozenImporter.load_module('__hello__') - check = {'__name__': '__hello__', - '__package__': '', - '__loader__': self.machinery.FrozenImporter, - } - for attr, value in check.items(): - self.assertEqual(getattr(module, attr), value) - self.assertEqual(stdout.getvalue(), 'Hello world!\n') - self.assertFalse(hasattr(module, '__file__')) + module, stdout = self.load_module('__hello__') + check = {'__name__': '__hello__', + '__package__': '', + '__loader__': self.machinery.FrozenImporter, + } + for attr, value in check.items(): + self.assertEqual(getattr(module, attr), value) + self.assertEqual(stdout.getvalue(), 'Hello world!\n') + self.assertFalse(hasattr(module, '__file__')) def test_package(self): - with util.uncache('__phello__'), captured_stdout() as stdout: - with warnings.catch_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) - module = self.machinery.FrozenImporter.load_module('__phello__') - check = {'__name__': '__phello__', - '__package__': '__phello__', - '__path__': [], - '__loader__': self.machinery.FrozenImporter, - } - for attr, value in check.items(): - attr_value = getattr(module, attr) - self.assertEqual(attr_value, value, - "for __phello__.%s, %r != %r" % - (attr, attr_value, value)) - self.assertEqual(stdout.getvalue(), 'Hello world!\n') - self.assertFalse(hasattr(module, '__file__')) + module, stdout = self.load_module('__phello__') + check = {'__name__': '__phello__', + '__package__': '__phello__', + '__path__': [], + '__loader__': self.machinery.FrozenImporter, + } + for attr, value in check.items(): + attr_value = getattr(module, attr) + self.assertEqual(attr_value, value, + "for __phello__.%s, %r != %r" % + (attr, attr_value, value)) + self.assertEqual(stdout.getvalue(), 'Hello world!\n') + self.assertFalse(hasattr(module, '__file__')) def test_lacking_parent(self): - with util.uncache('__phello__', '__phello__.spam'), \ - captured_stdout() as stdout: - with warnings.catch_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) - module = self.machinery.FrozenImporter.load_module('__phello__.spam') - check = {'__name__': '__phello__.spam', - '__package__': '__phello__', - '__loader__': self.machinery.FrozenImporter, - } - for attr, value in check.items(): - attr_value = getattr(module, attr) - self.assertEqual(attr_value, value, - "for __phello__.spam.%s, %r != %r" % - (attr, attr_value, value)) - self.assertEqual(stdout.getvalue(), 'Hello world!\n') - self.assertFalse(hasattr(module, '__file__')) + with util.uncache('__phello__'): + module, stdout = self.load_module('__phello__.spam') + check = {'__name__': '__phello__.spam', + '__package__': '__phello__', + '__loader__': self.machinery.FrozenImporter, + } + for attr, value in check.items(): + attr_value = getattr(module, attr) + self.assertEqual(attr_value, value, + "for __phello__.spam.%s, %r != %r" % + (attr, attr_value, value)) + self.assertEqual(stdout.getvalue(), 'Hello world!\n') + self.assertFalse(hasattr(module, '__file__')) def test_module_reuse(self): - with util.uncache('__hello__'), captured_stdout() as stdout: - with warnings.catch_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) - module1 = self.machinery.FrozenImporter.load_module('__hello__') - module2 = self.machinery.FrozenImporter.load_module('__hello__') - self.assertIs(module1, module2) - self.assertEqual(stdout.getvalue(), - 'Hello world!\nHello world!\n') + with fresh('__hello__', oldapi=True) as stdout: + module1 = self.machinery.FrozenImporter.load_module('__hello__') + module2 = self.machinery.FrozenImporter.load_module('__hello__') + self.assertIs(module1, module2) + self.assertEqual(stdout.getvalue(), + 'Hello world!\nHello world!\n') def test_module_repr(self): - with util.uncache('__hello__'), captured_stdout(): - with warnings.catch_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) - module = self.machinery.FrozenImporter.load_module('__hello__') - repr_str = self.machinery.FrozenImporter.module_repr(module) - self.assertEqual(repr_str, - "<module '__hello__' (frozen)>") + with fresh('__hello__', oldapi=True) as stdout: + module = self.machinery.FrozenImporter.load_module('__hello__') + repr_str = self.machinery.FrozenImporter.module_repr(module) + self.assertEqual(repr_str, + "<module '__hello__' (frozen)>") # No way to trigger an error in a frozen module. test_state_after_failure = None def test_unloadable(self): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - assert self.machinery.FrozenImporter.find_module('_not_real') is None + with import_helper.frozen_modules(): + with deprecated(): + assert self.machinery.FrozenImporter.find_module('_not_real') is None with self.assertRaises(ImportError) as cm: - self.machinery.FrozenImporter.load_module('_not_real') + self.load_module('_not_real') self.assertEqual(cm.exception.name, '_not_real') @@ -185,15 +204,17 @@ class InspectLoaderTests: # Make sure that the code object is good. name = '__hello__' with captured_stdout() as stdout: - code = self.machinery.FrozenImporter.get_code(name) - mod = types.ModuleType(name) - exec(code, mod.__dict__) - self.assertTrue(hasattr(mod, 'initialized')) - self.assertEqual(stdout.getvalue(), 'Hello world!\n') + with import_helper.frozen_modules(): + code = self.machinery.FrozenImporter.get_code(name) + mod = types.ModuleType(name) + exec(code, mod.__dict__) + self.assertTrue(hasattr(mod, 'initialized')) + self.assertEqual(stdout.getvalue(), 'Hello world!\n') def test_get_source(self): # Should always return None. - result = self.machinery.FrozenImporter.get_source('__hello__') + with import_helper.frozen_modules(): + result = self.machinery.FrozenImporter.get_source('__hello__') self.assertIsNone(result) def test_is_package(self): @@ -201,7 +222,8 @@ class InspectLoaderTests: test_for = (('__hello__', False), ('__phello__', True), ('__phello__.spam', False)) for name, is_package in test_for: - result = self.machinery.FrozenImporter.is_package(name) + with import_helper.frozen_modules(): + result = self.machinery.FrozenImporter.is_package(name) self.assertEqual(bool(result), is_package) def test_failure(self): @@ -209,7 +231,8 @@ class InspectLoaderTests: for meth_name in ('get_code', 'get_source', 'is_package'): method = getattr(self.machinery.FrozenImporter, meth_name) with self.assertRaises(ImportError) as cm: - method('importlib') + with import_helper.frozen_modules(): + method('importlib') self.assertEqual(cm.exception.name, 'importlib') (Frozen_ILTests, |