diff options
Diffstat (limited to 'Lib/unittest.py')
-rw-r--r-- | Lib/unittest.py | 117 |
1 files changed, 87 insertions, 30 deletions
diff --git a/Lib/unittest.py b/Lib/unittest.py index 0b7cea4..c6d893e 100644 --- a/Lib/unittest.py +++ b/Lib/unittest.py @@ -173,10 +173,22 @@ class TestResult(object): "Called when the given test is about to be run" self.testsRun = self.testsRun + 1 + def startTestRun(self): + """Called once before any tests are executed. + + See startTest for a method called before each test. + """ + def stopTest(self, test): "Called when the given test has been run" pass + def stopTestRun(self): + """Called once after all tests are executed. + + See stopTest for a method called after each test. + """ + def addError(self, test, err): """Called when an error has occurred. 'err' is a tuple of values as returned by sys.exc_info(). @@ -262,7 +274,7 @@ class _AssertRaisesContext(object): def __enter__(self): pass - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type, exc_value, tb): if exc_type is None: try: exc_name = self.expected.__name__ @@ -341,12 +353,14 @@ class TestCase(object): not have a method with the specified name. """ self._testMethodName = methodName + self._result = None try: testMethod = getattr(self, methodName) except AttributeError: raise ValueError("no such test method in %s: %s" % \ (self.__class__, methodName)) self._testMethodDoc = testMethod.__doc__ + self._cleanups = [] # Map types to custom assertEqual functions that will compare # instances of said type in more detail to generate a more useful @@ -373,6 +387,14 @@ class TestCase(object): """ self._type_equality_funcs[typeobj] = _AssertWrapper(function) + def addCleanup(self, function, *args, **kwargs): + """Add a function, with arguments, to be called when the test is + completed. Functions added are called on a LIFO basis and are + called after tearDown on test failure or success. + + Cleanup items are called even if setUp fails (unlike tearDown).""" + self._cleanups.append((function, args, kwargs)) + def setUp(self): "Hook method for setting up the test fixture before exercising it." pass @@ -428,45 +450,70 @@ class TestCase(object): (_strclass(self.__class__), self._testMethodName) def run(self, result=None): + orig_result = result if result is None: result = self.defaultTestResult() + startTestRun = getattr(result, 'startTestRun', None) + if startTestRun is not None: + startTestRun() + + self._result = result result.startTest(self) testMethod = getattr(self, self._testMethodName) try: - try: - self.setUp() - except SkipTest as e: - result.addSkip(self, str(e)) - return - except Exception: - result.addError(self, sys.exc_info()) - return - success = False try: - testMethod() - except self.failureException: - result.addFailure(self, sys.exc_info()) - except _ExpectedFailure as e: - result.addExpectedFailure(self, e.exc_info) - except _UnexpectedSuccess: - result.addUnexpectedSuccess(self) + self.setUp() except SkipTest as e: result.addSkip(self, str(e)) except Exception: result.addError(self, sys.exc_info()) else: - success = True + try: + testMethod() + except self.failureException: + result.addFailure(self, sys.exc_info()) + except _ExpectedFailure as e: + result.addExpectedFailure(self, e.exc_info) + except _UnexpectedSuccess: + result.addUnexpectedSuccess(self) + except SkipTest as e: + result.addSkip(self, str(e)) + except Exception: + result.addError(self, sys.exc_info()) + else: + success = True - try: - self.tearDown() - except Exception: - result.addError(self, sys.exc_info()) - success = False + try: + self.tearDown() + except Exception: + result.addError(self, sys.exc_info()) + success = False + + cleanUpSuccess = self.doCleanups() + success = success and cleanUpSuccess if success: result.addSuccess(self) finally: result.stopTest(self) + if orig_result is None: + stopTestRun = getattr(result, 'stopTestRun', None) + if stopTestRun is not None: + stopTestRun() + + def doCleanups(self): + """Execute all cleanup functions. Normally called for you after + tearDown.""" + result = self._result + ok = True + while self._cleanups: + function, args, kwargs = self._cleanups.pop(-1) + try: + function(*args, **kwargs) + except Exception: + ok = False + result.addError(self, sys.exc_info()) + return ok def __call__(self, *args, **kwds): return self.run(*args, **kwds) @@ -1037,7 +1084,7 @@ class TestSuite(object): def __eq__(self, other): if not isinstance(other, self.__class__): return NotImplemented - return self._tests == other._tests + return list(self) == list(other) def __ne__(self, other): return not self == other @@ -1160,8 +1207,7 @@ class FunctionTestCase(TestCase): self._testFunc, self._description)) def __str__(self): - return "%s (%s)" % (_strclass(self.__class__), - self.__testFunc.__name__) + return "%s (%s)" % (_strclass(self.__class__), self._testFunc.__name__) def __repr__(self): return "<%s testFunc=%s>" % (_strclass(self.__class__), self._testFunc) @@ -1449,7 +1495,15 @@ class TextTestRunner(object): "Run the given test case or test suite." result = self._makeResult() startTime = time.time() - test(result) + startTestRun = getattr(result, 'startTestRun', None) + if startTestRun is not None: + startTestRun() + try: + test(result) + finally: + stopTestRun = getattr(result, 'stopTestRun', None) + if stopTestRun is not None: + stopTestRun() stopTime = time.time() timeTaken = stopTime - startTime result.printErrors() @@ -1511,7 +1565,7 @@ Examples: """ def __init__(self, module='__main__', defaultTest=None, argv=None, testRunner=TextTestRunner, - testLoader=defaultTestLoader): + testLoader=defaultTestLoader, exit=True): if isinstance(module, str): self.module = __import__(module) for part in module.split('.')[1:]: @@ -1520,6 +1574,8 @@ Examples: self.module = module if argv is None: argv = sys.argv + + self.exit = exit self.verbosity = 1 self.defaultTest = defaultTest self.testRunner = testRunner @@ -1571,8 +1627,9 @@ Examples: else: # it is assumed to be a TestRunner instance testRunner = self.testRunner - result = testRunner.run(self.test) - sys.exit(not result.wasSuccessful()) + self.result = testRunner.run(self.test) + if self.exit: + sys.exit(not self.result.wasSuccessful()) main = TestProgram |