summaryrefslogtreecommitdiffstats
path: root/Lib/importlib/test/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/importlib/test/util.py')
-rw-r--r--Lib/importlib/test/util.py34
1 files changed, 20 insertions, 14 deletions
diff --git a/Lib/importlib/test/util.py b/Lib/importlib/test/util.py
index 845e380..93b7cd2 100644
--- a/Lib/importlib/test/util.py
+++ b/Lib/importlib/test/util.py
@@ -6,21 +6,22 @@ import unittest
import sys
-def case_insensitive_tests(class_):
+CASE_INSENSITIVE_FS = True
+# Windows is the only OS that is *always* case-insensitive
+# (OS X *can* be case-sensitive).
+if sys.platform not in ('win32', 'cygwin'):
+ changed_name = __file__.upper()
+ if changed_name == __file__:
+ changed_name = __file__.lower()
+ if not os.path.exists(changed_name):
+ CASE_INSENSITIVE_FS = False
+
+
+def case_insensitive_tests(test):
"""Class decorator that nullifies tests requiring a case-insensitive
file system."""
- # Windows is the only OS that is *always* case-insensitive
- # (OS X *can* be case-sensitive).
- if sys.platform not in ('win32', 'cygwin'):
- changed_name = __file__.upper()
- if changed_name == __file__:
- changed_name = __file__.lower()
- if os.path.exists(changed_name):
- return class_
- else:
- return unittest.TestCase
- else:
- return class_
+ return unittest.skipIf(not CASE_INSENSITIVE_FS,
+ "requires a case-insensitive filesystem")(test)
@contextmanager
@@ -83,8 +84,9 @@ class mock_modules:
"""A mock importer/loader."""
- def __init__(self, *names):
+ def __init__(self, *names, module_code={}):
self.modules = {}
+ self.module_code = {}
for name in names:
if not name.endswith('.__init__'):
import_name = name
@@ -104,6 +106,8 @@ class mock_modules:
if import_name != name:
module.__path__ = ['<mock __path__>']
self.modules[import_name] = module
+ if import_name in module_code:
+ self.module_code[import_name] = module_code[import_name]
def __getitem__(self, name):
return self.modules[name]
@@ -119,6 +123,8 @@ class mock_modules:
raise ImportError
else:
sys.modules[fullname] = self.modules[fullname]
+ if fullname in self.module_code:
+ self.module_code[fullname]()
return self.modules[fullname]
def __enter__(self):