diff options
Diffstat (limited to 'Lib/unittest')
-rw-r--r-- | Lib/unittest/__init__.py | 9 | ||||
-rw-r--r-- | Lib/unittest/case.py | 99 | ||||
-rw-r--r-- | Lib/unittest/loader.py | 256 | ||||
-rw-r--r-- | Lib/unittest/main.py | 25 | ||||
-rw-r--r-- | Lib/unittest/mock.py | 71 | ||||
-rw-r--r-- | Lib/unittest/result.py | 7 | ||||
-rw-r--r-- | Lib/unittest/runner.py | 10 | ||||
-rw-r--r-- | Lib/unittest/test/support.py | 4 | ||||
-rw-r--r-- | Lib/unittest/test/test_break.py | 3 | ||||
-rw-r--r-- | Lib/unittest/test/test_case.py | 113 | ||||
-rw-r--r-- | Lib/unittest/test/test_discovery.py | 320 | ||||
-rw-r--r-- | Lib/unittest/test/test_loader.py | 423 | ||||
-rw-r--r-- | Lib/unittest/test/test_program.py | 26 | ||||
-rw-r--r-- | Lib/unittest/test/test_result.py | 43 | ||||
-rw-r--r-- | Lib/unittest/test/test_runner.py | 10 | ||||
-rw-r--r-- | Lib/unittest/test/test_setups.py | 7 | ||||
-rw-r--r-- | Lib/unittest/test/testmock/testmagicmethods.py | 11 | ||||
-rw-r--r-- | Lib/unittest/test/testmock/testmock.py | 45 | ||||
-rw-r--r-- | Lib/unittest/test/testmock/testpatch.py | 24 | ||||
-rw-r--r-- | Lib/unittest/util.py | 2 |
20 files changed, 1210 insertions, 298 deletions
diff --git a/Lib/unittest/__init__.py b/Lib/unittest/__init__.py index a5d50af..f6d7ae2 100644 --- a/Lib/unittest/__init__.py +++ b/Lib/unittest/__init__.py @@ -67,3 +67,12 @@ from .signals import installHandler, registerResult, removeResult, removeHandler # deprecated _TextTestResult = TextTestResult + +# There are no tests here, so don't try to run anything discovered from +# introspecting the symbols (e.g. FunctionTestCase). Instead, all our +# tests come from within unittest.test. +def load_tests(loader, tests, pattern): + import os.path + # top level directory cached on loader instance + this_dir = os.path.dirname(__file__) + return loader.discover(start_dir=this_dir, pattern=pattern) diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py index 69888a5..7701ad3 100644 --- a/Lib/unittest/case.py +++ b/Lib/unittest/case.py @@ -119,6 +119,10 @@ def expectedFailure(test_item): test_item.__unittest_expecting_failure__ = True return test_item +def _is_subtype(expected, basetype): + if isinstance(expected, tuple): + return all(_is_subtype(e, basetype) for e in expected) + return isinstance(expected, type) and issubclass(expected, basetype) class _BaseTestCaseContext: @@ -129,35 +133,45 @@ class _BaseTestCaseContext: msg = self.test_case._formatMessage(self.msg, standardMsg) raise self.test_case.failureException(msg) - class _AssertRaisesBaseContext(_BaseTestCaseContext): - def __init__(self, expected, test_case, callable_obj=None, - expected_regex=None): + def __init__(self, expected, test_case, expected_regex=None): _BaseTestCaseContext.__init__(self, test_case) self.expected = expected self.test_case = test_case - if callable_obj is not None: - try: - self.obj_name = callable_obj.__name__ - except AttributeError: - self.obj_name = str(callable_obj) - else: - self.obj_name = None if expected_regex is not None: expected_regex = re.compile(expected_regex) self.expected_regex = expected_regex + self.obj_name = None self.msg = None - def handle(self, name, callable_obj, args, kwargs): + def handle(self, name, args, kwargs): """ - If callable_obj is None, assertRaises/Warns is being used as a + If args is empty, assertRaises/Warns is being used as a context manager, so check for a 'msg' kwarg and return self. - If callable_obj is not None, call it passing args and kwargs. + If args is not empty, call a callable passing positional and keyword + arguments. """ - if callable_obj is None: + if not _is_subtype(self.expected, self._base_type): + raise TypeError('%s() arg 1 must be %s' % + (name, self._base_type_str)) + if args and args[0] is None: + warnings.warn("callable is None", + DeprecationWarning, 3) + args = () + if not args: self.msg = kwargs.pop('msg', None) + if kwargs: + warnings.warn('%r is an invalid keyword argument for ' + 'this function' % next(iter(kwargs)), + DeprecationWarning, 3) return self + + callable_obj, *args = args + try: + self.obj_name = callable_obj.__name__ + except AttributeError: + self.obj_name = str(callable_obj) with self: callable_obj(*args, **kwargs) @@ -165,6 +179,9 @@ class _AssertRaisesBaseContext(_BaseTestCaseContext): class _AssertRaisesContext(_AssertRaisesBaseContext): """A context manager used to implement TestCase.assertRaises* methods.""" + _base_type = BaseException + _base_type_str = 'an exception type or tuple of exception types' + def __enter__(self): return self @@ -199,6 +216,9 @@ class _AssertRaisesContext(_AssertRaisesBaseContext): class _AssertWarnsContext(_AssertRaisesBaseContext): """A context manager used to implement TestCase.assertWarns* methods.""" + _base_type = Warning + _base_type_str = 'a warning type or tuple of warning types' + def __enter__(self): # The __warningregistry__'s need to be in a pristine state for tests # to work properly. @@ -674,15 +694,15 @@ class TestCase(object): except UnicodeDecodeError: return '%s : %s' % (safe_repr(standardMsg), safe_repr(msg)) - def assertRaises(self, excClass, callableObj=None, *args, **kwargs): - """Fail unless an exception of class excClass is raised - by callableObj when invoked with arguments args and keyword - arguments kwargs. If a different type of exception is + def assertRaises(self, expected_exception, *args, **kwargs): + """Fail unless an exception of class expected_exception is raised + by the callable when invoked with specified positional and + keyword arguments. If a different type of exception is raised, it will not be caught, and the test case will be deemed to have suffered an error, exactly as for an unexpected exception. - If called with callableObj omitted or None, will return a + If called with the callable and arguments omitted, will return a context object used like this:: with self.assertRaises(SomeException): @@ -700,18 +720,18 @@ class TestCase(object): the_exception = cm.exception self.assertEqual(the_exception.error_code, 3) """ - context = _AssertRaisesContext(excClass, self, callableObj) - return context.handle('assertRaises', callableObj, args, kwargs) + context = _AssertRaisesContext(expected_exception, self) + return context.handle('assertRaises', args, kwargs) - def assertWarns(self, expected_warning, callable_obj=None, *args, **kwargs): + def assertWarns(self, expected_warning, *args, **kwargs): """Fail unless a warning of class warnClass is triggered - by callable_obj when invoked with arguments args and keyword - arguments kwargs. If a different type of warning is + by the callable when invoked with specified positional and + keyword arguments. If a different type of warning is triggered, it will not be handled: depending on the other warning filtering rules in effect, it might be silenced, printed out, or raised as an exception. - If called with callable_obj omitted or None, will return a + If called with the callable and arguments omitted, will return a context object used like this:: with self.assertWarns(SomeWarning): @@ -731,8 +751,8 @@ class TestCase(object): the_warning = cm.warning self.assertEqual(the_warning.some_attribute, 147) """ - context = _AssertWarnsContext(expected_warning, self, callable_obj) - return context.handle('assertWarns', callable_obj, args, kwargs) + context = _AssertWarnsContext(expected_warning, self) + return context.handle('assertWarns', args, kwargs) def assertLogs(self, logger=None, level=None): """Fail unless a log message of level *level* or higher is emitted @@ -1219,26 +1239,23 @@ class TestCase(object): self.fail(self._formatMessage(msg, standardMsg)) def assertRaisesRegex(self, expected_exception, expected_regex, - callable_obj=None, *args, **kwargs): + *args, **kwargs): """Asserts that the message in a raised exception matches a regex. Args: expected_exception: Exception class expected to be raised. expected_regex: Regex (re pattern object or string) expected to be found in error message. - callable_obj: Function to be called. + args: Function to be called and extra positional args. + kwargs: Extra kwargs. msg: Optional message used in case of failure. Can only be used when assertRaisesRegex is used as a context manager. - args: Extra args. - kwargs: Extra kwargs. """ - context = _AssertRaisesContext(expected_exception, self, callable_obj, - expected_regex) - - return context.handle('assertRaisesRegex', callable_obj, args, kwargs) + context = _AssertRaisesContext(expected_exception, self, expected_regex) + return context.handle('assertRaisesRegex', args, kwargs) def assertWarnsRegex(self, expected_warning, expected_regex, - callable_obj=None, *args, **kwargs): + *args, **kwargs): """Asserts that the message in a triggered warning matches a regexp. Basic functioning is similar to assertWarns() with the addition that only warnings whose messages also match the regular expression @@ -1248,15 +1265,13 @@ class TestCase(object): expected_warning: Warning class expected to be triggered. expected_regex: Regex (re pattern object or string) expected to be found in error message. - callable_obj: Function to be called. + args: Function to be called and extra positional args. + kwargs: Extra kwargs. msg: Optional message used in case of failure. Can only be used when assertWarnsRegex is used as a context manager. - args: Extra args. - kwargs: Extra kwargs. """ - context = _AssertWarnsContext(expected_warning, self, callable_obj, - expected_regex) - return context.handle('assertWarnsRegex', callable_obj, args, kwargs) + context = _AssertWarnsContext(expected_warning, self, expected_regex) + return context.handle('assertWarnsRegex', args, kwargs) def assertRegex(self, text, expected_regex, msg=None): """Fail the test unless the text matches the regular expression.""" diff --git a/Lib/unittest/loader.py b/Lib/unittest/loader.py index af39216..c776f16 100644 --- a/Lib/unittest/loader.py +++ b/Lib/unittest/loader.py @@ -6,6 +6,7 @@ import sys import traceback import types import functools +import warnings from fnmatch import fnmatch @@ -13,9 +14,9 @@ from . import case, suite, util __unittest = True -# what about .pyc or .pyo (etc) +# what about .pyc (etc) # we would need to avoid loading the same tests multiple times -# from '.py', '.pyc' *and* '.pyo' +# from '.py', *and* '.pyc' VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', re.IGNORECASE) @@ -35,15 +36,18 @@ class _FailedTest(case.TestCase): def _make_failed_import_test(name, suiteClass): - message = 'Failed to import test module: %s\n%s' % (name, traceback.format_exc()) - return _make_failed_test(name, ImportError(message), suiteClass) + message = 'Failed to import test module: %s\n%s' % ( + name, traceback.format_exc()) + return _make_failed_test(name, ImportError(message), suiteClass, message) def _make_failed_load_tests(name, exception, suiteClass): - return _make_failed_test(name, exception, suiteClass) + message = 'Failed to call load_tests:\n%s' % (traceback.format_exc(),) + return _make_failed_test( + name, exception, suiteClass, message) -def _make_failed_test(methodname, exception, suiteClass): +def _make_failed_test(methodname, exception, suiteClass, message): test = _FailedTest(methodname, exception) - return suiteClass((test,)) + return suiteClass((test,)), message def _make_skipped_test(methodname, exception, suiteClass): @case.skip(str(exception)) @@ -69,6 +73,13 @@ class TestLoader(object): suiteClass = suite.TestSuite _top_level_dir = None + def __init__(self): + super(TestLoader, self).__init__() + self.errors = [] + # Tracks packages which we have called into via load_tests, to + # avoid infinite re-entrancy. + self._loading_packages = set() + def loadTestsFromTestCase(self, testCaseClass): """Return a suite of all tests cases contained in testCaseClass""" if issubclass(testCaseClass, suite.TestSuite): @@ -81,8 +92,30 @@ class TestLoader(object): loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames)) return loaded_suite - def loadTestsFromModule(self, module, use_load_tests=True): + # XXX After Python 3.5, remove backward compatibility hacks for + # use_load_tests deprecation via *args and **kws. See issue 16662. + def loadTestsFromModule(self, module, *args, pattern=None, **kws): """Return a suite of all tests cases contained in the given module""" + # This method used to take an undocumented and unofficial + # use_load_tests argument. For backward compatibility, we still + # accept the argument (which can also be the first position) but we + # ignore it and issue a deprecation warning if it's present. + if len(args) > 0 or 'use_load_tests' in kws: + warnings.warn('use_load_tests is deprecated and ignored', + DeprecationWarning) + kws.pop('use_load_tests', None) + if len(args) > 1: + # Complain about the number of arguments, but don't forget the + # required `module` argument. + complaint = len(args) + 1 + raise TypeError('loadTestsFromModule() takes 1 positional argument but {} were given'.format(complaint)) + if len(kws) != 0: + # Since the keyword arguments are unsorted (see PEP 468), just + # pick the alphabetically sorted first argument to complain about, + # if multiple were given. At least the error message will be + # predictable. + complaint = sorted(kws)[0] + raise TypeError("loadTestsFromModule() got an unexpected keyword argument '{}'".format(complaint)) tests = [] for name in dir(module): obj = getattr(module, name) @@ -91,12 +124,14 @@ class TestLoader(object): load_tests = getattr(module, 'load_tests', None) tests = self.suiteClass(tests) - if use_load_tests and load_tests is not None: + if load_tests is not None: try: - return load_tests(self, tests, None) + return load_tests(self, tests, pattern) except Exception as e: - return _make_failed_load_tests(module.__name__, e, - self.suiteClass) + error_case, error_message = _make_failed_load_tests( + module.__name__, e, self.suiteClass) + self.errors.append(error_message) + return error_case return tests def loadTestsFromName(self, name, module=None): @@ -109,20 +144,47 @@ class TestLoader(object): The method optionally resolves the names relative to a given module. """ parts = name.split('.') + error_case, error_message = None, None if module is None: parts_copy = parts[:] while parts_copy: try: - module = __import__('.'.join(parts_copy)) + module_name = '.'.join(parts_copy) + module = __import__(module_name) break except ImportError: - del parts_copy[-1] + next_attribute = parts_copy.pop() + # Last error so we can give it to the user if needed. + error_case, error_message = _make_failed_import_test( + next_attribute, self.suiteClass) if not parts_copy: - raise + # Even the top level import failed: report that error. + self.errors.append(error_message) + return error_case parts = parts[1:] obj = module for part in parts: - parent, obj = obj, getattr(obj, part) + try: + parent, obj = obj, getattr(obj, part) + except AttributeError as e: + # We can't traverse some part of the name. + if (getattr(obj, '__path__', None) is not None + and error_case is not None): + # This is a package (no __path__ per importlib docs), and we + # encountered an error importing something. We cannot tell + # the difference between package.WrongNameTestClass and + # package.wrong_module_name so we just report the + # ImportError - it is more informative. + self.errors.append(error_message) + return error_case + else: + # Otherwise, we signal that an AttributeError has occurred. + error_case, error_message = _make_failed_test( + part, e, self.suiteClass, + 'Failed to access attribute:\n%s' % ( + traceback.format_exc(),)) + self.errors.append(error_message) + return error_case if isinstance(obj, types.ModuleType): return self.loadTestsFromModule(obj) @@ -181,9 +243,13 @@ class TestLoader(object): If a test package name (directory with '__init__.py') matches the pattern then the package will be checked for a 'load_tests' function. If - this exists then it will be called with loader, tests, pattern. + this exists then it will be called with (loader, tests, pattern) unless + the package has already had load_tests called from the same discovery + invocation, in which case the package module object is not scanned for + tests - this ensures that when a package uses discover to further + discover child tests that infinite recursion does not happen. - If load_tests exists then discovery does *not* recurse into the package, + If load_tests exists then discovery does *not* recurse into the package, load_tests is responsible for loading all tests in the package. The pattern is deliberately not stored as a loader attribute so that @@ -288,6 +354,8 @@ class TestLoader(object): return os.path.dirname(full_path) def _get_name_from_path(self, path): + if path == self._top_level_dir: + return '.' path = _jython_aware_splitext(os.path.normpath(path)) _relpath = os.path.relpath(path, self._top_level_dir) @@ -307,63 +375,111 @@ class TestLoader(object): def _find_tests(self, start_dir, pattern, namespace=False): """Used by discovery. Yields test suites it loads.""" + # Handle the __init__ in this package + name = self._get_name_from_path(start_dir) + # name is '.' when start_dir == top_level_dir (and top_level_dir is by + # definition not a package). + if name != '.' and name not in self._loading_packages: + # name is in self._loading_packages while we have called into + # loadTestsFromModule with name. + tests, should_recurse = self._find_test_path( + start_dir, pattern, namespace) + if tests is not None: + yield tests + if not should_recurse: + # Either an error occured, or load_tests was used by the + # package. + return + # Handle the contents. paths = sorted(os.listdir(start_dir)) - for path in paths: full_path = os.path.join(start_dir, path) - if os.path.isfile(full_path): - if not VALID_MODULE_NAME.match(path): - # valid Python identifiers only - continue - if not self._match_path(path, full_path, pattern): - continue - # if the test file matches, load it + tests, should_recurse = self._find_test_path( + full_path, pattern, namespace) + if tests is not None: + yield tests + if should_recurse: + # we found a package that didn't use load_tests. name = self._get_name_from_path(full_path) + self._loading_packages.add(name) try: - module = self._get_module_from_name(name) - except case.SkipTest as e: - yield _make_skipped_test(name, e, self.suiteClass) - except: - yield _make_failed_import_test(name, self.suiteClass) - else: - mod_file = os.path.abspath(getattr(module, '__file__', full_path)) - realpath = _jython_aware_splitext(os.path.realpath(mod_file)) - fullpath_noext = _jython_aware_splitext(os.path.realpath(full_path)) - if realpath.lower() != fullpath_noext.lower(): - module_dir = os.path.dirname(realpath) - mod_name = _jython_aware_splitext(os.path.basename(full_path)) - expected_dir = os.path.dirname(full_path) - msg = ("%r module incorrectly imported from %r. Expected %r. " - "Is this module globally installed?") - raise ImportError(msg % (mod_name, module_dir, expected_dir)) - yield self.loadTestsFromModule(module) - elif os.path.isdir(full_path): - if (not namespace and - not os.path.isfile(os.path.join(full_path, '__init__.py'))): - continue - - load_tests = None - tests = None - if fnmatch(path, pattern): - # only check load_tests if the package directory itself matches the filter - name = self._get_name_from_path(full_path) - package = self._get_module_from_name(name) - load_tests = getattr(package, 'load_tests', None) - tests = self.loadTestsFromModule(package, use_load_tests=False) - - if load_tests is None: - if tests is not None: - # tests loaded from package file - yield tests - # recurse into the package - yield from self._find_tests(full_path, pattern, - namespace=namespace) - else: - try: - yield load_tests(self, tests, pattern) - except Exception as e: - yield _make_failed_load_tests(package.__name__, e, - self.suiteClass) + yield from self._find_tests(full_path, pattern, namespace) + finally: + self._loading_packages.discard(name) + + def _find_test_path(self, full_path, pattern, namespace=False): + """Used by discovery. + + Loads tests from a single file, or a directories' __init__.py when + passed the directory. + + Returns a tuple (None_or_tests_from_file, should_recurse). + """ + basename = os.path.basename(full_path) + if os.path.isfile(full_path): + if not VALID_MODULE_NAME.match(basename): + # valid Python identifiers only + return None, False + if not self._match_path(basename, full_path, pattern): + return None, False + # if the test file matches, load it + name = self._get_name_from_path(full_path) + try: + module = self._get_module_from_name(name) + except case.SkipTest as e: + return _make_skipped_test(name, e, self.suiteClass), False + except: + error_case, error_message = \ + _make_failed_import_test(name, self.suiteClass) + self.errors.append(error_message) + return error_case, False + else: + mod_file = os.path.abspath( + getattr(module, '__file__', full_path)) + realpath = _jython_aware_splitext( + os.path.realpath(mod_file)) + fullpath_noext = _jython_aware_splitext( + os.path.realpath(full_path)) + if realpath.lower() != fullpath_noext.lower(): + module_dir = os.path.dirname(realpath) + mod_name = _jython_aware_splitext( + os.path.basename(full_path)) + expected_dir = os.path.dirname(full_path) + msg = ("%r module incorrectly imported from %r. Expected " + "%r. Is this module globally installed?") + raise ImportError( + msg % (mod_name, module_dir, expected_dir)) + return self.loadTestsFromModule(module, pattern=pattern), False + elif os.path.isdir(full_path): + if (not namespace and + not os.path.isfile(os.path.join(full_path, '__init__.py'))): + return None, False + + load_tests = None + tests = None + name = self._get_name_from_path(full_path) + try: + package = self._get_module_from_name(name) + except case.SkipTest as e: + return _make_skipped_test(name, e, self.suiteClass), False + except: + error_case, error_message = \ + _make_failed_import_test(name, self.suiteClass) + self.errors.append(error_message) + return error_case, False + else: + load_tests = getattr(package, 'load_tests', None) + # Mark this package as being in load_tests (possibly ;)) + self._loading_packages.add(name) + try: + tests = self.loadTestsFromModule(package, pattern=pattern) + if load_tests is not None: + # loadTestsFromModule(package) has loaded tests for us. + return tests, False + return tests, True + finally: + self._loading_packages.discard(name) + defaultTestLoader = TestLoader() diff --git a/Lib/unittest/main.py b/Lib/unittest/main.py index 180df86..b209a3a 100644 --- a/Lib/unittest/main.py +++ b/Lib/unittest/main.py @@ -58,7 +58,7 @@ class TestProgram(object): def __init__(self, module='__main__', defaultTest=None, argv=None, testRunner=None, testLoader=loader.defaultTestLoader, exit=True, verbosity=1, failfast=None, catchbreak=None, - buffer=None, warnings=None): + buffer=None, warnings=None, *, tb_locals=False): if isinstance(module, str): self.module = __import__(module) for part in module.split('.')[1:]: @@ -73,8 +73,9 @@ class TestProgram(object): self.catchbreak = catchbreak self.verbosity = verbosity self.buffer = buffer + self.tb_locals = tb_locals if warnings is None and not sys.warnoptions: - # even if DreprecationWarnings are ignored by default + # even if DeprecationWarnings are ignored by default # print them anyway unless other warnings settings are # specified by the warnings arg or the -W python flag self.warnings = 'default' @@ -159,7 +160,9 @@ class TestProgram(object): parser.add_argument('-q', '--quiet', dest='verbosity', action='store_const', const=0, help='Quiet output') - + parser.add_argument('--locals', dest='tb_locals', + action='store_true', + help='Show local variables in tracebacks') if self.failfast is None: parser.add_argument('-f', '--failfast', dest='failfast', action='store_true', @@ -231,10 +234,18 @@ class TestProgram(object): self.testRunner = runner.TextTestRunner if isinstance(self.testRunner, type): try: - testRunner = self.testRunner(verbosity=self.verbosity, - failfast=self.failfast, - buffer=self.buffer, - warnings=self.warnings) + try: + testRunner = self.testRunner(verbosity=self.verbosity, + failfast=self.failfast, + buffer=self.buffer, + warnings=self.warnings, + tb_locals=self.tb_locals) + except TypeError: + # didn't accept the tb_locals argument + testRunner = self.testRunner(verbosity=self.verbosity, + failfast=self.failfast, + buffer=self.buffer, + warnings=self.warnings) except TypeError: # didn't accept the verbosity, buffer or failfast arguments testRunner = self.testRunner() diff --git a/Lib/unittest/mock.py b/Lib/unittest/mock.py index 3fbe846..efe5763 100644 --- a/Lib/unittest/mock.py +++ b/Lib/unittest/mock.py @@ -27,9 +27,13 @@ __version__ = '1.0' import inspect import pprint import sys +import builtins +from types import ModuleType from functools import wraps, partial +_builtins = {name for name in dir(builtins) if not name.startswith('_')} + BaseExceptions = (BaseException,) if 'java' in sys.platform: # jython @@ -271,13 +275,11 @@ def _copy(value): return value -_allowed_names = set( - [ - 'return_value', '_mock_return_value', 'side_effect', - '_mock_side_effect', '_mock_parent', '_mock_new_parent', - '_mock_name', '_mock_new_name' - ] -) +_allowed_names = { + 'return_value', '_mock_return_value', 'side_effect', + '_mock_side_effect', '_mock_parent', '_mock_new_parent', + '_mock_name', '_mock_new_name' +} def _delegating_property(name): @@ -375,7 +377,7 @@ class NonCallableMock(Base): def __init__( self, spec=None, wraps=None, name=None, spec_set=None, parent=None, _spec_state=None, _new_name='', _new_parent=None, - _spec_as_instance=False, _eat_self=None, **kwargs + _spec_as_instance=False, _eat_self=None, unsafe=False, **kwargs ): if _new_parent is None: _new_parent = parent @@ -405,6 +407,7 @@ class NonCallableMock(Base): __dict__['_mock_mock_calls'] = _CallList() __dict__['method_calls'] = _CallList() + __dict__['_mock_unsafe'] = unsafe if kwargs: self.configure_mock(**kwargs) @@ -503,7 +506,8 @@ class NonCallableMock(Base): if delegated is None: return self._mock_side_effect sf = delegated.side_effect - if sf is not None and not callable(sf) and not isinstance(sf, _MockIter): + if (sf is not None and not callable(sf) + and not isinstance(sf, _MockIter) and not _is_exception(sf)): sf = _MockIter(sf) delegated.side_effect = sf return sf @@ -567,13 +571,16 @@ class NonCallableMock(Base): def __getattr__(self, name): - if name == '_mock_methods': + if name in {'_mock_methods', '_mock_unsafe'}: raise AttributeError(name) elif self._mock_methods is not None: if name not in self._mock_methods or name in _all_magics: raise AttributeError("Mock object has no attribute %r" % name) elif _is_magic(name): raise AttributeError(name) + if not self._mock_unsafe: + if name.startswith(('assert', 'assret')): + raise AttributeError(name) result = self._mock_children.get(name) if result is _deleted: @@ -756,6 +763,14 @@ class NonCallableMock(Base): else: return _call + def assert_not_called(_mock_self): + """assert that the mock was never called. + """ + self = _mock_self + if self.call_count != 0: + msg = ("Expected '%s' to not have been called. Called %s times." % + (self._mock_name or 'mock', self.call_count)) + raise AssertionError(msg) def assert_called_with(_mock_self, *args, **kwargs): """assert that the mock was called with the specified arguments. @@ -1172,6 +1187,9 @@ class _patch(object): else: local = True + if name in _builtins and isinstance(target, ModuleType): + self.create = True + if not self.create and original is DEFAULT: raise AttributeError( "%s does not have the attribute %r" % (target, name) @@ -1659,7 +1677,7 @@ magic_methods = ( ) numerics = ( - "add sub mul div floordiv mod lshift rshift and xor or pow truediv" + "add sub mul matmul div floordiv mod lshift rshift and xor or pow truediv" ) inplace = ' '.join('i%s' % n for n in numerics.split()) right = ' '.join('r%s' % n for n in numerics.split()) @@ -1668,11 +1686,12 @@ right = ' '.join('r%s' % n for n in numerics.split()) # (as they are metaclass methods) # __del__ is not supported at all as it causes problems if it exists -_non_defaults = set('__%s__' % method for method in [ - 'get', 'set', 'delete', 'reversed', 'missing', 'reduce', 'reduce_ex', - 'getinitargs', 'getnewargs', 'getstate', 'setstate', 'getformat', - 'setformat', 'repr', 'dir', 'subclasses', 'format', -]) +_non_defaults = { + '__get__', '__set__', '__delete__', '__reversed__', '__missing__', + '__reduce__', '__reduce_ex__', '__getinitargs__', '__getnewargs__', + '__getstate__', '__setstate__', '__getformat__', '__setformat__', + '__repr__', '__dir__', '__subclasses__', '__format__', +} def _get_method(name, func): @@ -1683,19 +1702,19 @@ def _get_method(name, func): return method -_magics = set( +_magics = { '__%s__' % method for method in ' '.join([magic_methods, numerics, inplace, right]).split() -) +} _all_magics = _magics | _non_defaults -_unsupported_magics = set([ +_unsupported_magics = { '__getattr__', '__setattr__', '__init__', '__new__', '__prepare__' '__instancecheck__', '__subclasscheck__', '__del__' -]) +} _calculate_return_value = { '__hash__': lambda self: object.__hash__(self), @@ -1884,7 +1903,7 @@ def _format_call_signature(name, args, kwargs): formatted_args = '' args_string = ', '.join([repr(arg) for arg in args]) kwargs_string = ', '.join([ - '%s=%r' % (key, value) for key, value in kwargs.items() + '%s=%r' % (key, value) for key, value in sorted(kwargs.items()) ]) if args_string: formatted_args = args_string @@ -2006,10 +2025,6 @@ class _Call(tuple): return (other_args, other_kwargs) == (self_args, self_kwargs) - def __ne__(self, other): - return not self.__eq__(other) - - def __call__(self, *args, **kwargs): if self.name is None: return _Call(('', args, kwargs), name='()') @@ -2025,6 +2040,12 @@ class _Call(tuple): return _Call(name=name, parent=self, from_kall=False) + def count(self, *args, **kwargs): + return self.__getattr__('count')(*args, **kwargs) + + def index(self, *args, **kwargs): + return self.__getattr__('index')(*args, **kwargs) + def __repr__(self): if not self.from_kall: name = self.name or 'call' diff --git a/Lib/unittest/result.py b/Lib/unittest/result.py index 8e0a643..a18f11b 100644 --- a/Lib/unittest/result.py +++ b/Lib/unittest/result.py @@ -45,6 +45,7 @@ class TestResult(object): self.unexpectedSuccesses = [] self.shouldStop = False self.buffer = False + self.tb_locals = False self._stdout_buffer = None self._stderr_buffer = None self._original_stdout = sys.stdout @@ -179,9 +180,11 @@ class TestResult(object): if exctype is test.failureException: # Skip assert*() traceback levels length = self._count_relevant_tb_levels(tb) - msgLines = traceback.format_exception(exctype, value, tb, length) else: - msgLines = traceback.format_exception(exctype, value, tb) + length = None + tb_e = traceback.TracebackException( + exctype, value, tb, limit=length, capture_locals=self.tb_locals) + msgLines = list(tb_e.format()) if self.buffer: output = sys.stdout.getvalue() diff --git a/Lib/unittest/runner.py b/Lib/unittest/runner.py index 28b8865..2112262 100644 --- a/Lib/unittest/runner.py +++ b/Lib/unittest/runner.py @@ -126,7 +126,13 @@ class TextTestRunner(object): resultclass = TextTestResult def __init__(self, stream=None, descriptions=True, verbosity=1, - failfast=False, buffer=False, resultclass=None, warnings=None): + failfast=False, buffer=False, resultclass=None, warnings=None, + *, tb_locals=False): + """Construct a TextTestRunner. + + Subclasses should accept **kwargs to ensure compatibility as the + interface changes. + """ if stream is None: stream = sys.stderr self.stream = _WritelnDecorator(stream) @@ -134,6 +140,7 @@ class TextTestRunner(object): self.verbosity = verbosity self.failfast = failfast self.buffer = buffer + self.tb_locals = tb_locals self.warnings = warnings if resultclass is not None: self.resultclass = resultclass @@ -147,6 +154,7 @@ class TextTestRunner(object): registerResult(result) result.failfast = self.failfast result.buffer = self.buffer + result.tb_locals = self.tb_locals with warnings.catch_warnings(): if self.warnings: # if self.warnings is set, use it to filter all the warnings diff --git a/Lib/unittest/test/support.py b/Lib/unittest/test/support.py index 02e8f3a..5292653 100644 --- a/Lib/unittest/test/support.py +++ b/Lib/unittest/test/support.py @@ -25,8 +25,6 @@ class TestHashing(object): try: if not hash(obj_1) == hash(obj_2): self.fail("%r and %r do not hash equal" % (obj_1, obj_2)) - except KeyboardInterrupt: - raise except Exception as e: self.fail("Problem hashing %r and %r: %s" % (obj_1, obj_2, e)) @@ -35,8 +33,6 @@ class TestHashing(object): if hash(obj_1) == hash(obj_2): self.fail("%s and %s hash equal, but shouldn't" % (obj_1, obj_2)) - except KeyboardInterrupt: - raise except Exception as e: self.fail("Problem hashing %s and %s: %s" % (obj_1, obj_2, e)) diff --git a/Lib/unittest/test/test_break.py b/Lib/unittest/test/test_break.py index 0bf1a22..2c75019 100644 --- a/Lib/unittest/test/test_break.py +++ b/Lib/unittest/test/test_break.py @@ -211,6 +211,7 @@ class TestBreak(unittest.TestCase): self.verbosity = verbosity self.failfast = failfast self.catchbreak = catchbreak + self.tb_locals = False self.testRunner = FakeRunner self.test = test self.result = None @@ -221,6 +222,7 @@ class TestBreak(unittest.TestCase): self.assertEqual(FakeRunner.initArgs, [((), {'buffer': None, 'verbosity': verbosity, 'failfast': failfast, + 'tb_locals': False, 'warnings': None})]) self.assertEqual(FakeRunner.runArgs, [test]) self.assertEqual(p.result, result) @@ -235,6 +237,7 @@ class TestBreak(unittest.TestCase): self.assertEqual(FakeRunner.initArgs, [((), {'buffer': None, 'verbosity': verbosity, 'failfast': failfast, + 'tb_locals': False, 'warnings': None})]) self.assertEqual(FakeRunner.runArgs, [test]) self.assertEqual(p.result, result) diff --git a/Lib/unittest/test/test_case.py b/Lib/unittest/test/test_case.py index 321d67a..ada733b 100644 --- a/Lib/unittest/test/test_case.py +++ b/Lib/unittest/test/test_case.py @@ -1103,12 +1103,9 @@ test case except self.failureException as e: # need to remove the first line of the error message error = str(e).split('\n', 1)[1] + self.assertEqual(sample_text_error, error) - # no fair testing ourself with ourself, and assertEqual is used for strings - # so can't use assertEqual either. Just use assertTrue. - self.assertTrue(sample_text_error == error) - - def testAsertEqualSingleLine(self): + def testAssertEqualSingleLine(self): sample_text = "laden swallows fly slowly" revised_sample_text = "unladen swallows fly quickly" sample_text_error = """\ @@ -1120,8 +1117,9 @@ test case try: self.assertEqual(sample_text, revised_sample_text) except self.failureException as e: + # need to remove the first line of the error message error = str(e).split('\n', 1)[1] - self.assertTrue(sample_text_error == error) + self.assertEqual(sample_text_error, error) def testAssertIsNone(self): self.assertIsNone(None) @@ -1147,6 +1145,9 @@ test case # Failure when no exception is raised with self.assertRaises(self.failureException): self.assertRaises(ExceptionMock, lambda: 0) + # Failure when the function is None + with self.assertWarns(DeprecationWarning): + self.assertRaises(ExceptionMock, None) # Failure when another exception is raised with self.assertRaises(ExceptionMock): self.assertRaises(ValueError, Stub) @@ -1171,10 +1172,31 @@ test case with self.assertRaises(self.failureException): with self.assertRaises(ExceptionMock): pass + # Custom message + with self.assertRaisesRegex(self.failureException, 'foobar'): + with self.assertRaises(ExceptionMock, msg='foobar'): + pass + # Invalid keyword argument + with self.assertWarnsRegex(DeprecationWarning, 'foobar'), \ + self.assertRaises(AssertionError): + with self.assertRaises(ExceptionMock, foobar=42): + pass # Failure when another exception is raised with self.assertRaises(ExceptionMock): self.assertRaises(ValueError, Stub) + def testAssertRaisesNoExceptionType(self): + with self.assertRaises(TypeError): + self.assertRaises() + with self.assertRaises(TypeError): + self.assertRaises(1) + with self.assertRaises(TypeError): + self.assertRaises(object) + with self.assertRaises(TypeError): + self.assertRaises((ValueError, 1)) + with self.assertRaises(TypeError): + self.assertRaises((ValueError, object)) + def testAssertRaisesRegex(self): class ExceptionMock(Exception): pass @@ -1184,6 +1206,8 @@ test case self.assertRaisesRegex(ExceptionMock, re.compile('expect$'), Stub) self.assertRaisesRegex(ExceptionMock, 'expect$', Stub) + with self.assertWarns(DeprecationWarning): + self.assertRaisesRegex(ExceptionMock, 'expect$', None) def testAssertNotRaisesRegex(self): self.assertRaisesRegex( @@ -1194,6 +1218,15 @@ test case self.failureException, '^Exception not raised by <lambda>$', self.assertRaisesRegex, Exception, 'x', lambda: None) + # Custom message + with self.assertRaisesRegex(self.failureException, 'foobar'): + with self.assertRaisesRegex(Exception, 'expect', msg='foobar'): + pass + # Invalid keyword argument + with self.assertWarnsRegex(DeprecationWarning, 'foobar'), \ + self.assertRaises(AssertionError): + with self.assertRaisesRegex(Exception, 'expect', foobar=42): + pass def testAssertRaisesRegexInvalidRegex(self): # Issue 20145. @@ -1237,6 +1270,20 @@ test case self.assertIsInstance(e, ExceptionMock) self.assertEqual(e.args[0], v) + def testAssertRaisesRegexNoExceptionType(self): + with self.assertRaises(TypeError): + self.assertRaisesRegex() + with self.assertRaises(TypeError): + self.assertRaisesRegex(ValueError) + with self.assertRaises(TypeError): + self.assertRaisesRegex(1, 'expect') + with self.assertRaises(TypeError): + self.assertRaisesRegex(object, 'expect') + with self.assertRaises(TypeError): + self.assertRaisesRegex((ValueError, 1), 'expect') + with self.assertRaises(TypeError): + self.assertRaisesRegex((ValueError, object), 'expect') + def testAssertWarnsCallable(self): def _runtime_warn(): warnings.warn("foo", RuntimeWarning) @@ -1251,6 +1298,9 @@ test case # Failure when no warning is triggered with self.assertRaises(self.failureException): self.assertWarns(RuntimeWarning, lambda: 0) + # Failure when the function is None + with self.assertWarns(DeprecationWarning): + self.assertWarns(RuntimeWarning, None) # Failure when another warning is triggered with warnings.catch_warnings(): # Force default filter (in case tests are run with -We) @@ -1289,6 +1339,15 @@ test case with self.assertRaises(self.failureException): with self.assertWarns(RuntimeWarning): pass + # Custom message + with self.assertRaisesRegex(self.failureException, 'foobar'): + with self.assertWarns(RuntimeWarning, msg='foobar'): + pass + # Invalid keyword argument + with self.assertWarnsRegex(DeprecationWarning, 'foobar'), \ + self.assertRaises(AssertionError): + with self.assertWarns(RuntimeWarning, foobar=42): + pass # Failure when another warning is triggered with warnings.catch_warnings(): # Force default filter (in case tests are run with -We) @@ -1303,6 +1362,20 @@ test case with self.assertWarns(DeprecationWarning): _runtime_warn() + def testAssertWarnsNoExceptionType(self): + with self.assertRaises(TypeError): + self.assertWarns() + with self.assertRaises(TypeError): + self.assertWarns(1) + with self.assertRaises(TypeError): + self.assertWarns(object) + with self.assertRaises(TypeError): + self.assertWarns((UserWarning, 1)) + with self.assertRaises(TypeError): + self.assertWarns((UserWarning, object)) + with self.assertRaises(TypeError): + self.assertWarns((UserWarning, Exception)) + def testAssertWarnsRegexCallable(self): def _runtime_warn(msg): warnings.warn(msg, RuntimeWarning) @@ -1312,6 +1385,9 @@ test case with self.assertRaises(self.failureException): self.assertWarnsRegex(RuntimeWarning, "o+", lambda: 0) + # Failure when the function is None + with self.assertWarns(DeprecationWarning): + self.assertWarnsRegex(RuntimeWarning, "o+", None) # Failure when another warning is triggered with warnings.catch_warnings(): # Force default filter (in case tests are run with -We) @@ -1348,6 +1424,15 @@ test case with self.assertRaises(self.failureException): with self.assertWarnsRegex(RuntimeWarning, "o+"): pass + # Custom message + with self.assertRaisesRegex(self.failureException, 'foobar'): + with self.assertWarnsRegex(RuntimeWarning, 'o+', msg='foobar'): + pass + # Invalid keyword argument + with self.assertWarnsRegex(DeprecationWarning, 'foobar'), \ + self.assertRaises(AssertionError): + with self.assertWarnsRegex(RuntimeWarning, 'o+', foobar=42): + pass # Failure when another warning is triggered with warnings.catch_warnings(): # Force default filter (in case tests are run with -We) @@ -1369,6 +1454,22 @@ test case with self.assertWarnsRegex(RuntimeWarning, "o+"): _runtime_warn("barz") + def testAssertWarnsRegexNoExceptionType(self): + with self.assertRaises(TypeError): + self.assertWarnsRegex() + with self.assertRaises(TypeError): + self.assertWarnsRegex(UserWarning) + with self.assertRaises(TypeError): + self.assertWarnsRegex(1, 'expect') + with self.assertRaises(TypeError): + self.assertWarnsRegex(object, 'expect') + with self.assertRaises(TypeError): + self.assertWarnsRegex((UserWarning, 1), 'expect') + with self.assertRaises(TypeError): + self.assertWarnsRegex((UserWarning, object), 'expect') + with self.assertRaises(TypeError): + self.assertWarnsRegex((UserWarning, Exception), 'expect') + @contextlib.contextmanager def assertNoStderr(self): with captured_stderr() as buf: diff --git a/Lib/unittest/test/test_discovery.py b/Lib/unittest/test/test_discovery.py index f12e898..8991f38 100644 --- a/Lib/unittest/test/test_discovery.py +++ b/Lib/unittest/test/test_discovery.py @@ -1,4 +1,5 @@ -import os +import os.path +from os.path import abspath import re import sys import types @@ -69,7 +70,13 @@ class TestDiscovery(unittest.TestCase): self.addCleanup(restore_isfile) loader._get_module_from_name = lambda path: path + ' module' - loader.loadTestsFromModule = lambda module: module + ' tests' + orig_load_tests = loader.loadTestsFromModule + def loadTestsFromModule(module, pattern=None): + # This is where load_tests is called. + base = orig_load_tests(module, pattern=pattern) + return base + [module + ' tests'] + loader.loadTestsFromModule = loadTestsFromModule + loader.suiteClass = lambda thing: thing top_level = os.path.abspath('/foo') loader._top_level_dir = top_level @@ -77,9 +84,9 @@ class TestDiscovery(unittest.TestCase): # The test suites found should be sorted alphabetically for reliable # execution order. - expected = [name + ' module tests' for name in - ('test1', 'test2')] - expected.extend([('test_dir.%s' % name) + ' module tests' for name in + expected = [[name + ' module tests'] for name in + ('test1', 'test2', 'test_dir')] + expected.extend([[('test_dir.%s' % name) + ' module tests'] for name in ('test3', 'test4')]) self.assertEqual(suite, expected) @@ -117,34 +124,204 @@ class TestDiscovery(unittest.TestCase): if os.path.basename(path) == 'test_directory': def load_tests(loader, tests, pattern): self.load_tests_args.append((loader, tests, pattern)) - return 'load_tests' + return [self.path + ' load_tests'] self.load_tests = load_tests def __eq__(self, other): return self.path == other.path loader._get_module_from_name = lambda name: Module(name) - def loadTestsFromModule(module, use_load_tests): - if use_load_tests: - raise self.failureException('use_load_tests should be False for packages') - return module.path + ' module tests' + orig_load_tests = loader.loadTestsFromModule + def loadTestsFromModule(module, pattern=None): + # This is where load_tests is called. + base = orig_load_tests(module, pattern=pattern) + return base + [module.path + ' module tests'] loader.loadTestsFromModule = loadTestsFromModule + loader.suiteClass = lambda thing: thing loader._top_level_dir = '/foo' # this time no '.py' on the pattern so that it can match # a test package suite = list(loader._find_tests('/foo', 'test*')) - # We should have loaded tests from the test_directory package by calling load_tests - # and directly from the test_directory2 package + # We should have loaded tests from the a_directory and test_directory2 + # directly and via load_tests for the test_directory package, which + # still calls the baseline module loader. self.assertEqual(suite, - ['load_tests', 'test_directory2' + ' module tests']) + [['a_directory module tests'], + ['test_directory load_tests', + 'test_directory module tests'], + ['test_directory2 module tests']]) + + # The test module paths should be sorted for reliable execution order - self.assertEqual(Module.paths, ['test_directory', 'test_directory2']) + self.assertEqual(Module.paths, + ['a_directory', 'test_directory', 'test_directory2']) # load_tests should have been called once with loader, tests and pattern + # (but there are no tests in our stub module itself, so thats [] at the + # time of call. + self.assertEqual(Module.load_tests_args, + [(loader, [], 'test*')]) + + def test_find_tests_default_calls_package_load_tests(self): + loader = unittest.TestLoader() + + original_listdir = os.listdir + def restore_listdir(): + os.listdir = original_listdir + original_isfile = os.path.isfile + def restore_isfile(): + os.path.isfile = original_isfile + original_isdir = os.path.isdir + def restore_isdir(): + os.path.isdir = original_isdir + + directories = ['a_directory', 'test_directory', 'test_directory2'] + path_lists = [directories, [], [], []] + os.listdir = lambda path: path_lists.pop(0) + self.addCleanup(restore_listdir) + + os.path.isdir = lambda path: True + self.addCleanup(restore_isdir) + + os.path.isfile = lambda path: os.path.basename(path) not in directories + self.addCleanup(restore_isfile) + + class Module(object): + paths = [] + load_tests_args = [] + + def __init__(self, path): + self.path = path + self.paths.append(path) + if os.path.basename(path) == 'test_directory': + def load_tests(loader, tests, pattern): + self.load_tests_args.append((loader, tests, pattern)) + return [self.path + ' load_tests'] + self.load_tests = load_tests + + def __eq__(self, other): + return self.path == other.path + + loader._get_module_from_name = lambda name: Module(name) + orig_load_tests = loader.loadTestsFromModule + def loadTestsFromModule(module, pattern=None): + # This is where load_tests is called. + base = orig_load_tests(module, pattern=pattern) + return base + [module.path + ' module tests'] + loader.loadTestsFromModule = loadTestsFromModule + loader.suiteClass = lambda thing: thing + + loader._top_level_dir = '/foo' + # this time no '.py' on the pattern so that it can match + # a test package + suite = list(loader._find_tests('/foo', 'test*.py')) + + # We should have loaded tests from the a_directory and test_directory2 + # directly and via load_tests for the test_directory package, which + # still calls the baseline module loader. + self.assertEqual(suite, + [['a_directory module tests'], + ['test_directory load_tests', + 'test_directory module tests'], + ['test_directory2 module tests']]) + # The test module paths should be sorted for reliable execution order + self.assertEqual(Module.paths, + ['a_directory', 'test_directory', 'test_directory2']) + + + # load_tests should have been called once with loader, tests and pattern + self.assertEqual(Module.load_tests_args, + [(loader, [], 'test*.py')]) + + def test_find_tests_customise_via_package_pattern(self): + # This test uses the example 'do-nothing' load_tests from + # https://docs.python.org/3/library/unittest.html#load-tests-protocol + # to make sure that that actually works. + # Housekeeping + original_listdir = os.listdir + def restore_listdir(): + os.listdir = original_listdir + self.addCleanup(restore_listdir) + original_isfile = os.path.isfile + def restore_isfile(): + os.path.isfile = original_isfile + self.addCleanup(restore_isfile) + original_isdir = os.path.isdir + def restore_isdir(): + os.path.isdir = original_isdir + self.addCleanup(restore_isdir) + self.addCleanup(sys.path.remove, abspath('/foo')) + + # Test data: we expect the following: + # a listdir to find our package, and a isfile and isdir check on it. + # a module-from-name call to turn that into a module + # followed by load_tests. + # then our load_tests will call discover() which is messy + # but that finally chains into find_tests again for the child dir - + # which is why we don't have a infinite loop. + # We expect to see: + # the module load tests for both package and plain module called, + # and the plain module result nested by the package module load_tests + # indicating that it was processed and could have been mutated. + vfs = {abspath('/foo'): ['my_package'], + abspath('/foo/my_package'): ['__init__.py', 'test_module.py']} + def list_dir(path): + return list(vfs[path]) + os.listdir = list_dir + os.path.isdir = lambda path: not path.endswith('.py') + os.path.isfile = lambda path: path.endswith('.py') + + class Module(object): + paths = [] + load_tests_args = [] + + def __init__(self, path): + self.path = path + self.paths.append(path) + if path.endswith('test_module'): + def load_tests(loader, tests, pattern): + self.load_tests_args.append((loader, tests, pattern)) + return [self.path + ' load_tests'] + else: + def load_tests(loader, tests, pattern): + self.load_tests_args.append((loader, tests, pattern)) + # top level directory cached on loader instance + __file__ = '/foo/my_package/__init__.py' + this_dir = os.path.dirname(__file__) + pkg_tests = loader.discover( + start_dir=this_dir, pattern=pattern) + return [self.path + ' load_tests', tests + ] + pkg_tests + self.load_tests = load_tests + + def __eq__(self, other): + return self.path == other.path + + loader = unittest.TestLoader() + loader._get_module_from_name = lambda name: Module(name) + loader.suiteClass = lambda thing: thing + + loader._top_level_dir = abspath('/foo') + # this time no '.py' on the pattern so that it can match + # a test package + suite = list(loader._find_tests(abspath('/foo'), 'test*.py')) + + # We should have loaded tests from both my_package and + # my_pacakge.test_module, and also run the load_tests hook in both. + # (normally this would be nested TestSuites.) + self.assertEqual(suite, + [['my_package load_tests', [], + ['my_package.test_module load_tests']]]) + # Parents before children. + self.assertEqual(Module.paths, + ['my_package', 'my_package.test_module']) + + # load_tests should have been called twice with loader, tests and pattern self.assertEqual(Module.load_tests_args, - [(loader, 'test_directory' + ' module tests', 'test*')]) + [(loader, [], 'test*.py'), + (loader, [], 'test*.py')]) def test_discover(self): loader = unittest.TestLoader() @@ -192,6 +369,51 @@ class TestDiscovery(unittest.TestCase): self.assertEqual(_find_tests_args, [(start_dir, 'pattern')]) self.assertIn(top_level_dir, sys.path) + def test_discover_start_dir_is_package_calls_package_load_tests(self): + # This test verifies that the package load_tests in a package is indeed + # invoked when the start_dir is a package (and not the top level). + # http://bugs.python.org/issue22457 + + # Test data: we expect the following: + # an isfile to verify the package, then importing and scanning + # as per _find_tests' normal behaviour. + # We expect to see our load_tests hook called once. + vfs = {abspath('/toplevel'): ['startdir'], + abspath('/toplevel/startdir'): ['__init__.py']} + def list_dir(path): + return list(vfs[path]) + self.addCleanup(setattr, os, 'listdir', os.listdir) + os.listdir = list_dir + self.addCleanup(setattr, os.path, 'isfile', os.path.isfile) + os.path.isfile = lambda path: path.endswith('.py') + self.addCleanup(setattr, os.path, 'isdir', os.path.isdir) + os.path.isdir = lambda path: not path.endswith('.py') + self.addCleanup(sys.path.remove, abspath('/toplevel')) + + class Module(object): + paths = [] + load_tests_args = [] + + def __init__(self, path): + self.path = path + + def load_tests(self, loader, tests, pattern): + return ['load_tests called ' + self.path] + + def __eq__(self, other): + return self.path == other.path + + loader = unittest.TestLoader() + loader._get_module_from_name = lambda name: Module(name) + loader.suiteClass = lambda thing: thing + + suite = loader.discover('/toplevel/startdir', top_level_dir='/toplevel') + + # We should have loaded tests from the package __init__. + # (normally this would be nested TestSuites.) + self.assertEqual(suite, + [['load_tests called startdir']]) + def setup_import_issue_tests(self, fakefile): listdir = os.listdir os.listdir = lambda _: [fakefile] @@ -204,6 +426,17 @@ class TestDiscovery(unittest.TestCase): sys.path[:] = orig_sys_path self.addCleanup(restore) + def setup_import_issue_package_tests(self, vfs): + self.addCleanup(setattr, os, 'listdir', os.listdir) + self.addCleanup(setattr, os.path, 'isfile', os.path.isfile) + self.addCleanup(setattr, os.path, 'isdir', os.path.isdir) + self.addCleanup(sys.path.__setitem__, slice(None), list(sys.path)) + def list_dir(path): + return list(vfs[path]) + os.listdir = list_dir + os.path.isdir = lambda path: not path.endswith('.py') + os.path.isfile = lambda path: path.endswith('.py') + def test_discover_with_modules_that_fail_to_import(self): loader = unittest.TestLoader() @@ -212,11 +445,44 @@ class TestDiscovery(unittest.TestCase): suite = loader.discover('.') self.assertIn(os.getcwd(), sys.path) self.assertEqual(suite.countTestCases(), 1) + # Errors loading the suite are also captured for introspection. + self.assertNotEqual([], loader.errors) + self.assertEqual(1, len(loader.errors)) + error = loader.errors[0] + self.assertTrue( + 'Failed to import test module: test_this_does_not_exist' in error, + 'missing error string in %r' % error) test = list(list(suite)[0])[0] # extract test from suite with self.assertRaises(ImportError): test.test_this_does_not_exist() + def test_discover_with_init_modules_that_fail_to_import(self): + vfs = {abspath('/foo'): ['my_package'], + abspath('/foo/my_package'): ['__init__.py', 'test_module.py']} + self.setup_import_issue_package_tests(vfs) + import_calls = [] + def _get_module_from_name(name): + import_calls.append(name) + raise ImportError("Cannot import Name") + loader = unittest.TestLoader() + loader._get_module_from_name = _get_module_from_name + suite = loader.discover(abspath('/foo')) + + self.assertIn(abspath('/foo'), sys.path) + self.assertEqual(suite.countTestCases(), 1) + # Errors loading the suite are also captured for introspection. + self.assertNotEqual([], loader.errors) + self.assertEqual(1, len(loader.errors)) + error = loader.errors[0] + self.assertTrue( + 'Failed to import test module: my_package' in error, + 'missing error string in %r' % error) + test = list(list(suite)[0])[0] # extract test from suite + with self.assertRaises(ImportError): + test.my_package() + self.assertEqual(import_calls, ['my_package']) + # Check picklability for proto in range(pickle.HIGHEST_PROTOCOL + 1): pickle.loads(pickle.dumps(test, proto)) @@ -241,6 +507,30 @@ class TestDiscovery(unittest.TestCase): for proto in range(pickle.HIGHEST_PROTOCOL + 1): pickle.loads(pickle.dumps(suite, proto)) + def test_discover_with_init_module_that_raises_SkipTest_on_import(self): + vfs = {abspath('/foo'): ['my_package'], + abspath('/foo/my_package'): ['__init__.py', 'test_module.py']} + self.setup_import_issue_package_tests(vfs) + import_calls = [] + def _get_module_from_name(name): + import_calls.append(name) + raise unittest.SkipTest('skipperoo') + loader = unittest.TestLoader() + loader._get_module_from_name = _get_module_from_name + suite = loader.discover(abspath('/foo')) + + self.assertIn(abspath('/foo'), sys.path) + self.assertEqual(suite.countTestCases(), 1) + result = unittest.TestResult() + suite.run(result) + self.assertEqual(len(result.skipped), 1) + self.assertEqual(result.testsRun, 1) + self.assertEqual(import_calls, ['my_package']) + + # Check picklability + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + pickle.loads(pickle.dumps(suite, proto)) + def test_command_line_handling_parseArgs(self): program = TestableTestProgram() diff --git a/Lib/unittest/test/test_loader.py b/Lib/unittest/test/test_loader.py index b62a1b5..68f1036 100644 --- a/Lib/unittest/test/test_loader.py +++ b/Lib/unittest/test/test_loader.py @@ -1,12 +1,36 @@ import sys import types - +import warnings import unittest +# Decorator used in the deprecation tests to reset the warning registry for +# test isolation and reproducibility. +def warningregistry(func): + def wrapper(*args, **kws): + missing = object() + saved = getattr(warnings, '__warningregistry__', missing).copy() + try: + return func(*args, **kws) + finally: + if saved is missing: + try: + del warnings.__warningregistry__ + except AttributeError: + pass + else: + warnings.__warningregistry__ = saved + class Test_TestLoader(unittest.TestCase): + ### Basic object tests + ################################################################ + + def test___init__(self): + loader = unittest.TestLoader() + self.assertEqual([], loader.errors) + ### Tests for TestLoader.loadTestsFromTestCase ################################################################ @@ -150,6 +174,7 @@ class Test_TestLoader(unittest.TestCase): # Check that loadTestsFromModule honors (or not) a module # with a load_tests function. + @warningregistry def test_loadTestsFromModule__load_tests(self): m = types.ModuleType('m') class MyTestCase(unittest.TestCase): @@ -168,10 +193,144 @@ class Test_TestLoader(unittest.TestCase): suite = loader.loadTestsFromModule(m) self.assertIsInstance(suite, unittest.TestSuite) self.assertEqual(load_tests_args, [loader, suite, None]) + # With Python 3.5, the undocumented and unofficial use_load_tests is + # ignored (and deprecated). + load_tests_args = [] + with warnings.catch_warnings(record=False): + warnings.simplefilter('never') + suite = loader.loadTestsFromModule(m, use_load_tests=False) + self.assertEqual(load_tests_args, [loader, suite, None]) + + @warningregistry + def test_loadTestsFromModule__use_load_tests_deprecated_positional(self): + m = types.ModuleType('m') + class MyTestCase(unittest.TestCase): + def test(self): + pass + m.testcase_1 = MyTestCase + + load_tests_args = [] + def load_tests(loader, tests, pattern): + self.assertIsInstance(tests, unittest.TestSuite) + load_tests_args.extend((loader, tests, pattern)) + return tests + m.load_tests = load_tests + # The method still works. + loader = unittest.TestLoader() + # use_load_tests=True as a positional argument. + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + suite = loader.loadTestsFromModule(m, False) + self.assertIsInstance(suite, unittest.TestSuite) + # load_tests was still called because use_load_tests is deprecated + # and ignored. + self.assertEqual(load_tests_args, [loader, suite, None]) + # We got a warning. + self.assertIs(w[-1].category, DeprecationWarning) + self.assertEqual(str(w[-1].message), + 'use_load_tests is deprecated and ignored') + + @warningregistry + def test_loadTestsFromModule__use_load_tests_deprecated_keyword(self): + m = types.ModuleType('m') + class MyTestCase(unittest.TestCase): + def test(self): + pass + m.testcase_1 = MyTestCase + + load_tests_args = [] + def load_tests(loader, tests, pattern): + self.assertIsInstance(tests, unittest.TestSuite) + load_tests_args.extend((loader, tests, pattern)) + return tests + m.load_tests = load_tests + # The method still works. + loader = unittest.TestLoader() + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + suite = loader.loadTestsFromModule(m, use_load_tests=False) + self.assertIsInstance(suite, unittest.TestSuite) + # load_tests was still called because use_load_tests is deprecated + # and ignored. + self.assertEqual(load_tests_args, [loader, suite, None]) + # We got a warning. + self.assertIs(w[-1].category, DeprecationWarning) + self.assertEqual(str(w[-1].message), + 'use_load_tests is deprecated and ignored') + + @warningregistry + def test_loadTestsFromModule__too_many_positional_args(self): + m = types.ModuleType('m') + class MyTestCase(unittest.TestCase): + def test(self): + pass + m.testcase_1 = MyTestCase load_tests_args = [] - suite = loader.loadTestsFromModule(m, use_load_tests=False) - self.assertEqual(load_tests_args, []) + def load_tests(loader, tests, pattern): + self.assertIsInstance(tests, unittest.TestSuite) + load_tests_args.extend((loader, tests, pattern)) + return tests + m.load_tests = load_tests + loader = unittest.TestLoader() + with self.assertRaises(TypeError) as cm, \ + warnings.catch_warning(record=True) as w: + loader.loadTestsFromModule(m, False, 'testme.*') + # We still got the deprecation warning. + self.assertIs(w[-1].category, DeprecationWarning) + self.assertEqual(str(w[-1].message), + 'use_load_tests is deprecated and ignored') + # We also got a TypeError for too many positional arguments. + self.assertEqual(type(cm.exception), TypeError) + self.assertEqual( + str(cm.exception), + 'loadTestsFromModule() takes 1 positional argument but 3 were given') + + @warningregistry + def test_loadTestsFromModule__use_load_tests_other_bad_keyword(self): + m = types.ModuleType('m') + class MyTestCase(unittest.TestCase): + def test(self): + pass + m.testcase_1 = MyTestCase + + load_tests_args = [] + def load_tests(loader, tests, pattern): + self.assertIsInstance(tests, unittest.TestSuite) + load_tests_args.extend((loader, tests, pattern)) + return tests + m.load_tests = load_tests + loader = unittest.TestLoader() + with warnings.catch_warnings(): + warnings.simplefilter('never') + with self.assertRaises(TypeError) as cm: + loader.loadTestsFromModule( + m, use_load_tests=False, very_bad=True, worse=False) + self.assertEqual(type(cm.exception), TypeError) + # The error message names the first bad argument alphabetically, + # however use_load_tests (which sorts first) is ignored. + self.assertEqual( + str(cm.exception), + "loadTestsFromModule() got an unexpected keyword argument 'very_bad'") + + def test_loadTestsFromModule__pattern(self): + m = types.ModuleType('m') + class MyTestCase(unittest.TestCase): + def test(self): + pass + m.testcase_1 = MyTestCase + + load_tests_args = [] + def load_tests(loader, tests, pattern): + self.assertIsInstance(tests, unittest.TestSuite) + load_tests_args.extend((loader, tests, pattern)) + return tests + m.load_tests = load_tests + + loader = unittest.TestLoader() + suite = loader.loadTestsFromModule(m, pattern='testme.*') + self.assertIsInstance(suite, unittest.TestSuite) + self.assertEqual(load_tests_args, [loader, suite, 'testme.*']) def test_loadTestsFromModule__faulty_load_tests(self): m = types.ModuleType('m') @@ -184,6 +343,13 @@ class Test_TestLoader(unittest.TestCase): suite = loader.loadTestsFromModule(m) self.assertIsInstance(suite, unittest.TestSuite) self.assertEqual(suite.countTestCases(), 1) + # Errors loading the suite are also captured for introspection. + self.assertNotEqual([], loader.errors) + self.assertEqual(1, len(loader.errors)) + error = loader.errors[0] + self.assertTrue( + 'Failed to call load_tests:' in error, + 'missing error string in %r' % error) test = list(suite)[0] self.assertRaisesRegex(TypeError, "some failure", test.m) @@ -219,15 +385,15 @@ class Test_TestLoader(unittest.TestCase): def test_loadTestsFromName__malformed_name(self): loader = unittest.TestLoader() - # XXX Should this raise ValueError or ImportError? - try: - loader.loadTestsFromName('abc () //') - except ValueError: - pass - except ImportError: - pass - else: - self.fail("TestLoader.loadTestsFromName failed to raise ValueError") + suite = loader.loadTestsFromName('abc () //') + error, test = self.check_deferred_error(loader, suite) + expected = "Failed to import test module: abc () //" + expected_regex = "Failed to import test module: abc \(\) //" + self.assertIn( + expected, error, + 'missing error string in %r' % error) + self.assertRaisesRegex( + ImportError, expected_regex, getattr(test, 'abc () //')) # "The specifier name is a ``dotted name'' that may resolve ... to a # module" @@ -236,28 +402,47 @@ class Test_TestLoader(unittest.TestCase): def test_loadTestsFromName__unknown_module_name(self): loader = unittest.TestLoader() - try: - loader.loadTestsFromName('sdasfasfasdf') - except ImportError as e: - self.assertEqual(str(e), "No module named 'sdasfasfasdf'") - else: - self.fail("TestLoader.loadTestsFromName failed to raise ImportError") + suite = loader.loadTestsFromName('sdasfasfasdf') + expected = "No module named 'sdasfasfasdf'" + error, test = self.check_deferred_error(loader, suite) + self.assertIn( + expected, error, + 'missing error string in %r' % error) + self.assertRaisesRegex(ImportError, expected, test.sdasfasfasdf) # "The specifier name is a ``dotted name'' that may resolve either to # a module, a test case class, a TestSuite instance, a test method # within a test case class, or a callable object which returns a # TestCase or TestSuite instance." # - # What happens when the module is found, but the attribute can't? - def test_loadTestsFromName__unknown_attr_name(self): + # What happens when the module is found, but the attribute isn't? + def test_loadTestsFromName__unknown_attr_name_on_module(self): loader = unittest.TestLoader() - try: - loader.loadTestsFromName('unittest.sdasfasfasdf') - except AttributeError as e: - self.assertEqual(str(e), "'module' object has no attribute 'sdasfasfasdf'") - else: - self.fail("TestLoader.loadTestsFromName failed to raise AttributeError") + suite = loader.loadTestsFromName('unittest.loader.sdasfasfasdf') + expected = "module 'unittest.loader' has no attribute 'sdasfasfasdf'" + error, test = self.check_deferred_error(loader, suite) + self.assertIn( + expected, error, + 'missing error string in %r' % error) + self.assertRaisesRegex(AttributeError, expected, test.sdasfasfasdf) + + # "The specifier name is a ``dotted name'' that may resolve either to + # a module, a test case class, a TestSuite instance, a test method + # within a test case class, or a callable object which returns a + # TestCase or TestSuite instance." + # + # What happens when the module is found, but the attribute isn't? + def test_loadTestsFromName__unknown_attr_name_on_package(self): + loader = unittest.TestLoader() + + suite = loader.loadTestsFromName('unittest.sdasfasfasdf') + expected = "No module named 'unittest.sdasfasfasdf'" + error, test = self.check_deferred_error(loader, suite) + self.assertIn( + expected, error, + 'missing error string in %r' % error) + self.assertRaisesRegex(ImportError, expected, test.sdasfasfasdf) # "The specifier name is a ``dotted name'' that may resolve either to # a module, a test case class, a TestSuite instance, a test method @@ -269,12 +454,13 @@ class Test_TestLoader(unittest.TestCase): def test_loadTestsFromName__relative_unknown_name(self): loader = unittest.TestLoader() - try: - loader.loadTestsFromName('sdasfasfasdf', unittest) - except AttributeError as e: - self.assertEqual(str(e), "'module' object has no attribute 'sdasfasfasdf'") - else: - self.fail("TestLoader.loadTestsFromName failed to raise AttributeError") + suite = loader.loadTestsFromName('sdasfasfasdf', unittest) + expected = "module 'unittest' has no attribute 'sdasfasfasdf'" + error, test = self.check_deferred_error(loader, suite) + self.assertIn( + expected, error, + 'missing error string in %r' % error) + self.assertRaisesRegex(AttributeError, expected, test.sdasfasfasdf) # "The specifier name is a ``dotted name'' that may resolve either to # a module, a test case class, a TestSuite instance, a test method @@ -290,12 +476,13 @@ class Test_TestLoader(unittest.TestCase): def test_loadTestsFromName__relative_empty_name(self): loader = unittest.TestLoader() - try: - loader.loadTestsFromName('', unittest) - except AttributeError as e: - pass - else: - self.fail("Failed to raise AttributeError") + suite = loader.loadTestsFromName('', unittest) + error, test = self.check_deferred_error(loader, suite) + expected = "has no attribute ''" + self.assertIn( + expected, error, + 'missing error string in %r' % error) + self.assertRaisesRegex(AttributeError, expected, getattr(test, '')) # "The specifier name is a ``dotted name'' that may resolve either to # a module, a test case class, a TestSuite instance, a test method @@ -310,14 +497,15 @@ class Test_TestLoader(unittest.TestCase): loader = unittest.TestLoader() # XXX Should this raise AttributeError or ValueError? - try: - loader.loadTestsFromName('abc () //', unittest) - except ValueError: - pass - except AttributeError: - pass - else: - self.fail("TestLoader.loadTestsFromName failed to raise ValueError") + suite = loader.loadTestsFromName('abc () //', unittest) + error, test = self.check_deferred_error(loader, suite) + expected = "module 'unittest' has no attribute 'abc () //'" + expected_regex = "module 'unittest' has no attribute 'abc \(\) //'" + self.assertIn( + expected, error, + 'missing error string in %r' % error) + self.assertRaisesRegex( + AttributeError, expected_regex, getattr(test, 'abc () //')) # "The method optionally resolves name relative to the given module" # @@ -423,12 +611,13 @@ class Test_TestLoader(unittest.TestCase): m.testcase_1 = MyTestCase loader = unittest.TestLoader() - try: - loader.loadTestsFromName('testcase_1.testfoo', m) - except AttributeError as e: - self.assertEqual(str(e), "type object 'MyTestCase' has no attribute 'testfoo'") - else: - self.fail("Failed to raise AttributeError") + suite = loader.loadTestsFromName('testcase_1.testfoo', m) + expected = "type object 'MyTestCase' has no attribute 'testfoo'" + error, test = self.check_deferred_error(loader, suite) + self.assertIn( + expected, error, + 'missing error string in %r' % error) + self.assertRaisesRegex(AttributeError, expected, test.testfoo) # "The specifier name is a ``dotted name'' that may resolve ... to # ... a callable object which returns a ... TestSuite instance" @@ -546,6 +735,23 @@ class Test_TestLoader(unittest.TestCase): ### Tests for TestLoader.loadTestsFromNames() ################################################################ + def check_deferred_error(self, loader, suite): + """Helper function for checking that errors in loading are reported. + + :param loader: A loader with some errors. + :param suite: A suite that should have a late bound error. + :return: The first error message from the loader and the test object + from the suite. + """ + self.assertIsInstance(suite, unittest.TestSuite) + self.assertEqual(suite.countTestCases(), 1) + # Errors loading the suite are also captured for introspection. + self.assertNotEqual([], loader.errors) + self.assertEqual(1, len(loader.errors)) + error = loader.errors[0] + test = list(suite)[0] + return error, test + # "Similar to loadTestsFromName(), but takes a sequence of names rather # than a single name." # @@ -598,14 +804,15 @@ class Test_TestLoader(unittest.TestCase): loader = unittest.TestLoader() # XXX Should this raise ValueError or ImportError? - try: - loader.loadTestsFromNames(['abc () //']) - except ValueError: - pass - except ImportError: - pass - else: - self.fail("TestLoader.loadTestsFromNames failed to raise ValueError") + suite = loader.loadTestsFromNames(['abc () //']) + error, test = self.check_deferred_error(loader, list(suite)[0]) + expected = "Failed to import test module: abc () //" + expected_regex = "Failed to import test module: abc \(\) //" + self.assertIn( + expected, error, + 'missing error string in %r' % error) + self.assertRaisesRegex( + ImportError, expected_regex, getattr(test, 'abc () //')) # "The specifier name is a ``dotted name'' that may resolve either to # a module, a test case class, a TestSuite instance, a test method @@ -616,12 +823,13 @@ class Test_TestLoader(unittest.TestCase): def test_loadTestsFromNames__unknown_module_name(self): loader = unittest.TestLoader() - try: - loader.loadTestsFromNames(['sdasfasfasdf']) - except ImportError as e: - self.assertEqual(str(e), "No module named 'sdasfasfasdf'") - else: - self.fail("TestLoader.loadTestsFromNames failed to raise ImportError") + suite = loader.loadTestsFromNames(['sdasfasfasdf']) + error, test = self.check_deferred_error(loader, list(suite)[0]) + expected = "Failed to import test module: sdasfasfasdf" + self.assertIn( + expected, error, + 'missing error string in %r' % error) + self.assertRaisesRegex(ImportError, expected, test.sdasfasfasdf) # "The specifier name is a ``dotted name'' that may resolve either to # a module, a test case class, a TestSuite instance, a test method @@ -632,12 +840,14 @@ class Test_TestLoader(unittest.TestCase): def test_loadTestsFromNames__unknown_attr_name(self): loader = unittest.TestLoader() - try: - loader.loadTestsFromNames(['unittest.sdasfasfasdf', 'unittest']) - except AttributeError as e: - self.assertEqual(str(e), "'module' object has no attribute 'sdasfasfasdf'") - else: - self.fail("TestLoader.loadTestsFromNames failed to raise AttributeError") + suite = loader.loadTestsFromNames( + ['unittest.loader.sdasfasfasdf', 'unittest.test.dummy']) + error, test = self.check_deferred_error(loader, list(suite)[0]) + expected = "module 'unittest.loader' has no attribute 'sdasfasfasdf'" + self.assertIn( + expected, error, + 'missing error string in %r' % error) + self.assertRaisesRegex(AttributeError, expected, test.sdasfasfasdf) # "The specifier name is a ``dotted name'' that may resolve either to # a module, a test case class, a TestSuite instance, a test method @@ -651,12 +861,13 @@ class Test_TestLoader(unittest.TestCase): def test_loadTestsFromNames__unknown_name_relative_1(self): loader = unittest.TestLoader() - try: - loader.loadTestsFromNames(['sdasfasfasdf'], unittest) - except AttributeError as e: - self.assertEqual(str(e), "'module' object has no attribute 'sdasfasfasdf'") - else: - self.fail("TestLoader.loadTestsFromName failed to raise AttributeError") + suite = loader.loadTestsFromNames(['sdasfasfasdf'], unittest) + error, test = self.check_deferred_error(loader, list(suite)[0]) + expected = "module 'unittest' has no attribute 'sdasfasfasdf'" + self.assertIn( + expected, error, + 'missing error string in %r' % error) + self.assertRaisesRegex(AttributeError, expected, test.sdasfasfasdf) # "The specifier name is a ``dotted name'' that may resolve either to # a module, a test case class, a TestSuite instance, a test method @@ -670,12 +881,13 @@ class Test_TestLoader(unittest.TestCase): def test_loadTestsFromNames__unknown_name_relative_2(self): loader = unittest.TestLoader() - try: - loader.loadTestsFromNames(['TestCase', 'sdasfasfasdf'], unittest) - except AttributeError as e: - self.assertEqual(str(e), "'module' object has no attribute 'sdasfasfasdf'") - else: - self.fail("TestLoader.loadTestsFromName failed to raise AttributeError") + suite = loader.loadTestsFromNames(['TestCase', 'sdasfasfasdf'], unittest) + error, test = self.check_deferred_error(loader, list(suite)[1]) + expected = "module 'unittest' has no attribute 'sdasfasfasdf'" + self.assertIn( + expected, error, + 'missing error string in %r' % error) + self.assertRaisesRegex(AttributeError, expected, test.sdasfasfasdf) # "The specifier name is a ``dotted name'' that may resolve either to # a module, a test case class, a TestSuite instance, a test method @@ -691,12 +903,13 @@ class Test_TestLoader(unittest.TestCase): def test_loadTestsFromNames__relative_empty_name(self): loader = unittest.TestLoader() - try: - loader.loadTestsFromNames([''], unittest) - except AttributeError: - pass - else: - self.fail("Failed to raise ValueError") + suite = loader.loadTestsFromNames([''], unittest) + error, test = self.check_deferred_error(loader, list(suite)[0]) + expected = "has no attribute ''" + self.assertIn( + expected, error, + 'missing error string in %r' % error) + self.assertRaisesRegex(AttributeError, expected, getattr(test, '')) # "The specifier name is a ``dotted name'' that may resolve either to # a module, a test case class, a TestSuite instance, a test method @@ -710,14 +923,15 @@ class Test_TestLoader(unittest.TestCase): loader = unittest.TestLoader() # XXX Should this raise AttributeError or ValueError? - try: - loader.loadTestsFromNames(['abc () //'], unittest) - except AttributeError: - pass - except ValueError: - pass - else: - self.fail("TestLoader.loadTestsFromNames failed to raise ValueError") + suite = loader.loadTestsFromNames(['abc () //'], unittest) + error, test = self.check_deferred_error(loader, list(suite)[0]) + expected = "module 'unittest' has no attribute 'abc () //'" + expected_regex = "module 'unittest' has no attribute 'abc \(\) //'" + self.assertIn( + expected, error, + 'missing error string in %r' % error) + self.assertRaisesRegex( + AttributeError, expected_regex, getattr(test, 'abc () //')) # "The method optionally resolves name relative to the given module" # @@ -835,12 +1049,13 @@ class Test_TestLoader(unittest.TestCase): m.testcase_1 = MyTestCase loader = unittest.TestLoader() - try: - loader.loadTestsFromNames(['testcase_1.testfoo'], m) - except AttributeError as e: - self.assertEqual(str(e), "type object 'MyTestCase' has no attribute 'testfoo'") - else: - self.fail("Failed to raise AttributeError") + suite = loader.loadTestsFromNames(['testcase_1.testfoo'], m) + error, test = self.check_deferred_error(loader, list(suite)[0]) + expected = "type object 'MyTestCase' has no attribute 'testfoo'" + self.assertIn( + expected, error, + 'missing error string in %r' % error) + self.assertRaisesRegex(AttributeError, expected, test.testfoo) # "The specifier name is a ``dotted name'' that may resolve ... to # ... a callable object which returns a ... TestSuite instance" diff --git a/Lib/unittest/test/test_program.py b/Lib/unittest/test/test_program.py index 725d67f..1cfc179 100644 --- a/Lib/unittest/test/test_program.py +++ b/Lib/unittest/test/test_program.py @@ -134,6 +134,7 @@ class InitialisableProgram(unittest.TestProgram): result = None verbosity = 1 defaultTest = None + tb_locals = False testRunner = None testLoader = unittest.defaultTestLoader module = '__main__' @@ -147,18 +148,19 @@ RESULT = object() class FakeRunner(object): initArgs = None test = None - raiseError = False + raiseError = 0 def __init__(self, **kwargs): FakeRunner.initArgs = kwargs if FakeRunner.raiseError: - FakeRunner.raiseError = False + FakeRunner.raiseError -= 1 raise TypeError def run(self, test): FakeRunner.test = test return RESULT + class TestCommandLineArgs(unittest.TestCase): def setUp(self): @@ -166,7 +168,7 @@ class TestCommandLineArgs(unittest.TestCase): self.program.createTests = lambda: None FakeRunner.initArgs = None FakeRunner.test = None - FakeRunner.raiseError = False + FakeRunner.raiseError = 0 def testVerbosity(self): program = self.program @@ -256,6 +258,7 @@ class TestCommandLineArgs(unittest.TestCase): self.assertEqual(FakeRunner.initArgs, {'verbosity': 'verbosity', 'failfast': 'failfast', 'buffer': 'buffer', + 'tb_locals': False, 'warnings': 'warnings'}) self.assertEqual(FakeRunner.test, 'test') self.assertIs(program.result, RESULT) @@ -274,10 +277,25 @@ class TestCommandLineArgs(unittest.TestCase): self.assertEqual(FakeRunner.test, 'test') self.assertIs(program.result, RESULT) + def test_locals(self): + program = self.program + + program.testRunner = FakeRunner + program.parseArgs([None, '--locals']) + self.assertEqual(True, program.tb_locals) + program.runTests() + self.assertEqual(FakeRunner.initArgs, {'buffer': False, + 'failfast': False, + 'tb_locals': True, + 'verbosity': 1, + 'warnings': None}) + def testRunTestsOldRunnerClass(self): program = self.program - FakeRunner.raiseError = True + # Two TypeErrors are needed to fall all the way back to old-style + # runners - one to fail tb_locals, one to fail buffer etc. + FakeRunner.raiseError = 2 program.testRunner = FakeRunner program.verbosity = 'verbosity' program.failfast = 'failfast' diff --git a/Lib/unittest/test/test_result.py b/Lib/unittest/test/test_result.py index 489fe17..e39e2ea 100644 --- a/Lib/unittest/test/test_result.py +++ b/Lib/unittest/test/test_result.py @@ -8,6 +8,20 @@ import traceback import unittest +class MockTraceback(object): + class TracebackException: + def __init__(self, *args, **kwargs): + self.capture_locals = kwargs.get('capture_locals', False) + def format(self): + result = ['A traceback'] + if self.capture_locals: + result.append('locals') + return result + +def restore_traceback(): + unittest.result.traceback = traceback + + class Test_TestResult(unittest.TestCase): # Note: there are not separate tests for TestResult.wasSuccessful(), # TestResult.errors, TestResult.failures, TestResult.testsRun or @@ -227,6 +241,25 @@ class Test_TestResult(unittest.TestCase): self.assertIs(test_case, test) self.assertIsInstance(formatted_exc, str) + def test_addError_locals(self): + class Foo(unittest.TestCase): + def test_1(self): + 1/0 + + test = Foo('test_1') + result = unittest.TestResult() + result.tb_locals = True + + unittest.result.traceback = MockTraceback + self.addCleanup(restore_traceback) + result.startTestRun() + test.run(result) + result.stopTestRun() + + self.assertEqual(len(result.errors), 1) + test_case, formatted_exc = result.errors[0] + self.assertEqual('A tracebacklocals', formatted_exc) + def test_addSubTest(self): class Foo(unittest.TestCase): def test_1(self): @@ -398,6 +431,7 @@ def __init__(self, stream=None, descriptions=None, verbosity=None): self.testsRun = 0 self.shouldStop = False self.buffer = False + self.tb_locals = False classDict['__init__'] = __init__ OldResult = type('OldResult', (object,), classDict) @@ -454,15 +488,6 @@ class Test_OldTestResult(unittest.TestCase): runner.run(Test('testFoo')) -class MockTraceback(object): - @staticmethod - def format_exception(*_): - return ['A traceback'] - -def restore_traceback(): - unittest.result.traceback = traceback - - class TestOutputBuffering(unittest.TestCase): def setUp(self): diff --git a/Lib/unittest/test/test_runner.py b/Lib/unittest/test/test_runner.py index 7c0bd51..9cbc260 100644 --- a/Lib/unittest/test/test_runner.py +++ b/Lib/unittest/test/test_runner.py @@ -158,7 +158,7 @@ class Test_TextTestRunner(unittest.TestCase): self.assertEqual(runner.warnings, None) self.assertTrue(runner.descriptions) self.assertEqual(runner.resultclass, unittest.TextTestResult) - + self.assertFalse(runner.tb_locals) def test_multiple_inheritance(self): class AResult(unittest.TestResult): @@ -172,14 +172,13 @@ class Test_TextTestRunner(unittest.TestCase): # on arguments in its __init__ super call ATextResult(None, None, 1) - def testBufferAndFailfast(self): class Test(unittest.TestCase): def testFoo(self): pass result = unittest.TestResult() runner = unittest.TextTestRunner(stream=io.StringIO(), failfast=True, - buffer=True) + buffer=True) # Use our result object runner._makeResult = lambda: result runner.run(Test('testFoo')) @@ -187,6 +186,11 @@ class Test_TextTestRunner(unittest.TestCase): self.assertTrue(result.failfast) self.assertTrue(result.buffer) + def test_locals(self): + runner = unittest.TextTestRunner(stream=io.StringIO(), tb_locals=True) + result = runner.run(unittest.TestSuite()) + self.assertEqual(True, result.tb_locals) + def testRunnerRegistersResult(self): class Test(unittest.TestCase): def testFoo(self): diff --git a/Lib/unittest/test/test_setups.py b/Lib/unittest/test/test_setups.py index 392f95e..2df703e 100644 --- a/Lib/unittest/test/test_setups.py +++ b/Lib/unittest/test/test_setups.py @@ -111,7 +111,7 @@ class TestSetups(unittest.TestCase): self.assertEqual(len(result.errors), 1) error, _ = result.errors[0] self.assertEqual(str(error), - 'setUpClass (%s.BrokenTest)' % __name__) + 'setUpClass (%s.%s)' % (__name__, BrokenTest.__qualname__)) def test_error_in_teardown_class(self): class Test(unittest.TestCase): @@ -144,7 +144,7 @@ class TestSetups(unittest.TestCase): error, _ = result.errors[0] self.assertEqual(str(error), - 'tearDownClass (%s.Test)' % __name__) + 'tearDownClass (%s.%s)' % (__name__, Test.__qualname__)) def test_class_not_torndown_when_setup_fails(self): class Test(unittest.TestCase): @@ -414,7 +414,8 @@ class TestSetups(unittest.TestCase): self.assertEqual(len(result.errors), 0) self.assertEqual(len(result.skipped), 1) skipped = result.skipped[0][0] - self.assertEqual(str(skipped), 'setUpClass (%s.Test)' % __name__) + self.assertEqual(str(skipped), + 'setUpClass (%s.%s)' % (__name__, Test.__qualname__)) def test_skiptest_in_setupmodule(self): class Test(unittest.TestCase): diff --git a/Lib/unittest/test/testmock/testmagicmethods.py b/Lib/unittest/test/testmock/testmagicmethods.py index e05c6e0..bb9b956 100644 --- a/Lib/unittest/test/testmock/testmagicmethods.py +++ b/Lib/unittest/test/testmock/testmagicmethods.py @@ -424,6 +424,17 @@ class TestMockingMagicMethods(unittest.TestCase): self.assertEqual(list(m), []) + def test_matmul(self): + m = MagicMock() + self.assertIsInstance(m @ 1, MagicMock) + m.__matmul__.return_value = 42 + m.__rmatmul__.return_value = 666 + m.__imatmul__.return_value = 24 + self.assertEqual(m @ 1, 42) + self.assertEqual(1 @ m, 666) + m @= 24 + self.assertEqual(m, 24) + def test_divmod_and_rdivmod(self): m = MagicMock() self.assertIsInstance(divmod(5, m), MagicMock) diff --git a/Lib/unittest/test/testmock/testmock.py b/Lib/unittest/test/testmock/testmock.py index 976c40f..97d844f 100644 --- a/Lib/unittest/test/testmock/testmock.py +++ b/Lib/unittest/test/testmock/testmock.py @@ -174,6 +174,15 @@ class MockTest(unittest.TestCase): self.assertEqual([mock(), mock(), mock()], [3, 2, 1], "callable side effect not used correctly") + def test_autospec_side_effect_exception(self): + # Test for issue 23661 + def f(): + pass + + mock = create_autospec(f) + mock.side_effect = ValueError('Bazinga!') + self.assertRaisesRegex(ValueError, 'Bazinga!', mock) + @unittest.skipUnless('java' in sys.platform, 'This test only applies to Jython') def test_java_exception_side_effect(self): @@ -1191,6 +1200,42 @@ class MockTest(unittest.TestCase): m = mock.create_autospec(object(), name='sweet_func') self.assertIn('sweet_func', repr(m)) + #Issue21238 + def test_mock_unsafe(self): + m = Mock() + with self.assertRaises(AttributeError): + m.assert_foo_call() + with self.assertRaises(AttributeError): + m.assret_foo_call() + m = Mock(unsafe=True) + m.assert_foo_call() + m.assret_foo_call() + + #Issue21262 + def test_assert_not_called(self): + m = Mock() + m.hello.assert_not_called() + m.hello() + with self.assertRaises(AssertionError): + m.hello.assert_not_called() + + #Issue21256 printout of keyword args should be in deterministic order + def test_sorted_call_signature(self): + m = Mock() + m.hello(name='hello', daddy='hero') + text = "call(daddy='hero', name='hello')" + self.assertEqual(repr(m.hello.call_args), text) + + #Issue21270 overrides tuple methods for mock.call objects + def test_override_tuple_methods(self): + c = call.count() + i = call.index(132,'hello') + m = Mock() + m.count() + m.index(132,"hello") + self.assertEqual(m.method_calls[0], c) + self.assertEqual(m.method_calls[1], i) + def test_mock_add_spec(self): class _One(object): one = 1 diff --git a/Lib/unittest/test/testmock/testpatch.py b/Lib/unittest/test/testmock/testpatch.py index b516f42..28fe86b 100644 --- a/Lib/unittest/test/testmock/testpatch.py +++ b/Lib/unittest/test/testmock/testpatch.py @@ -377,7 +377,7 @@ class PatchTest(unittest.TestCase): def test_patchobject_wont_create_by_default(self): try: - @patch.object(SomeClass, 'frooble', sentinel.Frooble) + @patch.object(SomeClass, 'ord', sentinel.Frooble) def test(): self.fail('Patching non existent attributes should fail') @@ -386,7 +386,27 @@ class PatchTest(unittest.TestCase): pass else: self.fail('Patching non existent attributes should fail') - self.assertFalse(hasattr(SomeClass, 'frooble')) + self.assertFalse(hasattr(SomeClass, 'ord')) + + + def test_patch_builtins_without_create(self): + @patch(__name__+'.ord') + def test_ord(mock_ord): + mock_ord.return_value = 101 + return ord('c') + + @patch(__name__+'.open') + def test_open(mock_open): + m = mock_open.return_value + m.read.return_value = 'abcd' + + fobj = open('doesnotexists.txt') + data = fobj.read() + fobj.close() + return data + + self.assertEqual(test_ord(), 101) + self.assertEqual(test_open(), 'abcd') def test_patch_with_static_methods(self): diff --git a/Lib/unittest/util.py b/Lib/unittest/util.py index aee498f..45485dc 100644 --- a/Lib/unittest/util.py +++ b/Lib/unittest/util.py @@ -52,7 +52,7 @@ def safe_repr(obj, short=False): return result[:_MAX_LENGTH] + ' [truncated]...' def strclass(cls): - return "%s.%s" % (cls.__module__, cls.__name__) + return "%s.%s" % (cls.__module__, cls.__qualname__) def sorted_list_difference(expected, actual): """Finds elements in only one or the other of two, sorted input lists. |