summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/collections.py9
-rw-r--r--Lib/test/test_collections.py114
2 files changed, 121 insertions, 2 deletions
diff --git a/Lib/collections.py b/Lib/collections.py
index 2f19459..98c4325 100644
--- a/Lib/collections.py
+++ b/Lib/collections.py
@@ -694,6 +694,15 @@ class _ChainMap(MutableMapping):
__copy__ = copy
+ def new_child(self): # like Django's Context.push()
+ 'New ChainMap with a new dict followed by all previous maps.'
+ return self.__class__({}, *self.maps)
+
+ @property
+ def parents(self): # like Django's Context.pop()
+ 'New ChainMap from maps[1:].'
+ return self.__class__(*self.maps[1:])
+
def __setitem__(self, key, value):
self.maps[0][key] = value
diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py
index d785fcb..b3e0907 100644
--- a/Lib/test/test_collections.py
+++ b/Lib/test/test_collections.py
@@ -10,6 +10,7 @@ from random import randrange, shuffle
import keyword
import re
import sys
+from collections import _ChainMap as ChainMap
from collections import Hashable, Iterable, Iterator
from collections import Sized, Container, Callable
from collections import Set, MutableSet
@@ -17,6 +18,97 @@ from collections import Mapping, MutableMapping, KeysView, ItemsView, UserDict
from collections import Sequence, MutableSequence
from collections import ByteString
+
+################################################################################
+### _ChainMap (helper class for configparser)
+################################################################################
+
+class TestChainMap(unittest.TestCase):
+
+ def test_basics(self):
+ c = ChainMap()
+ c['a'] = 1
+ c['b'] = 2
+ d = c.new_child()
+ d['b'] = 20
+ d['c'] = 30
+ self.assertEqual(d.maps, [{'b':20, 'c':30}, {'a':1, 'b':2}]) # check internal state
+ self.assertEqual(d.items(), dict(a=1, b=20, c=30).items()) # check items/iter/getitem
+ self.assertEqual(len(d), 3) # check len
+ for key in 'abc': # check contains
+ self.assertIn(key, d)
+ for k, v in dict(a=1, b=20, c=30, z=100).items(): # check get
+ self.assertEqual(d.get(k, 100), v)
+
+ del d['b'] # unmask a value
+ self.assertEqual(d.maps, [{'c':30}, {'a':1, 'b':2}]) # check internal state
+ self.assertEqual(d.items(), dict(a=1, b=2, c=30).items()) # check items/iter/getitem
+ self.assertEqual(len(d), 3) # check len
+ for key in 'abc': # check contains
+ self.assertIn(key, d)
+ for k, v in dict(a=1, b=2, c=30, z=100).items(): # check get
+ self.assertEqual(d.get(k, 100), v)
+ self.assertIn(repr(d), [ # check repr
+ type(d).__name__ + "({'c': 30}, {'a': 1, 'b': 2})",
+ type(d).__name__ + "({'c': 30}, {'b': 2, 'a': 1})"
+ ])
+
+ for e in d.copy(), copy.copy(d): # check shallow copies
+ self.assertEqual(d, e)
+ self.assertEqual(d.maps, e.maps)
+ self.assertIsNot(d, e)
+ self.assertIsNot(d.maps[0], e.maps[0])
+ for m1, m2 in zip(d.maps[1:], e.maps[1:]):
+ self.assertIs(m1, m2)
+
+ for e in [pickle.loads(pickle.dumps(d)),
+ copy.deepcopy(d),
+ eval(repr(d))
+ ]: # check deep copies
+ self.assertEqual(d, e)
+ self.assertEqual(d.maps, e.maps)
+ self.assertIsNot(d, e)
+ for m1, m2 in zip(d.maps, e.maps):
+ self.assertIsNot(m1, m2, e)
+
+ d.new_child()
+ d['b'] = 5
+ self.assertEqual(d.maps, [{'b': 5}, {'c':30}, {'a':1, 'b':2}])
+ self.assertEqual(d.parents.maps, [{'c':30}, {'a':1, 'b':2}]) # check parents
+ self.assertEqual(d['b'], 5) # find first in chain
+ self.assertEqual(d.parents['b'], 2) # look beyond maps[0]
+
+ def test_contructor(self):
+ self.assertEqual(ChainedContext().maps, [{}]) # no-args --> one new dict
+ self.assertEqual(ChainMap({1:2}).maps, [{1:2}]) # 1 arg --> list
+
+ def test_missing(self):
+ class DefaultChainMap(ChainMap):
+ def __missing__(self, key):
+ return 999
+ d = DefaultChainMap(dict(a=1, b=2), dict(b=20, c=30))
+ for k, v in dict(a=1, b=2, c=30, d=999).items():
+ self.assertEqual(d[k], v) # check __getitem__ w/missing
+ for k, v in dict(a=1, b=2, c=30, d=77).items():
+ self.assertEqual(d.get(k, 77), v) # check get() w/ missing
+ for k, v in dict(a=True, b=True, c=True, d=False).items():
+ self.assertEqual(k in d, v) # check __contains__ w/missing
+ self.assertEqual(d.pop('a', 1001), 1, d)
+ self.assertEqual(d.pop('a', 1002), 1002) # check pop() w/missing
+ self.assertEqual(d.popitem(), ('b', 2)) # check popitem() w/missing
+ with self.assertRaises(KeyError):
+ d.popitem()
+
+ def test_dict_coercion(self):
+ d = ChainMap(dict(a=1, b=2), dict(b=20, c=30))
+ self.assertEqual(dict(d), dict(a=1, b=2, c=30))
+ self.assertEqual(dict(d.items()), dict(a=1, b=2, c=30))
+
+
+################################################################################
+### Named Tuples
+################################################################################
+
TestNT = namedtuple('TestNT', 'x y z') # type used for pickle tests
class TestNamedTuple(unittest.TestCase):
@@ -228,6 +320,10 @@ class TestNamedTuple(unittest.TestCase):
self.assertEqual(repr(B(1)), 'B(x=1)')
+################################################################################
+### Abstract Base Classes
+################################################################################
+
class ABCTestCase(unittest.TestCase):
def validate_abstract_methods(self, abc, *names):
@@ -507,7 +603,7 @@ class TestCollectionABCs(ABCTestCase):
def test_issue_4920(self):
# MutableSet.pop() method did not work
- class MySet(collections.MutableSet):
+ class MySet(MutableSet):
__slots__=['__s']
def __init__(self,items=None):
if items is None:
@@ -553,7 +649,7 @@ class TestCollectionABCs(ABCTestCase):
self.assertTrue(issubclass(sample, Mapping))
self.validate_abstract_methods(Mapping, '__contains__', '__iter__', '__len__',
'__getitem__')
- class MyMapping(collections.Mapping):
+ class MyMapping(Mapping):
def __len__(self):
return 0
def __getitem__(self, i):
@@ -625,6 +721,11 @@ class TestCollectionABCs(ABCTestCase):
self.validate_abstract_methods(MutableSequence, '__contains__', '__iter__',
'__len__', '__getitem__', '__setitem__', '__delitem__', 'insert')
+
+################################################################################
+### Counter
+################################################################################
+
class TestCounter(unittest.TestCase):
def test_basics(self):
@@ -788,6 +889,11 @@ class TestCounter(unittest.TestCase):
self.assertEqual(m,
OrderedDict([('a', 5), ('b', 2), ('r', 2), ('c', 1), ('d', 1)]))
+
+################################################################################
+### OrderedDict
+################################################################################
+
class TestOrderedDict(unittest.TestCase):
def test_init(self):
@@ -1066,6 +1172,10 @@ class SubclassMappingTests(mapping_tests.BasicTestMappingProtocol):
self.assertRaises(KeyError, d.popitem)
+################################################################################
+### Run tests
+################################################################################
+
import doctest, collections
def test_main(verbose=None):