diff options
Diffstat (limited to 'Lib/unittest/case.py')
-rw-r--r-- | Lib/unittest/case.py | 399 |
1 files changed, 291 insertions, 108 deletions
diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py index f56af55..69888a5 100644 --- a/Lib/unittest/case.py +++ b/Lib/unittest/case.py @@ -3,14 +3,17 @@ import sys import functools import difflib +import logging import pprint import re import warnings import collections +import contextlib +import traceback from . import result from .util import (strclass, safe_repr, _count_diff_all_purpose, - _count_diff_hashable) + _count_diff_hashable, _common_shorten_repr) __unittest = True @@ -26,17 +29,11 @@ class SkipTest(Exception): instead of raising this directly. """ -class _ExpectedFailure(Exception): +class _ShouldStop(Exception): """ - Raise this when a test is expected to fail. - - This is an implementation detail. + The test should stop. """ - def __init__(self, exc_info): - super(_ExpectedFailure, self).__init__() - self.exc_info = exc_info - class _UnexpectedSuccess(Exception): """ The test was supposed to fail, but it didn't! @@ -44,13 +41,43 @@ class _UnexpectedSuccess(Exception): class _Outcome(object): - def __init__(self): + def __init__(self, result=None): + self.expecting_failure = False + self.result = result + self.result_supports_subtests = hasattr(result, "addSubTest") self.success = True - self.skipped = None - self.unexpectedSuccess = None + self.skipped = [] self.expectedFailure = None self.errors = [] - self.failures = [] + + @contextlib.contextmanager + def testPartExecutor(self, test_case, isTest=False): + old_success = self.success + self.success = True + try: + yield + except KeyboardInterrupt: + raise + except SkipTest as e: + self.success = False + self.skipped.append((test_case, str(e))) + except _ShouldStop: + pass + except: + exc_info = sys.exc_info() + if self.expecting_failure: + self.expectedFailure = exc_info + else: + self.success = False + self.errors.append((test_case, exc_info)) + # explicitly break a reference cycle: + # exc_info -> frame -> exc_info + exc_info = None + else: + if self.result_supports_subtests and self.success: + self.errors.append((test_case, None)) + finally: + self.success = self.success and old_success def _id(obj): @@ -88,22 +115,26 @@ def skipUnless(condition, reason): return skip(reason) return _id +def expectedFailure(test_item): + test_item.__unittest_expecting_failure__ = True + return test_item -def expectedFailure(func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - try: - func(*args, **kwargs) - except Exception: - raise _ExpectedFailure(sys.exc_info()) - raise _UnexpectedSuccess - return wrapper +class _BaseTestCaseContext: + + def __init__(self, test_case): + self.test_case = test_case + + def _raiseFailure(self, standardMsg): + msg = self.test_case._formatMessage(self.msg, standardMsg) + raise self.test_case.failureException(msg) -class _AssertRaisesBaseContext(object): + +class _AssertRaisesBaseContext(_BaseTestCaseContext): def __init__(self, expected, test_case, callable_obj=None, expected_regex=None): + _BaseTestCaseContext.__init__(self, test_case) self.expected = expected self.test_case = test_case if callable_obj is not None: @@ -113,15 +144,11 @@ class _AssertRaisesBaseContext(object): self.obj_name = str(callable_obj) else: self.obj_name = None - if isinstance(expected_regex, (bytes, str)): + if expected_regex is not None: expected_regex = re.compile(expected_regex) self.expected_regex = expected_regex self.msg = None - def _raiseFailure(self, standardMsg): - msg = self.test_case._formatMessage(self.msg, standardMsg) - raise self.test_case.failureException(msg) - def handle(self, name, callable_obj, args, kwargs): """ If callable_obj is None, assertRaises/Warns is being used as a @@ -135,7 +162,6 @@ class _AssertRaisesBaseContext(object): callable_obj(*args, **kwargs) - class _AssertRaisesContext(_AssertRaisesBaseContext): """A context manager used to implement TestCase.assertRaises* methods.""" @@ -153,6 +179,8 @@ class _AssertRaisesContext(_AssertRaisesBaseContext): self.obj_name)) else: self._raiseFailure("{} not raised".format(exc_name)) + else: + traceback.clear_frames(tb) if not issubclass(exc_type, self.expected): # let unexpected exceptions pass through return False @@ -217,6 +245,74 @@ class _AssertWarnsContext(_AssertRaisesBaseContext): self._raiseFailure("{} not triggered".format(exc_name)) + +_LoggingWatcher = collections.namedtuple("_LoggingWatcher", + ["records", "output"]) + + +class _CapturingHandler(logging.Handler): + """ + A logging handler capturing all (raw and formatted) logging output. + """ + + def __init__(self): + logging.Handler.__init__(self) + self.watcher = _LoggingWatcher([], []) + + def flush(self): + pass + + def emit(self, record): + self.watcher.records.append(record) + msg = self.format(record) + self.watcher.output.append(msg) + + + +class _AssertLogsContext(_BaseTestCaseContext): + """A context manager used to implement TestCase.assertLogs().""" + + LOGGING_FORMAT = "%(levelname)s:%(name)s:%(message)s" + + def __init__(self, test_case, logger_name, level): + _BaseTestCaseContext.__init__(self, test_case) + self.logger_name = logger_name + if level: + self.level = logging._nameToLevel.get(level, level) + else: + self.level = logging.INFO + self.msg = None + + def __enter__(self): + if isinstance(self.logger_name, logging.Logger): + logger = self.logger = self.logger_name + else: + logger = self.logger = logging.getLogger(self.logger_name) + formatter = logging.Formatter(self.LOGGING_FORMAT) + handler = _CapturingHandler() + handler.setFormatter(formatter) + self.watcher = handler.watcher + self.old_handlers = logger.handlers[:] + self.old_level = logger.level + self.old_propagate = logger.propagate + logger.handlers = [handler] + logger.setLevel(self.level) + logger.propagate = False + return handler.watcher + + def __exit__(self, exc_type, exc_value, tb): + self.logger.handlers = self.old_handlers + self.logger.propagate = self.old_propagate + self.logger.setLevel(self.old_level) + if exc_type is not None: + # let unexpected exceptions pass through + return False + if len(self.watcher.records) == 0: + self._raiseFailure( + "no logs of level {} or higher triggered on {}" + .format(logging.getLevelName(self.level), self.logger.name)) + + class TestCase(object): """A class whose instances are single test cases. @@ -270,7 +366,7 @@ class TestCase(object): not have a method with the specified name. """ self._testMethodName = methodName - self._outcomeForDoCleanups = None + self._outcome = None self._testMethodDoc = 'No test' try: testMethod = getattr(self, methodName) @@ -283,6 +379,7 @@ class TestCase(object): else: self._testMethodDoc = testMethod.__doc__ self._cleanups = [] + self._subtest = None # Map types to custom assertEqual functions that will compare # instances of said type in more detail to generate a more useful @@ -370,44 +467,80 @@ class TestCase(object): return "<%s testMethod=%s>" % \ (strclass(self.__class__), self._testMethodName) - def _addSkip(self, result, reason): + def _addSkip(self, result, test_case, reason): addSkip = getattr(result, 'addSkip', None) if addSkip is not None: - addSkip(self, reason) + addSkip(test_case, reason) else: warnings.warn("TestResult has no addSkip method, skips not reported", RuntimeWarning, 2) + result.addSuccess(test_case) + + @contextlib.contextmanager + def subTest(self, msg=None, **params): + """Return a context manager that will return the enclosed block + of code in a subtest identified by the optional message and + keyword parameters. A failure in the subtest marks the test + case as failed but resumes execution at the end of the enclosed + block, allowing further test code to be executed. + """ + if not self._outcome.result_supports_subtests: + yield + return + parent = self._subtest + if parent is None: + params_map = collections.ChainMap(params) + else: + params_map = parent.params.new_child(params) + self._subtest = _SubTest(self, msg, params_map) + try: + with self._outcome.testPartExecutor(self._subtest, isTest=True): + yield + if not self._outcome.success: + result = self._outcome.result + if result is not None and result.failfast: + raise _ShouldStop + elif self._outcome.expectedFailure: + # If the test is expecting a failure, we really want to + # stop now and register the expected failure. + raise _ShouldStop + finally: + self._subtest = parent + + def _feedErrorsToResult(self, result, errors): + for test, exc_info in errors: + if isinstance(test, _SubTest): + result.addSubTest(test.test_case, test, exc_info) + elif exc_info is not None: + if issubclass(exc_info[0], self.failureException): + result.addFailure(test, exc_info) + else: + result.addError(test, exc_info) + + def _addExpectedFailure(self, result, exc_info): + try: + addExpectedFailure = result.addExpectedFailure + except AttributeError: + warnings.warn("TestResult has no addExpectedFailure method, reporting as passes", + RuntimeWarning) result.addSuccess(self) + else: + addExpectedFailure(self, exc_info) - def _executeTestPart(self, function, outcome, isTest=False): + def _addUnexpectedSuccess(self, result): try: - function() - except KeyboardInterrupt: - raise - except SkipTest as e: - outcome.success = False - outcome.skipped = str(e) - except _UnexpectedSuccess: - exc_info = sys.exc_info() - outcome.success = False - if isTest: - outcome.unexpectedSuccess = exc_info - else: - outcome.errors.append(exc_info) - except _ExpectedFailure: - outcome.success = False - exc_info = sys.exc_info() - if isTest: - outcome.expectedFailure = exc_info - else: - outcome.errors.append(exc_info) - except self.failureException: - outcome.success = False - outcome.failures.append(sys.exc_info()) - exc_info = sys.exc_info() - except: - outcome.success = False - outcome.errors.append(sys.exc_info()) + addUnexpectedSuccess = result.addUnexpectedSuccess + except AttributeError: + warnings.warn("TestResult has no addUnexpectedSuccess method, reporting as failure", + RuntimeWarning) + # We need to pass an actual exception and traceback to addFailure, + # otherwise the legacy result can choke. + try: + raise _UnexpectedSuccess from None + except _UnexpectedSuccess: + result.addFailure(self, sys.exc_info()) + else: + addUnexpectedSuccess(self) def run(self, result=None): orig_result = result @@ -426,46 +559,38 @@ class TestCase(object): try: skip_why = (getattr(self.__class__, '__unittest_skip_why__', '') or getattr(testMethod, '__unittest_skip_why__', '')) - self._addSkip(result, skip_why) + self._addSkip(result, self, skip_why) finally: result.stopTest(self) return + expecting_failure = getattr(testMethod, + "__unittest_expecting_failure__", False) + outcome = _Outcome(result) try: - outcome = _Outcome() - self._outcomeForDoCleanups = outcome + self._outcome = outcome - self._executeTestPart(self.setUp, outcome) + with outcome.testPartExecutor(self): + self.setUp() if outcome.success: - self._executeTestPart(testMethod, outcome, isTest=True) - self._executeTestPart(self.tearDown, outcome) + outcome.expecting_failure = expecting_failure + with outcome.testPartExecutor(self, isTest=True): + testMethod() + outcome.expecting_failure = False + with outcome.testPartExecutor(self): + self.tearDown() self.doCleanups() + for test, reason in outcome.skipped: + self._addSkip(result, test, reason) + self._feedErrorsToResult(result, outcome.errors) if outcome.success: - result.addSuccess(self) - else: - if outcome.skipped is not None: - self._addSkip(result, outcome.skipped) - for exc_info in outcome.errors: - result.addError(self, exc_info) - for exc_info in outcome.failures: - result.addFailure(self, exc_info) - if outcome.unexpectedSuccess is not None: - addUnexpectedSuccess = getattr(result, 'addUnexpectedSuccess', None) - if addUnexpectedSuccess is not None: - addUnexpectedSuccess(self) - else: - warnings.warn("TestResult has no addUnexpectedSuccess method, reporting as failures", - RuntimeWarning) - result.addFailure(self, outcome.unexpectedSuccess) - - if outcome.expectedFailure is not None: - addExpectedFailure = getattr(result, 'addExpectedFailure', None) - if addExpectedFailure is not None: - addExpectedFailure(self, outcome.expectedFailure) + if expecting_failure: + if outcome.expectedFailure: + self._addExpectedFailure(result, outcome.expectedFailure) else: - warnings.warn("TestResult has no addExpectedFailure method, reporting as passes", - RuntimeWarning) - result.addSuccess(self) + self._addUnexpectedSuccess(result) + else: + result.addSuccess(self) return result finally: result.stopTest(self) @@ -474,14 +599,23 @@ class TestCase(object): if stopTestRun is not None: stopTestRun() + # explicitly break reference cycles: + # outcome.errors -> frame -> outcome -> outcome.errors + # outcome.expectedFailure -> frame -> outcome -> outcome.expectedFailure + outcome.errors.clear() + outcome.expectedFailure = None + + # clear the outcome, no more needed + self._outcome = None + def doCleanups(self): """Execute all cleanup functions. Normally called for you after tearDown.""" - outcome = self._outcomeForDoCleanups or _Outcome() + outcome = self._outcome or _Outcome() while self._cleanups: function, args, kwargs = self._cleanups.pop() - part = lambda: function(*args, **kwargs) - self._executeTestPart(part, outcome) + with outcome.testPartExecutor(self): + function(*args, **kwargs) # return this for backwards compatibility # even though we no longer us it internally @@ -600,6 +734,28 @@ class TestCase(object): context = _AssertWarnsContext(expected_warning, self, callable_obj) return context.handle('assertWarns', callable_obj, args, kwargs) + def assertLogs(self, logger=None, level=None): + """Fail unless a log message of level *level* or higher is emitted + on *logger_name* or its children. If omitted, *level* defaults to + INFO and *logger* defaults to the root logger. + + This method must be used as a context manager, and will yield + a recording object with two attributes: `output` and `records`. + At the end of the context manager, the `output` attribute will + be a list of the matching formatted log messages and the + `records` attribute will be a list of the corresponding LogRecord + objects. + + Example:: + + with self.assertLogs('foo', level='INFO') as cm: + logging.getLogger('foo').info('first message') + logging.getLogger('foo.bar').error('second message') + self.assertEqual(cm.output, ['INFO:foo:first message', + 'ERROR:foo.bar:second message']) + """ + return _AssertLogsContext(self, logger, level) + def _getAssertEqualityFunc(self, first, second): """Get a detailed comparison function for the types of the two args. @@ -629,7 +785,7 @@ class TestCase(object): def _baseAssertEqual(self, first, second, msg=None): """The default assertEqual implementation, not type specific.""" if not first == second: - standardMsg = '%s != %s' % (safe_repr(first), safe_repr(second)) + standardMsg = '%s != %s' % _common_shorten_repr(first, second) msg = self._formatMessage(msg, standardMsg) raise self.failureException(msg) @@ -764,14 +920,9 @@ class TestCase(object): if seq1 == seq2: return - seq1_repr = safe_repr(seq1) - seq2_repr = safe_repr(seq2) - if len(seq1_repr) > 30: - seq1_repr = seq1_repr[:30] + '...' - if len(seq2_repr) > 30: - seq2_repr = seq2_repr[:30] + '...' - elements = (seq_type_name.capitalize(), seq1_repr, seq2_repr) - differing = '%ss differ: %s != %s\n' % elements + differing = '%ss differ: %s != %s\n' % ( + (seq_type_name.capitalize(),) + + _common_shorten_repr(seq1, seq2)) for i in range(min(len1, len2)): try: @@ -929,7 +1080,7 @@ class TestCase(object): self.assertIsInstance(d2, dict, 'Second argument is not a dictionary') if d1 != d2: - standardMsg = '%s != %s' % (safe_repr(d1, True), safe_repr(d2, True)) + standardMsg = '%s != %s' % _common_shorten_repr(d1, d2) diff = ('\n' + '\n'.join(difflib.ndiff( pprint.pformat(d1).splitlines(), pprint.pformat(d2).splitlines()))) @@ -1013,8 +1164,7 @@ class TestCase(object): if len(firstlines) == 1 and first.strip('\r\n') == first: firstlines = [first + '\n'] secondlines = [second + '\n'] - standardMsg = '%s != %s' % (safe_repr(first, True), - safe_repr(second, True)) + standardMsg = '%s != %s' % _common_shorten_repr(first, second) diff = '\n' + ''.join(difflib.ndiff(firstlines, secondlines)) standardMsg = self._truncateMessage(standardMsg, diff) self.fail(self._formatMessage(msg, standardMsg)) @@ -1192,9 +1342,6 @@ class FunctionTestCase(TestCase): self._testFunc == other._testFunc and \ self._description == other._description - def __ne__(self, other): - return not self == other - def __hash__(self): return hash((type(self), self._setUpFunc, self._tearDownFunc, self._testFunc, self._description)) @@ -1212,3 +1359,39 @@ class FunctionTestCase(TestCase): return self._description doc = self._testFunc.__doc__ return doc and doc.split("\n")[0].strip() or None + + +class _SubTest(TestCase): + + def __init__(self, test_case, message, params): + super().__init__() + self._message = message + self.test_case = test_case + self.params = params + self.failureException = test_case.failureException + + def runTest(self): + raise NotImplementedError("subtests cannot be run directly") + + def _subDescription(self): + parts = [] + if self._message: + parts.append("[{}]".format(self._message)) + if self.params: + params_desc = ', '.join( + "{}={!r}".format(k, v) + for (k, v) in sorted(self.params.items())) + parts.append("({})".format(params_desc)) + return " ".join(parts) or '(<subtest>)' + + def id(self): + return "{} {}".format(self.test_case.id(), self._subDescription()) + + def shortDescription(self): + """Returns a one-line description of the subtest, or None if no + description has been provided. + """ + return self.test_case.shortDescription() + + def __str__(self): + return "{} {}".format(self.test_case, self._subDescription()) |