summaryrefslogtreecommitdiffstats
path: root/Lib/unittest/case.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/unittest/case.py')
-rw-r--r--Lib/unittest/case.py114
1 files changed, 110 insertions, 4 deletions
diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py
index 03346a1..761ac46 100644
--- a/Lib/unittest/case.py
+++ b/Lib/unittest/case.py
@@ -90,8 +90,7 @@ def expectedFailure(func):
return wrapper
-class _AssertRaisesContext(object):
- """A context manager used to implement TestCase.assertRaises* methods."""
+class _AssertRaisesBaseContext(object):
def __init__(self, expected, test_case, callable_obj=None,
expected_regexp=None):
@@ -104,8 +103,14 @@ class _AssertRaisesContext(object):
self.obj_name = str(callable_obj)
else:
self.obj_name = None
+ if isinstance(expected_regexp, (bytes, str)):
+ expected_regexp = re.compile(expected_regexp)
self.expected_regexp = expected_regexp
+
+class _AssertRaisesContext(_AssertRaisesBaseContext):
+ """A context manager used to implement TestCase.assertRaises* methods."""
+
def __enter__(self):
return self
@@ -130,14 +135,62 @@ class _AssertRaisesContext(object):
return True
expected_regexp = self.expected_regexp
- if isinstance(expected_regexp, (bytes, str)):
- expected_regexp = re.compile(expected_regexp)
if not expected_regexp.search(str(exc_value)):
raise self.failureException('"%s" does not match "%s"' %
(expected_regexp.pattern, str(exc_value)))
return True
+class _AssertWarnsContext(_AssertRaisesBaseContext):
+ """A context manager used to implement TestCase.assertWarns* methods."""
+
+ def __enter__(self):
+ # The __warningregistry__'s need to be in a pristine state for tests
+ # to work properly.
+ for v in sys.modules.values():
+ if getattr(v, '__warningregistry__', None):
+ v.__warningregistry__ = {}
+ self.warnings_manager = warnings.catch_warnings(record=True)
+ self.warnings = self.warnings_manager.__enter__()
+ warnings.simplefilter("always", self.expected)
+ return self
+
+ def __exit__(self, exc_type, exc_value, tb):
+ self.warnings_manager.__exit__(exc_type, exc_value, tb)
+ if exc_type is not None:
+ # let unexpected exceptions pass through
+ return
+ try:
+ exc_name = self.expected.__name__
+ except AttributeError:
+ exc_name = str(self.expected)
+ first_matching = None
+ for m in self.warnings:
+ w = m.message
+ if not isinstance(w, self.expected):
+ continue
+ if first_matching is None:
+ first_matching = w
+ if (self.expected_regexp is not None and
+ not self.expected_regexp.search(str(w))):
+ continue
+ # store warning for later retrieval
+ self.warning = w
+ self.filename = m.filename
+ self.lineno = m.lineno
+ return
+ # Now we simply try to choose a helpful failure message
+ if first_matching is not None:
+ raise self.failureException('"%s" does not match "%s"' %
+ (self.expected_regexp.pattern, str(first_matching)))
+ if self.obj_name:
+ raise self.failureException("{0} not triggered by {1}"
+ .format(exc_name, self.obj_name))
+ else:
+ raise self.failureException("{0} not triggered"
+ .format(exc_name))
+
+
class TestCase(object):
"""A class whose instances are single test cases.
@@ -464,6 +517,37 @@ class TestCase(object):
with context:
callableObj(*args, **kwargs)
+ def assertWarns(self, expected_warning, callable_obj=None, *args, **kwargs):
+ """Fail unless a warning of class warnClass is triggered
+ by callableObj when invoked with arguments args and keyword
+ arguments kwargs. If a different type of warning is
+ triggered, it will not be handled: depending on the other
+ warning filtering rules in effect, it might be silenced, printed
+ out, or raised as an exception.
+
+ If called with callableObj omitted or None, will return a
+ context object used like this::
+
+ with self.assertWarns(SomeWarning):
+ do_something()
+
+ The context manager keeps a reference to the first matching
+ warning as the 'warning' attribute; similarly, the 'filename'
+ and 'lineno' attributes give you information about the line
+ of Python code from which the warning was triggered.
+ This allows you to inspect the warning after the assertion::
+
+ with self.assertWarns(SomeWarning) as cm:
+ do_something()
+ the_warning = cm.warning
+ self.assertEqual(the_warning.some_attribute, 147)
+ """
+ context = _AssertWarnsContext(expected_warning, self, callable_obj)
+ if callable_obj is None:
+ return context
+ with context:
+ callable_obj(*args, **kwargs)
+
def _getAssertEqualityFunc(self, first, second):
"""Get a detailed comparison function for the types of the two args.
@@ -1019,6 +1103,28 @@ class TestCase(object):
with context:
callable_obj(*args, **kwargs)
+ def assertWarnsRegexp(self, expected_warning, expected_regexp,
+ callable_obj=None, *args, **kwargs):
+ """Asserts that the message in a triggered warning matches a regexp.
+ Basic functioning is similar to assertWarns() with the addition
+ that only warnings whose messages also match the regular expression
+ are considered successful matches.
+
+ Args:
+ expected_warning: Warning class expected to be triggered.
+ expected_regexp: Regexp (re pattern object or string) expected
+ to be found in error message.
+ callable_obj: Function to be called.
+ args: Extra args.
+ kwargs: Extra kwargs.
+ """
+ context = _AssertWarnsContext(expected_warning, self, callable_obj,
+ expected_regexp)
+ if callable_obj is None:
+ return context
+ with context:
+ callable_obj(*args, **kwargs)
+
def assertRegexpMatches(self, text, expected_regexp, msg=None):
"""Fail the test unless the text matches the regular expression."""
if isinstance(expected_regexp, (str, bytes)):