summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorMichael Foord <michael@python.org>2011-03-17 17:58:22 (GMT)
committerMichael Foord <michael@python.org>2011-03-17 17:58:22 (GMT)
commite9ff2ef20488eb3d1e8bba04516939585f35a148 (patch)
treee92d21272c822a9b4beb170aa6a4a71a48eb1fa6 /Lib
parentf694a409aadf04e67d4e72a6f8e45dcbe34ab6e5 (diff)
parentf40834f39b7bf1e667fbe040fe869232d2488f60 (diff)
downloadcpython-e9ff2ef20488eb3d1e8bba04516939585f35a148.zip
cpython-e9ff2ef20488eb3d1e8bba04516939585f35a148.tar.gz
cpython-e9ff2ef20488eb3d1e8bba04516939585f35a148.tar.bz2
Closes issue 10979. unittest buffering now works with class and module setup and teardown
Diffstat (limited to 'Lib')
-rw-r--r--Lib/unittest/result.py8
-rw-r--r--Lib/unittest/suite.py18
-rw-r--r--Lib/unittest/test/test_result.py67
3 files changed, 91 insertions, 2 deletions
diff --git a/Lib/unittest/result.py b/Lib/unittest/result.py
index 3dc7154..44bf186 100644
--- a/Lib/unittest/result.py
+++ b/Lib/unittest/result.py
@@ -59,6 +59,9 @@ class TestResult(object):
"Called when the given test is about to be run"
self.testsRun += 1
self._mirrorOutput = False
+ self._setupStdout()
+
+ def _setupStdout(self):
if self.buffer:
if self._stderr_buffer is None:
self._stderr_buffer = io.StringIO()
@@ -74,6 +77,10 @@ class TestResult(object):
def stopTest(self, test):
"""Called when the given test has been run"""
+ self._restoreStdout()
+ self._mirrorOutput = False
+
+ def _restoreStdout(self):
if self.buffer:
if self._mirrorOutput:
output = sys.stdout.getvalue()
@@ -93,7 +100,6 @@ class TestResult(object):
self._stdout_buffer.truncate()
self._stderr_buffer.seek(0)
self._stderr_buffer.truncate()
- self._mirrorOutput = False
def stopTestRun(self):
"""Called once after all tests are executed.
diff --git a/Lib/unittest/suite.py b/Lib/unittest/suite.py
index 77ce089..38bd6b8 100644
--- a/Lib/unittest/suite.py
+++ b/Lib/unittest/suite.py
@@ -8,6 +8,11 @@ from . import util
__unittest = True
+def _call_if_exists(parent, attr):
+ func = getattr(parent, attr, lambda: None)
+ func()
+
+
class BaseTestSuite(object):
"""A simple test suite that doesn't provide class or module shared fixtures.
"""
@@ -133,6 +138,7 @@ class TestSuite(BaseTestSuite):
setUpClass = getattr(currentClass, 'setUpClass', None)
if setUpClass is not None:
+ _call_if_exists(result, '_setupStdout')
try:
setUpClass()
except Exception as e:
@@ -142,6 +148,8 @@ class TestSuite(BaseTestSuite):
className = util.strclass(currentClass)
errorName = 'setUpClass (%s)' % className
self._addClassOrModuleLevelException(result, e, errorName)
+ finally:
+ _call_if_exists(result, '_restoreStdout')
def _get_previous_module(self, result):
previousModule = None
@@ -167,6 +175,7 @@ class TestSuite(BaseTestSuite):
return
setUpModule = getattr(module, 'setUpModule', None)
if setUpModule is not None:
+ _call_if_exists(result, '_setupStdout')
try:
setUpModule()
except Exception as e:
@@ -175,6 +184,8 @@ class TestSuite(BaseTestSuite):
result._moduleSetUpFailed = True
errorName = 'setUpModule (%s)' % currentModule
self._addClassOrModuleLevelException(result, e, errorName)
+ finally:
+ _call_if_exists(result, '_restoreStdout')
def _addClassOrModuleLevelException(self, result, exception, errorName):
error = _ErrorHolder(errorName)
@@ -198,6 +209,7 @@ class TestSuite(BaseTestSuite):
tearDownModule = getattr(module, 'tearDownModule', None)
if tearDownModule is not None:
+ _call_if_exists(result, '_setupStdout')
try:
tearDownModule()
except Exception as e:
@@ -205,6 +217,8 @@ class TestSuite(BaseTestSuite):
raise
errorName = 'tearDownModule (%s)' % previousModule
self._addClassOrModuleLevelException(result, e, errorName)
+ finally:
+ _call_if_exists(result, '_restoreStdout')
def _tearDownPreviousClass(self, test, result):
previousClass = getattr(result, '_previousTestClass', None)
@@ -220,6 +234,7 @@ class TestSuite(BaseTestSuite):
tearDownClass = getattr(previousClass, 'tearDownClass', None)
if tearDownClass is not None:
+ _call_if_exists(result, '_setupStdout')
try:
tearDownClass()
except Exception as e:
@@ -228,7 +243,8 @@ class TestSuite(BaseTestSuite):
className = util.strclass(previousClass)
errorName = 'tearDownClass (%s)' % className
self._addClassOrModuleLevelException(result, e, errorName)
-
+ finally:
+ _call_if_exists(result, '_restoreStdout')
class _ErrorHolder(object):
diff --git a/Lib/unittest/test/test_result.py b/Lib/unittest/test/test_result.py
index 64798a1..1c58e61 100644
--- a/Lib/unittest/test/test_result.py
+++ b/Lib/unittest/test/test_result.py
@@ -497,5 +497,72 @@ class TestOutputBuffering(unittest.TestCase):
self.assertEqual(result._original_stderr.getvalue(), expectedErrMessage)
self.assertMultiLineEqual(message, expectedFullMessage)
+ def testBufferSetupClass(self):
+ result = unittest.TestResult()
+ result.buffer = True
+
+ class Foo(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ 1/0
+ def test_foo(self):
+ pass
+ suite = unittest.TestSuite([Foo('test_foo')])
+ suite(result)
+ self.assertEqual(len(result.errors), 1)
+
+ def testBufferTearDownClass(self):
+ result = unittest.TestResult()
+ result.buffer = True
+
+ class Foo(unittest.TestCase):
+ @classmethod
+ def tearDownClass(cls):
+ 1/0
+ def test_foo(self):
+ pass
+ suite = unittest.TestSuite([Foo('test_foo')])
+ suite(result)
+ self.assertEqual(len(result.errors), 1)
+
+ def testBufferSetUpModule(self):
+ result = unittest.TestResult()
+ result.buffer = True
+
+ class Foo(unittest.TestCase):
+ def test_foo(self):
+ pass
+ class Module(object):
+ @staticmethod
+ def setUpModule():
+ 1/0
+
+ Foo.__module__ = 'Module'
+ sys.modules['Module'] = Module
+ self.addCleanup(sys.modules.pop, 'Module')
+ suite = unittest.TestSuite([Foo('test_foo')])
+ suite(result)
+ self.assertEqual(len(result.errors), 1)
+
+ def testBufferTearDownModule(self):
+ result = unittest.TestResult()
+ result.buffer = True
+
+ class Foo(unittest.TestCase):
+ def test_foo(self):
+ pass
+ class Module(object):
+ @staticmethod
+ def tearDownModule():
+ 1/0
+
+ Foo.__module__ = 'Module'
+ sys.modules['Module'] = Module
+ self.addCleanup(sys.modules.pop, 'Module')
+ suite = unittest.TestSuite([Foo('test_foo')])
+ suite(result)
+ self.assertEqual(len(result.errors), 1)
+
+
if __name__ == '__main__':
unittest.main()