diff options
author | Robert Collins <rbtcollins@hp.com> | 2014-11-04 14:09:01 (GMT) |
---|---|---|
committer | Robert Collins <rbtcollins@hp.com> | 2014-11-04 14:09:01 (GMT) |
commit | bf2bda3c9704181cebb6163f5eacd5ad4e1c15f4 (patch) | |
tree | 1f2e15056fc7e7bd965dd5b043183bc575d28f62 /Lib/unittest | |
parent | d39e199a0d49504583c5672252f653fc01837e32 (diff) | |
download | cpython-bf2bda3c9704181cebb6163f5eacd5ad4e1c15f4.zip cpython-bf2bda3c9704181cebb6163f5eacd5ad4e1c15f4.tar.gz cpython-bf2bda3c9704181cebb6163f5eacd5ad4e1c15f4.tar.bz2 |
Close #22457: Honour load_tests in the start_dir of discovery.
We were not honouring load_tests in a package/__init__.py when that was the
start_dir parameter, though we do when it is a child package. The fix required
a little care since it introduces the possibility of infinite recursion.
Diffstat (limited to 'Lib/unittest')
-rw-r--r-- | Lib/unittest/__init__.py | 9 | ||||
-rw-r--r-- | Lib/unittest/loader.py | 162 | ||||
-rw-r--r-- | Lib/unittest/test/test_discovery.py | 45 | ||||
-rw-r--r-- | Lib/unittest/test/test_loader.py | 2 |
4 files changed, 160 insertions, 58 deletions
diff --git a/Lib/unittest/__init__.py b/Lib/unittest/__init__.py index a5d50af..f6d7ae2 100644 --- a/Lib/unittest/__init__.py +++ b/Lib/unittest/__init__.py @@ -67,3 +67,12 @@ from .signals import installHandler, registerResult, removeResult, removeHandler # deprecated _TextTestResult = TextTestResult + +# There are no tests here, so don't try to run anything discovered from +# introspecting the symbols (e.g. FunctionTestCase). Instead, all our +# tests come from within unittest.test. +def load_tests(loader, tests, pattern): + import os.path + # top level directory cached on loader instance + this_dir = os.path.dirname(__file__) + return loader.discover(start_dir=this_dir, pattern=pattern) diff --git a/Lib/unittest/loader.py b/Lib/unittest/loader.py index 811bedf..8c10ad1 100644 --- a/Lib/unittest/loader.py +++ b/Lib/unittest/loader.py @@ -65,6 +65,9 @@ class TestLoader(object): def __init__(self): super(TestLoader, self).__init__() self.errors = [] + # Tracks packages which we have called into via load_tests, to + # avoid infinite re-entrancy. + self._loading_packages = set() def loadTestsFromTestCase(self, testCaseClass): """Return a suite of all tests cases contained in testCaseClass""" @@ -229,9 +232,13 @@ class TestLoader(object): If a test package name (directory with '__init__.py') matches the pattern then the package will be checked for a 'load_tests' function. If - this exists then it will be called with loader, tests, pattern. + this exists then it will be called with (loader, tests, pattern) unless + the package has already had load_tests called from the same discovery + invocation, in which case the package module object is not scanned for + tests - this ensures that when a package uses discover to further + discover child tests that infinite recursion does not happen. - If load_tests exists then discovery does *not* recurse into the package, + If load_tests exists then discovery does *not* recurse into the package, load_tests is responsible for loading all tests in the package. The pattern is deliberately not stored as a loader attribute so that @@ -355,69 +362,110 @@ class TestLoader(object): def _find_tests(self, start_dir, pattern, namespace=False): """Used by discovery. Yields test suites it loads.""" + # Handle the __init__ in this package + name = self._get_name_from_path(start_dir) + # name is '.' when start_dir == top_level_dir (and top_level_dir is by + # definition not a package). + if name != '.' and name not in self._loading_packages: + # name is in self._loading_packages while we have called into + # loadTestsFromModule with name. + tests, should_recurse = self._find_test_path( + start_dir, pattern, namespace) + if tests is not None: + yield tests + if not should_recurse: + # Either an error occured, or load_tests was used by the + # package. + return + # Handle the contents. paths = sorted(os.listdir(start_dir)) - for path in paths: full_path = os.path.join(start_dir, path) - if os.path.isfile(full_path): - if not VALID_MODULE_NAME.match(path): - # valid Python identifiers only - continue - if not self._match_path(path, full_path, pattern): - continue - # if the test file matches, load it + tests, should_recurse = self._find_test_path( + full_path, pattern, namespace) + if tests is not None: + yield tests + if should_recurse: + # we found a package that didn't use load_tests. name = self._get_name_from_path(full_path) + self._loading_packages.add(name) try: - module = self._get_module_from_name(name) - except case.SkipTest as e: - yield _make_skipped_test(name, e, self.suiteClass) - except: - error_case, error_message = \ - _make_failed_import_test(name, self.suiteClass) - self.errors.append(error_message) - yield error_case - else: - mod_file = os.path.abspath(getattr(module, '__file__', full_path)) - realpath = _jython_aware_splitext(os.path.realpath(mod_file)) - fullpath_noext = _jython_aware_splitext(os.path.realpath(full_path)) - if realpath.lower() != fullpath_noext.lower(): - module_dir = os.path.dirname(realpath) - mod_name = _jython_aware_splitext(os.path.basename(full_path)) - expected_dir = os.path.dirname(full_path) - msg = ("%r module incorrectly imported from %r. Expected %r. " - "Is this module globally installed?") - raise ImportError(msg % (mod_name, module_dir, expected_dir)) - yield self.loadTestsFromModule(module, pattern=pattern) - elif os.path.isdir(full_path): - if (not namespace and - not os.path.isfile(os.path.join(full_path, '__init__.py'))): - continue - - load_tests = None - tests = None - name = self._get_name_from_path(full_path) + yield from self._find_tests(full_path, pattern, namespace) + finally: + self._loading_packages.discard(name) + + def _find_test_path(self, full_path, pattern, namespace=False): + """Used by discovery. + + Loads tests from a single file, or a directories' __init__.py when + passed the directory. + + Returns a tuple (None_or_tests_from_file, should_recurse). + """ + basename = os.path.basename(full_path) + if os.path.isfile(full_path): + if not VALID_MODULE_NAME.match(basename): + # valid Python identifiers only + return None, False + if not self._match_path(basename, full_path, pattern): + return None, False + # if the test file matches, load it + name = self._get_name_from_path(full_path) + try: + module = self._get_module_from_name(name) + except case.SkipTest as e: + return _make_skipped_test(name, e, self.suiteClass), False + except: + error_case, error_message = \ + _make_failed_import_test(name, self.suiteClass) + self.errors.append(error_message) + return error_case, False + else: + mod_file = os.path.abspath( + getattr(module, '__file__', full_path)) + realpath = _jython_aware_splitext( + os.path.realpath(mod_file)) + fullpath_noext = _jython_aware_splitext( + os.path.realpath(full_path)) + if realpath.lower() != fullpath_noext.lower(): + module_dir = os.path.dirname(realpath) + mod_name = _jython_aware_splitext( + os.path.basename(full_path)) + expected_dir = os.path.dirname(full_path) + msg = ("%r module incorrectly imported from %r. Expected " + "%r. Is this module globally installed?") + raise ImportError( + msg % (mod_name, module_dir, expected_dir)) + return self.loadTestsFromModule(module, pattern=pattern), False + elif os.path.isdir(full_path): + if (not namespace and + not os.path.isfile(os.path.join(full_path, '__init__.py'))): + return None, False + + load_tests = None + tests = None + name = self._get_name_from_path(full_path) + try: + package = self._get_module_from_name(name) + except case.SkipTest as e: + return _make_skipped_test(name, e, self.suiteClass), False + except: + error_case, error_message = \ + _make_failed_import_test(name, self.suiteClass) + self.errors.append(error_message) + return error_case, False + else: + load_tests = getattr(package, 'load_tests', None) + # Mark this package as being in load_tests (possibly ;)) + self._loading_packages.add(name) try: - package = self._get_module_from_name(name) - except case.SkipTest as e: - yield _make_skipped_test(name, e, self.suiteClass) - except: - error_case, error_message = \ - _make_failed_import_test(name, self.suiteClass) - self.errors.append(error_message) - yield error_case - else: - load_tests = getattr(package, 'load_tests', None) tests = self.loadTestsFromModule(package, pattern=pattern) - if tests is not None: - # tests loaded from package file - yield tests - if load_tests is not None: - # loadTestsFromModule(package) has load_tests for us. - continue - # recurse into the package - yield from self._find_tests(full_path, pattern, - namespace=namespace) + # loadTestsFromModule(package) has loaded tests for us. + return tests, False + return tests, True + finally: + self._loading_packages.discard(name) defaultTestLoader = TestLoader() diff --git a/Lib/unittest/test/test_discovery.py b/Lib/unittest/test/test_discovery.py index 92b983a..4f61314 100644 --- a/Lib/unittest/test/test_discovery.py +++ b/Lib/unittest/test/test_discovery.py @@ -368,6 +368,51 @@ class TestDiscovery(unittest.TestCase): self.assertEqual(_find_tests_args, [(start_dir, 'pattern')]) self.assertIn(top_level_dir, sys.path) + def test_discover_start_dir_is_package_calls_package_load_tests(self): + # This test verifies that the package load_tests in a package is indeed + # invoked when the start_dir is a package (and not the top level). + # http://bugs.python.org/issue22457 + + # Test data: we expect the following: + # an isfile to verify the package, then importing and scanning + # as per _find_tests' normal behaviour. + # We expect to see our load_tests hook called once. + vfs = {abspath('/toplevel'): ['startdir'], + abspath('/toplevel/startdir'): ['__init__.py']} + def list_dir(path): + return list(vfs[path]) + self.addCleanup(setattr, os, 'listdir', os.listdir) + os.listdir = list_dir + self.addCleanup(setattr, os.path, 'isfile', os.path.isfile) + os.path.isfile = lambda path: path.endswith('.py') + self.addCleanup(setattr, os.path, 'isdir', os.path.isdir) + os.path.isdir = lambda path: not path.endswith('.py') + self.addCleanup(sys.path.remove, abspath('/toplevel')) + + class Module(object): + paths = [] + load_tests_args = [] + + def __init__(self, path): + self.path = path + + def load_tests(self, loader, tests, pattern): + return ['load_tests called ' + self.path] + + def __eq__(self, other): + return self.path == other.path + + loader = unittest.TestLoader() + loader._get_module_from_name = lambda name: Module(name) + loader.suiteClass = lambda thing: thing + + suite = loader.discover('/toplevel/startdir', top_level_dir='/toplevel') + + # We should have loaded tests from the package __init__. + # (normally this would be nested TestSuites.) + self.assertEqual(suite, + [['load_tests called startdir']]) + def setup_import_issue_tests(self, fakefile): listdir = os.listdir os.listdir = lambda _: [fakefile] diff --git a/Lib/unittest/test/test_loader.py b/Lib/unittest/test/test_loader.py index c489730..68f1036 100644 --- a/Lib/unittest/test/test_loader.py +++ b/Lib/unittest/test/test_loader.py @@ -841,7 +841,7 @@ class Test_TestLoader(unittest.TestCase): loader = unittest.TestLoader() suite = loader.loadTestsFromNames( - ['unittest.loader.sdasfasfasdf', 'unittest']) + ['unittest.loader.sdasfasfasdf', 'unittest.test.dummy']) error, test = self.check_deferred_error(loader, list(suite)[0]) expected = "module 'unittest.loader' has no attribute 'sdasfasfasdf'" self.assertIn( |