diff options
author | Christian Heimes <christian@cheimes.de> | 2013-11-23 14:59:07 (GMT) |
---|---|---|
committer | Christian Heimes <christian@cheimes.de> | 2013-11-23 14:59:07 (GMT) |
commit | 5de397e158d538ffa065974006e58547891bd955 (patch) | |
tree | dfa6715108e51cb33635bdf23210443f42898b03 /Lib/unittest/loader.py | |
parent | 4c05b472ddd4634138b6abfa857ee37761d33185 (diff) | |
parent | 2cf3917954dab65c025ea39f8b6a0298c598f9f7 (diff) | |
download | cpython-5de397e158d538ffa065974006e58547891bd955.zip cpython-5de397e158d538ffa065974006e58547891bd955.tar.gz cpython-5de397e158d538ffa065974006e58547891bd955.tar.bz2 |
merge
Diffstat (limited to 'Lib/unittest/loader.py')
-rw-r--r-- | Lib/unittest/loader.py | 60 |
1 files changed, 51 insertions, 9 deletions
diff --git a/Lib/unittest/loader.py b/Lib/unittest/loader.py index e872fcc..808c50e 100644 --- a/Lib/unittest/loader.py +++ b/Lib/unittest/loader.py @@ -61,8 +61,9 @@ class TestLoader(object): def loadTestsFromTestCase(self, testCaseClass): """Return a suite of all tests cases contained in testCaseClass""" if issubclass(testCaseClass, suite.TestSuite): - raise TypeError("Test cases should not be derived from TestSuite." \ - " Maybe you meant to derive from TestCase?") + raise TypeError("Test cases should not be derived from " + "TestSuite. Maybe you meant to derive from " + "TestCase?") testCaseNames = self.getTestCaseNames(testCaseClass) if not testCaseNames and hasattr(testCaseClass, 'runTest'): testCaseNames = ['runTest'] @@ -200,6 +201,8 @@ class TestLoader(object): self._top_level_dir = top_level_dir is_not_importable = False + is_namespace = False + tests = [] if os.path.isdir(os.path.abspath(start_dir)): start_dir = os.path.abspath(start_dir) if start_dir != top_level_dir: @@ -213,15 +216,52 @@ class TestLoader(object): else: the_module = sys.modules[start_dir] top_part = start_dir.split('.')[0] - start_dir = os.path.abspath(os.path.dirname((the_module.__file__))) + try: + start_dir = os.path.abspath( + os.path.dirname((the_module.__file__))) + except AttributeError: + # look for namespace packages + try: + spec = the_module.__spec__ + except AttributeError: + spec = None + + if spec and spec.loader is None: + if spec.submodule_search_locations is not None: + is_namespace = True + + for path in the_module.__path__: + if (not set_implicit_top and + not path.startswith(top_level_dir)): + continue + self._top_level_dir = \ + (path.split(the_module.__name__ + .replace(".", os.path.sep))[0]) + tests.extend(self._find_tests(path, + pattern, + namespace=True)) + elif the_module.__name__ in sys.builtin_module_names: + # builtin module + raise TypeError('Can not use builtin modules ' + 'as dotted module names') from None + else: + raise TypeError( + 'don\'t know how to discover from {!r}' + .format(the_module)) from None + if set_implicit_top: - self._top_level_dir = self._get_directory_containing_module(top_part) - sys.path.remove(top_level_dir) + if not is_namespace: + self._top_level_dir = \ + self._get_directory_containing_module(top_part) + sys.path.remove(top_level_dir) + else: + sys.path.remove(top_level_dir) if is_not_importable: raise ImportError('Start directory is not importable: %r' % start_dir) - tests = list(self._find_tests(start_dir, pattern)) + if not is_namespace: + tests = list(self._find_tests(start_dir, pattern)) return self.suiteClass(tests) def _get_directory_containing_module(self, module_name): @@ -254,7 +294,7 @@ class TestLoader(object): # override this method to use alternative matching strategy return fnmatch(path, pattern) - def _find_tests(self, start_dir, pattern): + def _find_tests(self, start_dir, pattern, namespace=False): """Used by discovery. Yields test suites it loads.""" paths = sorted(os.listdir(start_dir)) @@ -287,7 +327,8 @@ class TestLoader(object): raise ImportError(msg % (mod_name, module_dir, expected_dir)) yield self.loadTestsFromModule(module) elif os.path.isdir(full_path): - if not os.path.isfile(os.path.join(full_path, '__init__.py')): + if (not namespace and + not os.path.isfile(os.path.join(full_path, '__init__.py'))): continue load_tests = None @@ -304,7 +345,8 @@ class TestLoader(object): # tests loaded from package file yield tests # recurse into the package - yield from self._find_tests(full_path, pattern) + yield from self._find_tests(full_path, pattern, + namespace=namespace) else: try: yield load_tests(self, tests, pattern) |