diff options
author | Benjamin Peterson <benjamin@python.org> | 2009-03-24 00:35:20 (GMT) |
---|---|---|
committer | Benjamin Peterson <benjamin@python.org> | 2009-03-24 00:35:20 (GMT) |
commit | a7d441de68233e2f035930d57f5373844c29fb89 (patch) | |
tree | 453565ae04541f9cebc1b7934155d14e91a9e82c /Lib | |
parent | 21b617bd98260a69c71f77586c8e6f2eb52e0ebf (diff) | |
download | cpython-a7d441de68233e2f035930d57f5373844c29fb89.zip cpython-a7d441de68233e2f035930d57f5373844c29fb89.tar.gz cpython-a7d441de68233e2f035930d57f5373844c29fb89.tar.bz2 |
some cleanup and modernization
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/unittest.py | 91 |
1 files changed, 50 insertions, 41 deletions
diff --git a/Lib/unittest.py b/Lib/unittest.py index 000c201..b9ef3d7 100644 --- a/Lib/unittest.py +++ b/Lib/unittest.py @@ -251,12 +251,16 @@ class TestResult(object): (_strclass(self.__class__), self.testsRun, len(self.errors), len(self.failures)) + class AssertRaisesContext(object): + def __init__(self, expected, test_case): self.expected = expected self.failureException = test_case.failureException + def __enter__(self): pass + def __exit__(self, exc_type, exc_value, traceback): if exc_type is None: try: @@ -270,6 +274,7 @@ class AssertRaisesContext(object): # Let unexpected exceptions skip through return False + class TestCase(object): """A class whose instances are single test cases. @@ -303,13 +308,13 @@ class TestCase(object): method when executed. Raises a ValueError if the instance does not have a method with the specified name. """ + self._testMethodName = methodName try: - self._testMethodName = methodName testMethod = getattr(self, methodName) - self._testMethodDoc = testMethod.__doc__ except AttributeError: raise ValueError("no such test method in %s: %s" % \ (self.__class__, methodName)) + self._testMethodDoc = testMethod.__doc__ def setUp(self): "Hook method for setting up the test fixture before exercising it." @@ -340,7 +345,7 @@ class TestCase(object): def __eq__(self, other): if type(self) is not type(other): - return False + return NotImplemented return self._testMethodName == other._testMethodName @@ -358,7 +363,8 @@ class TestCase(object): (_strclass(self.__class__), self._testMethodName) def run(self, result=None): - if result is None: result = self.defaultTestResult() + if result is None: + result = self.defaultTestResult() result.startTest(self) testMethod = getattr(self, self._testMethodName) try: @@ -423,11 +429,13 @@ class TestCase(object): def failIf(self, expr, msg=None): "Fail the test if the expression is true." - if expr: raise self.failureException(msg) + if expr: + raise self.failureException(msg) def failUnless(self, expr, msg=None): """Fail the test unless the expression is true.""" - if not expr: raise self.failureException(msg) + if not expr: + raise self.failureException(msg) def failUnlessRaises(self, excClass, callableObj=None, *args, **kwargs): """Fail unless an exception of class excClass is thrown @@ -521,8 +529,6 @@ class TestSuite(object): def __repr__(self): return "<%s tests=%s>" % (_strclass(self.__class__), self._tests) - __str__ = __repr__ - def __eq__(self, other): if not isinstance(other, self.__class__): return NotImplemented @@ -547,8 +553,7 @@ class TestSuite(object): # sanity checks if not hasattr(test, '__call__'): raise TypeError("the test to add must be callable") - if (isinstance(test, (type, types.ClassType)) and - issubclass(test, (TestCase, TestSuite))): + if isinstance(test, type) and issubclass(test, (TestCase, TestSuite)): raise TypeError("TestCases and TestSuites must be instantiated " "before passing them to addTest()") self._tests.append(test) @@ -571,7 +576,8 @@ class TestSuite(object): def debug(self): """Run the tests without collecting errors in a TestResult""" - for test in self._tests: test.debug() + for test in self._tests: + test.debug() class ClassTestSuite(TestSuite): @@ -614,9 +620,8 @@ class FunctionTestCase(TestCase): always be called if the set-up ('setUp') function ran successfully. """ - def __init__(self, testFunc, setUp=None, tearDown=None, - description=None): - TestCase.__init__(self) + def __init__(self, testFunc, setUp=None, tearDown=None, description=None): + super(FunctionTestCase, self).__init__() self.__setUpFunc = setUp self.__tearDownFunc = tearDown self.__testFunc = testFunc @@ -637,8 +642,8 @@ class FunctionTestCase(TestCase): return self.__testFunc.__name__ def __eq__(self, other): - if type(self) is not type(other): - return False + if not isinstance(other, self.__class__): + return NotImplemented return self.__setUpFunc == other.__setUpFunc and \ self.__tearDownFunc == other.__tearDownFunc and \ @@ -670,8 +675,9 @@ class FunctionTestCase(TestCase): ############################################################################## class TestLoader(object): - """This class is responsible for loading tests according to various - criteria and returning them wrapped in a TestSuite + """ + This class is responsible for loading tests according to various criteria + and returning them wrapped in a TestSuite """ testMethodPrefix = 'test' sortTestMethodsUsing = cmp @@ -681,7 +687,8 @@ class TestLoader(object): def loadTestsFromTestCase(self, testCaseClass): """Return a suite of all tests cases contained in testCaseClass""" if issubclass(testCaseClass, TestSuite): - raise TypeError("Test cases should not be derived from TestSuite. Maybe you meant to derive from TestCase?") + raise TypeError("Test cases should not be derived from TestSuite." \ + " Maybe you meant to derive from TestCase?") testCaseNames = self.getTestCaseNames(testCaseClass) if not testCaseNames and hasattr(testCaseClass, 'runTest'): testCaseNames = ['runTest'] @@ -694,8 +701,7 @@ class TestLoader(object): tests = [] for name in dir(module): obj = getattr(module, name) - if (isinstance(obj, (type, types.ClassType)) and - issubclass(obj, TestCase)): + if isinstance(obj, type) and issubclass(obj, TestCase): tests.append(self.loadTestsFromTestCase(obj)) return self.suiteClass(tests) @@ -717,7 +723,8 @@ class TestLoader(object): break except ImportError: del parts_copy[-1] - if not parts_copy: raise + if not parts_copy: + raise parts = parts[1:] obj = module for part in parts: @@ -725,11 +732,10 @@ class TestLoader(object): if isinstance(obj, types.ModuleType): return self.loadTestsFromModule(obj) - elif (isinstance(obj, (type, types.ClassType)) and - issubclass(obj, TestCase)): + elif isinstance(obj, type) and issubclass(obj, TestCase): return self.loadTestsFromTestCase(obj) elif (isinstance(obj, types.UnboundMethodType) and - isinstance(parent, (type, types.ClassType)) and + isinstance(parent, type) and issubclass(parent, TestCase)): return TestSuite([parent(obj.__name__)]) elif isinstance(obj, TestSuite): @@ -756,8 +762,10 @@ class TestLoader(object): def getTestCaseNames(self, testCaseClass): """Return a sorted sequence of method names found within testCaseClass """ - def isTestMethod(attrname, testCaseClass=testCaseClass, prefix=self.testMethodPrefix): - return attrname.startswith(prefix) and hasattr(getattr(testCaseClass, attrname), '__call__') + def isTestMethod(attrname, testCaseClass=testCaseClass, + prefix=self.testMethodPrefix): + return attrname.startswith(prefix) and \ + hasattr(getattr(testCaseClass, attrname), '__call__') testFnNames = filter(isTestMethod, dir(testCaseClass)) if self.sortTestMethodsUsing: testFnNames.sort(key=_CmpToKey(self.sortTestMethodsUsing)) @@ -815,7 +823,7 @@ class _TextTestResult(TestResult): separator2 = '-' * 70 def __init__(self, stream, descriptions, verbosity): - TestResult.__init__(self) + super(_TextTestResult, self).__init__() self.stream = stream self.showAll = verbosity > 1 self.dots = verbosity == 1 @@ -828,14 +836,14 @@ class _TextTestResult(TestResult): return str(test) def startTest(self, test): - TestResult.startTest(self, test) + super(_TextTestResult, self).startTest(test) if self.showAll: self.stream.write(self.getDescription(test)) self.stream.write(" ... ") self.stream.flush() def addSuccess(self, test): - TestResult.addSuccess(self, test) + super(_TextTestResult, self).addSuccess(test) if self.showAll: self.stream.writeln("ok") elif self.dots: @@ -843,7 +851,7 @@ class _TextTestResult(TestResult): self.stream.flush() def addError(self, test, err): - TestResult.addError(self, test, err) + super(_TextTestResult, self).addError(test, err) if self.showAll: self.stream.writeln("ERROR") elif self.dots: @@ -851,7 +859,7 @@ class _TextTestResult(TestResult): self.stream.flush() def addFailure(self, test, err): - TestResult.addFailure(self, test, err) + super(_TextTestResult, self).addFailure(test, err) if self.showAll: self.stream.writeln("FAIL") elif self.dots: @@ -859,7 +867,7 @@ class _TextTestResult(TestResult): self.stream.flush() def addSkip(self, test, reason): - TestResult.addSkip(self, test, reason) + super(_TextTestResult, self).addSkip(test, reason) if self.showAll: self.stream.writeln("skipped {0!r}".format(reason)) elif self.dots: @@ -867,7 +875,7 @@ class _TextTestResult(TestResult): self.stream.flush() def addExpectedFailure(self, test, err): - TestResult.addExpectedFailure(self, test, err) + super(_TextTestResult, self).addExpectedFailure(test, err) if self.showAll: self.stream.writeln("expected failure") elif self.dots: @@ -875,7 +883,7 @@ class _TextTestResult(TestResult): self.stream.flush() def addUnexpectedSuccess(self, test): - TestResult.addUnexpectedSuccess(self, test) + super(_TextTestResult, self).addUnexpectedSuccess(test) if self.showAll: self.stream.writeln("unexpected success") elif self.dots: @@ -936,13 +944,13 @@ class TextTestRunner(object): if errored: infos.append("errors=%d" % errored) else: - self.stream.write("OK") + self.stream.writeln("OK") if skipped: infos.append("skipped=%d" % skipped) - if expected_fails: - infos.append("expected failures=%d" % expected_fails) - if unexpected_successes: - infos.append("unexpected successes=%d" % unexpected_successes) + if expectedFails: + infos.append("expected failures=%d" % expectedFails) + if unexpectedSuccesses: + infos.append("unexpected successes=%d" % unexpectedSuccesses) if infos: self.stream.writeln(" (%s)" % (", ".join(infos),)) return result @@ -992,7 +1000,8 @@ Examples: self.runTests() def usageExit(self, msg=None): - if msg: print msg + if msg: + print msg print self.USAGE % self.__dict__ sys.exit(2) |