diff options
author | Antoine Pitrou <solipsis@pitrou.net> | 2010-09-06 19:25:46 (GMT) |
---|---|---|
committer | Antoine Pitrou <solipsis@pitrou.net> | 2010-09-06 19:25:46 (GMT) |
commit | 4bc12ef47dd57abda134fc0e90f946d862d8989e (patch) | |
tree | 2737280117621973c50edcf4483dcaf173990f10 /Lib/unittest | |
parent | 972ee13e037432497fa003d4a786b2342a38db94 (diff) | |
download | cpython-4bc12ef47dd57abda134fc0e90f946d862d8989e.zip cpython-4bc12ef47dd57abda134fc0e90f946d862d8989e.tar.gz cpython-4bc12ef47dd57abda134fc0e90f946d862d8989e.tar.bz2 |
Issue #9754: Similarly to assertRaises and assertRaisesRegexp, unittest
test cases now also have assertWarns and assertWarnsRegexp methods to
check that a given warning type was triggered by the code under test.
Diffstat (limited to 'Lib/unittest')
-rw-r--r-- | Lib/unittest/case.py | 114 | ||||
-rw-r--r-- | Lib/unittest/test/test_case.py | 134 |
2 files changed, 244 insertions, 4 deletions
diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py index 03346a1..761ac46 100644 --- a/Lib/unittest/case.py +++ b/Lib/unittest/case.py @@ -90,8 +90,7 @@ def expectedFailure(func): return wrapper -class _AssertRaisesContext(object): - """A context manager used to implement TestCase.assertRaises* methods.""" +class _AssertRaisesBaseContext(object): def __init__(self, expected, test_case, callable_obj=None, expected_regexp=None): @@ -104,8 +103,14 @@ class _AssertRaisesContext(object): self.obj_name = str(callable_obj) else: self.obj_name = None + if isinstance(expected_regexp, (bytes, str)): + expected_regexp = re.compile(expected_regexp) self.expected_regexp = expected_regexp + +class _AssertRaisesContext(_AssertRaisesBaseContext): + """A context manager used to implement TestCase.assertRaises* methods.""" + def __enter__(self): return self @@ -130,14 +135,62 @@ class _AssertRaisesContext(object): return True expected_regexp = self.expected_regexp - if isinstance(expected_regexp, (bytes, str)): - expected_regexp = re.compile(expected_regexp) if not expected_regexp.search(str(exc_value)): raise self.failureException('"%s" does not match "%s"' % (expected_regexp.pattern, str(exc_value))) return True +class _AssertWarnsContext(_AssertRaisesBaseContext): + """A context manager used to implement TestCase.assertWarns* methods.""" + + def __enter__(self): + # The __warningregistry__'s need to be in a pristine state for tests + # to work properly. + for v in sys.modules.values(): + if getattr(v, '__warningregistry__', None): + v.__warningregistry__ = {} + self.warnings_manager = warnings.catch_warnings(record=True) + self.warnings = self.warnings_manager.__enter__() + warnings.simplefilter("always", self.expected) + return self + + def __exit__(self, exc_type, exc_value, tb): + self.warnings_manager.__exit__(exc_type, exc_value, tb) + if exc_type is not None: + # let unexpected exceptions pass through + return + try: + exc_name = self.expected.__name__ + except AttributeError: + exc_name = str(self.expected) + first_matching = None + for m in self.warnings: + w = m.message + if not isinstance(w, self.expected): + continue + if first_matching is None: + first_matching = w + if (self.expected_regexp is not None and + not self.expected_regexp.search(str(w))): + continue + # store warning for later retrieval + self.warning = w + self.filename = m.filename + self.lineno = m.lineno + return + # Now we simply try to choose a helpful failure message + if first_matching is not None: + raise self.failureException('"%s" does not match "%s"' % + (self.expected_regexp.pattern, str(first_matching))) + if self.obj_name: + raise self.failureException("{0} not triggered by {1}" + .format(exc_name, self.obj_name)) + else: + raise self.failureException("{0} not triggered" + .format(exc_name)) + + class TestCase(object): """A class whose instances are single test cases. @@ -464,6 +517,37 @@ class TestCase(object): with context: callableObj(*args, **kwargs) + def assertWarns(self, expected_warning, callable_obj=None, *args, **kwargs): + """Fail unless a warning of class warnClass is triggered + by callableObj when invoked with arguments args and keyword + arguments kwargs. If a different type of warning is + triggered, it will not be handled: depending on the other + warning filtering rules in effect, it might be silenced, printed + out, or raised as an exception. + + If called with callableObj omitted or None, will return a + context object used like this:: + + with self.assertWarns(SomeWarning): + do_something() + + The context manager keeps a reference to the first matching + warning as the 'warning' attribute; similarly, the 'filename' + and 'lineno' attributes give you information about the line + of Python code from which the warning was triggered. + This allows you to inspect the warning after the assertion:: + + with self.assertWarns(SomeWarning) as cm: + do_something() + the_warning = cm.warning + self.assertEqual(the_warning.some_attribute, 147) + """ + context = _AssertWarnsContext(expected_warning, self, callable_obj) + if callable_obj is None: + return context + with context: + callable_obj(*args, **kwargs) + def _getAssertEqualityFunc(self, first, second): """Get a detailed comparison function for the types of the two args. @@ -1019,6 +1103,28 @@ class TestCase(object): with context: callable_obj(*args, **kwargs) + def assertWarnsRegexp(self, expected_warning, expected_regexp, + callable_obj=None, *args, **kwargs): + """Asserts that the message in a triggered warning matches a regexp. + Basic functioning is similar to assertWarns() with the addition + that only warnings whose messages also match the regular expression + are considered successful matches. + + Args: + expected_warning: Warning class expected to be triggered. + expected_regexp: Regexp (re pattern object or string) expected + to be found in error message. + callable_obj: Function to be called. + args: Extra args. + kwargs: Extra kwargs. + """ + context = _AssertWarnsContext(expected_warning, self, callable_obj, + expected_regexp) + if callable_obj is None: + return context + with context: + callable_obj(*args, **kwargs) + def assertRegexpMatches(self, text, expected_regexp, msg=None): """Fail the test unless the text matches the regular expression.""" if isinstance(expected_regexp, (str, bytes)): diff --git a/Lib/unittest/test/test_case.py b/Lib/unittest/test/test_case.py index 1800f2e..1bd839f 100644 --- a/Lib/unittest/test/test_case.py +++ b/Lib/unittest/test/test_case.py @@ -2,6 +2,8 @@ import difflib import pprint import re import sys +import warnings +import inspect from copy import deepcopy from test import support @@ -917,6 +919,138 @@ test case self.assertIsInstance(e, ExceptionMock) self.assertEqual(e.args[0], v) + def testAssertWarnsCallable(self): + def _runtime_warn(): + warnings.warn("foo", RuntimeWarning) + # Success when the right warning is triggered, even several times + self.assertWarns(RuntimeWarning, _runtime_warn) + self.assertWarns(RuntimeWarning, _runtime_warn) + # A tuple of warning classes is accepted + self.assertWarns((DeprecationWarning, RuntimeWarning), _runtime_warn) + # *args and **kwargs also work + self.assertWarns(RuntimeWarning, + warnings.warn, "foo", category=RuntimeWarning) + # Failure when no warning is triggered + with self.assertRaises(self.failureException): + self.assertWarns(RuntimeWarning, lambda: 0) + # Failure when another warning is triggered + with warnings.catch_warnings(): + # Force default filter (in case tests are run with -We) + warnings.simplefilter("default", RuntimeWarning) + with self.assertRaises(self.failureException): + self.assertWarns(DeprecationWarning, _runtime_warn) + # Filters for other warnings are not modified + with warnings.catch_warnings(): + warnings.simplefilter("error", RuntimeWarning) + with self.assertRaises(RuntimeWarning): + self.assertWarns(DeprecationWarning, _runtime_warn) + + def testAssertWarnsContext(self): + # Believe it or not, it is preferrable to duplicate all tests above, + # to make sure the __warningregistry__ $@ is circumvented correctly. + def _runtime_warn(): + warnings.warn("foo", RuntimeWarning) + _runtime_warn_lineno = inspect.getsourcelines(_runtime_warn)[1] + with self.assertWarns(RuntimeWarning) as cm: + _runtime_warn() + # A tuple of warning classes is accepted + with self.assertWarns((DeprecationWarning, RuntimeWarning)) as cm: + _runtime_warn() + # The context manager exposes various useful attributes + self.assertIsInstance(cm.warning, RuntimeWarning) + self.assertEqual(cm.warning.args[0], "foo") + self.assertIn("test_case.py", cm.filename) + self.assertEqual(cm.lineno, _runtime_warn_lineno + 1) + # Same with several warnings + with self.assertWarns(RuntimeWarning): + _runtime_warn() + _runtime_warn() + with self.assertWarns(RuntimeWarning): + warnings.warn("foo", category=RuntimeWarning) + # Failure when no warning is triggered + with self.assertRaises(self.failureException): + with self.assertWarns(RuntimeWarning): + pass + # Failure when another warning is triggered + with warnings.catch_warnings(): + # Force default filter (in case tests are run with -We) + warnings.simplefilter("default", RuntimeWarning) + with self.assertRaises(self.failureException): + with self.assertWarns(DeprecationWarning): + _runtime_warn() + # Filters for other warnings are not modified + with warnings.catch_warnings(): + warnings.simplefilter("error", RuntimeWarning) + with self.assertRaises(RuntimeWarning): + with self.assertWarns(DeprecationWarning): + _runtime_warn() + + def testAssertWarnsRegexpCallable(self): + def _runtime_warn(msg): + warnings.warn(msg, RuntimeWarning) + self.assertWarnsRegexp(RuntimeWarning, "o+", + _runtime_warn, "foox") + # Failure when no warning is triggered + with self.assertRaises(self.failureException): + self.assertWarnsRegexp(RuntimeWarning, "o+", + lambda: 0) + # Failure when another warning is triggered + with warnings.catch_warnings(): + # Force default filter (in case tests are run with -We) + warnings.simplefilter("default", RuntimeWarning) + with self.assertRaises(self.failureException): + self.assertWarnsRegexp(DeprecationWarning, "o+", + _runtime_warn, "foox") + # Failure when message doesn't match + with self.assertRaises(self.failureException): + self.assertWarnsRegexp(RuntimeWarning, "o+", + _runtime_warn, "barz") + # A little trickier: we ask RuntimeWarnings to be raised, and then + # check for some of them. It is implementation-defined whether + # non-matching RuntimeWarnings are simply re-raised, or produce a + # failureException. + with warnings.catch_warnings(): + warnings.simplefilter("error", RuntimeWarning) + with self.assertRaises((RuntimeWarning, self.failureException)): + self.assertWarnsRegexp(RuntimeWarning, "o+", + _runtime_warn, "barz") + + def testAssertWarnsRegexpContext(self): + # Same as above, but with assertWarnsRegexp as a context manager + def _runtime_warn(msg): + warnings.warn(msg, RuntimeWarning) + _runtime_warn_lineno = inspect.getsourcelines(_runtime_warn)[1] + with self.assertWarnsRegexp(RuntimeWarning, "o+") as cm: + _runtime_warn("foox") + self.assertIsInstance(cm.warning, RuntimeWarning) + self.assertEqual(cm.warning.args[0], "foox") + self.assertIn("test_case.py", cm.filename) + self.assertEqual(cm.lineno, _runtime_warn_lineno + 1) + # Failure when no warning is triggered + with self.assertRaises(self.failureException): + with self.assertWarnsRegexp(RuntimeWarning, "o+"): + pass + # Failure when another warning is triggered + with warnings.catch_warnings(): + # Force default filter (in case tests are run with -We) + warnings.simplefilter("default", RuntimeWarning) + with self.assertRaises(self.failureException): + with self.assertWarnsRegexp(DeprecationWarning, "o+"): + _runtime_warn("foox") + # Failure when message doesn't match + with self.assertRaises(self.failureException): + with self.assertWarnsRegexp(RuntimeWarning, "o+"): + _runtime_warn("barz") + # A little trickier: we ask RuntimeWarnings to be raised, and then + # check for some of them. It is implementation-defined whether + # non-matching RuntimeWarnings are simply re-raised, or produce a + # failureException. + with warnings.catch_warnings(): + warnings.simplefilter("error", RuntimeWarning) + with self.assertRaises((RuntimeWarning, self.failureException)): + with self.assertWarnsRegexp(RuntimeWarning, "o+"): + _runtime_warn("barz") + def testSynonymAssertMethodNames(self): """Test undocumented method name synonyms. |