diff options
Diffstat (limited to 'Lib/unittest/case.py')
-rw-r--r-- | Lib/unittest/case.py | 114 |
1 files changed, 110 insertions, 4 deletions
diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py index 03346a1..761ac46 100644 --- a/Lib/unittest/case.py +++ b/Lib/unittest/case.py @@ -90,8 +90,7 @@ def expectedFailure(func): return wrapper -class _AssertRaisesContext(object): - """A context manager used to implement TestCase.assertRaises* methods.""" +class _AssertRaisesBaseContext(object): def __init__(self, expected, test_case, callable_obj=None, expected_regexp=None): @@ -104,8 +103,14 @@ class _AssertRaisesContext(object): self.obj_name = str(callable_obj) else: self.obj_name = None + if isinstance(expected_regexp, (bytes, str)): + expected_regexp = re.compile(expected_regexp) self.expected_regexp = expected_regexp + +class _AssertRaisesContext(_AssertRaisesBaseContext): + """A context manager used to implement TestCase.assertRaises* methods.""" + def __enter__(self): return self @@ -130,14 +135,62 @@ class _AssertRaisesContext(object): return True expected_regexp = self.expected_regexp - if isinstance(expected_regexp, (bytes, str)): - expected_regexp = re.compile(expected_regexp) if not expected_regexp.search(str(exc_value)): raise self.failureException('"%s" does not match "%s"' % (expected_regexp.pattern, str(exc_value))) return True +class _AssertWarnsContext(_AssertRaisesBaseContext): + """A context manager used to implement TestCase.assertWarns* methods.""" + + def __enter__(self): + # The __warningregistry__'s need to be in a pristine state for tests + # to work properly. + for v in sys.modules.values(): + if getattr(v, '__warningregistry__', None): + v.__warningregistry__ = {} + self.warnings_manager = warnings.catch_warnings(record=True) + self.warnings = self.warnings_manager.__enter__() + warnings.simplefilter("always", self.expected) + return self + + def __exit__(self, exc_type, exc_value, tb): + self.warnings_manager.__exit__(exc_type, exc_value, tb) + if exc_type is not None: + # let unexpected exceptions pass through + return + try: + exc_name = self.expected.__name__ + except AttributeError: + exc_name = str(self.expected) + first_matching = None + for m in self.warnings: + w = m.message + if not isinstance(w, self.expected): + continue + if first_matching is None: + first_matching = w + if (self.expected_regexp is not None and + not self.expected_regexp.search(str(w))): + continue + # store warning for later retrieval + self.warning = w + self.filename = m.filename + self.lineno = m.lineno + return + # Now we simply try to choose a helpful failure message + if first_matching is not None: + raise self.failureException('"%s" does not match "%s"' % + (self.expected_regexp.pattern, str(first_matching))) + if self.obj_name: + raise self.failureException("{0} not triggered by {1}" + .format(exc_name, self.obj_name)) + else: + raise self.failureException("{0} not triggered" + .format(exc_name)) + + class TestCase(object): """A class whose instances are single test cases. @@ -464,6 +517,37 @@ class TestCase(object): with context: callableObj(*args, **kwargs) + def assertWarns(self, expected_warning, callable_obj=None, *args, **kwargs): + """Fail unless a warning of class warnClass is triggered + by callableObj when invoked with arguments args and keyword + arguments kwargs. If a different type of warning is + triggered, it will not be handled: depending on the other + warning filtering rules in effect, it might be silenced, printed + out, or raised as an exception. + + If called with callableObj omitted or None, will return a + context object used like this:: + + with self.assertWarns(SomeWarning): + do_something() + + The context manager keeps a reference to the first matching + warning as the 'warning' attribute; similarly, the 'filename' + and 'lineno' attributes give you information about the line + of Python code from which the warning was triggered. + This allows you to inspect the warning after the assertion:: + + with self.assertWarns(SomeWarning) as cm: + do_something() + the_warning = cm.warning + self.assertEqual(the_warning.some_attribute, 147) + """ + context = _AssertWarnsContext(expected_warning, self, callable_obj) + if callable_obj is None: + return context + with context: + callable_obj(*args, **kwargs) + def _getAssertEqualityFunc(self, first, second): """Get a detailed comparison function for the types of the two args. @@ -1019,6 +1103,28 @@ class TestCase(object): with context: callable_obj(*args, **kwargs) + def assertWarnsRegexp(self, expected_warning, expected_regexp, + callable_obj=None, *args, **kwargs): + """Asserts that the message in a triggered warning matches a regexp. + Basic functioning is similar to assertWarns() with the addition + that only warnings whose messages also match the regular expression + are considered successful matches. + + Args: + expected_warning: Warning class expected to be triggered. + expected_regexp: Regexp (re pattern object or string) expected + to be found in error message. + callable_obj: Function to be called. + args: Extra args. + kwargs: Extra kwargs. + """ + context = _AssertWarnsContext(expected_warning, self, callable_obj, + expected_regexp) + if callable_obj is None: + return context + with context: + callable_obj(*args, **kwargs) + def assertRegexpMatches(self, text, expected_regexp, msg=None): """Fail the test unless the text matches the regular expression.""" if isinstance(expected_regexp, (str, bytes)): |