diff options
Diffstat (limited to 'Lib/unittest')
-rw-r--r-- | Lib/unittest/case.py | 30 | ||||
-rw-r--r-- | Lib/unittest/test/test_assertions.py | 6 | ||||
-rw-r--r-- | Lib/unittest/util.py | 58 |
3 files changed, 72 insertions, 22 deletions
diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py index 02dbd7e..235af82 100644 --- a/Lib/unittest/case.py +++ b/Lib/unittest/case.py @@ -10,7 +10,8 @@ import collections from . import result from .util import (strclass, safe_repr, sorted_list_difference, - unorderable_list_difference) + unorderable_list_difference, _count_diff_all_purpose, + _count_diff_hashable) __unittest = True @@ -1022,23 +1023,22 @@ class TestCase(object): expected = collections.Counter(expected_seq) except TypeError: # Handle case with unhashable elements - missing, unexpected = unorderable_list_difference(expected_seq, actual_seq) + differences = _count_diff_all_purpose(expected_seq, actual_seq) else: if actual == expected: return - missing = list(expected - actual) - unexpected = list(actual - expected) - - errors = [] - if missing: - errors.append('Expected, but missing:\n %s' % - safe_repr(missing)) - if unexpected: - errors.append('Unexpected, but present:\n %s' % - safe_repr(unexpected)) - if errors: - standardMsg = '\n'.join(errors) - self.fail(self._formatMessage(msg, standardMsg)) + differences = _count_diff_hashable(expected_seq, actual_seq) + + if differences: + standardMsg = 'Element counts were not equal:\n' + lines = [] + for act, exp, elem in differences: + line = 'Expected %d, got %d: %r' % (exp, act, elem) + lines.append(line) + diffMsg = '\n'.join(lines) + standardMsg = self._truncateMessage(standardMsg, diffMsg) + msg = self._formatMessage(msg, standardMsg) + self.fail(msg) def assertMultiLineEqual(self, first, second, msg=None): """Assert that two multi-line strings are equal.""" diff --git a/Lib/unittest/test/test_assertions.py b/Lib/unittest/test/test_assertions.py index c81db24..e5dc9fa 100644 --- a/Lib/unittest/test/test_assertions.py +++ b/Lib/unittest/test/test_assertions.py @@ -229,12 +229,6 @@ class TestLongMessage(unittest.TestCase): "^Missing: 'key'$", "^Missing: 'key' : oops$"]) - def testassertCountEqual(self): - self.assertMessages('assertCountEqual', ([], [None]), - [r"\[None\]$", "^oops$", - r"\[None\]$", - r"\[None\] : oops$"]) - def testAssertMultiLineEqual(self): self.assertMessages('assertMultiLineEqual', ("", "foo"), [r"\+ foo$", "^oops$", diff --git a/Lib/unittest/util.py b/Lib/unittest/util.py index c3f4a2d..0407ae9 100644 --- a/Lib/unittest/util.py +++ b/Lib/unittest/util.py @@ -1,5 +1,7 @@ """Various utility functions.""" +from collections import namedtuple, Counter + __unittest = True _MAX_LENGTH = 80 @@ -12,7 +14,6 @@ def safe_repr(obj, short=False): return result return result[:_MAX_LENGTH] + ' [truncated]...' - def strclass(cls): return "%s.%s" % (cls.__module__, cls.__name__) @@ -77,3 +78,58 @@ def unorderable_list_difference(expected, actual): def three_way_cmp(x, y): """Return -1 if x < y, 0 if x == y and 1 if x > y""" return (x > y) - (x < y) + +_Mismatch = namedtuple('Mismatch', 'actual expected value') + +def _count_diff_all_purpose(actual, expected): + 'Returns list of (cnt_act, cnt_exp, elem) triples where the counts differ' + # elements need not be hashable + s, t = list(actual), list(expected) + m, n = len(s), len(t) + NULL = object() + result = [] + for i, elem in enumerate(s): + if elem is NULL: + continue + cnt_s = cnt_t = 0 + for j in range(i, m): + if s[j] == elem: + cnt_s += 1 + s[j] = NULL + for j, other_elem in enumerate(t): + if other_elem == elem: + cnt_t += 1 + t[j] = NULL + if cnt_s != cnt_t: + diff = _Mismatch(cnt_s, cnt_t, elem) + result.append(diff) + + for i, elem in enumerate(t): + if elem is NULL: + continue + cnt_t = 0 + for j in range(i, n): + if t[j] == elem: + cnt_t += 1 + t[j] = NULL + diff = _Mismatch(0, cnt_t, elem) + result.append(diff) + return result + +def _count_diff_hashable(actual, expected): + 'Returns list of (cnt_act, cnt_exp, elem) triples where the counts differ' + # elements must be hashable + s, t = Counter(actual), Counter(expected) + if s == t: + return [] + result = [] + for elem, cnt_s in s.items(): + cnt_t = t[elem] + if cnt_s != cnt_t: + diff = _Mismatch(cnt_s, cnt_t, elem) + result.append(diff) + for elem, cnt_t in t.items(): + if elem not in s: + diff = _Mismatch(0, cnt_t, elem) + result.append(diff) + return result |