summaryrefslogtreecommitdiffstats
path: root/Lib/unittest.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/unittest.py')
-rw-r--r--Lib/unittest.py119
1 files changed, 89 insertions, 30 deletions
diff --git a/Lib/unittest.py b/Lib/unittest.py
index cd91c2c..74f15d6 100644
--- a/Lib/unittest.py
+++ b/Lib/unittest.py
@@ -25,7 +25,7 @@ Simple usage:
Further information is available in the bundled documentation, and from
- http://pyunit.sourceforge.net/
+ http://docs.python.org/lib/module-unittest.html
Copyright (c) 1999-2003 Steve Purcell
This module is free software, and you may redistribute it and/or modify
@@ -68,7 +68,6 @@ __all__.extend(['getTestCaseNames', 'makeSuite', 'findTestCases'])
# Backward compatibility
##############################################################################
if sys.version_info[:2] < (2, 2):
- False, True = 0, 1
def isinstance(obj, clsinfo):
import __builtin__
if type(clsinfo) in (tuple, list):
@@ -79,6 +78,14 @@ if sys.version_info[:2] < (2, 2):
return 0
else: return __builtin__.isinstance(obj, clsinfo)
+def _CmpToKey(mycmp):
+ 'Convert a cmp= function into a key= function'
+ class K(object):
+ def __init__(self, obj):
+ self.obj = obj
+ def __lt__(self, other):
+ return mycmp(self.obj, other.obj) == -1
+ return K
##############################################################################
# Test framework core
@@ -107,7 +114,7 @@ class TestResult:
self.failures = []
self.errors = []
self.testsRun = 0
- self.shouldStop = 0
+ self.shouldStop = False
def startTest(self, test):
"Called when the given test is about to be run"
@@ -153,7 +160,7 @@ class TestResult:
return ''.join(traceback.format_exception(exctype, value, tb))
def _is_relevant_tb_level(self, tb):
- return tb.tb_frame.f_globals.has_key('__unittest')
+ return '__unittest' in tb.tb_frame.f_globals
def _count_relevant_tb_levels(self, tb):
length = 0
@@ -235,6 +242,18 @@ class TestCase:
def id(self):
return "%s.%s" % (_strclass(self.__class__), self._testMethodName)
+ def __eq__(self, other):
+ if type(self) is not type(other):
+ return False
+
+ return self._testMethodName == other._testMethodName
+
+ def __ne__(self, other):
+ return not self == other
+
+ def __hash__(self):
+ return hash((type(self), self._testMethodName))
+
def __str__(self):
return "%s (%s)" % (self._testMethodName, _strclass(self.__class__))
@@ -291,10 +310,7 @@ class TestCase:
minimised; usually the top level of the traceback frame is not
needed.
"""
- exctype, excvalue, tb = sys.exc_info()
- if sys.platform[:4] == 'java': ## tracebacks look different in Jython
- return (exctype, excvalue, tb)
- return (exctype, excvalue, tb)
+ return sys.exc_info()
def fail(self, msg=None):
"""Fail immediately, with the given message."""
@@ -349,7 +365,7 @@ class TestCase:
Note that decimal places (from zero) are usually not the same
as significant digits (measured from the most signficant digit).
"""
- if round(second-first, places) != 0:
+ if round(abs(second-first), places) != 0:
raise self.failureException, \
(msg or '%r != %r within %r places' % (first, second, places))
@@ -361,7 +377,7 @@ class TestCase:
Note that decimal places (from zero) are usually not the same
as significant digits (measured from the most signficant digit).
"""
- if round(second-first, places) == 0:
+ if round(abs(second-first), places) == 0:
raise self.failureException, \
(msg or '%r == %r within %r places' % (first, second, places))
@@ -401,6 +417,17 @@ class TestSuite:
__str__ = __repr__
+ def __eq__(self, other):
+ if type(self) is not type(other):
+ return False
+ return self._tests == other._tests
+
+ def __ne__(self, other):
+ return not self == other
+
+ # Can't guarantee hash invariant, so flag as unhashable
+ __hash__ = None
+
def __iter__(self):
return iter(self._tests)
@@ -412,7 +439,7 @@ class TestSuite:
def addTest(self, test):
# sanity checks
- if not callable(test):
+ if not hasattr(test, '__call__'):
raise TypeError("the test to add must be callable")
if (isinstance(test, (type, types.ClassType)) and
issubclass(test, (TestCase, TestSuite))):
@@ -445,7 +472,7 @@ class FunctionTestCase(TestCase):
"""A test case that wraps a test function.
This is useful for slipping pre-existing test functions into the
- PyUnit framework. Optionally, set-up and tidy-up functions can be
+ unittest framework. Optionally, set-up and tidy-up functions can be
supplied. As with TestCase, the tidy-up ('tearDown') function will
always be called if the set-up ('setUp') function ran successfully.
"""
@@ -472,6 +499,22 @@ class FunctionTestCase(TestCase):
def id(self):
return self.__testFunc.__name__
+ def __eq__(self, other):
+ if type(self) is not type(other):
+ return False
+
+ return self.__setUpFunc == other.__setUpFunc and \
+ self.__tearDownFunc == other.__tearDownFunc and \
+ self.__testFunc == other.__testFunc and \
+ self.__description == other.__description
+
+ def __ne__(self, other):
+ return not self == other
+
+ def __hash__(self):
+ return hash((type(self), self.__setUpFunc, self.__tearDownFunc,
+ self.__testFunc, self.__description))
+
def __str__(self):
return "%s (%s)" % (_strclass(self.__class__), self.__testFunc.__name__)
@@ -491,7 +534,7 @@ class FunctionTestCase(TestCase):
class TestLoader:
"""This class is responsible for loading tests according to various
- criteria and returning them wrapped in a Test
+ criteria and returning them wrapped in a TestSuite
"""
testMethodPrefix = 'test'
sortTestMethodsUsing = cmp
@@ -545,18 +588,23 @@ class TestLoader:
elif (isinstance(obj, (type, types.ClassType)) and
issubclass(obj, TestCase)):
return self.loadTestsFromTestCase(obj)
- elif type(obj) == types.UnboundMethodType:
- return parent(obj.__name__)
+ elif (type(obj) == types.UnboundMethodType and
+ isinstance(parent, (type, types.ClassType)) and
+ issubclass(parent, TestCase)):
+ return TestSuite([parent(obj.__name__)])
elif isinstance(obj, TestSuite):
return obj
- elif callable(obj):
+ elif hasattr(obj, '__call__'):
test = obj()
- if not isinstance(test, (TestCase, TestSuite)):
- raise ValueError, \
- "calling %s returned %s, not a test" % (obj,test)
- return test
+ if isinstance(test, TestSuite):
+ return test
+ elif isinstance(test, TestCase):
+ return TestSuite([test])
+ else:
+ raise TypeError("calling %s returned %s, not a test" %
+ (obj, test))
else:
- raise ValueError, "don't know how to make test from: %s" % obj
+ raise TypeError("don't know how to make test from: %s" % obj)
def loadTestsFromNames(self, names, module=None):
"""Return a suite of all tests cases found using the given sequence
@@ -569,14 +617,10 @@ class TestLoader:
"""Return a sorted sequence of method names found within testCaseClass
"""
def isTestMethod(attrname, testCaseClass=testCaseClass, prefix=self.testMethodPrefix):
- return attrname.startswith(prefix) and callable(getattr(testCaseClass, attrname))
+ return attrname.startswith(prefix) and hasattr(getattr(testCaseClass, attrname), '__call__')
testFnNames = filter(isTestMethod, dir(testCaseClass))
- for baseclass in testCaseClass.__bases__:
- for testFnName in self.getTestCaseNames(baseclass):
- if testFnName not in testFnNames: # handle overridden methods
- testFnNames.append(testFnName)
if self.sortTestMethodsUsing:
- testFnNames.sort(self.sortTestMethodsUsing)
+ testFnNames.sort(key=_CmpToKey(self.sortTestMethodsUsing))
return testFnNames
@@ -648,6 +692,7 @@ class _TextTestResult(TestResult):
if self.showAll:
self.stream.write(self.getDescription(test))
self.stream.write(" ... ")
+ self.stream.flush()
def addSuccess(self, test):
TestResult.addSuccess(self, test)
@@ -655,6 +700,7 @@ class _TextTestResult(TestResult):
self.stream.writeln("ok")
elif self.dots:
self.stream.write('.')
+ self.stream.flush()
def addError(self, test, err):
TestResult.addError(self, test, err)
@@ -662,6 +708,7 @@ class _TextTestResult(TestResult):
self.stream.writeln("ERROR")
elif self.dots:
self.stream.write('E')
+ self.stream.flush()
def addFailure(self, test, err):
TestResult.addFailure(self, test, err)
@@ -669,6 +716,7 @@ class _TextTestResult(TestResult):
self.stream.writeln("FAIL")
elif self.dots:
self.stream.write('F')
+ self.stream.flush()
def printErrors(self):
if self.dots or self.showAll:
@@ -750,7 +798,8 @@ Examples:
in MyTestCase
"""
def __init__(self, module='__main__', defaultTest=None,
- argv=None, testRunner=None, testLoader=defaultTestLoader):
+ argv=None, testRunner=None,
+ testLoader=defaultTestLoader):
if type(module) == type(''):
self.module = __import__(module)
for part in module.split('.')[1:]:
@@ -801,8 +850,18 @@ Examples:
def runTests(self):
if self.testRunner is None:
- self.testRunner = TextTestRunner(verbosity=self.verbosity)
- result = self.testRunner.run(self.test)
+ self.testRunner = TextTestRunner
+
+ if isinstance(self.testRunner, (type, types.ClassType)):
+ try:
+ testRunner = self.testRunner(verbosity=self.verbosity)
+ except TypeError:
+ # didn't accept the verbosity argument
+ testRunner = self.testRunner()
+ else:
+ # it is assumed to be a TestRunner instance
+ testRunner = self.testRunner
+ result = testRunner.run(self.test)
sys.exit(not result.wasSuccessful())
main = TestProgram