summaryrefslogtreecommitdiffstats
path: root/Lib/unittest/loader.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/unittest/loader.py')
-rw-r--r--Lib/unittest/loader.py258
1 files changed, 188 insertions, 70 deletions
diff --git a/Lib/unittest/loader.py b/Lib/unittest/loader.py
index af39216..eb447d7 100644
--- a/Lib/unittest/loader.py
+++ b/Lib/unittest/loader.py
@@ -6,6 +6,7 @@ import sys
import traceback
import types
import functools
+import warnings
from fnmatch import fnmatch
@@ -13,9 +14,9 @@ from . import case, suite, util
__unittest = True
-# what about .pyc or .pyo (etc)
+# what about .pyc (etc)
# we would need to avoid loading the same tests multiple times
-# from '.py', '.pyc' *and* '.pyo'
+# from '.py', *and* '.pyc'
VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', re.IGNORECASE)
@@ -35,15 +36,18 @@ class _FailedTest(case.TestCase):
def _make_failed_import_test(name, suiteClass):
- message = 'Failed to import test module: %s\n%s' % (name, traceback.format_exc())
- return _make_failed_test(name, ImportError(message), suiteClass)
+ message = 'Failed to import test module: %s\n%s' % (
+ name, traceback.format_exc())
+ return _make_failed_test(name, ImportError(message), suiteClass, message)
def _make_failed_load_tests(name, exception, suiteClass):
- return _make_failed_test(name, exception, suiteClass)
+ message = 'Failed to call load_tests:\n%s' % (traceback.format_exc(),)
+ return _make_failed_test(
+ name, exception, suiteClass, message)
-def _make_failed_test(methodname, exception, suiteClass):
+def _make_failed_test(methodname, exception, suiteClass, message):
test = _FailedTest(methodname, exception)
- return suiteClass((test,))
+ return suiteClass((test,)), message
def _make_skipped_test(methodname, exception, suiteClass):
@case.skip(str(exception))
@@ -69,6 +73,13 @@ class TestLoader(object):
suiteClass = suite.TestSuite
_top_level_dir = None
+ def __init__(self):
+ super(TestLoader, self).__init__()
+ self.errors = []
+ # Tracks packages which we have called into via load_tests, to
+ # avoid infinite re-entrancy.
+ self._loading_packages = set()
+
def loadTestsFromTestCase(self, testCaseClass):
"""Return a suite of all tests cases contained in testCaseClass"""
if issubclass(testCaseClass, suite.TestSuite):
@@ -81,8 +92,30 @@ class TestLoader(object):
loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))
return loaded_suite
- def loadTestsFromModule(self, module, use_load_tests=True):
+ # XXX After Python 3.5, remove backward compatibility hacks for
+ # use_load_tests deprecation via *args and **kws. See issue 16662.
+ def loadTestsFromModule(self, module, *args, pattern=None, **kws):
"""Return a suite of all tests cases contained in the given module"""
+ # This method used to take an undocumented and unofficial
+ # use_load_tests argument. For backward compatibility, we still
+ # accept the argument (which can also be the first position) but we
+ # ignore it and issue a deprecation warning if it's present.
+ if len(args) > 0 or 'use_load_tests' in kws:
+ warnings.warn('use_load_tests is deprecated and ignored',
+ DeprecationWarning)
+ kws.pop('use_load_tests', None)
+ if len(args) > 1:
+ # Complain about the number of arguments, but don't forget the
+ # required `module` argument.
+ complaint = len(args) + 1
+ raise TypeError('loadTestsFromModule() takes 1 positional argument but {} were given'.format(complaint))
+ if len(kws) != 0:
+ # Since the keyword arguments are unsorted (see PEP 468), just
+ # pick the alphabetically sorted first argument to complain about,
+ # if multiple were given. At least the error message will be
+ # predictable.
+ complaint = sorted(kws)[0]
+ raise TypeError("loadTestsFromModule() got an unexpected keyword argument '{}'".format(complaint))
tests = []
for name in dir(module):
obj = getattr(module, name)
@@ -91,12 +124,14 @@ class TestLoader(object):
load_tests = getattr(module, 'load_tests', None)
tests = self.suiteClass(tests)
- if use_load_tests and load_tests is not None:
+ if load_tests is not None:
try:
- return load_tests(self, tests, None)
+ return load_tests(self, tests, pattern)
except Exception as e:
- return _make_failed_load_tests(module.__name__, e,
- self.suiteClass)
+ error_case, error_message = _make_failed_load_tests(
+ module.__name__, e, self.suiteClass)
+ self.errors.append(error_message)
+ return error_case
return tests
def loadTestsFromName(self, name, module=None):
@@ -109,20 +144,47 @@ class TestLoader(object):
The method optionally resolves the names relative to a given module.
"""
parts = name.split('.')
+ error_case, error_message = None, None
if module is None:
parts_copy = parts[:]
while parts_copy:
try:
- module = __import__('.'.join(parts_copy))
+ module_name = '.'.join(parts_copy)
+ module = __import__(module_name)
break
except ImportError:
- del parts_copy[-1]
+ next_attribute = parts_copy.pop()
+ # Last error so we can give it to the user if needed.
+ error_case, error_message = _make_failed_import_test(
+ next_attribute, self.suiteClass)
if not parts_copy:
- raise
+ # Even the top level import failed: report that error.
+ self.errors.append(error_message)
+ return error_case
parts = parts[1:]
obj = module
for part in parts:
- parent, obj = obj, getattr(obj, part)
+ try:
+ parent, obj = obj, getattr(obj, part)
+ except AttributeError as e:
+ # We can't traverse some part of the name.
+ if (getattr(obj, '__path__', None) is not None
+ and error_case is not None):
+ # This is a package (no __path__ per importlib docs), and we
+ # encountered an error importing something. We cannot tell
+ # the difference between package.WrongNameTestClass and
+ # package.wrong_module_name so we just report the
+ # ImportError - it is more informative.
+ self.errors.append(error_message)
+ return error_case
+ else:
+ # Otherwise, we signal that an AttributeError has occurred.
+ error_case, error_message = _make_failed_test(
+ part, e, self.suiteClass,
+ 'Failed to access attribute:\n%s' % (
+ traceback.format_exc(),))
+ self.errors.append(error_message)
+ return error_case
if isinstance(obj, types.ModuleType):
return self.loadTestsFromModule(obj)
@@ -181,9 +243,13 @@ class TestLoader(object):
If a test package name (directory with '__init__.py') matches the
pattern then the package will be checked for a 'load_tests' function. If
- this exists then it will be called with loader, tests, pattern.
+ this exists then it will be called with (loader, tests, pattern) unless
+ the package has already had load_tests called from the same discovery
+ invocation, in which case the package module object is not scanned for
+ tests - this ensures that when a package uses discover to further
+ discover child tests that infinite recursion does not happen.
- If load_tests exists then discovery does *not* recurse into the package,
+ If load_tests exists then discovery does *not* recurse into the package,
load_tests is responsible for loading all tests in the package.
The pattern is deliberately not stored as a loader attribute so that
@@ -288,6 +354,8 @@ class TestLoader(object):
return os.path.dirname(full_path)
def _get_name_from_path(self, path):
+ if path == self._top_level_dir:
+ return '.'
path = _jython_aware_splitext(os.path.normpath(path))
_relpath = os.path.relpath(path, self._top_level_dir)
@@ -307,63 +375,113 @@ class TestLoader(object):
def _find_tests(self, start_dir, pattern, namespace=False):
"""Used by discovery. Yields test suites it loads."""
+ # Handle the __init__ in this package
+ name = self._get_name_from_path(start_dir)
+ # name is '.' when start_dir == top_level_dir (and top_level_dir is by
+ # definition not a package).
+ if name != '.' and name not in self._loading_packages:
+ # name is in self._loading_packages while we have called into
+ # loadTestsFromModule with name.
+ tests, should_recurse = self._find_test_path(
+ start_dir, pattern, namespace)
+ if tests is not None:
+ yield tests
+ if not should_recurse:
+ # Either an error occurred, or load_tests was used by the
+ # package.
+ return
+ # Handle the contents.
paths = sorted(os.listdir(start_dir))
-
for path in paths:
full_path = os.path.join(start_dir, path)
- if os.path.isfile(full_path):
- if not VALID_MODULE_NAME.match(path):
- # valid Python identifiers only
- continue
- if not self._match_path(path, full_path, pattern):
- continue
- # if the test file matches, load it
+ tests, should_recurse = self._find_test_path(
+ full_path, pattern, namespace)
+ if tests is not None:
+ yield tests
+ if should_recurse:
+ # we found a package that didn't use load_tests.
name = self._get_name_from_path(full_path)
+ self._loading_packages.add(name)
try:
- module = self._get_module_from_name(name)
- except case.SkipTest as e:
- yield _make_skipped_test(name, e, self.suiteClass)
- except:
- yield _make_failed_import_test(name, self.suiteClass)
- else:
- mod_file = os.path.abspath(getattr(module, '__file__', full_path))
- realpath = _jython_aware_splitext(os.path.realpath(mod_file))
- fullpath_noext = _jython_aware_splitext(os.path.realpath(full_path))
- if realpath.lower() != fullpath_noext.lower():
- module_dir = os.path.dirname(realpath)
- mod_name = _jython_aware_splitext(os.path.basename(full_path))
- expected_dir = os.path.dirname(full_path)
- msg = ("%r module incorrectly imported from %r. Expected %r. "
- "Is this module globally installed?")
- raise ImportError(msg % (mod_name, module_dir, expected_dir))
- yield self.loadTestsFromModule(module)
- elif os.path.isdir(full_path):
- if (not namespace and
- not os.path.isfile(os.path.join(full_path, '__init__.py'))):
- continue
-
- load_tests = None
- tests = None
- if fnmatch(path, pattern):
- # only check load_tests if the package directory itself matches the filter
- name = self._get_name_from_path(full_path)
- package = self._get_module_from_name(name)
- load_tests = getattr(package, 'load_tests', None)
- tests = self.loadTestsFromModule(package, use_load_tests=False)
-
- if load_tests is None:
- if tests is not None:
- # tests loaded from package file
- yield tests
- # recurse into the package
- yield from self._find_tests(full_path, pattern,
- namespace=namespace)
- else:
- try:
- yield load_tests(self, tests, pattern)
- except Exception as e:
- yield _make_failed_load_tests(package.__name__, e,
- self.suiteClass)
+ yield from self._find_tests(full_path, pattern, namespace)
+ finally:
+ self._loading_packages.discard(name)
+
+ def _find_test_path(self, full_path, pattern, namespace=False):
+ """Used by discovery.
+
+ Loads tests from a single file, or a directories' __init__.py when
+ passed the directory.
+
+ Returns a tuple (None_or_tests_from_file, should_recurse).
+ """
+ basename = os.path.basename(full_path)
+ if os.path.isfile(full_path):
+ if not VALID_MODULE_NAME.match(basename):
+ # valid Python identifiers only
+ return None, False
+ if not self._match_path(basename, full_path, pattern):
+ return None, False
+ # if the test file matches, load it
+ name = self._get_name_from_path(full_path)
+ try:
+ module = self._get_module_from_name(name)
+ except case.SkipTest as e:
+ return _make_skipped_test(name, e, self.suiteClass), False
+ except:
+ error_case, error_message = \
+ _make_failed_import_test(name, self.suiteClass)
+ self.errors.append(error_message)
+ return error_case, False
+ else:
+ mod_file = os.path.abspath(
+ getattr(module, '__file__', full_path))
+ realpath = _jython_aware_splitext(
+ os.path.realpath(mod_file))
+ fullpath_noext = _jython_aware_splitext(
+ os.path.realpath(full_path))
+ if realpath.lower() != fullpath_noext.lower():
+ module_dir = os.path.dirname(realpath)
+ mod_name = _jython_aware_splitext(
+ os.path.basename(full_path))
+ expected_dir = os.path.dirname(full_path)
+ msg = ("%r module incorrectly imported from %r. Expected "
+ "%r. Is this module globally installed?")
+ raise ImportError(
+ msg % (mod_name, module_dir, expected_dir))
+ return self.loadTestsFromModule(module, pattern=pattern), False
+ elif os.path.isdir(full_path):
+ if (not namespace and
+ not os.path.isfile(os.path.join(full_path, '__init__.py'))):
+ return None, False
+
+ load_tests = None
+ tests = None
+ name = self._get_name_from_path(full_path)
+ try:
+ package = self._get_module_from_name(name)
+ except case.SkipTest as e:
+ return _make_skipped_test(name, e, self.suiteClass), False
+ except:
+ error_case, error_message = \
+ _make_failed_import_test(name, self.suiteClass)
+ self.errors.append(error_message)
+ return error_case, False
+ else:
+ load_tests = getattr(package, 'load_tests', None)
+ # Mark this package as being in load_tests (possibly ;))
+ self._loading_packages.add(name)
+ try:
+ tests = self.loadTestsFromModule(package, pattern=pattern)
+ if load_tests is not None:
+ # loadTestsFromModule(package) has loaded tests for us.
+ return tests, False
+ return tests, True
+ finally:
+ self._loading_packages.discard(name)
+ else:
+ return None, False
+
defaultTestLoader = TestLoader()
l kwa">None, number) gcold = gc.isenabled() gc.disable() try: timing = self.inner(it, self.timer) finally: if gcold: gc.enable() return timing def repeat(self, repeat=default_repeat, number=default_number): """Call timeit() a few times. This is a convenience function that calls the timeit() repeatedly, returning a list of results. The first argument specifies how many times to call timeit(), defaulting to 3; the second argument specifies the timer argument, defaulting to one million. Note: it's tempting to calculate mean and standard deviation from the result vector and report these. However, this is not very useful. In a typical case, the lowest value gives a lower bound for how fast your machine can run the given code snippet; higher values in the result vector are typically not caused by variability in Python's speed, but by other processes interfering with your timing accuracy. So the min() of the result is probably the only number you should be interested in. After that, you should look at the entire vector and apply common sense rather than statistics. """ r = [] for i in range(repeat): t = self.timeit(number) r.append(t) return r def timeit(stmt="pass", setup="pass", timer=default_timer, number=default_number, globals=None): """Convenience function to create Timer object and call timeit method.""" return Timer(stmt, setup, timer, globals).timeit(number) def repeat(stmt="pass", setup="pass", timer=default_timer, repeat=default_repeat, number=default_number, globals=None): """Convenience function to create Timer object and call repeat method.""" return Timer(stmt, setup, timer, globals).repeat(repeat, number) def main(args=None, *, _wrap_timer=None): """Main program, used when run as a script. The optional 'args' argument specifies the command line to be parsed, defaulting to sys.argv[1:]. The return value is an exit code to be passed to sys.exit(); it may be None to indicate success. When an exception happens during timing, a traceback is printed to stderr and the return value is 1. Exceptions at other times (including the template compilation) are not caught. '_wrap_timer' is an internal interface used for unit testing. If it is not None, it must be a callable that accepts a timer function and returns another timer function (used for unit testing). """ if args is None: args = sys.argv[1:] import getopt try: opts, args = getopt.getopt(args, "n:u:s:r:tcpvh", ["number=", "setup=", "repeat=", "time", "clock", "process", "verbose", "unit=", "help"]) except getopt.error as err: print(err) print("use -h/--help for command line help") return 2 timer = default_timer stmt = "\n".join(args) or "pass" number = 0 # auto-determine setup = [] repeat = default_repeat verbose = 0 time_unit = None units = {"usec": 1, "msec": 1e3, "sec": 1e6} precision = 3 for o, a in opts: if o in ("-n", "--number"): number = int(a) if o in ("-s", "--setup"): setup.append(a) if o in ("-u", "--unit"): if a in units: time_unit = a else: print("Unrecognized unit. Please select usec, msec, or sec.", file=sys.stderr) return 2 if o in ("-r", "--repeat"): repeat = int(a) if repeat <= 0: repeat = 1 if o in ("-t", "--time"): timer = time.time if o in ("-c", "--clock"): timer = time.clock if o in ("-p", "--process"): timer = time.process_time if o in ("-v", "--verbose"): if verbose: precision += 1 verbose += 1 if o in ("-h", "--help"): print(__doc__, end=' ') return 0 setup = "\n".join(setup) or "pass" # Include the current directory, so that local imports work (sys.path # contains the directory of this script, rather than the current # directory) import os sys.path.insert(0, os.curdir) if _wrap_timer is not None: timer = _wrap_timer(timer) t = Timer(stmt, setup, timer) if number == 0: # determine number so that 0.2 <= total time < 2.0 for i in range(1, 10): number = 10**i try: x = t.timeit(number) except: t.print_exc() return 1 if verbose: print("%d loops -> %.*g secs" % (number, precision, x)) if x >= 0.2: break try: r = t.repeat(repeat, number) except: t.print_exc() return 1 best = min(r) if verbose: print("raw times:", " ".join(["%.*g" % (precision, x) for x in r])) print("%d loops," % number, end=' ') usec = best * 1e6 / number if time_unit is not None: scale = units[time_unit] else: scales = [(scale, unit) for unit, scale in units.items()] scales.sort(reverse=True) for scale, time_unit in scales: if usec >= scale: break print("best of %d: %.*g %s per loop" % (repeat, precision, usec/scale, time_unit)) best = min(r) usec = best * 1e6 / number worst = max(r) if worst >= best * 4: usec = worst * 1e6 / number import warnings warnings.warn_explicit( "The test results are likely unreliable. The worst\n" "time (%.*g %s) was more than four times slower than the best time." % (precision, usec/scale, time_unit), UserWarning, '', 0) return None if __name__ == "__main__": sys.exit(main())