summaryrefslogtreecommitdiffstats
path: root/Lib/unittest/loader.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/unittest/loader.py')
-rw-r--r--Lib/unittest/loader.py256
1 files changed, 186 insertions, 70 deletions
diff --git a/Lib/unittest/loader.py b/Lib/unittest/loader.py
index af39216..c776f16 100644
--- a/Lib/unittest/loader.py
+++ b/Lib/unittest/loader.py
@@ -6,6 +6,7 @@ import sys
import traceback
import types
import functools
+import warnings
from fnmatch import fnmatch
@@ -13,9 +14,9 @@ from . import case, suite, util
__unittest = True
-# what about .pyc or .pyo (etc)
+# what about .pyc (etc)
# we would need to avoid loading the same tests multiple times
-# from '.py', '.pyc' *and* '.pyo'
+# from '.py', *and* '.pyc'
VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', re.IGNORECASE)
@@ -35,15 +36,18 @@ class _FailedTest(case.TestCase):
def _make_failed_import_test(name, suiteClass):
- message = 'Failed to import test module: %s\n%s' % (name, traceback.format_exc())
- return _make_failed_test(name, ImportError(message), suiteClass)
+ message = 'Failed to import test module: %s\n%s' % (
+ name, traceback.format_exc())
+ return _make_failed_test(name, ImportError(message), suiteClass, message)
def _make_failed_load_tests(name, exception, suiteClass):
- return _make_failed_test(name, exception, suiteClass)
+ message = 'Failed to call load_tests:\n%s' % (traceback.format_exc(),)
+ return _make_failed_test(
+ name, exception, suiteClass, message)
-def _make_failed_test(methodname, exception, suiteClass):
+def _make_failed_test(methodname, exception, suiteClass, message):
test = _FailedTest(methodname, exception)
- return suiteClass((test,))
+ return suiteClass((test,)), message
def _make_skipped_test(methodname, exception, suiteClass):
@case.skip(str(exception))
@@ -69,6 +73,13 @@ class TestLoader(object):
suiteClass = suite.TestSuite
_top_level_dir = None
+ def __init__(self):
+ super(TestLoader, self).__init__()
+ self.errors = []
+ # Tracks packages which we have called into via load_tests, to
+ # avoid infinite re-entrancy.
+ self._loading_packages = set()
+
def loadTestsFromTestCase(self, testCaseClass):
"""Return a suite of all tests cases contained in testCaseClass"""
if issubclass(testCaseClass, suite.TestSuite):
@@ -81,8 +92,30 @@ class TestLoader(object):
loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))
return loaded_suite
- def loadTestsFromModule(self, module, use_load_tests=True):
+ # XXX After Python 3.5, remove backward compatibility hacks for
+ # use_load_tests deprecation via *args and **kws. See issue 16662.
+ def loadTestsFromModule(self, module, *args, pattern=None, **kws):
"""Return a suite of all tests cases contained in the given module"""
+ # This method used to take an undocumented and unofficial
+ # use_load_tests argument. For backward compatibility, we still
+ # accept the argument (which can also be the first position) but we
+ # ignore it and issue a deprecation warning if it's present.
+ if len(args) > 0 or 'use_load_tests' in kws:
+ warnings.warn('use_load_tests is deprecated and ignored',
+ DeprecationWarning)
+ kws.pop('use_load_tests', None)
+ if len(args) > 1:
+ # Complain about the number of arguments, but don't forget the
+ # required `module` argument.
+ complaint = len(args) + 1
+ raise TypeError('loadTestsFromModule() takes 1 positional argument but {} were given'.format(complaint))
+ if len(kws) != 0:
+ # Since the keyword arguments are unsorted (see PEP 468), just
+ # pick the alphabetically sorted first argument to complain about,
+ # if multiple were given. At least the error message will be
+ # predictable.
+ complaint = sorted(kws)[0]
+ raise TypeError("loadTestsFromModule() got an unexpected keyword argument '{}'".format(complaint))
tests = []
for name in dir(module):
obj = getattr(module, name)
@@ -91,12 +124,14 @@ class TestLoader(object):
load_tests = getattr(module, 'load_tests', None)
tests = self.suiteClass(tests)
- if use_load_tests and load_tests is not None:
+ if load_tests is not None:
try:
- return load_tests(self, tests, None)
+ return load_tests(self, tests, pattern)
except Exception as e:
- return _make_failed_load_tests(module.__name__, e,
- self.suiteClass)
+ error_case, error_message = _make_failed_load_tests(
+ module.__name__, e, self.suiteClass)
+ self.errors.append(error_message)
+ return error_case
return tests
def loadTestsFromName(self, name, module=None):
@@ -109,20 +144,47 @@ class TestLoader(object):
The method optionally resolves the names relative to a given module.
"""
parts = name.split('.')
+ error_case, error_message = None, None
if module is None:
parts_copy = parts[:]
while parts_copy:
try:
- module = __import__('.'.join(parts_copy))
+ module_name = '.'.join(parts_copy)
+ module = __import__(module_name)
break
except ImportError:
- del parts_copy[-1]
+ next_attribute = parts_copy.pop()
+ # Last error so we can give it to the user if needed.
+ error_case, error_message = _make_failed_import_test(
+ next_attribute, self.suiteClass)
if not parts_copy:
- raise
+ # Even the top level import failed: report that error.
+ self.errors.append(error_message)
+ return error_case
parts = parts[1:]
obj = module
for part in parts:
- parent, obj = obj, getattr(obj, part)
+ try:
+ parent, obj = obj, getattr(obj, part)
+ except AttributeError as e:
+ # We can't traverse some part of the name.
+ if (getattr(obj, '__path__', None) is not None
+ and error_case is not None):
+ # This is a package (no __path__ per importlib docs), and we
+ # encountered an error importing something. We cannot tell
+ # the difference between package.WrongNameTestClass and
+ # package.wrong_module_name so we just report the
+ # ImportError - it is more informative.
+ self.errors.append(error_message)
+ return error_case
+ else:
+ # Otherwise, we signal that an AttributeError has occurred.
+ error_case, error_message = _make_failed_test(
+ part, e, self.suiteClass,
+ 'Failed to access attribute:\n%s' % (
+ traceback.format_exc(),))
+ self.errors.append(error_message)
+ return error_case
if isinstance(obj, types.ModuleType):
return self.loadTestsFromModule(obj)
@@ -181,9 +243,13 @@ class TestLoader(object):
If a test package name (directory with '__init__.py') matches the
pattern then the package will be checked for a 'load_tests' function. If
- this exists then it will be called with loader, tests, pattern.
+ this exists then it will be called with (loader, tests, pattern) unless
+ the package has already had load_tests called from the same discovery
+ invocation, in which case the package module object is not scanned for
+ tests - this ensures that when a package uses discover to further
+ discover child tests that infinite recursion does not happen.
- If load_tests exists then discovery does *not* recurse into the package,
+ If load_tests exists then discovery does *not* recurse into the package,
load_tests is responsible for loading all tests in the package.
The pattern is deliberately not stored as a loader attribute so that
@@ -288,6 +354,8 @@ class TestLoader(object):
return os.path.dirname(full_path)
def _get_name_from_path(self, path):
+ if path == self._top_level_dir:
+ return '.'
path = _jython_aware_splitext(os.path.normpath(path))
_relpath = os.path.relpath(path, self._top_level_dir)
@@ -307,63 +375,111 @@ class TestLoader(object):
def _find_tests(self, start_dir, pattern, namespace=False):
"""Used by discovery. Yields test suites it loads."""
+ # Handle the __init__ in this package
+ name = self._get_name_from_path(start_dir)
+ # name is '.' when start_dir == top_level_dir (and top_level_dir is by
+ # definition not a package).
+ if name != '.' and name not in self._loading_packages:
+ # name is in self._loading_packages while we have called into
+ # loadTestsFromModule with name.
+ tests, should_recurse = self._find_test_path(
+ start_dir, pattern, namespace)
+ if tests is not None:
+ yield tests
+ if not should_recurse:
+ # Either an error occured, or load_tests was used by the
+ # package.
+ return
+ # Handle the contents.
paths = sorted(os.listdir(start_dir))
-
for path in paths:
full_path = os.path.join(start_dir, path)
- if os.path.isfile(full_path):
- if not VALID_MODULE_NAME.match(path):
- # valid Python identifiers only
- continue
- if not self._match_path(path, full_path, pattern):
- continue
- # if the test file matches, load it
+ tests, should_recurse = self._find_test_path(
+ full_path, pattern, namespace)
+ if tests is not None:
+ yield tests
+ if should_recurse:
+ # we found a package that didn't use load_tests.
name = self._get_name_from_path(full_path)
+ self._loading_packages.add(name)
try:
- module = self._get_module_from_name(name)
- except case.SkipTest as e:
- yield _make_skipped_test(name, e, self.suiteClass)
- except:
- yield _make_failed_import_test(name, self.suiteClass)
- else:
- mod_file = os.path.abspath(getattr(module, '__file__', full_path))
- realpath = _jython_aware_splitext(os.path.realpath(mod_file))
- fullpath_noext = _jython_aware_splitext(os.path.realpath(full_path))
- if realpath.lower() != fullpath_noext.lower():
- module_dir = os.path.dirname(realpath)
- mod_name = _jython_aware_splitext(os.path.basename(full_path))
- expected_dir = os.path.dirname(full_path)
- msg = ("%r module incorrectly imported from %r. Expected %r. "
- "Is this module globally installed?")
- raise ImportError(msg % (mod_name, module_dir, expected_dir))
- yield self.loadTestsFromModule(module)
- elif os.path.isdir(full_path):
- if (not namespace and
- not os.path.isfile(os.path.join(full_path, '__init__.py'))):
- continue
-
- load_tests = None
- tests = None
- if fnmatch(path, pattern):
- # only check load_tests if the package directory itself matches the filter
- name = self._get_name_from_path(full_path)
- package = self._get_module_from_name(name)
- load_tests = getattr(package, 'load_tests', None)
- tests = self.loadTestsFromModule(package, use_load_tests=False)
-
- if load_tests is None:
- if tests is not None:
- # tests loaded from package file
- yield tests
- # recurse into the package
- yield from self._find_tests(full_path, pattern,
- namespace=namespace)
- else:
- try:
- yield load_tests(self, tests, pattern)
- except Exception as e:
- yield _make_failed_load_tests(package.__name__, e,
- self.suiteClass)
+ yield from self._find_tests(full_path, pattern, namespace)
+ finally:
+ self._loading_packages.discard(name)
+
+ def _find_test_path(self, full_path, pattern, namespace=False):
+ """Used by discovery.
+
+ Loads tests from a single file, or a directories' __init__.py when
+ passed the directory.
+
+ Returns a tuple (None_or_tests_from_file, should_recurse).
+ """
+ basename = os.path.basename(full_path)
+ if os.path.isfile(full_path):
+ if not VALID_MODULE_NAME.match(basename):
+ # valid Python identifiers only
+ return None, False
+ if not self._match_path(basename, full_path, pattern):
+ return None, False
+ # if the test file matches, load it
+ name = self._get_name_from_path(full_path)
+ try:
+ module = self._get_module_from_name(name)
+ except case.SkipTest as e:
+ return _make_skipped_test(name, e, self.suiteClass), False
+ except:
+ error_case, error_message = \
+ _make_failed_import_test(name, self.suiteClass)
+ self.errors.append(error_message)
+ return error_case, False
+ else:
+ mod_file = os.path.abspath(
+ getattr(module, '__file__', full_path))
+ realpath = _jython_aware_splitext(
+ os.path.realpath(mod_file))
+ fullpath_noext = _jython_aware_splitext(
+ os.path.realpath(full_path))
+ if realpath.lower() != fullpath_noext.lower():
+ module_dir = os.path.dirname(realpath)
+ mod_name = _jython_aware_splitext(
+ os.path.basename(full_path))
+ expected_dir = os.path.dirname(full_path)
+ msg = ("%r module incorrectly imported from %r. Expected "
+ "%r. Is this module globally installed?")
+ raise ImportError(
+ msg % (mod_name, module_dir, expected_dir))
+ return self.loadTestsFromModule(module, pattern=pattern), False
+ elif os.path.isdir(full_path):
+ if (not namespace and
+ not os.path.isfile(os.path.join(full_path, '__init__.py'))):
+ return None, False
+
+ load_tests = None
+ tests = None
+ name = self._get_name_from_path(full_path)
+ try:
+ package = self._get_module_from_name(name)
+ except case.SkipTest as e:
+ return _make_skipped_test(name, e, self.suiteClass), False
+ except:
+ error_case, error_message = \
+ _make_failed_import_test(name, self.suiteClass)
+ self.errors.append(error_message)
+ return error_case, False
+ else:
+ load_tests = getattr(package, 'load_tests', None)
+ # Mark this package as being in load_tests (possibly ;))
+ self._loading_packages.add(name)
+ try:
+ tests = self.loadTestsFromModule(package, pattern=pattern)
+ if load_tests is not None:
+ # loadTestsFromModule(package) has loaded tests for us.
+ return tests, False
+ return tests, True
+ finally:
+ self._loading_packages.discard(name)
+
defaultTestLoader = TestLoader()