diff options
Diffstat (limited to 'Lib/unittest/loader.py')
-rw-r--r-- | Lib/unittest/loader.py | 162 |
1 files changed, 105 insertions, 57 deletions
diff --git a/Lib/unittest/loader.py b/Lib/unittest/loader.py index 811bedf..8c10ad1 100644 --- a/Lib/unittest/loader.py +++ b/Lib/unittest/loader.py @@ -65,6 +65,9 @@ class TestLoader(object): def __init__(self): super(TestLoader, self).__init__() self.errors = [] + # Tracks packages which we have called into via load_tests, to + # avoid infinite re-entrancy. + self._loading_packages = set() def loadTestsFromTestCase(self, testCaseClass): """Return a suite of all tests cases contained in testCaseClass""" @@ -229,9 +232,13 @@ class TestLoader(object): If a test package name (directory with '__init__.py') matches the pattern then the package will be checked for a 'load_tests' function. If - this exists then it will be called with loader, tests, pattern. + this exists then it will be called with (loader, tests, pattern) unless + the package has already had load_tests called from the same discovery + invocation, in which case the package module object is not scanned for + tests - this ensures that when a package uses discover to further + discover child tests that infinite recursion does not happen. - If load_tests exists then discovery does *not* recurse into the package, + If load_tests exists then discovery does *not* recurse into the package, load_tests is responsible for loading all tests in the package. The pattern is deliberately not stored as a loader attribute so that @@ -355,69 +362,110 @@ class TestLoader(object): def _find_tests(self, start_dir, pattern, namespace=False): """Used by discovery. Yields test suites it loads.""" + # Handle the __init__ in this package + name = self._get_name_from_path(start_dir) + # name is '.' when start_dir == top_level_dir (and top_level_dir is by + # definition not a package). + if name != '.' and name not in self._loading_packages: + # name is in self._loading_packages while we have called into + # loadTestsFromModule with name. + tests, should_recurse = self._find_test_path( + start_dir, pattern, namespace) + if tests is not None: + yield tests + if not should_recurse: + # Either an error occured, or load_tests was used by the + # package. + return + # Handle the contents. paths = sorted(os.listdir(start_dir)) - for path in paths: full_path = os.path.join(start_dir, path) - if os.path.isfile(full_path): - if not VALID_MODULE_NAME.match(path): - # valid Python identifiers only - continue - if not self._match_path(path, full_path, pattern): - continue - # if the test file matches, load it + tests, should_recurse = self._find_test_path( + full_path, pattern, namespace) + if tests is not None: + yield tests + if should_recurse: + # we found a package that didn't use load_tests. name = self._get_name_from_path(full_path) + self._loading_packages.add(name) try: - module = self._get_module_from_name(name) - except case.SkipTest as e: - yield _make_skipped_test(name, e, self.suiteClass) - except: - error_case, error_message = \ - _make_failed_import_test(name, self.suiteClass) - self.errors.append(error_message) - yield error_case - else: - mod_file = os.path.abspath(getattr(module, '__file__', full_path)) - realpath = _jython_aware_splitext(os.path.realpath(mod_file)) - fullpath_noext = _jython_aware_splitext(os.path.realpath(full_path)) - if realpath.lower() != fullpath_noext.lower(): - module_dir = os.path.dirname(realpath) - mod_name = _jython_aware_splitext(os.path.basename(full_path)) - expected_dir = os.path.dirname(full_path) - msg = ("%r module incorrectly imported from %r. Expected %r. " - "Is this module globally installed?") - raise ImportError(msg % (mod_name, module_dir, expected_dir)) - yield self.loadTestsFromModule(module, pattern=pattern) - elif os.path.isdir(full_path): - if (not namespace and - not os.path.isfile(os.path.join(full_path, '__init__.py'))): - continue - - load_tests = None - tests = None - name = self._get_name_from_path(full_path) + yield from self._find_tests(full_path, pattern, namespace) + finally: + self._loading_packages.discard(name) + + def _find_test_path(self, full_path, pattern, namespace=False): + """Used by discovery. + + Loads tests from a single file, or a directories' __init__.py when + passed the directory. + + Returns a tuple (None_or_tests_from_file, should_recurse). + """ + basename = os.path.basename(full_path) + if os.path.isfile(full_path): + if not VALID_MODULE_NAME.match(basename): + # valid Python identifiers only + return None, False + if not self._match_path(basename, full_path, pattern): + return None, False + # if the test file matches, load it + name = self._get_name_from_path(full_path) + try: + module = self._get_module_from_name(name) + except case.SkipTest as e: + return _make_skipped_test(name, e, self.suiteClass), False + except: + error_case, error_message = \ + _make_failed_import_test(name, self.suiteClass) + self.errors.append(error_message) + return error_case, False + else: + mod_file = os.path.abspath( + getattr(module, '__file__', full_path)) + realpath = _jython_aware_splitext( + os.path.realpath(mod_file)) + fullpath_noext = _jython_aware_splitext( + os.path.realpath(full_path)) + if realpath.lower() != fullpath_noext.lower(): + module_dir = os.path.dirname(realpath) + mod_name = _jython_aware_splitext( + os.path.basename(full_path)) + expected_dir = os.path.dirname(full_path) + msg = ("%r module incorrectly imported from %r. Expected " + "%r. Is this module globally installed?") + raise ImportError( + msg % (mod_name, module_dir, expected_dir)) + return self.loadTestsFromModule(module, pattern=pattern), False + elif os.path.isdir(full_path): + if (not namespace and + not os.path.isfile(os.path.join(full_path, '__init__.py'))): + return None, False + + load_tests = None + tests = None + name = self._get_name_from_path(full_path) + try: + package = self._get_module_from_name(name) + except case.SkipTest as e: + return _make_skipped_test(name, e, self.suiteClass), False + except: + error_case, error_message = \ + _make_failed_import_test(name, self.suiteClass) + self.errors.append(error_message) + return error_case, False + else: + load_tests = getattr(package, 'load_tests', None) + # Mark this package as being in load_tests (possibly ;)) + self._loading_packages.add(name) try: - package = self._get_module_from_name(name) - except case.SkipTest as e: - yield _make_skipped_test(name, e, self.suiteClass) - except: - error_case, error_message = \ - _make_failed_import_test(name, self.suiteClass) - self.errors.append(error_message) - yield error_case - else: - load_tests = getattr(package, 'load_tests', None) tests = self.loadTestsFromModule(package, pattern=pattern) - if tests is not None: - # tests loaded from package file - yield tests - if load_tests is not None: - # loadTestsFromModule(package) has load_tests for us. - continue - # recurse into the package - yield from self._find_tests(full_path, pattern, - namespace=namespace) + # loadTestsFromModule(package) has loaded tests for us. + return tests, False + return tests, True + finally: + self._loading_packages.discard(name) defaultTestLoader = TestLoader() |