diff options
author | Antoine Pitrou <solipsis@pitrou.net> | 2013-09-14 17:45:47 (GMT) |
---|---|---|
committer | Antoine Pitrou <solipsis@pitrou.net> | 2013-09-14 17:45:47 (GMT) |
commit | 0715b9fad375751d93de90f23bab20faba1b1b62 (patch) | |
tree | ca77149d5d6549d441ecdcabd3b69c52ced7b4b1 /Lib/unittest | |
parent | 692ee9eaf0ae910a0a024f0122c4e51df6c04711 (diff) | |
download | cpython-0715b9fad375751d93de90f23bab20faba1b1b62.zip cpython-0715b9fad375751d93de90f23bab20faba1b1b62.tar.gz cpython-0715b9fad375751d93de90f23bab20faba1b1b62.tar.bz2 |
Issue #18937: Add an assertLogs() context manager to unittest.TestCase to ensure that a block of code emits a message using the logging module.
Diffstat (limited to 'Lib/unittest')
-rw-r--r-- | Lib/unittest/case.py | 109 | ||||
-rw-r--r-- | Lib/unittest/test/test_case.py | 96 |
2 files changed, 199 insertions, 6 deletions
diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py index d04ea61..2909610 100644 --- a/Lib/unittest/case.py +++ b/Lib/unittest/case.py @@ -3,6 +3,7 @@ import sys import functools import difflib +import logging import pprint import re import warnings @@ -115,10 +116,21 @@ def expectedFailure(test_item): return test_item -class _AssertRaisesBaseContext(object): +class _BaseTestCaseContext: + + def __init__(self, test_case): + self.test_case = test_case + + def _raiseFailure(self, standardMsg): + msg = self.test_case._formatMessage(self.msg, standardMsg) + raise self.test_case.failureException(msg) + + +class _AssertRaisesBaseContext(_BaseTestCaseContext): def __init__(self, expected, test_case, callable_obj=None, expected_regex=None): + _BaseTestCaseContext.__init__(self, test_case) self.expected = expected self.test_case = test_case if callable_obj is not None: @@ -133,10 +145,6 @@ class _AssertRaisesBaseContext(object): self.expected_regex = expected_regex self.msg = None - def _raiseFailure(self, standardMsg): - msg = self.test_case._formatMessage(self.msg, standardMsg) - raise self.test_case.failureException(msg) - def handle(self, name, callable_obj, args, kwargs): """ If callable_obj is None, assertRaises/Warns is being used as a @@ -150,7 +158,6 @@ class _AssertRaisesBaseContext(object): callable_obj(*args, **kwargs) - class _AssertRaisesContext(_AssertRaisesBaseContext): """A context manager used to implement TestCase.assertRaises* methods.""" @@ -232,6 +239,74 @@ class _AssertWarnsContext(_AssertRaisesBaseContext): self._raiseFailure("{} not triggered".format(exc_name)) + +_LoggingWatcher = collections.namedtuple("_LoggingWatcher", + ["records", "output"]) + + +class _CapturingHandler(logging.Handler): + """ + A logging handler capturing all (raw and formatted) logging output. + """ + + def __init__(self): + logging.Handler.__init__(self) + self.watcher = _LoggingWatcher([], []) + + def flush(self): + pass + + def emit(self, record): + self.watcher.records.append(record) + msg = self.format(record) + self.watcher.output.append(msg) + + + +class _AssertLogsContext(_BaseTestCaseContext): + """A context manager used to implement TestCase.assertLogs().""" + + LOGGING_FORMAT = "%(levelname)s:%(name)s:%(message)s" + + def __init__(self, test_case, logger_name, level): + _BaseTestCaseContext.__init__(self, test_case) + self.logger_name = logger_name + if level: + self.level = logging._nameToLevel.get(level, level) + else: + self.level = logging.INFO + self.msg = None + + def __enter__(self): + if isinstance(self.logger_name, logging.Logger): + logger = self.logger = self.logger_name + else: + logger = self.logger = logging.getLogger(self.logger_name) + formatter = logging.Formatter(self.LOGGING_FORMAT) + handler = _CapturingHandler() + handler.setFormatter(formatter) + self.watcher = handler.watcher + self.old_handlers = logger.handlers[:] + self.old_level = logger.level + self.old_propagate = logger.propagate + logger.handlers = [handler] + logger.setLevel(self.level) + logger.propagate = False + return handler.watcher + + def __exit__(self, exc_type, exc_value, tb): + self.logger.handlers = self.old_handlers + self.logger.propagate = self.old_propagate + self.logger.setLevel(self.old_level) + if exc_type is not None: + # let unexpected exceptions pass through + return False + if len(self.watcher.records) == 0: + self._raiseFailure( + "no logs of level {} or higher triggered on {}" + .format(logging.getLevelName(self.level), self.logger.name)) + + class TestCase(object): """A class whose instances are single test cases. @@ -644,6 +719,28 @@ class TestCase(object): context = _AssertWarnsContext(expected_warning, self, callable_obj) return context.handle('assertWarns', callable_obj, args, kwargs) + def assertLogs(self, logger=None, level=None): + """Fail unless a log message of level *level* or higher is emitted + on *logger_name* or its children. If omitted, *level* defaults to + INFO and *logger* defaults to the root logger. + + This method must be used as a context manager, and will yield + a recording object with two attributes: `output` and `records`. + At the end of the context manager, the `output` attribute will + be a list of the matching formatted log messages and the + `records` attribute will be a list of the corresponding LogRecord + objects. + + Example:: + + with self.assertLogs('foo', level='INFO') as cm: + logging.getLogger('foo').info('first message') + logging.getLogger('foo.bar').error('second message') + self.assertEqual(cm.output, ['INFO:foo:first message', + 'ERROR:foo.bar:second message']) + """ + return _AssertLogsContext(self, logger, level) + def _getAssertEqualityFunc(self, first, second): """Get a detailed comparison function for the types of the two args. diff --git a/Lib/unittest/test/test_case.py b/Lib/unittest/test/test_case.py index 51b06bc..f08668f 100644 --- a/Lib/unittest/test/test_case.py +++ b/Lib/unittest/test/test_case.py @@ -1,8 +1,10 @@ +import contextlib import difflib import pprint import pickle import re import sys +import logging import warnings import weakref import inspect @@ -16,6 +18,12 @@ from .support import ( TestEquality, TestHashing, LoggingResult, LegacyLoggingResult, ResultWithNoStartTestRunStopTestRun ) +from test.support import captured_stderr + + +log_foo = logging.getLogger('foo') +log_foobar = logging.getLogger('foo.bar') +log_quux = logging.getLogger('quux') class Test(object): @@ -1251,6 +1259,94 @@ test case with self.assertWarnsRegex(RuntimeWarning, "o+"): _runtime_warn("barz") + @contextlib.contextmanager + def assertNoStderr(self): + with captured_stderr() as buf: + yield + self.assertEqual(buf.getvalue(), "") + + def assertLogRecords(self, records, matches): + self.assertEqual(len(records), len(matches)) + for rec, match in zip(records, matches): + self.assertIsInstance(rec, logging.LogRecord) + for k, v in match.items(): + self.assertEqual(getattr(rec, k), v) + + def testAssertLogsDefaults(self): + # defaults: root logger, level INFO + with self.assertNoStderr(): + with self.assertLogs() as cm: + log_foo.info("1") + log_foobar.debug("2") + self.assertEqual(cm.output, ["INFO:foo:1"]) + self.assertLogRecords(cm.records, [{'name': 'foo'}]) + + def testAssertLogsTwoMatchingMessages(self): + # Same, but with two matching log messages + with self.assertNoStderr(): + with self.assertLogs() as cm: + log_foo.info("1") + log_foobar.debug("2") + log_quux.warning("3") + self.assertEqual(cm.output, ["INFO:foo:1", "WARNING:quux:3"]) + self.assertLogRecords(cm.records, + [{'name': 'foo'}, {'name': 'quux'}]) + + def checkAssertLogsPerLevel(self, level): + # Check level filtering + with self.assertNoStderr(): + with self.assertLogs(level=level) as cm: + log_foo.warning("1") + log_foobar.error("2") + log_quux.critical("3") + self.assertEqual(cm.output, ["ERROR:foo.bar:2", "CRITICAL:quux:3"]) + self.assertLogRecords(cm.records, + [{'name': 'foo.bar'}, {'name': 'quux'}]) + + def testAssertLogsPerLevel(self): + self.checkAssertLogsPerLevel(logging.ERROR) + self.checkAssertLogsPerLevel('ERROR') + + def checkAssertLogsPerLogger(self, logger): + # Check per-logger fitering + with self.assertNoStderr(): + with self.assertLogs(level='DEBUG') as outer_cm: + with self.assertLogs(logger, level='DEBUG') as cm: + log_foo.info("1") + log_foobar.debug("2") + log_quux.warning("3") + self.assertEqual(cm.output, ["INFO:foo:1", "DEBUG:foo.bar:2"]) + self.assertLogRecords(cm.records, + [{'name': 'foo'}, {'name': 'foo.bar'}]) + # The outer catchall caught the quux log + self.assertEqual(outer_cm.output, ["WARNING:quux:3"]) + + def testAssertLogsPerLogger(self): + self.checkAssertLogsPerLogger(logging.getLogger('foo')) + self.checkAssertLogsPerLogger('foo') + + def testAssertLogsFailureNoLogs(self): + # Failure due to no logs + with self.assertNoStderr(): + with self.assertRaises(self.failureException): + with self.assertLogs(): + pass + + def testAssertLogsFailureLevelTooHigh(self): + # Failure due to level too high + with self.assertNoStderr(): + with self.assertRaises(self.failureException): + with self.assertLogs(level='WARNING'): + log_foo.info("1") + + def testAssertLogsFailureMismatchingLogger(self): + # Failure due to mismatching logger (and the logged message is + # passed through) + with self.assertLogs('quux', level='ERROR'): + with self.assertRaises(self.failureException): + with self.assertLogs('foo'): + log_quux.error("1") + def testDeprecatedMethodNames(self): """ Test that the deprecated methods raise a DeprecationWarning. See #9424. |