summaryrefslogtreecommitdiffstats
path: root/Lib/importlib/test/support.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/importlib/test/support.py')
-rw-r--r--Lib/importlib/test/support.py223
1 files changed, 223 insertions, 0 deletions
diff --git a/Lib/importlib/test/support.py b/Lib/importlib/test/support.py
new file mode 100644
index 0000000..4e63cd1
--- /dev/null
+++ b/Lib/importlib/test/support.py
@@ -0,0 +1,223 @@
+from importlib import Import
+
+from contextlib import contextmanager
+from functools import update_wrapper
+import imp
+import os.path
+from test.support import unlink
+import sys
+from tempfile import gettempdir
+
+
+using___import__ = False
+
+def import_(*args, **kwargs):
+ """Delegate to allow for injecting different implementations of import."""
+ if using___import__:
+ return __import__(*args, **kwargs)
+ return Import()(*args, **kwargs)
+
+def importlib_only(fxn):
+ """Decorator to mark which tests are not supported by the current
+ implementation of __import__()."""
+ def inner(*args, **kwargs):
+ if using___import__:
+ return
+ else:
+ return fxn(*args, **kwargs)
+ update_wrapper(inner, fxn)
+ return inner
+
+def writes_bytecode(fxn):
+ """Decorator that returns the function if writing bytecode is enabled, else
+ a stub function that accepts anything and simply returns None."""
+ if sys.dont_write_bytecode:
+ return lambda *args, **kwargs: None
+ else:
+ return fxn
+
+@contextmanager
+def uncache(*names):
+ """Uncache a module from sys.modules.
+
+ A basic sanity check is performed to prevent uncaching modules that either
+ cannot/shouldn't be uncached.
+
+ """
+ for name in names:
+ if name in ('sys', 'marshal', 'imp'):
+ raise ValueError(
+ "cannot uncache {0} as it will break _importlib".format(name))
+ try:
+ del sys.modules[name]
+ except KeyError:
+ pass
+ try:
+ yield
+ finally:
+ for name in names:
+ try:
+ del sys.modules[name]
+ except KeyError:
+ pass
+
+@contextmanager
+def import_state(**kwargs):
+ """Context manager to manage the various importers and stored state in the
+ sys module.
+
+ The 'modules' attribute is not supported as the interpreter state stores a
+ pointer to the dict that the interpreter uses internally;
+ reassigning to sys.modules does not have the desired effect.
+
+ """
+ originals = {}
+ try:
+ for attr, default in (('meta_path', []), ('path', []),
+ ('path_hooks', []),
+ ('path_importer_cache', {})):
+ originals[attr] = getattr(sys, attr)
+ if attr in kwargs:
+ new_value = kwargs[attr]
+ del kwargs[attr]
+ else:
+ new_value = default
+ setattr(sys, attr, new_value)
+ if len(kwargs):
+ raise ValueError(
+ 'unrecognized arguments: {0}'.format(kwargs.keys()))
+ yield
+ finally:
+ for attr, value in originals.items():
+ setattr(sys, attr, value)
+
+
+@contextmanager
+def create_modules(*names):
+ """Temporarily create each named module with an attribute (named 'attr')
+ that contains the name passed into the context manager that caused the
+ creation of the module.
+
+ All files are created in a temporary directory specified by
+ tempfile.gettempdir(). This directory is inserted at the beginning of
+ sys.path. When the context manager exits all created files (source and
+ bytecode) are explicitly deleted.
+
+ No magic is performed when creating packages! This means that if you create
+ a module within a package you must also create the package's __init__ as
+ well.
+
+ """
+ source = 'attr = {0!r}'
+ created_paths = []
+ mapping = {}
+ try:
+ temp_dir = gettempdir()
+ mapping['.root'] = temp_dir
+ import_names = set()
+ for name in names:
+ if not name.endswith('__init__'):
+ import_name = name
+ else:
+ import_name = name[:-len('.__init__')]
+ import_names.add(import_name)
+ if import_name in sys.modules:
+ del sys.modules[import_name]
+ name_parts = name.split('.')
+ file_path = temp_dir
+ for directory in name_parts[:-1]:
+ file_path = os.path.join(file_path, directory)
+ if not os.path.exists(file_path):
+ os.mkdir(file_path)
+ created_paths.append(file_path)
+ file_path = os.path.join(file_path, name_parts[-1] + '.py')
+ with open(file_path, 'w') as file:
+ file.write(source.format(name))
+ created_paths.append(file_path)
+ mapping[name] = file_path
+ uncache_manager = uncache(*import_names)
+ uncache_manager.__enter__()
+ state_manager = import_state(path=[temp_dir])
+ state_manager.__enter__()
+ yield mapping
+ finally:
+ state_manager.__exit__(None, None, None)
+ uncache_manager.__exit__(None, None, None)
+ # Reverse the order for path removal to unroll directory creation.
+ for path in reversed(created_paths):
+ if file_path.endswith('.py'):
+ unlink(path)
+ unlink(path + 'c')
+ unlink(path + 'o')
+ else:
+ os.rmdir(path)
+
+
+class mock_modules:
+
+ """A mock importer/loader."""
+
+ def __init__(self, *names):
+ self.modules = {}
+ for name in names:
+ if not name.endswith('.__init__'):
+ import_name = name
+ else:
+ import_name = name[:-len('.__init__')]
+ if '.' not in name:
+ package = None
+ elif import_name == name:
+ package = name.rsplit('.', 1)[0]
+ else:
+ package = import_name
+ module = imp.new_module(import_name)
+ module.__loader__ = self
+ module.__file__ = '<mock __file__>'
+ module.__package__ = package
+ module.attr = name
+ if import_name != name:
+ module.__path__ = ['<mock __path__>']
+ self.modules[import_name] = module
+
+ def __getitem__(self, name):
+ return self.modules[name]
+
+ def find_module(self, fullname, path=None):
+ if fullname not in self.modules:
+ return None
+ else:
+ return self
+
+ def load_module(self, fullname):
+ if fullname not in self.modules:
+ raise ImportError
+ else:
+ sys.modules[fullname] = self.modules[fullname]
+ return self.modules[fullname]
+
+ def __enter__(self):
+ self._uncache = uncache(*self.modules.keys())
+ self._uncache.__enter__()
+ return self
+
+ def __exit__(self, *exc_info):
+ self._uncache.__exit__(None, None, None)
+
+
+def mock_path_hook(*entries, importer):
+ """A mock sys.path_hooks entry."""
+ def hook(entry):
+ if entry not in entries:
+ raise ImportError
+ return importer
+ return hook
+
+
+def bytecode_path(source_path):
+ for suffix, _, type_ in imp.get_suffixes():
+ if type_ == imp.PY_COMPILED:
+ bc_suffix = suffix
+ break
+ else:
+ raise ValueError("no bytecode suffix is defined")
+ return os.path.splitext(source_path)[0] + bc_suffix