diff options
Diffstat (limited to 'Lib/unittest/test/test_result.py')
-rw-r--r-- | Lib/unittest/test/test_result.py | 132 |
1 files changed, 131 insertions, 1 deletions
diff --git a/Lib/unittest/test/test_result.py b/Lib/unittest/test/test_result.py index 967b662..c703032 100644 --- a/Lib/unittest/test/test_result.py +++ b/Lib/unittest/test/test_result.py @@ -1,6 +1,6 @@ import io import sys -import warnings +import textwrap from test import support @@ -25,6 +25,8 @@ class Test_TestResult(unittest.TestCase): self.assertEqual(len(result.failures), 0) self.assertEqual(result.testsRun, 0) self.assertEqual(result.shouldStop, False) + self.assertIsNone(result._stdout_buffer) + self.assertIsNone(result._stderr_buffer) # "This method can be called to signal that the set of tests being # run should be aborted by setting the TestResult's shouldStop @@ -302,6 +304,8 @@ def __init__(self, stream=None, descriptions=None, verbosity=None): self.errors = [] self.testsRun = 0 self.shouldStop = False + self.buffer = False + classDict['__init__'] = __init__ OldResult = type('OldResult', (object,), classDict) @@ -355,3 +359,129 @@ class Test_OldTestResult(unittest.TestCase): # This will raise an exception if TextTestRunner can't handle old # test result objects runner.run(Test('testFoo')) + + +class TestOutputBuffering(unittest.TestCase): + + def setUp(self): + self._real_out = sys.stdout + self._real_err = sys.stderr + + def tearDown(self): + sys.stdout = self._real_out + sys.stderr = self._real_err + + def testBufferOutputOff(self): + real_out = self._real_out + real_err = self._real_err + + result = unittest.TestResult() + self.assertFalse(result.buffer) + + self.assertIs(real_out, sys.stdout) + self.assertIs(real_err, sys.stderr) + + result.startTest(self) + + self.assertIs(real_out, sys.stdout) + self.assertIs(real_err, sys.stderr) + + def testBufferOutputStartTestAddSuccess(self): + real_out = self._real_out + real_err = self._real_err + + result = unittest.TestResult() + self.assertFalse(result.buffer) + + result.buffer = True + + self.assertIs(real_out, sys.stdout) + self.assertIs(real_err, sys.stderr) + + result.startTest(self) + + self.assertIsNot(real_out, sys.stdout) + self.assertIsNot(real_err, sys.stderr) + self.assertIsInstance(sys.stdout, io.StringIO) + self.assertIsInstance(sys.stderr, io.StringIO) + self.assertIsNot(sys.stdout, sys.stderr) + + out_stream = sys.stdout + err_stream = sys.stderr + + result._original_stdout = io.StringIO() + result._original_stderr = io.StringIO() + + print('foo') + print('bar', file=sys.stderr) + + self.assertEqual(out_stream.getvalue(), 'foo\n') + self.assertEqual(err_stream.getvalue(), 'bar\n') + + self.assertEqual(result._original_stdout.getvalue(), '') + self.assertEqual(result._original_stderr.getvalue(), '') + + result.addSuccess(self) + result.stopTest(self) + + self.assertIs(sys.stdout, result._original_stdout) + self.assertIs(sys.stderr, result._original_stderr) + + self.assertEqual(result._original_stdout.getvalue(), '') + self.assertEqual(result._original_stderr.getvalue(), '') + + self.assertEqual(out_stream.getvalue(), '') + self.assertEqual(err_stream.getvalue(), '') + + + def getStartedResult(self): + result = unittest.TestResult() + result.buffer = True + result.startTest(self) + return result + + def testBufferOutputAddErrorOrFailure(self): + for message_attr, add_attr, include_error in [ + ('errors', 'addError', True), + ('failures', 'addFailure', False), + ('errors', 'addError', True), + ('failures', 'addFailure', False) + ]: + result = self.getStartedResult() + buffered_out = sys.stdout + buffered_err = sys.stderr + result._original_stdout = io.StringIO() + result._original_stderr = io.StringIO() + + print('foo', file=sys.stdout) + if include_error: + print('bar', file=sys.stderr) + + + addFunction = getattr(result, add_attr) + addFunction(self, (None, None, None)) + result.stopTest(self) + + result_list = getattr(result, message_attr) + self.assertEqual(len(result_list), 1) + + test, message = result_list[0] + expectedOutMessage = textwrap.dedent(""" + Stdout: + foo + """) + expectedErrMessage = '' + if include_error: + expectedErrMessage = textwrap.dedent(""" + Stderr: + bar + """) + expectedFullMessage = 'NoneType\n%s%s' % (expectedOutMessage, expectedErrMessage) + + self.assertIs(test, self) + self.assertEqual(result._original_stdout.getvalue(), expectedOutMessage) + self.assertEqual(result._original_stderr.getvalue(), expectedErrMessage) + self.assertMultiLineEqual(message, expectedFullMessage) + +if __name__ == '__main__': + unittest.main() |