summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_collections.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_collections.py')
-rw-r--r--Lib/test/test_collections.py354
1 files changed, 348 insertions, 6 deletions
diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py
index 0882bac..5238382 100644
--- a/Lib/test/test_collections.py
+++ b/Lib/test/test_collections.py
@@ -3,6 +3,7 @@
import collections
import copy
import doctest
+import inspect
import keyword
import operator
import pickle
@@ -11,12 +12,15 @@ import re
import string
import sys
from test import support
+import types
import unittest
from collections import namedtuple, Counter, OrderedDict, _count_elements
-from collections import UserDict
+from collections import UserDict, UserString, UserList
from collections import ChainMap
-from collections.abc import Hashable, Iterable, Iterator
+from collections import deque
+from collections.abc import Awaitable, Coroutine, AsyncIterator, AsyncIterable
+from collections.abc import Hashable, Iterable, Iterator, Generator
from collections.abc import Sized, Container, Callable
from collections.abc import Set, MutableSet
from collections.abc import Mapping, MutableMapping, KeysView, ItemsView
@@ -24,6 +28,26 @@ from collections.abc import Sequence, MutableSequence
from collections.abc import ByteString
+class TestUserObjects(unittest.TestCase):
+ def _superset_test(self, a, b):
+ self.assertGreaterEqual(
+ set(dir(a)),
+ set(dir(b)),
+ '{a} should have all the methods of {b}'.format(
+ a=a.__name__,
+ b=b.__name__,
+ ),
+ )
+ def test_str_protocol(self):
+ self._superset_test(UserString, str)
+
+ def test_list_protocol(self):
+ self._superset_test(UserList, list)
+
+ def test_dict_protocol(self):
+ self._superset_test(UserDict, dict)
+
+
################################################################################
### ChainMap (helper class for configparser and the string module)
################################################################################
@@ -90,7 +114,7 @@ class TestChainMap(unittest.TestCase):
self.assertEqual(f['b'], 5) # find first in chain
self.assertEqual(f.parents['b'], 2) # look beyond maps[0]
- def test_contructor(self):
+ def test_constructor(self):
self.assertEqual(ChainMap().maps, [{}]) # no-args --> one new dict
self.assertEqual(ChainMap({1:2}).maps, [{1:2}]) # 1 arg --> list
@@ -199,6 +223,14 @@ class TestNamedTuple(unittest.TestCase):
Point = namedtuple('Point', 'x y')
self.assertEqual(Point.__doc__, 'Point(x, y)')
+ @unittest.skipIf(sys.flags.optimize >= 2,
+ "Docstrings are omitted with -O2 and above")
+ def test_doc_writable(self):
+ Point = namedtuple('Point', 'x y')
+ self.assertEqual(Point.x.__doc__, 'Alias for field number 0')
+ Point.x.__doc__ = 'docstring for Point.x'
+ self.assertEqual(Point.x.__doc__, 'docstring for Point.x')
+
def test_name_fixer(self):
for spec, renamed in [
[('efg', 'g%hi'), ('efg', '_1')], # field with non-alpha char
@@ -457,6 +489,121 @@ class ABCTestCase(unittest.TestCase):
class TestOneTrickPonyABCs(ABCTestCase):
+ def test_Awaitable(self):
+ def gen():
+ yield
+
+ @types.coroutine
+ def coro():
+ yield
+
+ async def new_coro():
+ pass
+
+ class Bar:
+ def __await__(self):
+ yield
+
+ class MinimalCoro(Coroutine):
+ def send(self, value):
+ return value
+ def throw(self, typ, val=None, tb=None):
+ super().throw(typ, val, tb)
+ def __await__(self):
+ yield
+
+ non_samples = [None, int(), gen(), object()]
+ for x in non_samples:
+ self.assertNotIsInstance(x, Awaitable)
+ self.assertFalse(issubclass(type(x), Awaitable), repr(type(x)))
+
+ samples = [Bar(), MinimalCoro()]
+ for x in samples:
+ self.assertIsInstance(x, Awaitable)
+ self.assertTrue(issubclass(type(x), Awaitable))
+
+ c = coro()
+ # Iterable coroutines (generators with CO_ITERABLE_COROUTINE
+ # flag don't have '__await__' method, hence can't be instances
+ # of Awaitable. Use inspect.isawaitable to detect them.
+ self.assertNotIsInstance(c, Awaitable)
+
+ c = new_coro()
+ self.assertIsInstance(c, Awaitable)
+ c.close() # awoid RuntimeWarning that coro() was not awaited
+
+ class CoroLike: pass
+ Coroutine.register(CoroLike)
+ self.assertTrue(isinstance(CoroLike(), Awaitable))
+ self.assertTrue(issubclass(CoroLike, Awaitable))
+ CoroLike = None
+ support.gc_collect() # Kill CoroLike to clean-up ABCMeta cache
+
+ def test_Coroutine(self):
+ def gen():
+ yield
+
+ @types.coroutine
+ def coro():
+ yield
+
+ async def new_coro():
+ pass
+
+ class Bar:
+ def __await__(self):
+ yield
+
+ class MinimalCoro(Coroutine):
+ def send(self, value):
+ return value
+ def throw(self, typ, val=None, tb=None):
+ super().throw(typ, val, tb)
+ def __await__(self):
+ yield
+
+ non_samples = [None, int(), gen(), object(), Bar()]
+ for x in non_samples:
+ self.assertNotIsInstance(x, Coroutine)
+ self.assertFalse(issubclass(type(x), Coroutine), repr(type(x)))
+
+ samples = [MinimalCoro()]
+ for x in samples:
+ self.assertIsInstance(x, Awaitable)
+ self.assertTrue(issubclass(type(x), Awaitable))
+
+ c = coro()
+ # Iterable coroutines (generators with CO_ITERABLE_COROUTINE
+ # flag don't have '__await__' method, hence can't be instances
+ # of Coroutine. Use inspect.isawaitable to detect them.
+ self.assertNotIsInstance(c, Coroutine)
+
+ c = new_coro()
+ self.assertIsInstance(c, Coroutine)
+ c.close() # awoid RuntimeWarning that coro() was not awaited
+
+ class CoroLike:
+ def send(self, value):
+ pass
+ def throw(self, typ, val=None, tb=None):
+ pass
+ def close(self):
+ pass
+ def __await__(self):
+ pass
+ self.assertTrue(isinstance(CoroLike(), Coroutine))
+ self.assertTrue(issubclass(CoroLike, Coroutine))
+
+ class CoroLike:
+ def send(self, value):
+ pass
+ def close(self):
+ pass
+ def __await__(self):
+ pass
+ self.assertFalse(isinstance(CoroLike(), Coroutine))
+ self.assertFalse(issubclass(CoroLike, Coroutine))
+
def test_Hashable(self):
# Check some non-hashables
non_samples = [bytearray(), list(), set(), dict()]
@@ -483,6 +630,40 @@ class TestOneTrickPonyABCs(ABCTestCase):
self.validate_abstract_methods(Hashable, '__hash__')
self.validate_isinstance(Hashable, '__hash__')
+ def test_AsyncIterable(self):
+ class AI:
+ async def __aiter__(self):
+ return self
+ self.assertTrue(isinstance(AI(), AsyncIterable))
+ self.assertTrue(issubclass(AI, AsyncIterable))
+ # Check some non-iterables
+ non_samples = [None, object, []]
+ for x in non_samples:
+ self.assertNotIsInstance(x, AsyncIterable)
+ self.assertFalse(issubclass(type(x), AsyncIterable), repr(type(x)))
+ self.validate_abstract_methods(AsyncIterable, '__aiter__')
+ self.validate_isinstance(AsyncIterable, '__aiter__')
+
+ def test_AsyncIterator(self):
+ class AI:
+ async def __aiter__(self):
+ return self
+ async def __anext__(self):
+ raise StopAsyncIteration
+ self.assertTrue(isinstance(AI(), AsyncIterator))
+ self.assertTrue(issubclass(AI, AsyncIterator))
+ non_samples = [None, object, []]
+ # Check some non-iterables
+ for x in non_samples:
+ self.assertNotIsInstance(x, AsyncIterator)
+ self.assertFalse(issubclass(type(x), AsyncIterator), repr(type(x)))
+ # Similarly to regular iterators (see issue 10565)
+ class AnextOnly:
+ async def __anext__(self):
+ raise StopAsyncIteration
+ self.assertNotIsInstance(AnextOnly(), AsyncIterator)
+ self.validate_abstract_methods(AsyncIterator, '__anext__', '__aiter__')
+
def test_Iterable(self):
# Check some non-iterables
non_samples = [None, 42, 3.14, 1j]
@@ -530,9 +711,80 @@ class TestOneTrickPonyABCs(ABCTestCase):
class NextOnly:
def __next__(self):
yield 1
- raise StopIteration
+ return
self.assertNotIsInstance(NextOnly(), Iterator)
+ def test_Generator(self):
+ class NonGen1:
+ def __iter__(self): return self
+ def __next__(self): return None
+ def close(self): pass
+ def throw(self, typ, val=None, tb=None): pass
+
+ class NonGen2:
+ def __iter__(self): return self
+ def __next__(self): return None
+ def close(self): pass
+ def send(self, value): return value
+
+ class NonGen3:
+ def close(self): pass
+ def send(self, value): return value
+ def throw(self, typ, val=None, tb=None): pass
+
+ non_samples = [
+ None, 42, 3.14, 1j, b"", "", (), [], {}, set(),
+ iter(()), iter([]), NonGen1(), NonGen2(), NonGen3()]
+ for x in non_samples:
+ self.assertNotIsInstance(x, Generator)
+ self.assertFalse(issubclass(type(x), Generator), repr(type(x)))
+
+ class Gen:
+ def __iter__(self): return self
+ def __next__(self): return None
+ def close(self): pass
+ def send(self, value): return value
+ def throw(self, typ, val=None, tb=None): pass
+
+ class MinimalGen(Generator):
+ def send(self, value):
+ return value
+ def throw(self, typ, val=None, tb=None):
+ super().throw(typ, val, tb)
+
+ def gen():
+ yield 1
+
+ samples = [gen(), (lambda: (yield))(), Gen(), MinimalGen()]
+ for x in samples:
+ self.assertIsInstance(x, Iterator)
+ self.assertIsInstance(x, Generator)
+ self.assertTrue(issubclass(type(x), Generator), repr(type(x)))
+ self.validate_abstract_methods(Generator, 'send', 'throw')
+
+ # mixin tests
+ mgen = MinimalGen()
+ self.assertIs(mgen, iter(mgen))
+ self.assertIs(mgen.send(None), next(mgen))
+ self.assertEqual(2, mgen.send(2))
+ self.assertIsNone(mgen.close())
+ self.assertRaises(ValueError, mgen.throw, ValueError)
+ self.assertRaisesRegex(ValueError, "^huhu$",
+ mgen.throw, ValueError, ValueError("huhu"))
+ self.assertRaises(StopIteration, mgen.throw, StopIteration())
+
+ class FailOnClose(Generator):
+ def send(self, value): return value
+ def throw(self, *args): raise ValueError
+
+ self.assertRaises(ValueError, FailOnClose().close)
+
+ class IgnoreGeneratorExit(Generator):
+ def send(self, value): return value
+ def throw(self, *args): pass
+
+ self.assertRaises(RuntimeError, IgnoreGeneratorExit().close)
+
def test_Sized(self):
non_samples = [None, 42, 3.14, 1j,
(lambda: (yield))(),
@@ -659,6 +911,59 @@ class TestCollectionABCs(ABCTestCase):
a, b = OneTwoThreeSet(), OneTwoThreeSet()
self.assertTrue(hash(a) == hash(b))
+ def test_isdisjoint_Set(self):
+ class MySet(Set):
+ def __init__(self, itr):
+ self.contents = itr
+ def __contains__(self, x):
+ return x in self.contents
+ def __iter__(self):
+ return iter(self.contents)
+ def __len__(self):
+ return len([x for x in self.contents])
+ s1 = MySet((1, 2, 3))
+ s2 = MySet((4, 5, 6))
+ s3 = MySet((1, 5, 6))
+ self.assertTrue(s1.isdisjoint(s2))
+ self.assertFalse(s1.isdisjoint(s3))
+
+ def test_equality_Set(self):
+ class MySet(Set):
+ def __init__(self, itr):
+ self.contents = itr
+ def __contains__(self, x):
+ return x in self.contents
+ def __iter__(self):
+ return iter(self.contents)
+ def __len__(self):
+ return len([x for x in self.contents])
+ s1 = MySet((1,))
+ s2 = MySet((1, 2))
+ s3 = MySet((3, 4))
+ s4 = MySet((3, 4))
+ self.assertTrue(s2 > s1)
+ self.assertTrue(s1 < s2)
+ self.assertFalse(s2 <= s1)
+ self.assertFalse(s2 <= s3)
+ self.assertFalse(s1 >= s2)
+ self.assertEqual(s3, s4)
+ self.assertNotEqual(s2, s3)
+
+ def test_arithmetic_Set(self):
+ class MySet(Set):
+ def __init__(self, itr):
+ self.contents = itr
+ def __contains__(self, x):
+ return x in self.contents
+ def __iter__(self):
+ return iter(self.contents)
+ def __len__(self):
+ return len([x for x in self.contents])
+ s1 = MySet((1, 2, 3))
+ s2 = MySet((3, 4, 5))
+ s3 = s1 & s2
+ self.assertEqual(s3, MySet((3,)))
+
def test_MutableSet(self):
self.assertIsInstance(set(), MutableSet)
self.assertTrue(issubclass(set, MutableSet))
@@ -959,6 +1264,41 @@ class TestCollectionABCs(ABCTestCase):
self.validate_abstract_methods(Sequence, '__contains__', '__iter__', '__len__',
'__getitem__')
+ def test_Sequence_mixins(self):
+ class SequenceSubclass(Sequence):
+ def __init__(self, seq=()):
+ self.seq = seq
+
+ def __getitem__(self, index):
+ return self.seq[index]
+
+ def __len__(self):
+ return len(self.seq)
+
+ # Compare Sequence.index() behavior to (list|str).index() behavior
+ def assert_index_same(seq1, seq2, index_args):
+ try:
+ expected = seq1.index(*index_args)
+ except ValueError:
+ with self.assertRaises(ValueError):
+ seq2.index(*index_args)
+ else:
+ actual = seq2.index(*index_args)
+ self.assertEqual(
+ actual, expected, '%r.index%s' % (seq1, index_args))
+
+ for ty in list, str:
+ nativeseq = ty('abracadabra')
+ indexes = [-10000, -9999] + list(range(-3, len(nativeseq) + 3))
+ seqseq = SequenceSubclass(nativeseq)
+ for letter in set(nativeseq) | {'z'}:
+ assert_index_same(nativeseq, seqseq, (letter,))
+ for start in range(-3, len(nativeseq) + 3):
+ assert_index_same(nativeseq, seqseq, (letter, start))
+ for stop in range(-3, len(nativeseq) + 3):
+ assert_index_same(
+ nativeseq, seqseq, (letter, start, stop))
+
def test_ByteString(self):
for sample in [bytes, bytearray]:
self.assertIsInstance(sample(), ByteString)
@@ -973,7 +1313,7 @@ class TestCollectionABCs(ABCTestCase):
for sample in [tuple, str, bytes]:
self.assertNotIsInstance(sample(), MutableSequence)
self.assertFalse(issubclass(sample, MutableSequence))
- for sample in [list, bytearray]:
+ for sample in [list, bytearray, deque]:
self.assertIsInstance(sample(), MutableSequence)
self.assertTrue(issubclass(sample, MutableSequence))
self.assertFalse(issubclass(str, MutableSequence))
@@ -1289,7 +1629,9 @@ class TestCounter(unittest.TestCase):
def test_main(verbose=None):
NamedTupleDocs = doctest.DocTestSuite(module=collections)
test_classes = [TestNamedTuple, NamedTupleDocs, TestOneTrickPonyABCs,
- TestCollectionABCs, TestCounter, TestChainMap]
+ TestCollectionABCs, TestCounter, TestChainMap,
+ TestUserObjects,
+ ]
support.run_unittest(*test_classes)
support.run_doctest(collections, verbose)