summaryrefslogtreecommitdiffstats
path: root/Lib/unittest.py
diff options
context:
space:
mode:
authorSteve Purcell <steve@pythonconsulting.com>2003-09-22 11:08:12 (GMT)
committerSteve Purcell <steve@pythonconsulting.com>2003-09-22 11:08:12 (GMT)
commit7e74384af5060e131675ef3c286ad98ed3db9128 (patch)
tree964ad46b3acd2263b221312e28cf39a5f60221fd /Lib/unittest.py
parent1e8035973325158eaa134164e7d6882c000774ff (diff)
downloadcpython-7e74384af5060e131675ef3c286ad98ed3db9128.zip
cpython-7e74384af5060e131675ef3c286ad98ed3db9128.tar.gz
cpython-7e74384af5060e131675ef3c286ad98ed3db9128.tar.bz2
- Fixed loading of tests by name when name refers to unbound
method (PyUnit issue 563882, thanks to Alexandre Fayolle) - Ignore non-callable attributes of classes when searching for test method names (PyUnit issue 769338, thanks to Seth Falcon) - New assertTrue and assertFalse aliases for comfort of JUnit users - Automatically discover 'runTest()' test methods (PyUnit issue 469444, thanks to Roeland Rengelink) - Dropped Python 1.5.2 compatibility, merged appropriate shortcuts from Python CVS; should work with Python >= 2.1. - Removed all references to string module by using string methods instead
Diffstat (limited to 'Lib/unittest.py')
-rw-r--r--Lib/unittest.py98
1 files changed, 58 insertions, 40 deletions
diff --git a/Lib/unittest.py b/Lib/unittest.py
index 043b9a8..f44769e 100644
--- a/Lib/unittest.py
+++ b/Lib/unittest.py
@@ -27,7 +27,7 @@ Further information is available in the bundled documentation, and from
http://pyunit.sourceforge.net/
-Copyright (c) 1999, 2000, 2001 Steve Purcell
+Copyright (c) 1999-2003 Steve Purcell
This module is free software, and you may redistribute it and/or modify
it under the same terms as Python itself, so long as this copyright message
and disclaimer are retained in their original form.
@@ -46,12 +46,11 @@ SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.
__author__ = "Steve Purcell"
__email__ = "stephen_purcell at yahoo dot com"
-__version__ = "#Revision: 1.46 $"[11:-2]
+__version__ = "#Revision: 1.56 $"[11:-2]
import time
import sys
import traceback
-import string
import os
import types
@@ -61,11 +60,27 @@ import types
__all__ = ['TestResult', 'TestCase', 'TestSuite', 'TextTestRunner',
'TestLoader', 'FunctionTestCase', 'main', 'defaultTestLoader']
-# Expose obsolete functions for backwards compatability
+# Expose obsolete functions for backwards compatibility
__all__.extend(['getTestCaseNames', 'makeSuite', 'findTestCases'])
##############################################################################
+# Backward compatibility
+##############################################################################
+if sys.version_info[:2] < (2, 2):
+ False, True = 0, 1
+ def isinstance(obj, clsinfo):
+ import __builtin__
+ if type(clsinfo) in (types.TupleType, types.ListType):
+ for cls in clsinfo:
+ if cls is type: cls = types.ClassType
+ if __builtin__.isinstance(obj, cls):
+ return 1
+ return 0
+ else: return __builtin__.isinstance(obj, clsinfo)
+
+
+##############################################################################
# Test framework core
##############################################################################
@@ -121,11 +136,11 @@ class TestResult:
def stop(self):
"Indicates that the tests should be aborted"
- self.shouldStop = 1
+ self.shouldStop = True
def _exc_info_to_string(self, err):
"""Converts a sys.exc_info()-style tuple of values into a string."""
- return string.join(traceback.format_exception(*err), '')
+ return ''.join(traceback.format_exception(*err))
def __repr__(self):
return "<%s run=%i errors=%i failures=%i>" % \
@@ -196,7 +211,7 @@ class TestCase:
the specified test method's docstring.
"""
doc = self.__testMethodDoc
- return doc and string.strip(string.split(doc, "\n")[0]) or None
+ return doc and doc.split("\n")[0].strip() or None
def id(self):
return "%s.%s" % (_strclass(self.__class__), self.__testMethodName)
@@ -209,9 +224,6 @@ class TestCase:
(_strclass(self.__class__), self.__testMethodName)
def run(self, result=None):
- return self(result)
-
- def __call__(self, result=None):
if result is None: result = self.defaultTestResult()
result.startTest(self)
testMethod = getattr(self, self.__testMethodName)
@@ -224,10 +236,10 @@ class TestCase:
result.addError(self, self.__exc_info())
return
- ok = 0
+ ok = False
try:
testMethod()
- ok = 1
+ ok = True
except self.failureException:
result.addFailure(self, self.__exc_info())
except KeyboardInterrupt:
@@ -241,11 +253,13 @@ class TestCase:
raise
except:
result.addError(self, self.__exc_info())
- ok = 0
+ ok = False
if ok: result.addSuccess(self)
finally:
result.stopTest(self)
+ __call__ = run
+
def debug(self):
"""Run the test without collecting errors in a TestResult"""
self.setUp()
@@ -292,7 +306,7 @@ class TestCase:
else:
if hasattr(excClass,'__name__'): excName = excClass.__name__
else: excName = str(excClass)
- raise self.failureException, excName
+ raise self.failureException, "%s not raised" % excName
def failUnlessEqual(self, first, second, msg=None):
"""Fail if the two objects are unequal as determined by the '=='
@@ -334,6 +348,8 @@ class TestCase:
raise self.failureException, \
(msg or '%s == %s within %s places' % (`first`, `second`, `places`))
+ # Synonyms for assertion methods
+
assertEqual = assertEquals = failUnlessEqual
assertNotEqual = assertNotEquals = failIfEqual
@@ -344,7 +360,9 @@ class TestCase:
assertRaises = failUnlessRaises
- assert_ = failUnless
+ assert_ = assertTrue = failUnless
+
+ assertFalse = failIf
@@ -369,7 +387,7 @@ class TestSuite:
def countTestCases(self):
cases = 0
for test in self._tests:
- cases = cases + test.countTestCases()
+ cases += test.countTestCases()
return cases
def addTest(self, test):
@@ -434,7 +452,7 @@ class FunctionTestCase(TestCase):
def shortDescription(self):
if self.__description is not None: return self.__description
doc = self.__testFunc.__doc__
- return doc and string.strip(string.split(doc, "\n")[0]) or None
+ return doc and doc.split("\n")[0].strip() or None
@@ -452,8 +470,10 @@ class TestLoader:
def loadTestsFromTestCase(self, testCaseClass):
"""Return a suite of all tests cases contained in testCaseClass"""
- return self.suiteClass(map(testCaseClass,
- self.getTestCaseNames(testCaseClass)))
+ testCaseNames = self.getTestCaseNames(testCaseClass)
+ if not testCaseNames and hasattr(testCaseClass, 'runTest'):
+ testCaseNames = ['runTest']
+ return self.suiteClass(map(testCaseClass, testCaseNames))
def loadTestsFromModule(self, module):
"""Return a suite of all tests cases contained in the given module"""
@@ -474,23 +494,20 @@ class TestLoader:
The method optionally resolves the names relative to a given module.
"""
- parts = string.split(name, '.')
+ parts = name.split('.')
if module is None:
- if not parts:
- raise ValueError, "incomplete test name: %s" % name
- else:
- parts_copy = parts[:]
- while parts_copy:
- try:
- module = __import__(string.join(parts_copy,'.'))
- break
- except ImportError:
- del parts_copy[-1]
- if not parts_copy: raise
+ parts_copy = parts[:]
+ while parts_copy:
+ try:
+ module = __import__('.'.join(parts_copy))
+ break
+ except ImportError:
+ del parts_copy[-1]
+ if not parts_copy: raise
parts = parts[1:]
obj = module
for part in parts:
- obj = getattr(obj, part)
+ parent, obj = obj, getattr(obj, part)
import unittest
if type(obj) == types.ModuleType:
@@ -499,11 +516,13 @@ class TestLoader:
issubclass(obj, unittest.TestCase)):
return self.loadTestsFromTestCase(obj)
elif type(obj) == types.UnboundMethodType:
+ return parent(obj.__name__)
return obj.im_class(obj.__name__)
+ elif isinstance(obj, unittest.TestSuite):
+ return obj
elif callable(obj):
test = obj()
- if not isinstance(test, unittest.TestCase) and \
- not isinstance(test, unittest.TestSuite):
+ if not isinstance(test, (unittest.TestCase, unittest.TestSuite)):
raise ValueError, \
"calling %s returned %s, not a test" % (obj,test)
return test
@@ -514,16 +533,15 @@ class TestLoader:
"""Return a suite of all tests cases found using the given sequence
of string specifiers. See 'loadTestsFromName()'.
"""
- suites = []
- for name in names:
- suites.append(self.loadTestsFromName(name, module))
+ suites = [self.loadTestsFromName(name, module) for name in names]
return self.suiteClass(suites)
def getTestCaseNames(self, testCaseClass):
"""Return a sorted sequence of method names found within testCaseClass
"""
- testFnNames = filter(lambda n,p=self.testMethodPrefix: n[:len(p)] == p,
- dir(testCaseClass))
+ def isTestMethod(attrname, testCaseClass=testCaseClass, prefix=self.testMethodPrefix):
+ return attrname[:len(prefix)] == prefix and callable(getattr(testCaseClass, attrname))
+ testFnNames = filter(isTestMethod, dir(testCaseClass))
for baseclass in testCaseClass.__bases__:
for testFnName in self.getTestCaseNames(baseclass):
if testFnName not in testFnNames: # handle overridden methods
@@ -706,7 +724,7 @@ Examples:
argv=None, testRunner=None, testLoader=defaultTestLoader):
if type(module) == type(''):
self.module = __import__(module)
- for part in string.split(module,'.')[1:]:
+ for part in module.split('.')[1:]:
self.module = getattr(self.module, part)
else:
self.module = module