diff options
Diffstat (limited to 'Lib/unittest/case.py')
-rw-r--r-- | Lib/unittest/case.py | 133 |
1 files changed, 52 insertions, 81 deletions
diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py index bd47493..2e5cb04 100644 --- a/Lib/unittest/case.py +++ b/Lib/unittest/case.py @@ -9,8 +9,7 @@ import warnings import collections from . import result -from .util import (strclass, safe_repr, sorted_list_difference, - unorderable_list_difference, _count_diff_all_purpose, +from .util import (strclass, safe_repr, _count_diff_all_purpose, _count_diff_hashable) __unittest = True @@ -104,9 +103,9 @@ def expectedFailure(func): class _AssertRaisesBaseContext(object): def __init__(self, expected, test_case, callable_obj=None, - expected_regex=None): + expected_regex=None): self.expected = expected - self.failureException = test_case.failureException + self.test_case = test_case if callable_obj is not None: try: self.obj_name = callable_obj.__name__ @@ -117,6 +116,24 @@ class _AssertRaisesBaseContext(object): if isinstance(expected_regex, (bytes, str)): expected_regex = re.compile(expected_regex) self.expected_regex = expected_regex + self.msg = None + + def _raiseFailure(self, standardMsg): + msg = self.test_case._formatMessage(self.msg, standardMsg) + raise self.test_case.failureException(msg) + + def handle(self, name, callable_obj, args, kwargs): + """ + If callable_obj is None, assertRaises/Warns is being used as a + context manager, so check for a 'msg' kwarg and return self. + If callable_obj is not None, call it passing args and kwargs. + """ + if callable_obj is None: + self.msg = kwargs.pop('msg', None) + return self + with self: + callable_obj(*args, **kwargs) + class _AssertRaisesContext(_AssertRaisesBaseContext): @@ -132,11 +149,10 @@ class _AssertRaisesContext(_AssertRaisesBaseContext): except AttributeError: exc_name = str(self.expected) if self.obj_name: - raise self.failureException("{0} not raised by {1}" - .format(exc_name, self.obj_name)) + self._raiseFailure("{} not raised by {}".format(exc_name, + self.obj_name)) else: - raise self.failureException("{0} not raised" - .format(exc_name)) + self._raiseFailure("{} not raised".format(exc_name)) if not issubclass(exc_type, self.expected): # let unexpected exceptions pass through return False @@ -147,8 +163,8 @@ class _AssertRaisesContext(_AssertRaisesBaseContext): expected_regex = self.expected_regex if not expected_regex.search(str(exc_value)): - raise self.failureException('"%s" does not match "%s"' % - (expected_regex.pattern, str(exc_value))) + self._raiseFailure('"{}" does not match "{}"'.format( + expected_regex.pattern, str(exc_value))) return True @@ -192,14 +208,13 @@ class _AssertWarnsContext(_AssertRaisesBaseContext): return # Now we simply try to choose a helpful failure message if first_matching is not None: - raise self.failureException('"%s" does not match "%s"' % - (self.expected_regex.pattern, str(first_matching))) + self._raiseFailure('"{}" does not match "{}"'.format( + self.expected_regex.pattern, str(first_matching))) if self.obj_name: - raise self.failureException("{0} not triggered by {1}" - .format(exc_name, self.obj_name)) + self._raiseFailure("{} not triggered by {}".format(exc_name, + self.obj_name)) else: - raise self.failureException("{0} not triggered" - .format(exc_name)) + self._raiseFailure("{} not triggered".format(exc_name)) class TestCase(object): @@ -452,7 +467,7 @@ class TestCase(object): warnings.warn("TestResult has no addExpectedFailure method, reporting as passes", RuntimeWarning) result.addSuccess(self) - + return result finally: result.stopTest(self) if orig_result is None: @@ -526,7 +541,6 @@ class TestCase(object): except UnicodeDecodeError: return '%s : %s' % (safe_repr(standardMsg), safe_repr(msg)) - def assertRaises(self, excClass, callableObj=None, *args, **kwargs): """Fail unless an exception of class excClass is thrown by callableObj when invoked with arguments args and keyword @@ -541,6 +555,9 @@ class TestCase(object): with self.assertRaises(SomeException): do_something() + An optional keyword argument 'msg' can be provided when assertRaises + is used as a context object. + The context manager keeps a reference to the exception as the 'exception' attribute. This allows you to inspect the exception after the assertion:: @@ -551,25 +568,25 @@ class TestCase(object): self.assertEqual(the_exception.error_code, 3) """ context = _AssertRaisesContext(excClass, self, callableObj) - if callableObj is None: - return context - with context: - callableObj(*args, **kwargs) + return context.handle('assertRaises', callableObj, args, kwargs) def assertWarns(self, expected_warning, callable_obj=None, *args, **kwargs): """Fail unless a warning of class warnClass is triggered - by callableObj when invoked with arguments args and keyword + by callable_obj when invoked with arguments args and keyword arguments kwargs. If a different type of warning is triggered, it will not be handled: depending on the other warning filtering rules in effect, it might be silenced, printed out, or raised as an exception. - If called with callableObj omitted or None, will return a + If called with callable_obj omitted or None, will return a context object used like this:: with self.assertWarns(SomeWarning): do_something() + An optional keyword argument 'msg' can be provided when assertWarns + is used as a context object. + The context manager keeps a reference to the first matching warning as the 'warning' attribute; similarly, the 'filename' and 'lineno' attributes give you information about the line @@ -582,10 +599,7 @@ class TestCase(object): self.assertEqual(the_warning.some_attribute, 147) """ context = _AssertWarnsContext(expected_warning, self, callable_obj) - if callable_obj is None: - return context - with context: - callable_obj(*args, **kwargs) + return context.handle('assertWarns', callable_obj, args, kwargs) def _getAssertEqualityFunc(self, first, second): """Get a detailed comparison function for the types of the two args. @@ -722,7 +736,7 @@ class TestCase(object): msg: Optional message to use on failure instead of a list of differences. """ - if seq_type != None: + if seq_type is not None: seq_type_name = seq_type.__name__ if not isinstance(seq1, seq_type): raise self.failureException('First sequence is not a %s: %s' @@ -951,48 +965,6 @@ class TestCase(object): self.fail(self._formatMessage(msg, standardMsg)) - def assertSameElements(self, expected_seq, actual_seq, msg=None): - """An unordered sequence specific comparison. - - Raises with an error message listing which elements of expected_seq - are missing from actual_seq and vice versa if any. - - Duplicate elements are ignored when comparing *expected_seq* and - *actual_seq*. It is the equivalent of ``assertEqual(set(expected), - set(actual))`` but it works with sequences of unhashable objects as - well. - """ - warnings.warn('assertSameElements is deprecated', - DeprecationWarning) - try: - expected = set(expected_seq) - actual = set(actual_seq) - missing = sorted(expected.difference(actual)) - unexpected = sorted(actual.difference(expected)) - except TypeError: - # Fall back to slower list-compare if any of the objects are - # not hashable. - expected = list(expected_seq) - actual = list(actual_seq) - try: - expected.sort() - actual.sort() - except TypeError: - missing, unexpected = unorderable_list_difference(expected, - actual) - else: - missing, unexpected = sorted_list_difference(expected, actual) - errors = [] - if missing: - errors.append('Expected, but missing:\n %s' % - safe_repr(missing)) - if unexpected: - errors.append('Unexpected, but present:\n %s' % - safe_repr(unexpected)) - if errors: - standardMsg = '\n'.join(errors) - self.fail(self._formatMessage(msg, standardMsg)) - def assertCountEqual(self, first, second, msg=None): """An unordered sequence comparison asserting that the same elements, @@ -1037,8 +1009,8 @@ class TestCase(object): if (len(first) > self._diffThreshold or len(second) > self._diffThreshold): self._baseAssertEqual(first, second, msg) - firstlines = first.splitlines(True) - secondlines = second.splitlines(True) + firstlines = first.splitlines(keepends=True) + secondlines = second.splitlines(keepends=True) if len(firstlines) == 1 and first.strip('\r\n') == first: firstlines = [first + '\n'] secondlines = [second + '\n'] @@ -1106,15 +1078,15 @@ class TestCase(object): expected_regex: Regex (re pattern object or string) expected to be found in error message. callable_obj: Function to be called. + msg: Optional message used in case of failure. Can only be used + when assertRaisesRegex is used as a context manager. args: Extra args. kwargs: Extra kwargs. """ context = _AssertRaisesContext(expected_exception, self, callable_obj, expected_regex) - if callable_obj is None: - return context - with context: - callable_obj(*args, **kwargs) + + return context.handle('assertRaisesRegex', callable_obj, args, kwargs) def assertWarnsRegex(self, expected_warning, expected_regex, callable_obj=None, *args, **kwargs): @@ -1128,15 +1100,14 @@ class TestCase(object): expected_regex: Regex (re pattern object or string) expected to be found in error message. callable_obj: Function to be called. + msg: Optional message used in case of failure. Can only be used + when assertWarnsRegex is used as a context manager. args: Extra args. kwargs: Extra kwargs. """ context = _AssertWarnsContext(expected_warning, self, callable_obj, expected_regex) - if callable_obj is None: - return context - with context: - callable_obj(*args, **kwargs) + return context.handle('assertWarnsRegex', callable_obj, args, kwargs) def assertRegex(self, text, expected_regex, msg=None): """Fail the test unless the text matches the regular expression.""" |