summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
Diffstat (limited to 'Lib')
-rw-r--r--Lib/unittest/result.py1
-rw-r--r--Lib/unittest/suite.py36
-rw-r--r--Lib/unittest/test/test_suite.py14
3 files changed, 32 insertions, 19 deletions
diff --git a/Lib/unittest/result.py b/Lib/unittest/result.py
index 1dbd04c..3dc7154 100644
--- a/Lib/unittest/result.py
+++ b/Lib/unittest/result.py
@@ -34,6 +34,7 @@ class TestResult(object):
formatted traceback of the error that occurred.
"""
_previousTestClass = None
+ _testRunEntered = False
_moduleSetUpFailed = False
def __init__(self, stream=None, descriptions=None, verbosity=None):
self.failfast = False
diff --git a/Lib/unittest/suite.py b/Lib/unittest/suite.py
index b6ae68c..73f0e17 100644
--- a/Lib/unittest/suite.py
+++ b/Lib/unittest/suite.py
@@ -77,23 +77,11 @@ class TestSuite(BaseTestSuite):
subclassing, do not forget to call the base class constructor.
"""
+ def run(self, result, debug=False):
+ topLevel = False
+ if getattr(result, '_testRunEntered', False) is False:
+ result._testRunEntered = topLevel = True
- def run(self, result):
- self._wrapped_run(result)
- self._tearDownPreviousClass(None, result)
- self._handleModuleTearDown(result)
- return result
-
- def debug(self):
- """Run the tests without collecting errors in a TestResult"""
- debug = _DebugResult()
- self._wrapped_run(debug, True)
- self._tearDownPreviousClass(None, debug)
- self._handleModuleTearDown(debug)
-
- ################################
- # private methods
- def _wrapped_run(self, result, debug=False):
for test in self:
if result.shouldStop:
break
@@ -108,13 +96,23 @@ class TestSuite(BaseTestSuite):
getattr(result, '_moduleSetUpFailed', False)):
continue
- if hasattr(test, '_wrapped_run'):
- test._wrapped_run(result, debug)
- elif not debug:
+ if not debug:
test(result)
else:
test.debug()
+ if topLevel:
+ self._tearDownPreviousClass(None, result)
+ self._handleModuleTearDown(result)
+ return result
+
+ def debug(self):
+ """Run the tests without collecting errors in a TestResult"""
+ debug = _DebugResult()
+ self.run(debug, True)
+
+ ################################
+
def _handleClassSetUp(self, test, result):
previousClass = getattr(result, '_previousTestClass', None)
currentClass = test.__class__
diff --git a/Lib/unittest/test/test_suite.py b/Lib/unittest/test/test_suite.py
index 47b57de..fa32247 100644
--- a/Lib/unittest/test/test_suite.py
+++ b/Lib/unittest/test/test_suite.py
@@ -345,5 +345,19 @@ class Test_TestSuite(unittest.TestCase, TestEquality):
self.assertEqual(result.testsRun, 2)
+ def test_overriding_call(self):
+ class MySuite(unittest.TestSuite):
+ called = False
+ def __call__(self, *args, **kw):
+ self.called = True
+ unittest.TestSuite.__call__(self, *args, **kw)
+
+ suite = MySuite()
+ wrapper = unittest.TestSuite()
+ wrapper.addTest(suite)
+ wrapper(unittest.TestResult())
+ self.assertTrue(suite.called)
+
+
if __name__ == '__main__':
unittest.main()