summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorGuido van Rossum <guido@python.org>2006-02-25 22:38:04 (GMT)
committerGuido van Rossum <guido@python.org>2006-02-25 22:38:04 (GMT)
commit1968ad32cd7f46d9bb64826672ef68cdaee35288 (patch)
treec46db5a446d9de18fb8436408ec29d2111a2f5ad /Lib
parentab51f5f24d6f6edef5e8fac5e31b2e4ac0cbdbac (diff)
downloadcpython-1968ad32cd7f46d9bb64826672ef68cdaee35288.zip
cpython-1968ad32cd7f46d9bb64826672ef68cdaee35288.tar.gz
cpython-1968ad32cd7f46d9bb64826672ef68cdaee35288.tar.bz2
- Patch 1433928:
- The copy module now "copies" function objects (as atomic objects). - dict.__getitem__ now looks for a __missing__ hook before raising KeyError. - Added a new type, defaultdict, to the collections module. This uses the new __missing__ hook behavior added to dict (see above).
Diffstat (limited to 'Lib')
-rw-r--r--Lib/UserDict.py7
-rw-r--r--Lib/copy.py4
-rw-r--r--Lib/test/test_copy.py16
-rw-r--r--Lib/test/test_defaultdict.py135
-rw-r--r--Lib/test/test_dict.py50
-rw-r--r--Lib/test/test_userdict.py49
6 files changed, 259 insertions, 2 deletions
diff --git a/Lib/UserDict.py b/Lib/UserDict.py
index 7168703..5e97817 100644
--- a/Lib/UserDict.py
+++ b/Lib/UserDict.py
@@ -14,7 +14,12 @@ class UserDict:
else:
return cmp(self.data, dict)
def __len__(self): return len(self.data)
- def __getitem__(self, key): return self.data[key]
+ def __getitem__(self, key):
+ if key in self.data:
+ return self.data[key]
+ if hasattr(self.__class__, "__missing__"):
+ return self.__class__.__missing__(self, key)
+ raise KeyError(key)
def __setitem__(self, key, item): self.data[key] = item
def __delitem__(self, key): del self.data[key]
def clear(self): self.data.clear()
diff --git a/Lib/copy.py b/Lib/copy.py
index b3419ca..9e60144 100644
--- a/Lib/copy.py
+++ b/Lib/copy.py
@@ -101,7 +101,8 @@ def _copy_immutable(x):
return x
for t in (type(None), int, long, float, bool, str, tuple,
frozenset, type, xrange, types.ClassType,
- types.BuiltinFunctionType):
+ types.BuiltinFunctionType,
+ types.FunctionType):
d[t] = _copy_immutable
for name in ("ComplexType", "UnicodeType", "CodeType"):
t = getattr(types, name, None)
@@ -217,6 +218,7 @@ d[type] = _deepcopy_atomic
d[xrange] = _deepcopy_atomic
d[types.ClassType] = _deepcopy_atomic
d[types.BuiltinFunctionType] = _deepcopy_atomic
+d[types.FunctionType] = _deepcopy_atomic
def _deepcopy_list(x, memo):
y = []
diff --git a/Lib/test/test_copy.py b/Lib/test/test_copy.py
index bd5a3e1..ff4c987 100644
--- a/Lib/test/test_copy.py
+++ b/Lib/test/test_copy.py
@@ -568,6 +568,22 @@ class TestCopy(unittest.TestCase):
raise ValueError, "ain't got no stickin' state"
self.assertRaises(ValueError, copy.copy, EvilState())
+ def test_copy_function(self):
+ self.assertEqual(copy.copy(global_foo), global_foo)
+ def foo(x, y): return x+y
+ self.assertEqual(copy.copy(foo), foo)
+ bar = lambda: None
+ self.assertEqual(copy.copy(bar), bar)
+
+ def test_deepcopy_function(self):
+ self.assertEqual(copy.deepcopy(global_foo), global_foo)
+ def foo(x, y): return x+y
+ self.assertEqual(copy.deepcopy(foo), foo)
+ bar = lambda: None
+ self.assertEqual(copy.deepcopy(bar), bar)
+
+def global_foo(x, y): return x+y
+
def test_main():
test_support.run_unittest(TestCopy)
diff --git a/Lib/test/test_defaultdict.py b/Lib/test/test_defaultdict.py
new file mode 100644
index 0000000..b5a6628
--- /dev/null
+++ b/Lib/test/test_defaultdict.py
@@ -0,0 +1,135 @@
+"""Unit tests for collections.defaultdict."""
+
+import os
+import copy
+import tempfile
+import unittest
+
+from collections import defaultdict
+
+def foobar():
+ return list
+
+class TestDefaultDict(unittest.TestCase):
+
+ def test_basic(self):
+ d1 = defaultdict()
+ self.assertEqual(d1.default_factory, None)
+ d1.default_factory = list
+ d1[12].append(42)
+ self.assertEqual(d1, {12: [42]})
+ d1[12].append(24)
+ self.assertEqual(d1, {12: [42, 24]})
+ d1[13]
+ d1[14]
+ self.assertEqual(d1, {12: [42, 24], 13: [], 14: []})
+ self.assert_(d1[12] is not d1[13] is not d1[14])
+ d2 = defaultdict(list, foo=1, bar=2)
+ self.assertEqual(d2.default_factory, list)
+ self.assertEqual(d2, {"foo": 1, "bar": 2})
+ self.assertEqual(d2["foo"], 1)
+ self.assertEqual(d2["bar"], 2)
+ self.assertEqual(d2[42], [])
+ self.assert_("foo" in d2)
+ self.assert_("foo" in d2.keys())
+ self.assert_("bar" in d2)
+ self.assert_("bar" in d2.keys())
+ self.assert_(42 in d2)
+ self.assert_(42 in d2.keys())
+ self.assert_(12 not in d2)
+ self.assert_(12 not in d2.keys())
+ d2.default_factory = None
+ self.assertEqual(d2.default_factory, None)
+ try:
+ d2[15]
+ except KeyError, err:
+ self.assertEqual(err.args, (15,))
+ else:
+ self.fail("d2[15] didn't raise KeyError")
+
+ def test_missing(self):
+ d1 = defaultdict()
+ self.assertRaises(KeyError, d1.__missing__, 42)
+ d1.default_factory = list
+ self.assertEqual(d1.__missing__(42), [])
+
+ def test_repr(self):
+ d1 = defaultdict()
+ self.assertEqual(d1.default_factory, None)
+ self.assertEqual(repr(d1), "defaultdict(None, {})")
+ d1[11] = 41
+ self.assertEqual(repr(d1), "defaultdict(None, {11: 41})")
+ d2 = defaultdict(0)
+ self.assertEqual(d2.default_factory, 0)
+ d2[12] = 42
+ self.assertEqual(repr(d2), "defaultdict(0, {12: 42})")
+ def foo(): return 43
+ d3 = defaultdict(foo)
+ self.assert_(d3.default_factory is foo)
+ d3[13]
+ self.assertEqual(repr(d3), "defaultdict(%s, {13: 43})" % repr(foo))
+
+ def test_print(self):
+ d1 = defaultdict()
+ def foo(): return 42
+ d2 = defaultdict(foo, {1: 2})
+ # NOTE: We can't use tempfile.[Named]TemporaryFile since this
+ # code must exercise the tp_print C code, which only gets
+ # invoked for *real* files.
+ tfn = tempfile.mktemp()
+ try:
+ f = open(tfn, "w+")
+ try:
+ print >>f, d1
+ print >>f, d2
+ f.seek(0)
+ self.assertEqual(f.readline(), repr(d1) + "\n")
+ self.assertEqual(f.readline(), repr(d2) + "\n")
+ finally:
+ f.close()
+ finally:
+ os.remove(tfn)
+
+ def test_copy(self):
+ d1 = defaultdict()
+ d2 = d1.copy()
+ self.assertEqual(type(d2), defaultdict)
+ self.assertEqual(d2.default_factory, None)
+ self.assertEqual(d2, {})
+ d1.default_factory = list
+ d3 = d1.copy()
+ self.assertEqual(type(d3), defaultdict)
+ self.assertEqual(d3.default_factory, list)
+ self.assertEqual(d3, {})
+ d1[42]
+ d4 = d1.copy()
+ self.assertEqual(type(d4), defaultdict)
+ self.assertEqual(d4.default_factory, list)
+ self.assertEqual(d4, {42: []})
+ d4[12]
+ self.assertEqual(d4, {42: [], 12: []})
+
+ def test_shallow_copy(self):
+ d1 = defaultdict(foobar, {1: 1})
+ d2 = copy.copy(d1)
+ self.assertEqual(d2.default_factory, foobar)
+ self.assertEqual(d2, d1)
+ d1.default_factory = list
+ d2 = copy.copy(d1)
+ self.assertEqual(d2.default_factory, list)
+ self.assertEqual(d2, d1)
+
+ def test_deep_copy(self):
+ d1 = defaultdict(foobar, {1: [1]})
+ d2 = copy.deepcopy(d1)
+ self.assertEqual(d2.default_factory, foobar)
+ self.assertEqual(d2, d1)
+ self.assert_(d1[1] is not d2[1])
+ d1.default_factory = list
+ d2 = copy.deepcopy(d1)
+ self.assertEqual(d2.default_factory, list)
+ self.assertEqual(d2, d1)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/Lib/test/test_dict.py b/Lib/test/test_dict.py
index e13829c..f3f78e7 100644
--- a/Lib/test/test_dict.py
+++ b/Lib/test/test_dict.py
@@ -395,6 +395,56 @@ class DictTest(unittest.TestCase):
else:
self.fail("< didn't raise Exc")
+ def test_missing(self):
+ # Make sure dict doesn't have a __missing__ method
+ self.assertEqual(hasattr(dict, "__missing__"), False)
+ self.assertEqual(hasattr({}, "__missing__"), False)
+ # Test several cases:
+ # (D) subclass defines __missing__ method returning a value
+ # (E) subclass defines __missing__ method raising RuntimeError
+ # (F) subclass sets __missing__ instance variable (no effect)
+ # (G) subclass doesn't define __missing__ at a all
+ class D(dict):
+ def __missing__(self, key):
+ return 42
+ d = D({1: 2, 3: 4})
+ self.assertEqual(d[1], 2)
+ self.assertEqual(d[3], 4)
+ self.assert_(2 not in d)
+ self.assert_(2 not in d.keys())
+ self.assertEqual(d[2], 42)
+ class E(dict):
+ def __missing__(self, key):
+ raise RuntimeError(key)
+ e = E()
+ try:
+ e[42]
+ except RuntimeError, err:
+ self.assertEqual(err.args, (42,))
+ else:
+ self.fail_("e[42] didn't raise RuntimeError")
+ class F(dict):
+ def __init__(self):
+ # An instance variable __missing__ should have no effect
+ self.__missing__ = lambda key: None
+ f = F()
+ try:
+ f[42]
+ except KeyError, err:
+ self.assertEqual(err.args, (42,))
+ else:
+ self.fail_("f[42] didn't raise KeyError")
+ class G(dict):
+ pass
+ g = G()
+ try:
+ g[42]
+ except KeyError, err:
+ self.assertEqual(err.args, (42,))
+ else:
+ self.fail_("g[42] didn't raise KeyError")
+
+
import mapping_tests
class GeneralMappingTests(mapping_tests.BasicTestMappingProtocol):
diff --git a/Lib/test/test_userdict.py b/Lib/test/test_userdict.py
index 2d5fa03..a4b7de4 100644
--- a/Lib/test/test_userdict.py
+++ b/Lib/test/test_userdict.py
@@ -148,6 +148,55 @@ class UserDictTest(mapping_tests.TestHashMappingProtocol):
self.assertEqual(t.popitem(), ("x", 42))
self.assertRaises(KeyError, t.popitem)
+ def test_missing(self):
+ # Make sure UserDict doesn't have a __missing__ method
+ self.assertEqual(hasattr(UserDict, "__missing__"), False)
+ # Test several cases:
+ # (D) subclass defines __missing__ method returning a value
+ # (E) subclass defines __missing__ method raising RuntimeError
+ # (F) subclass sets __missing__ instance variable (no effect)
+ # (G) subclass doesn't define __missing__ at a all
+ class D(UserDict.UserDict):
+ def __missing__(self, key):
+ return 42
+ d = D({1: 2, 3: 4})
+ self.assertEqual(d[1], 2)
+ self.assertEqual(d[3], 4)
+ self.assert_(2 not in d)
+ self.assert_(2 not in d.keys())
+ self.assertEqual(d[2], 42)
+ class E(UserDict.UserDict):
+ def __missing__(self, key):
+ raise RuntimeError(key)
+ e = E()
+ try:
+ e[42]
+ except RuntimeError, err:
+ self.assertEqual(err.args, (42,))
+ else:
+ self.fail_("e[42] didn't raise RuntimeError")
+ class F(UserDict.UserDict):
+ def __init__(self):
+ # An instance variable __missing__ should have no effect
+ self.__missing__ = lambda key: None
+ UserDict.UserDict.__init__(self)
+ f = F()
+ try:
+ f[42]
+ except KeyError, err:
+ self.assertEqual(err.args, (42,))
+ else:
+ self.fail_("f[42] didn't raise KeyError")
+ class G(UserDict.UserDict):
+ pass
+ g = G()
+ try:
+ g[42]
+ except KeyError, err:
+ self.assertEqual(err.args, (42,))
+ else:
+ self.fail_("g[42] didn't raise KeyError")
+
##########################
# Test Dict Mixin