diff options
Diffstat (limited to 'Lib/unittest.py')
-rw-r--r-- | Lib/unittest.py | 607 |
1 files changed, 544 insertions, 63 deletions
diff --git a/Lib/unittest.py b/Lib/unittest.py index ade806c..16a8663 100644 --- a/Lib/unittest.py +++ b/Lib/unittest.py @@ -14,11 +14,11 @@ Simple usage: class IntegerArithmenticTestCase(unittest.TestCase): def testAdd(self): ## test method names begin 'test*' - self.assertEquals((1 + 2), 3) - self.assertEquals(0 + 1, 1) + self.assertEqual((1 + 2), 3) + self.assertEqual(0 + 1, 1) def testMultiply(self): - self.assertEquals((0 * 10), 0) - self.assertEquals((5 * 8), 40) + self.assertEqual((0 * 10), 0) + self.assertEqual((5 * 8), 40) if __name__ == '__main__': unittest.main() @@ -45,12 +45,16 @@ AND THERE IS NO OBLIGATION WHATSOEVER TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. ''' -import time +import difflib +import functools +import os +import pprint +import re import sys +import time import traceback -import os import types -import functools +import warnings ############################################################################## # Exported classes and functions @@ -143,7 +147,6 @@ def expectedFailure(func): raise _UnexpectedSuccess return wrapper - __unittest = 1 class TestResult(object): @@ -239,10 +242,12 @@ class TestResult(object): len(self.failures)) -class AssertRaisesContext(object): +class _AssertRaisesContext(object): + """A context manager used to implement TestCase.assertRaises* methods.""" - def __init__(self, expected, test_case, callable_obj=None): + def __init__(self, expected, test_case, callable_obj=None, + expected_regexp=None): self.expected = expected self.failureException = test_case.failureException if callable_obj is not None: @@ -252,6 +257,7 @@ class AssertRaisesContext(object): self.obj_name = str(callable_obj) else: self.obj_name = None + self.expected_regex = expected_regexp def __enter__(self): pass @@ -268,10 +274,30 @@ class AssertRaisesContext(object): else: raise self.failureException("{0} not raised" .format(exc_name)) - if issubclass(exc_type, self.expected): + if not issubclass(exc_type, self.expected): + # let unexpected exceptions pass through + return False + if self.expected_regex is None: return True - # Let unexpected exceptions skip through - return False + + expected_regexp = self.expected_regex + if isinstance(expected_regexp, (bytes, str)): + expected_regexp = re.compile(expected_regexp) + if not expected_regexp.search(str(exc_value)): + raise self.failureException('"%s" does not match "%s"' % + (expected_regexp.pattern, str(exc_value))) + return True + + +class _AssertWrapper(object): + """Wrap entries in the _type_equality_funcs registry to make them deep + copyable.""" + + def __init__(self, function): + self.function = function + + def __deepcopy__(self, memo): + memo[id(self)] = self class TestCase(object): @@ -302,6 +328,13 @@ class TestCase(object): failureException = AssertionError + # This attribute determines whether long messages (including repr of + # objects used in assert methods) will be printed on failure in *addition* + # to any explicit message passed. + + longMessage = False + + def __init__(self, methodName='runTest'): """Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does @@ -315,6 +348,31 @@ class TestCase(object): (self.__class__, methodName)) self._testMethodDoc = testMethod.__doc__ + # Map types to custom assertEqual functions that will compare + # instances of said type in more detail to generate a more useful + # error message. + self._type_equality_funcs = {} + self.addTypeEqualityFunc(dict, self.assertDictEqual) + self.addTypeEqualityFunc(list, self.assertListEqual) + self.addTypeEqualityFunc(tuple, self.assertTupleEqual) + self.addTypeEqualityFunc(set, self.assertSetEqual) + self.addTypeEqualityFunc(frozenset, self.assertSetEqual) + + def addTypeEqualityFunc(self, typeobj, function): + """Add a type specific assertEqual style function to compare a type. + + This method is for use by TestCase subclasses that need to register + their own type equality functions to provide nicer error messages. + + Args: + typeobj: The data type to call this function on when both values + are of the same type in assertEqual(). + function: The callable taking two arguments and an optional + msg= argument that raises self.failureException with a + useful error message when the two arguments are not equal. + """ + self._type_equality_funcs[typeobj] = _AssertWrapper(function) + def setUp(self): "Hook method for setting up the test fixture before exercising it." pass @@ -330,14 +388,22 @@ class TestCase(object): return TestResult() def shortDescription(self): - """Returns a one-line description of the test, or None if no - description has been provided. + """Returns both the test method name and first line of its docstring. - The default implementation of this method returns the first line of - the specified test method's docstring. + If no docstring is given, only returns the method name. + + This method overrides unittest.TestCase.shortDescription(), which + only returns the first line of the docstring, obscuring the name + of the test upon failure. """ - doc = self._testMethodDoc - return doc and doc.split("\n")[0].strip() or None + desc = str(self) + doc_first_line = None + + if self._testMethodDoc: + doc_first_line = self._testMethodDoc.split("\n")[0].strip() + if doc_first_line: + desc = '\n'.join((desc, doc_first_line)) + return desc def id(self): return "%s.%s" % (_strclass(self.__class__), self._testMethodName) @@ -419,17 +485,36 @@ class TestCase(object): """Fail immediately, with the given message.""" raise self.failureException(msg) - def failIf(self, expr, msg=None): + def assertFalse(self, expr, msg=None): "Fail the test if the expression is true." if expr: + msg = self._formatMessage(msg, "%r is not False" % expr) raise self.failureException(msg) - def failUnless(self, expr, msg=None): + def assertTrue(self, expr, msg=None): """Fail the test unless the expression is true.""" if not expr: + msg = self._formatMessage(msg, "%r is not True" % expr) raise self.failureException(msg) - def failUnlessRaises(self, excClass, callableObj=None, *args, **kwargs): + def _formatMessage(self, msg, standardMsg): + """Honour the longMessage attribute when generating failure messages. + If longMessage is False this means: + * Use only an explicit message if it is provided + * Otherwise use the standard message for the assert + + If longMessage is True: + * Use the standard message + * If an explicit message is provided, plus ' : ' and the explicit message + """ + if not self.longMessage: + return msg or standardMsg + if msg is None: + return standardMsg + return standardMsg + ' : ' + msg + + + def assertRaises(self, excClass, callableObj=None, *args, **kwargs): """Fail unless an exception of class excClass is thrown by callableObj when invoked with arguments args and keyword arguments kwargs. If a different type of exception is @@ -440,30 +525,62 @@ class TestCase(object): If called with callableObj omitted or None, will return a context object used like this:: - with self.failUnlessRaises(some_error_class): + with self.assertRaises(some_error_class): do_something() """ - context = AssertRaisesContext(excClass, self, callableObj) + context = _AssertRaisesContext(excClass, self, callableObj) if callableObj is None: return context with context: callableObj(*args, **kwargs) - def failUnlessEqual(self, first, second, msg=None): + def _getAssertEqualityFunc(self, first, second): + """Get a detailed comparison function for the types of the two args. + + Returns: A callable accepting (first, second, msg=None) that will + raise a failure exception if first != second with a useful human + readable error message for those types. + """ + # + # NOTE(gregory.p.smith): I considered isinstance(first, type(second)) + # and vice versa. I opted for the conservative approach in case + # subclasses are not intended to be compared in detail to their super + # class instances using a type equality func. This means testing + # subtypes won't automagically use the detailed comparison. Callers + # should use their type specific assertSpamEqual method to compare + # subclasses if the detailed comparison is desired and appropriate. + # See the discussion in http://bugs.python.org/issue2578. + # + if type(first) is type(second): + asserter = self._type_equality_funcs.get(type(first)) + if asserter is not None: + return asserter.function + + return self._baseAssertEqual + + def _baseAssertEqual(self, first, second, msg=None): + """The default assertEqual implementation, not type specific.""" + if not first == second: + standardMsg = '%r != %r' % (first, second) + msg = self._formatMessage(msg, standardMsg) + raise self.failureException(msg) + + def assertEqual(self, first, second, msg=None): """Fail if the two objects are unequal as determined by the '==' operator. """ - if not first == second: - raise self.failureException(msg or '%r != %r' % (first, second)) + assertion_func = self._getAssertEqualityFunc(first, second) + assertion_func(first, second, msg=msg) - def failIfEqual(self, first, second, msg=None): + def assertNotEqual(self, first, second, msg=None): """Fail if the two objects are equal as determined by the '==' operator. """ - if first == second: - raise self.failureException(msg or '%r == %r' % (first, second)) + if not first != second: + msg = self._formatMessage(msg, '%r == %r' % (first, second)) + raise self.failureException(msg) - def failUnlessAlmostEqual(self, first, second, *, places=7, msg=None): + def assertAlmostEqual(self, first, second, *, places=7, msg=None): """Fail if the two objects are unequal as determined by their difference rounded to the given number of decimal places (default 7) and comparing to zero. @@ -472,10 +589,11 @@ class TestCase(object): as significant digits (measured from the most signficant digit). """ if round(abs(second-first), places) != 0: - raise self.failureException( - msg or '%r != %r within %r places' % (first, second, places)) + standardMsg = '%r != %r within %r places' % (first, second, places) + msg = self._formatMessage(msg, standardMsg) + raise self.failureException(msg) - def failIfAlmostEqual(self, first, second, *, places=7, msg=None): + def assertNotAlmostEqual(self, first, second, *, places=7, msg=None): """Fail if the two objects are equal as determined by their difference rounded to the given number of decimal places (default 7) and comparing to zero. @@ -484,25 +602,388 @@ class TestCase(object): as significant digits (measured from the most signficant digit). """ if round(abs(second-first), places) == 0: - raise self.failureException( - msg or '%r == %r within %r places' % (first, second, places)) + standardMsg = '%r == %r within %r places' % (first, second, places) + msg = self._formatMessage(msg, standardMsg) + raise self.failureException(msg) # Synonyms for assertion methods - assertEqual = assertEquals = failUnlessEqual + # The plurals are undocumented. Keep them that way to discourage use. + # Do not add more. Do not remove. + # Going through a deprecation cycle on these would annoy many people. + assertEquals = assertEqual + assertNotEquals = assertNotEqual + assertAlmostEquals = assertAlmostEqual + assertNotAlmostEquals = assertNotAlmostEqual + assert_ = assertTrue + + # These fail* assertion method names are pending deprecation and will + # be a DeprecationWarning in 3.2; http://bugs.python.org/issue2578 + def _deprecate(original_func): + def deprecated_func(*args, **kwargs): + warnings.warn( + 'Please use {0} instead.'.format(original_func.__name__), + PendingDeprecationWarning, 2) + return original_func(*args, **kwargs) + return deprecated_func + + failUnlessEqual = _deprecate(assertEqual) + failIfEqual = _deprecate(assertNotEqual) + failUnlessAlmostEqual = _deprecate(assertAlmostEqual) + failIfAlmostEqual = _deprecate(assertNotAlmostEqual) + failUnless = _deprecate(assertTrue) + failUnlessRaises = _deprecate(assertRaises) + failIf = _deprecate(assertFalse) + + def assertSequenceEqual(self, seq1, seq2, msg=None, seq_type=None): + """An equality assertion for ordered sequences (like lists and tuples). + + For the purposes of this function, a valid orderd sequence type is one + which can be indexed, has a length, and has an equality operator. + + Args: + seq1: The first sequence to compare. + seq2: The second sequence to compare. + seq_type: The expected datatype of the sequences, or None if no + datatype should be enforced. + msg: Optional message to use on failure instead of a list of + differences. + """ + if seq_type != None: + seq_type_name = seq_type.__name__ + if not isinstance(seq1, seq_type): + raise self.failureException('First sequence is not a %s: %r' + % (seq_type_name, seq1)) + if not isinstance(seq2, seq_type): + raise self.failureException('Second sequence is not a %s: %r' + % (seq_type_name, seq2)) + else: + seq_type_name = "sequence" + + differing = None + try: + len1 = len(seq1) + except (TypeError, NotImplementedError): + differing = 'First %s has no length. Non-sequence?' % ( + seq_type_name) + + if differing is None: + try: + len2 = len(seq2) + except (TypeError, NotImplementedError): + differing = 'Second %s has no length. Non-sequence?' % ( + seq_type_name) + + if differing is None: + if seq1 == seq2: + return + + for i in range(min(len1, len2)): + try: + item1 = seq1[i] + except (TypeError, IndexError, NotImplementedError): + differing = ('Unable to index element %d of first %s\n' % + (i, seq_type_name)) + break + + try: + item2 = seq2[i] + except (TypeError, IndexError, NotImplementedError): + differing = ('Unable to index element %d of second %s\n' % + (i, seq_type_name)) + break + + if item1 != item2: + differing = ('First differing element %d:\n%s\n%s\n' % + (i, item1, item2)) + break + else: + if (len1 == len2 and seq_type is None and + type(seq1) != type(seq2)): + # The sequences are the same, but have differing types. + return + # A catch-all message for handling arbitrary user-defined + # sequences. + differing = '%ss differ:\n' % seq_type_name.capitalize() + if len1 > len2: + differing = ('First %s contains %d additional ' + 'elements.\n' % (seq_type_name, len1 - len2)) + try: + differing += ('First extra element %d:\n%s\n' % + (len2, seq1[len2])) + except (TypeError, IndexError, NotImplementedError): + differing += ('Unable to index element %d ' + 'of first %s\n' % (len2, seq_type_name)) + elif len1 < len2: + differing = ('Second %s contains %d additional ' + 'elements.\n' % (seq_type_name, len2 - len1)) + try: + differing += ('First extra element %d:\n%s\n' % + (len1, seq2[len1])) + except (TypeError, IndexError, NotImplementedError): + differing += ('Unable to index element %d ' + 'of second %s\n' % (len1, seq_type_name)) + standardMsg = differing + '\n'.join(difflib.ndiff(pprint.pformat(seq1).splitlines(), + pprint.pformat(seq2).splitlines())) + msg = self._formatMessage(msg, standardMsg) + self.fail(msg) + + def assertListEqual(self, list1, list2, msg=None): + """A list-specific equality assertion. + + Args: + list1: The first list to compare. + list2: The second list to compare. + msg: Optional message to use on failure instead of a list of + differences. + + """ + self.assertSequenceEqual(list1, list2, msg, seq_type=list) - assertNotEqual = assertNotEquals = failIfEqual + def assertTupleEqual(self, tuple1, tuple2, msg=None): + """A tuple-specific equality assertion. - assertAlmostEqual = assertAlmostEquals = failUnlessAlmostEqual + Args: + tuple1: The first tuple to compare. + tuple2: The second tuple to compare. + msg: Optional message to use on failure instead of a list of + differences. + """ + self.assertSequenceEqual(tuple1, tuple2, msg, seq_type=tuple) - assertNotAlmostEqual = assertNotAlmostEquals = failIfAlmostEqual + def assertSetEqual(self, set1, set2, msg=None): + """A set-specific equality assertion. - assertRaises = failUnlessRaises + Args: + set1: The first set to compare. + set2: The second set to compare. + msg: Optional message to use on failure instead of a list of + differences. - assert_ = assertTrue = failUnless + For more general containership equality, assertSameElements will work + with things other than sets. This uses ducktyping to support + different types of sets, and is optimized for sets specifically + (parameters must support a difference method). + """ + try: + difference1 = set1.difference(set2) + except TypeError as e: + self.fail('invalid type when attempting set difference: %s' % e) + except AttributeError as e: + self.fail('first argument does not support set difference: %s' % e) - assertFalse = failIf + try: + difference2 = set2.difference(set1) + except TypeError as e: + self.fail('invalid type when attempting set difference: %s' % e) + except AttributeError as e: + self.fail('second argument does not support set difference: %s' % e) + + if not (difference1 or difference2): + return + + lines = [] + if difference1: + lines.append('Items in the first set but not the second:') + for item in difference1: + lines.append(repr(item)) + if difference2: + lines.append('Items in the second set but not the first:') + for item in difference2: + lines.append(repr(item)) + + standardMsg = '\n'.join(lines) + self.fail(self._formatMessage(msg, standardMsg)) + + def assertIn(self, member, container, msg=None): + """Just like self.assertTrue(a in b), but with a nicer default message.""" + if member not in container: + standardMsg = '%r not found in %r' % (member, container) + self.fail(self._formatMessage(msg, standardMsg)) + + def assertNotIn(self, member, container, msg=None): + """Just like self.assertTrue(a not in b), but with a nicer default message.""" + if member in container: + standardMsg = '%r unexpectedly found in %r' % (member, container) + self.fail(self._formatMessage(msg, standardMsg)) + + def assertDictEqual(self, d1, d2, msg=None): + self.assert_(isinstance(d1, dict), 'First argument is not a dictionary') + self.assert_(isinstance(d2, dict), 'Second argument is not a dictionary') + + if d1 != d2: + standardMsg = ('\n' + '\n'.join(difflib.ndiff( + pprint.pformat(d1).splitlines(), + pprint.pformat(d2).splitlines()))) + self.fail(self._formatMessage(msg, standardMsg)) + + def assertDictContainsSubset(self, expected, actual, msg=None): + """Checks whether actual is a superset of expected.""" + missing = [] + mismatched = [] + for key, value in expected.items(): + if key not in actual: + missing.append(key) + elif value != actual[key]: + mismatched.append('%s, expected: %s, actual: %s' % (key, value, actual[key])) + + if not (missing or mismatched): + return + + standardMsg = '' + if missing: + standardMsg = 'Missing: %r' % ','.join(missing) + if mismatched: + if standardMsg: + standardMsg += '; ' + standardMsg += 'Mismatched values: %s' % ','.join(mismatched) + + self.fail(self._formatMessage(msg, standardMsg)) + + def assertSameElements(self, expected_seq, actual_seq, msg=None): + """An unordered sequence specific comparison. + + Raises with an error message listing which elements of expected_seq + are missing from actual_seq and vice versa if any. + """ + try: + expected = set(expected_seq) + actual = set(actual_seq) + missing = list(expected.difference(actual)) + unexpected = list(actual.difference(expected)) + missing.sort() + unexpected.sort() + except TypeError: + # Fall back to slower list-compare if any of the objects are + # not hashable. + expected = list(expected_seq) + actual = list(actual_seq) + expected.sort() + actual.sort() + missing, unexpected = _SortedListDifference(expected, actual) + errors = [] + if missing: + errors.append('Expected, but missing:\n %r' % missing) + if unexpected: + errors.append('Unexpected, but present:\n %r' % unexpected) + if errors: + standardMsg = '\n'.join(errors) + self.fail(self._formatMessage(msg, standardMsg)) + + def assertMultiLineEqual(self, first, second, msg=None): + """Assert that two multi-line strings are equal.""" + self.assert_(isinstance(first, str), ( + 'First argument is not a string')) + self.assert_(isinstance(second, str), ( + 'Second argument is not a string')) + + if first != second: + standardMsg = '\n' + ''.join(difflib.ndiff(first.splitlines(True), second.splitlines(True))) + self.fail(self._formatMessage(msg, standardMsg)) + + def assertLess(self, a, b, msg=None): + """Just like self.assertTrue(a < b), but with a nicer default message.""" + if not a < b: + standardMsg = '%r not less than %r' % (a, b) + self.fail(self._formatMessage(msg, standardMsg)) + + def assertLessEqual(self, a, b, msg=None): + """Just like self.assertTrue(a <= b), but with a nicer default message.""" + if not a <= b: + standardMsg = '%r not less than or equal to %r' % (a, b) + self.fail(self._formatMessage(msg, standardMsg)) + + def assertGreater(self, a, b, msg=None): + """Just like self.assertTrue(a > b), but with a nicer default message.""" + if not a > b: + standardMsg = '%r not greater than %r' % (a, b) + self.fail(self._formatMessage(msg, standardMsg)) + + def assertGreaterEqual(self, a, b, msg=None): + """Just like self.assertTrue(a >= b), but with a nicer default message.""" + if not a >= b: + standardMsg = '%r not greater than or equal to %r' % (a, b) + self.fail(self._formatMessage(msg, standardMsg)) + + def assertIsNone(self, obj, msg=None): + """Same as self.assertTrue(obj is None), with a nicer default message.""" + if obj is not None: + standardMsg = '%r is not None' % obj + self.fail(self._formatMessage(msg, standardMsg)) + + def assertIsNotNone(self, obj, msg=None): + """Included for symmetry with assertIsNone.""" + if obj is None: + standardMsg = 'unexpectedly None' + self.fail(self._formatMessage(msg, standardMsg)) + + def assertRaisesRegexp(self, expected_exception, expected_regexp, + callable_obj=None, *args, **kwargs): + """Asserts that the message in a raised exception matches a regexp. + + Args: + expected_exception: Exception class expected to be raised. + expected_regexp: Regexp (re pattern object or string) expected + to be found in error message. + callable_obj: Function to be called. + args: Extra args. + kwargs: Extra kwargs. + """ + context = _AssertRaisesContext(expected_exception, self, callable_obj, + expected_regexp) + if callable_obj is None: + return context + with context: + callable_obj(*args, **kwargs) + + def assertRegexpMatches(self, text, expected_regex, msg=None): + if isinstance(expected_regex, (str, bytes)): + expected_regex = re.compile(expected_regex) + if not expected_regex.search(text): + msg = msg or "Regexp didn't match" + msg = '%s: %r not found in %r' % (msg, expected_regex.pattern, text) + raise self.failureException(msg) + + +def _SortedListDifference(expected, actual): + """Finds elements in only one or the other of two, sorted input lists. + Returns a two-element tuple of lists. The first list contains those + elements in the "expected" list but not in the "actual" list, and the + second contains those elements in the "actual" list but not in the + "expected" list. Duplicate elements in either input list are ignored. + """ + i = j = 0 + missing = [] + unexpected = [] + while True: + try: + e = expected[i] + a = actual[j] + if e < a: + missing.append(e) + i += 1 + while expected[i] == e: + i += 1 + elif e > a: + unexpected.append(a) + j += 1 + while actual[j] == a: + j += 1 + else: + i += 1 + try: + while expected[i] == e: + i += 1 + finally: + j += 1 + while actual[j] == a: + j += 1 + except IndexError: + missing.extend(expected[i:]) + unexpected.extend(actual[j:]) + break + return missing, unexpected class TestSuite(object): @@ -611,52 +1092,52 @@ class FunctionTestCase(TestCase): def __init__(self, testFunc, setUp=None, tearDown=None, description=None): super(FunctionTestCase, self).__init__() - self.__setUpFunc = setUp - self.__tearDownFunc = tearDown - self.__testFunc = testFunc - self.__description = description + self._setUpFunc = setUp + self._tearDownFunc = tearDown + self._testFunc = testFunc + self._description = description def setUp(self): - if self.__setUpFunc is not None: - self.__setUpFunc() + if self._setUpFunc is not None: + self._setUpFunc() def tearDown(self): - if self.__tearDownFunc is not None: - self.__tearDownFunc() + if self._tearDownFunc is not None: + self._tearDownFunc() def runTest(self): - self.__testFunc() + self._testFunc() def id(self): - return self.__testFunc.__name__ + return self._testFunc.__name__ def __eq__(self, other): if not isinstance(other, self.__class__): return NotImplemented - return self.__setUpFunc == other.__setUpFunc and \ - self.__tearDownFunc == other.__tearDownFunc and \ - self.__testFunc == other.__testFunc and \ - self.__description == other.__description + return self._setUpFunc == other._setUpFunc and \ + self._tearDownFunc == other._tearDownFunc and \ + self._testFunc == other._testFunc and \ + self._description == other._description def __ne__(self, other): return not self == other def __hash__(self): - return hash((type(self), self.__setUpFunc, self.__tearDownFunc, - self.__testFunc, self.__description)) + return hash((type(self), self._setUpFunc, self._tearDownFunc, + self._testFunc, self._description)) def __str__(self): return "%s (%s)" % (_strclass(self.__class__), self.__testFunc.__name__) def __repr__(self): - return "<%s testFunc=%s>" % (_strclass(self.__class__), - self.__testFunc) + return "<%s testFunc=%s>" % (_strclass(self.__class__), self._testFunc) def shortDescription(self): - if self.__description is not None: return self.__description - doc = self.__testFunc.__doc__ + if self._description is not None: + return self._description + doc = self._testFunc.__doc__ return doc and doc.split("\n")[0].strip() or None |