diff options
Diffstat (limited to 'Lib/unittest.py')
| -rw-r--r-- | Lib/unittest.py | 119 |
1 files changed, 89 insertions, 30 deletions
diff --git a/Lib/unittest.py b/Lib/unittest.py index cd91c2c..74f15d6 100644 --- a/Lib/unittest.py +++ b/Lib/unittest.py @@ -25,7 +25,7 @@ Simple usage: Further information is available in the bundled documentation, and from - http://pyunit.sourceforge.net/ + http://docs.python.org/lib/module-unittest.html Copyright (c) 1999-2003 Steve Purcell This module is free software, and you may redistribute it and/or modify @@ -68,7 +68,6 @@ __all__.extend(['getTestCaseNames', 'makeSuite', 'findTestCases']) # Backward compatibility ############################################################################## if sys.version_info[:2] < (2, 2): - False, True = 0, 1 def isinstance(obj, clsinfo): import __builtin__ if type(clsinfo) in (tuple, list): @@ -79,6 +78,14 @@ if sys.version_info[:2] < (2, 2): return 0 else: return __builtin__.isinstance(obj, clsinfo) +def _CmpToKey(mycmp): + 'Convert a cmp= function into a key= function' + class K(object): + def __init__(self, obj): + self.obj = obj + def __lt__(self, other): + return mycmp(self.obj, other.obj) == -1 + return K ############################################################################## # Test framework core @@ -107,7 +114,7 @@ class TestResult: self.failures = [] self.errors = [] self.testsRun = 0 - self.shouldStop = 0 + self.shouldStop = False def startTest(self, test): "Called when the given test is about to be run" @@ -153,7 +160,7 @@ class TestResult: return ''.join(traceback.format_exception(exctype, value, tb)) def _is_relevant_tb_level(self, tb): - return tb.tb_frame.f_globals.has_key('__unittest') + return '__unittest' in tb.tb_frame.f_globals def _count_relevant_tb_levels(self, tb): length = 0 @@ -235,6 +242,18 @@ class TestCase: def id(self): return "%s.%s" % (_strclass(self.__class__), self._testMethodName) + def __eq__(self, other): + if type(self) is not type(other): + return False + + return self._testMethodName == other._testMethodName + + def __ne__(self, other): + return not self == other + + def __hash__(self): + return hash((type(self), self._testMethodName)) + def __str__(self): return "%s (%s)" % (self._testMethodName, _strclass(self.__class__)) @@ -291,10 +310,7 @@ class TestCase: minimised; usually the top level of the traceback frame is not needed. """ - exctype, excvalue, tb = sys.exc_info() - if sys.platform[:4] == 'java': ## tracebacks look different in Jython - return (exctype, excvalue, tb) - return (exctype, excvalue, tb) + return sys.exc_info() def fail(self, msg=None): """Fail immediately, with the given message.""" @@ -349,7 +365,7 @@ class TestCase: Note that decimal places (from zero) are usually not the same as significant digits (measured from the most signficant digit). """ - if round(second-first, places) != 0: + if round(abs(second-first), places) != 0: raise self.failureException, \ (msg or '%r != %r within %r places' % (first, second, places)) @@ -361,7 +377,7 @@ class TestCase: Note that decimal places (from zero) are usually not the same as significant digits (measured from the most signficant digit). """ - if round(second-first, places) == 0: + if round(abs(second-first), places) == 0: raise self.failureException, \ (msg or '%r == %r within %r places' % (first, second, places)) @@ -401,6 +417,17 @@ class TestSuite: __str__ = __repr__ + def __eq__(self, other): + if type(self) is not type(other): + return False + return self._tests == other._tests + + def __ne__(self, other): + return not self == other + + # Can't guarantee hash invariant, so flag as unhashable + __hash__ = None + def __iter__(self): return iter(self._tests) @@ -412,7 +439,7 @@ class TestSuite: def addTest(self, test): # sanity checks - if not callable(test): + if not hasattr(test, '__call__'): raise TypeError("the test to add must be callable") if (isinstance(test, (type, types.ClassType)) and issubclass(test, (TestCase, TestSuite))): @@ -445,7 +472,7 @@ class FunctionTestCase(TestCase): """A test case that wraps a test function. This is useful for slipping pre-existing test functions into the - PyUnit framework. Optionally, set-up and tidy-up functions can be + unittest framework. Optionally, set-up and tidy-up functions can be supplied. As with TestCase, the tidy-up ('tearDown') function will always be called if the set-up ('setUp') function ran successfully. """ @@ -472,6 +499,22 @@ class FunctionTestCase(TestCase): def id(self): return self.__testFunc.__name__ + def __eq__(self, other): + if type(self) is not type(other): + return False + + return self.__setUpFunc == other.__setUpFunc and \ + self.__tearDownFunc == other.__tearDownFunc and \ + self.__testFunc == other.__testFunc and \ + self.__description == other.__description + + def __ne__(self, other): + return not self == other + + def __hash__(self): + return hash((type(self), self.__setUpFunc, self.__tearDownFunc, + self.__testFunc, self.__description)) + def __str__(self): return "%s (%s)" % (_strclass(self.__class__), self.__testFunc.__name__) @@ -491,7 +534,7 @@ class FunctionTestCase(TestCase): class TestLoader: """This class is responsible for loading tests according to various - criteria and returning them wrapped in a Test + criteria and returning them wrapped in a TestSuite """ testMethodPrefix = 'test' sortTestMethodsUsing = cmp @@ -545,18 +588,23 @@ class TestLoader: elif (isinstance(obj, (type, types.ClassType)) and issubclass(obj, TestCase)): return self.loadTestsFromTestCase(obj) - elif type(obj) == types.UnboundMethodType: - return parent(obj.__name__) + elif (type(obj) == types.UnboundMethodType and + isinstance(parent, (type, types.ClassType)) and + issubclass(parent, TestCase)): + return TestSuite([parent(obj.__name__)]) elif isinstance(obj, TestSuite): return obj - elif callable(obj): + elif hasattr(obj, '__call__'): test = obj() - if not isinstance(test, (TestCase, TestSuite)): - raise ValueError, \ - "calling %s returned %s, not a test" % (obj,test) - return test + if isinstance(test, TestSuite): + return test + elif isinstance(test, TestCase): + return TestSuite([test]) + else: + raise TypeError("calling %s returned %s, not a test" % + (obj, test)) else: - raise ValueError, "don't know how to make test from: %s" % obj + raise TypeError("don't know how to make test from: %s" % obj) def loadTestsFromNames(self, names, module=None): """Return a suite of all tests cases found using the given sequence @@ -569,14 +617,10 @@ class TestLoader: """Return a sorted sequence of method names found within testCaseClass """ def isTestMethod(attrname, testCaseClass=testCaseClass, prefix=self.testMethodPrefix): - return attrname.startswith(prefix) and callable(getattr(testCaseClass, attrname)) + return attrname.startswith(prefix) and hasattr(getattr(testCaseClass, attrname), '__call__') testFnNames = filter(isTestMethod, dir(testCaseClass)) - for baseclass in testCaseClass.__bases__: - for testFnName in self.getTestCaseNames(baseclass): - if testFnName not in testFnNames: # handle overridden methods - testFnNames.append(testFnName) if self.sortTestMethodsUsing: - testFnNames.sort(self.sortTestMethodsUsing) + testFnNames.sort(key=_CmpToKey(self.sortTestMethodsUsing)) return testFnNames @@ -648,6 +692,7 @@ class _TextTestResult(TestResult): if self.showAll: self.stream.write(self.getDescription(test)) self.stream.write(" ... ") + self.stream.flush() def addSuccess(self, test): TestResult.addSuccess(self, test) @@ -655,6 +700,7 @@ class _TextTestResult(TestResult): self.stream.writeln("ok") elif self.dots: self.stream.write('.') + self.stream.flush() def addError(self, test, err): TestResult.addError(self, test, err) @@ -662,6 +708,7 @@ class _TextTestResult(TestResult): self.stream.writeln("ERROR") elif self.dots: self.stream.write('E') + self.stream.flush() def addFailure(self, test, err): TestResult.addFailure(self, test, err) @@ -669,6 +716,7 @@ class _TextTestResult(TestResult): self.stream.writeln("FAIL") elif self.dots: self.stream.write('F') + self.stream.flush() def printErrors(self): if self.dots or self.showAll: @@ -750,7 +798,8 @@ Examples: in MyTestCase """ def __init__(self, module='__main__', defaultTest=None, - argv=None, testRunner=None, testLoader=defaultTestLoader): + argv=None, testRunner=None, + testLoader=defaultTestLoader): if type(module) == type(''): self.module = __import__(module) for part in module.split('.')[1:]: @@ -801,8 +850,18 @@ Examples: def runTests(self): if self.testRunner is None: - self.testRunner = TextTestRunner(verbosity=self.verbosity) - result = self.testRunner.run(self.test) + self.testRunner = TextTestRunner + + if isinstance(self.testRunner, (type, types.ClassType)): + try: + testRunner = self.testRunner(verbosity=self.verbosity) + except TypeError: + # didn't accept the verbosity argument + testRunner = self.testRunner() + else: + # it is assumed to be a TestRunner instance + testRunner = self.testRunner + result = testRunner.run(self.test) sys.exit(not result.wasSuccessful()) main = TestProgram |
