diff options
Diffstat (limited to 'Lib/test/test_collections.py')
-rw-r--r-- | Lib/test/test_collections.py | 49 |
1 files changed, 48 insertions, 1 deletions
diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py index 1bd49d9..e595b75 100644 --- a/Lib/test/test_collections.py +++ b/Lib/test/test_collections.py @@ -1,6 +1,6 @@ """Unit tests for collections.py.""" -import unittest, doctest +import unittest, doctest, operator import inspect from test import support from collections import namedtuple, Counter, OrderedDict @@ -246,6 +246,37 @@ class ABCTestCase(unittest.TestCase): self.assertNotIsInstance(C(), abc) self.assertFalse(issubclass(C, abc)) + def validate_comparison(self, instance): + ops = ['lt', 'gt', 'le', 'ge', 'ne', 'or', 'and', 'xor', 'sub'] + operators = {} + for op in ops: + name = '__' + op + '__' + operators[name] = getattr(operator, name) + + class Other: + def __init__(self): + self.right_side = False + def __eq__(self, other): + self.right_side = True + return True + __lt__ = __eq__ + __gt__ = __eq__ + __le__ = __eq__ + __ge__ = __eq__ + __ne__ = __eq__ + __ror__ = __eq__ + __rand__ = __eq__ + __rxor__ = __eq__ + __rsub__ = __eq__ + + for name, op in operators.items(): + if not hasattr(instance, name): + continue + other = Other() + op(instance, other) + self.assertTrue(other.right_side,'Right side not called for %s.%s' + % (type(instance), name)) + class TestOneTrickPonyABCs(ABCTestCase): def test_Hashable(self): @@ -420,6 +451,14 @@ class TestCollectionABCs(ABCTestCase): self.assertIsInstance(sample(), Set) self.assertTrue(issubclass(sample, Set)) self.validate_abstract_methods(Set, '__contains__', '__iter__', '__len__') + class MySet(Set): + def __contains__(self, x): + return False + def __len__(self): + return 0 + def __iter__(self): + return iter([]) + self.validate_comparison(MySet()) def test_hash_Set(self): class OneTwoThreeSet(Set): @@ -483,6 +522,14 @@ class TestCollectionABCs(ABCTestCase): self.assertTrue(issubclass(sample, Mapping)) self.validate_abstract_methods(Mapping, '__contains__', '__iter__', '__len__', '__getitem__') + class MyMapping(collections.Mapping): + def __len__(self): + return 0 + def __getitem__(self, i): + raise IndexError + def __iter__(self): + return iter(()) + self.validate_comparison(MyMapping()) def test_MutableMapping(self): for sample in [dict]: |