summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/unittest/case.py30
-rw-r--r--Lib/unittest/test/test_assertions.py6
-rw-r--r--Lib/unittest/util.py58
3 files changed, 72 insertions, 22 deletions
diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py
index 02dbd7e..235af82 100644
--- a/Lib/unittest/case.py
+++ b/Lib/unittest/case.py
@@ -10,7 +10,8 @@ import collections
from . import result
from .util import (strclass, safe_repr, sorted_list_difference,
- unorderable_list_difference)
+ unorderable_list_difference, _count_diff_all_purpose,
+ _count_diff_hashable)
__unittest = True
@@ -1022,23 +1023,22 @@ class TestCase(object):
expected = collections.Counter(expected_seq)
except TypeError:
# Handle case with unhashable elements
- missing, unexpected = unorderable_list_difference(expected_seq, actual_seq)
+ differences = _count_diff_all_purpose(expected_seq, actual_seq)
else:
if actual == expected:
return
- missing = list(expected - actual)
- unexpected = list(actual - expected)
-
- errors = []
- if missing:
- errors.append('Expected, but missing:\n %s' %
- safe_repr(missing))
- if unexpected:
- errors.append('Unexpected, but present:\n %s' %
- safe_repr(unexpected))
- if errors:
- standardMsg = '\n'.join(errors)
- self.fail(self._formatMessage(msg, standardMsg))
+ differences = _count_diff_hashable(expected_seq, actual_seq)
+
+ if differences:
+ standardMsg = 'Element counts were not equal:\n'
+ lines = []
+ for act, exp, elem in differences:
+ line = 'Expected %d, got %d: %r' % (exp, act, elem)
+ lines.append(line)
+ diffMsg = '\n'.join(lines)
+ standardMsg = self._truncateMessage(standardMsg, diffMsg)
+ msg = self._formatMessage(msg, standardMsg)
+ self.fail(msg)
def assertMultiLineEqual(self, first, second, msg=None):
"""Assert that two multi-line strings are equal."""
diff --git a/Lib/unittest/test/test_assertions.py b/Lib/unittest/test/test_assertions.py
index c81db24..e5dc9fa 100644
--- a/Lib/unittest/test/test_assertions.py
+++ b/Lib/unittest/test/test_assertions.py
@@ -229,12 +229,6 @@ class TestLongMessage(unittest.TestCase):
"^Missing: 'key'$",
"^Missing: 'key' : oops$"])
- def testassertCountEqual(self):
- self.assertMessages('assertCountEqual', ([], [None]),
- [r"\[None\]$", "^oops$",
- r"\[None\]$",
- r"\[None\] : oops$"])
-
def testAssertMultiLineEqual(self):
self.assertMessages('assertMultiLineEqual', ("", "foo"),
[r"\+ foo$", "^oops$",
diff --git a/Lib/unittest/util.py b/Lib/unittest/util.py
index c3f4a2d..0407ae9 100644
--- a/Lib/unittest/util.py
+++ b/Lib/unittest/util.py
@@ -1,5 +1,7 @@
"""Various utility functions."""
+from collections import namedtuple, Counter
+
__unittest = True
_MAX_LENGTH = 80
@@ -12,7 +14,6 @@ def safe_repr(obj, short=False):
return result
return result[:_MAX_LENGTH] + ' [truncated]...'
-
def strclass(cls):
return "%s.%s" % (cls.__module__, cls.__name__)
@@ -77,3 +78,58 @@ def unorderable_list_difference(expected, actual):
def three_way_cmp(x, y):
"""Return -1 if x < y, 0 if x == y and 1 if x > y"""
return (x > y) - (x < y)
+
+_Mismatch = namedtuple('Mismatch', 'actual expected value')
+
+def _count_diff_all_purpose(actual, expected):
+ 'Returns list of (cnt_act, cnt_exp, elem) triples where the counts differ'
+ # elements need not be hashable
+ s, t = list(actual), list(expected)
+ m, n = len(s), len(t)
+ NULL = object()
+ result = []
+ for i, elem in enumerate(s):
+ if elem is NULL:
+ continue
+ cnt_s = cnt_t = 0
+ for j in range(i, m):
+ if s[j] == elem:
+ cnt_s += 1
+ s[j] = NULL
+ for j, other_elem in enumerate(t):
+ if other_elem == elem:
+ cnt_t += 1
+ t[j] = NULL
+ if cnt_s != cnt_t:
+ diff = _Mismatch(cnt_s, cnt_t, elem)
+ result.append(diff)
+
+ for i, elem in enumerate(t):
+ if elem is NULL:
+ continue
+ cnt_t = 0
+ for j in range(i, n):
+ if t[j] == elem:
+ cnt_t += 1
+ t[j] = NULL
+ diff = _Mismatch(0, cnt_t, elem)
+ result.append(diff)
+ return result
+
+def _count_diff_hashable(actual, expected):
+ 'Returns list of (cnt_act, cnt_exp, elem) triples where the counts differ'
+ # elements must be hashable
+ s, t = Counter(actual), Counter(expected)
+ if s == t:
+ return []
+ result = []
+ for elem, cnt_s in s.items():
+ cnt_t = t[elem]
+ if cnt_s != cnt_t:
+ diff = _Mismatch(cnt_s, cnt_t, elem)
+ result.append(diff)
+ for elem, cnt_t in t.items():
+ if elem not in s:
+ diff = _Mismatch(0, cnt_t, elem)
+ result.append(diff)
+ return result