summaryrefslogtreecommitdiffstats
path: root/Lib/unittest
diff options
context:
space:
mode:
authorAntoine Pitrou <solipsis@pitrou.net>2010-09-06 19:25:46 (GMT)
committerAntoine Pitrou <solipsis@pitrou.net>2010-09-06 19:25:46 (GMT)
commit4bc12ef47dd57abda134fc0e90f946d862d8989e (patch)
tree2737280117621973c50edcf4483dcaf173990f10 /Lib/unittest
parent972ee13e037432497fa003d4a786b2342a38db94 (diff)
downloadcpython-4bc12ef47dd57abda134fc0e90f946d862d8989e.zip
cpython-4bc12ef47dd57abda134fc0e90f946d862d8989e.tar.gz
cpython-4bc12ef47dd57abda134fc0e90f946d862d8989e.tar.bz2
Issue #9754: Similarly to assertRaises and assertRaisesRegexp, unittest
test cases now also have assertWarns and assertWarnsRegexp methods to check that a given warning type was triggered by the code under test.
Diffstat (limited to 'Lib/unittest')
-rw-r--r--Lib/unittest/case.py114
-rw-r--r--Lib/unittest/test/test_case.py134
2 files changed, 244 insertions, 4 deletions
diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py
index 03346a1..761ac46 100644
--- a/Lib/unittest/case.py
+++ b/Lib/unittest/case.py
@@ -90,8 +90,7 @@ def expectedFailure(func):
return wrapper
-class _AssertRaisesContext(object):
- """A context manager used to implement TestCase.assertRaises* methods."""
+class _AssertRaisesBaseContext(object):
def __init__(self, expected, test_case, callable_obj=None,
expected_regexp=None):
@@ -104,8 +103,14 @@ class _AssertRaisesContext(object):
self.obj_name = str(callable_obj)
else:
self.obj_name = None
+ if isinstance(expected_regexp, (bytes, str)):
+ expected_regexp = re.compile(expected_regexp)
self.expected_regexp = expected_regexp
+
+class _AssertRaisesContext(_AssertRaisesBaseContext):
+ """A context manager used to implement TestCase.assertRaises* methods."""
+
def __enter__(self):
return self
@@ -130,14 +135,62 @@ class _AssertRaisesContext(object):
return True
expected_regexp = self.expected_regexp
- if isinstance(expected_regexp, (bytes, str)):
- expected_regexp = re.compile(expected_regexp)
if not expected_regexp.search(str(exc_value)):
raise self.failureException('"%s" does not match "%s"' %
(expected_regexp.pattern, str(exc_value)))
return True
+class _AssertWarnsContext(_AssertRaisesBaseContext):
+ """A context manager used to implement TestCase.assertWarns* methods."""
+
+ def __enter__(self):
+ # The __warningregistry__'s need to be in a pristine state for tests
+ # to work properly.
+ for v in sys.modules.values():
+ if getattr(v, '__warningregistry__', None):
+ v.__warningregistry__ = {}
+ self.warnings_manager = warnings.catch_warnings(record=True)
+ self.warnings = self.warnings_manager.__enter__()
+ warnings.simplefilter("always", self.expected)
+ return self
+
+ def __exit__(self, exc_type, exc_value, tb):
+ self.warnings_manager.__exit__(exc_type, exc_value, tb)
+ if exc_type is not None:
+ # let unexpected exceptions pass through
+ return
+ try:
+ exc_name = self.expected.__name__
+ except AttributeError:
+ exc_name = str(self.expected)
+ first_matching = None
+ for m in self.warnings:
+ w = m.message
+ if not isinstance(w, self.expected):
+ continue
+ if first_matching is None:
+ first_matching = w
+ if (self.expected_regexp is not None and
+ not self.expected_regexp.search(str(w))):
+ continue
+ # store warning for later retrieval
+ self.warning = w
+ self.filename = m.filename
+ self.lineno = m.lineno
+ return
+ # Now we simply try to choose a helpful failure message
+ if first_matching is not None:
+ raise self.failureException('"%s" does not match "%s"' %
+ (self.expected_regexp.pattern, str(first_matching)))
+ if self.obj_name:
+ raise self.failureException("{0} not triggered by {1}"
+ .format(exc_name, self.obj_name))
+ else:
+ raise self.failureException("{0} not triggered"
+ .format(exc_name))
+
+
class TestCase(object):
"""A class whose instances are single test cases.
@@ -464,6 +517,37 @@ class TestCase(object):
with context:
callableObj(*args, **kwargs)
+ def assertWarns(self, expected_warning, callable_obj=None, *args, **kwargs):
+ """Fail unless a warning of class warnClass is triggered
+ by callableObj when invoked with arguments args and keyword
+ arguments kwargs. If a different type of warning is
+ triggered, it will not be handled: depending on the other
+ warning filtering rules in effect, it might be silenced, printed
+ out, or raised as an exception.
+
+ If called with callableObj omitted or None, will return a
+ context object used like this::
+
+ with self.assertWarns(SomeWarning):
+ do_something()
+
+ The context manager keeps a reference to the first matching
+ warning as the 'warning' attribute; similarly, the 'filename'
+ and 'lineno' attributes give you information about the line
+ of Python code from which the warning was triggered.
+ This allows you to inspect the warning after the assertion::
+
+ with self.assertWarns(SomeWarning) as cm:
+ do_something()
+ the_warning = cm.warning
+ self.assertEqual(the_warning.some_attribute, 147)
+ """
+ context = _AssertWarnsContext(expected_warning, self, callable_obj)
+ if callable_obj is None:
+ return context
+ with context:
+ callable_obj(*args, **kwargs)
+
def _getAssertEqualityFunc(self, first, second):
"""Get a detailed comparison function for the types of the two args.
@@ -1019,6 +1103,28 @@ class TestCase(object):
with context:
callable_obj(*args, **kwargs)
+ def assertWarnsRegexp(self, expected_warning, expected_regexp,
+ callable_obj=None, *args, **kwargs):
+ """Asserts that the message in a triggered warning matches a regexp.
+ Basic functioning is similar to assertWarns() with the addition
+ that only warnings whose messages also match the regular expression
+ are considered successful matches.
+
+ Args:
+ expected_warning: Warning class expected to be triggered.
+ expected_regexp: Regexp (re pattern object or string) expected
+ to be found in error message.
+ callable_obj: Function to be called.
+ args: Extra args.
+ kwargs: Extra kwargs.
+ """
+ context = _AssertWarnsContext(expected_warning, self, callable_obj,
+ expected_regexp)
+ if callable_obj is None:
+ return context
+ with context:
+ callable_obj(*args, **kwargs)
+
def assertRegexpMatches(self, text, expected_regexp, msg=None):
"""Fail the test unless the text matches the regular expression."""
if isinstance(expected_regexp, (str, bytes)):
diff --git a/Lib/unittest/test/test_case.py b/Lib/unittest/test/test_case.py
index 1800f2e..1bd839f 100644
--- a/Lib/unittest/test/test_case.py
+++ b/Lib/unittest/test/test_case.py
@@ -2,6 +2,8 @@ import difflib
import pprint
import re
import sys
+import warnings
+import inspect
from copy import deepcopy
from test import support
@@ -917,6 +919,138 @@ test case
self.assertIsInstance(e, ExceptionMock)
self.assertEqual(e.args[0], v)
+ def testAssertWarnsCallable(self):
+ def _runtime_warn():
+ warnings.warn("foo", RuntimeWarning)
+ # Success when the right warning is triggered, even several times
+ self.assertWarns(RuntimeWarning, _runtime_warn)
+ self.assertWarns(RuntimeWarning, _runtime_warn)
+ # A tuple of warning classes is accepted
+ self.assertWarns((DeprecationWarning, RuntimeWarning), _runtime_warn)
+ # *args and **kwargs also work
+ self.assertWarns(RuntimeWarning,
+ warnings.warn, "foo", category=RuntimeWarning)
+ # Failure when no warning is triggered
+ with self.assertRaises(self.failureException):
+ self.assertWarns(RuntimeWarning, lambda: 0)
+ # Failure when another warning is triggered
+ with warnings.catch_warnings():
+ # Force default filter (in case tests are run with -We)
+ warnings.simplefilter("default", RuntimeWarning)
+ with self.assertRaises(self.failureException):
+ self.assertWarns(DeprecationWarning, _runtime_warn)
+ # Filters for other warnings are not modified
+ with warnings.catch_warnings():
+ warnings.simplefilter("error", RuntimeWarning)
+ with self.assertRaises(RuntimeWarning):
+ self.assertWarns(DeprecationWarning, _runtime_warn)
+
+ def testAssertWarnsContext(self):
+ # Believe it or not, it is preferrable to duplicate all tests above,
+ # to make sure the __warningregistry__ $@ is circumvented correctly.
+ def _runtime_warn():
+ warnings.warn("foo", RuntimeWarning)
+ _runtime_warn_lineno = inspect.getsourcelines(_runtime_warn)[1]
+ with self.assertWarns(RuntimeWarning) as cm:
+ _runtime_warn()
+ # A tuple of warning classes is accepted
+ with self.assertWarns((DeprecationWarning, RuntimeWarning)) as cm:
+ _runtime_warn()
+ # The context manager exposes various useful attributes
+ self.assertIsInstance(cm.warning, RuntimeWarning)
+ self.assertEqual(cm.warning.args[0], "foo")
+ self.assertIn("test_case.py", cm.filename)
+ self.assertEqual(cm.lineno, _runtime_warn_lineno + 1)
+ # Same with several warnings
+ with self.assertWarns(RuntimeWarning):
+ _runtime_warn()
+ _runtime_warn()
+ with self.assertWarns(RuntimeWarning):
+ warnings.warn("foo", category=RuntimeWarning)
+ # Failure when no warning is triggered
+ with self.assertRaises(self.failureException):
+ with self.assertWarns(RuntimeWarning):
+ pass
+ # Failure when another warning is triggered
+ with warnings.catch_warnings():
+ # Force default filter (in case tests are run with -We)
+ warnings.simplefilter("default", RuntimeWarning)
+ with self.assertRaises(self.failureException):
+ with self.assertWarns(DeprecationWarning):
+ _runtime_warn()
+ # Filters for other warnings are not modified
+ with warnings.catch_warnings():
+ warnings.simplefilter("error", RuntimeWarning)
+ with self.assertRaises(RuntimeWarning):
+ with self.assertWarns(DeprecationWarning):
+ _runtime_warn()
+
+ def testAssertWarnsRegexpCallable(self):
+ def _runtime_warn(msg):
+ warnings.warn(msg, RuntimeWarning)
+ self.assertWarnsRegexp(RuntimeWarning, "o+",
+ _runtime_warn, "foox")
+ # Failure when no warning is triggered
+ with self.assertRaises(self.failureException):
+ self.assertWarnsRegexp(RuntimeWarning, "o+",
+ lambda: 0)
+ # Failure when another warning is triggered
+ with warnings.catch_warnings():
+ # Force default filter (in case tests are run with -We)
+ warnings.simplefilter("default", RuntimeWarning)
+ with self.assertRaises(self.failureException):
+ self.assertWarnsRegexp(DeprecationWarning, "o+",
+ _runtime_warn, "foox")
+ # Failure when message doesn't match
+ with self.assertRaises(self.failureException):
+ self.assertWarnsRegexp(RuntimeWarning, "o+",
+ _runtime_warn, "barz")
+ # A little trickier: we ask RuntimeWarnings to be raised, and then
+ # check for some of them. It is implementation-defined whether
+ # non-matching RuntimeWarnings are simply re-raised, or produce a
+ # failureException.
+ with warnings.catch_warnings():
+ warnings.simplefilter("error", RuntimeWarning)
+ with self.assertRaises((RuntimeWarning, self.failureException)):
+ self.assertWarnsRegexp(RuntimeWarning, "o+",
+ _runtime_warn, "barz")
+
+ def testAssertWarnsRegexpContext(self):
+ # Same as above, but with assertWarnsRegexp as a context manager
+ def _runtime_warn(msg):
+ warnings.warn(msg, RuntimeWarning)
+ _runtime_warn_lineno = inspect.getsourcelines(_runtime_warn)[1]
+ with self.assertWarnsRegexp(RuntimeWarning, "o+") as cm:
+ _runtime_warn("foox")
+ self.assertIsInstance(cm.warning, RuntimeWarning)
+ self.assertEqual(cm.warning.args[0], "foox")
+ self.assertIn("test_case.py", cm.filename)
+ self.assertEqual(cm.lineno, _runtime_warn_lineno + 1)
+ # Failure when no warning is triggered
+ with self.assertRaises(self.failureException):
+ with self.assertWarnsRegexp(RuntimeWarning, "o+"):
+ pass
+ # Failure when another warning is triggered
+ with warnings.catch_warnings():
+ # Force default filter (in case tests are run with -We)
+ warnings.simplefilter("default", RuntimeWarning)
+ with self.assertRaises(self.failureException):
+ with self.assertWarnsRegexp(DeprecationWarning, "o+"):
+ _runtime_warn("foox")
+ # Failure when message doesn't match
+ with self.assertRaises(self.failureException):
+ with self.assertWarnsRegexp(RuntimeWarning, "o+"):
+ _runtime_warn("barz")
+ # A little trickier: we ask RuntimeWarnings to be raised, and then
+ # check for some of them. It is implementation-defined whether
+ # non-matching RuntimeWarnings are simply re-raised, or produce a
+ # failureException.
+ with warnings.catch_warnings():
+ warnings.simplefilter("error", RuntimeWarning)
+ with self.assertRaises((RuntimeWarning, self.failureException)):
+ with self.assertWarnsRegexp(RuntimeWarning, "o+"):
+ _runtime_warn("barz")
+
def testSynonymAssertMethodNames(self):
"""Test undocumented method name synonyms.