diff options
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/UserDict.py | 7 | ||||
-rw-r--r-- | Lib/copy.py | 4 | ||||
-rw-r--r-- | Lib/test/test_copy.py | 16 | ||||
-rw-r--r-- | Lib/test/test_defaultdict.py | 135 | ||||
-rw-r--r-- | Lib/test/test_dict.py | 50 | ||||
-rw-r--r-- | Lib/test/test_userdict.py | 49 |
6 files changed, 259 insertions, 2 deletions
diff --git a/Lib/UserDict.py b/Lib/UserDict.py index 7168703..5e97817 100644 --- a/Lib/UserDict.py +++ b/Lib/UserDict.py @@ -14,7 +14,12 @@ class UserDict: else: return cmp(self.data, dict) def __len__(self): return len(self.data) - def __getitem__(self, key): return self.data[key] + def __getitem__(self, key): + if key in self.data: + return self.data[key] + if hasattr(self.__class__, "__missing__"): + return self.__class__.__missing__(self, key) + raise KeyError(key) def __setitem__(self, key, item): self.data[key] = item def __delitem__(self, key): del self.data[key] def clear(self): self.data.clear() diff --git a/Lib/copy.py b/Lib/copy.py index b3419ca..9e60144 100644 --- a/Lib/copy.py +++ b/Lib/copy.py @@ -101,7 +101,8 @@ def _copy_immutable(x): return x for t in (type(None), int, long, float, bool, str, tuple, frozenset, type, xrange, types.ClassType, - types.BuiltinFunctionType): + types.BuiltinFunctionType, + types.FunctionType): d[t] = _copy_immutable for name in ("ComplexType", "UnicodeType", "CodeType"): t = getattr(types, name, None) @@ -217,6 +218,7 @@ d[type] = _deepcopy_atomic d[xrange] = _deepcopy_atomic d[types.ClassType] = _deepcopy_atomic d[types.BuiltinFunctionType] = _deepcopy_atomic +d[types.FunctionType] = _deepcopy_atomic def _deepcopy_list(x, memo): y = [] diff --git a/Lib/test/test_copy.py b/Lib/test/test_copy.py index bd5a3e1..ff4c987 100644 --- a/Lib/test/test_copy.py +++ b/Lib/test/test_copy.py @@ -568,6 +568,22 @@ class TestCopy(unittest.TestCase): raise ValueError, "ain't got no stickin' state" self.assertRaises(ValueError, copy.copy, EvilState()) + def test_copy_function(self): + self.assertEqual(copy.copy(global_foo), global_foo) + def foo(x, y): return x+y + self.assertEqual(copy.copy(foo), foo) + bar = lambda: None + self.assertEqual(copy.copy(bar), bar) + + def test_deepcopy_function(self): + self.assertEqual(copy.deepcopy(global_foo), global_foo) + def foo(x, y): return x+y + self.assertEqual(copy.deepcopy(foo), foo) + bar = lambda: None + self.assertEqual(copy.deepcopy(bar), bar) + +def global_foo(x, y): return x+y + def test_main(): test_support.run_unittest(TestCopy) diff --git a/Lib/test/test_defaultdict.py b/Lib/test/test_defaultdict.py new file mode 100644 index 0000000..b5a6628 --- /dev/null +++ b/Lib/test/test_defaultdict.py @@ -0,0 +1,135 @@ +"""Unit tests for collections.defaultdict.""" + +import os +import copy +import tempfile +import unittest + +from collections import defaultdict + +def foobar(): + return list + +class TestDefaultDict(unittest.TestCase): + + def test_basic(self): + d1 = defaultdict() + self.assertEqual(d1.default_factory, None) + d1.default_factory = list + d1[12].append(42) + self.assertEqual(d1, {12: [42]}) + d1[12].append(24) + self.assertEqual(d1, {12: [42, 24]}) + d1[13] + d1[14] + self.assertEqual(d1, {12: [42, 24], 13: [], 14: []}) + self.assert_(d1[12] is not d1[13] is not d1[14]) + d2 = defaultdict(list, foo=1, bar=2) + self.assertEqual(d2.default_factory, list) + self.assertEqual(d2, {"foo": 1, "bar": 2}) + self.assertEqual(d2["foo"], 1) + self.assertEqual(d2["bar"], 2) + self.assertEqual(d2[42], []) + self.assert_("foo" in d2) + self.assert_("foo" in d2.keys()) + self.assert_("bar" in d2) + self.assert_("bar" in d2.keys()) + self.assert_(42 in d2) + self.assert_(42 in d2.keys()) + self.assert_(12 not in d2) + self.assert_(12 not in d2.keys()) + d2.default_factory = None + self.assertEqual(d2.default_factory, None) + try: + d2[15] + except KeyError, err: + self.assertEqual(err.args, (15,)) + else: + self.fail("d2[15] didn't raise KeyError") + + def test_missing(self): + d1 = defaultdict() + self.assertRaises(KeyError, d1.__missing__, 42) + d1.default_factory = list + self.assertEqual(d1.__missing__(42), []) + + def test_repr(self): + d1 = defaultdict() + self.assertEqual(d1.default_factory, None) + self.assertEqual(repr(d1), "defaultdict(None, {})") + d1[11] = 41 + self.assertEqual(repr(d1), "defaultdict(None, {11: 41})") + d2 = defaultdict(0) + self.assertEqual(d2.default_factory, 0) + d2[12] = 42 + self.assertEqual(repr(d2), "defaultdict(0, {12: 42})") + def foo(): return 43 + d3 = defaultdict(foo) + self.assert_(d3.default_factory is foo) + d3[13] + self.assertEqual(repr(d3), "defaultdict(%s, {13: 43})" % repr(foo)) + + def test_print(self): + d1 = defaultdict() + def foo(): return 42 + d2 = defaultdict(foo, {1: 2}) + # NOTE: We can't use tempfile.[Named]TemporaryFile since this + # code must exercise the tp_print C code, which only gets + # invoked for *real* files. + tfn = tempfile.mktemp() + try: + f = open(tfn, "w+") + try: + print >>f, d1 + print >>f, d2 + f.seek(0) + self.assertEqual(f.readline(), repr(d1) + "\n") + self.assertEqual(f.readline(), repr(d2) + "\n") + finally: + f.close() + finally: + os.remove(tfn) + + def test_copy(self): + d1 = defaultdict() + d2 = d1.copy() + self.assertEqual(type(d2), defaultdict) + self.assertEqual(d2.default_factory, None) + self.assertEqual(d2, {}) + d1.default_factory = list + d3 = d1.copy() + self.assertEqual(type(d3), defaultdict) + self.assertEqual(d3.default_factory, list) + self.assertEqual(d3, {}) + d1[42] + d4 = d1.copy() + self.assertEqual(type(d4), defaultdict) + self.assertEqual(d4.default_factory, list) + self.assertEqual(d4, {42: []}) + d4[12] + self.assertEqual(d4, {42: [], 12: []}) + + def test_shallow_copy(self): + d1 = defaultdict(foobar, {1: 1}) + d2 = copy.copy(d1) + self.assertEqual(d2.default_factory, foobar) + self.assertEqual(d2, d1) + d1.default_factory = list + d2 = copy.copy(d1) + self.assertEqual(d2.default_factory, list) + self.assertEqual(d2, d1) + + def test_deep_copy(self): + d1 = defaultdict(foobar, {1: [1]}) + d2 = copy.deepcopy(d1) + self.assertEqual(d2.default_factory, foobar) + self.assertEqual(d2, d1) + self.assert_(d1[1] is not d2[1]) + d1.default_factory = list + d2 = copy.deepcopy(d1) + self.assertEqual(d2.default_factory, list) + self.assertEqual(d2, d1) + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_dict.py b/Lib/test/test_dict.py index e13829c..f3f78e7 100644 --- a/Lib/test/test_dict.py +++ b/Lib/test/test_dict.py @@ -395,6 +395,56 @@ class DictTest(unittest.TestCase): else: self.fail("< didn't raise Exc") + def test_missing(self): + # Make sure dict doesn't have a __missing__ method + self.assertEqual(hasattr(dict, "__missing__"), False) + self.assertEqual(hasattr({}, "__missing__"), False) + # Test several cases: + # (D) subclass defines __missing__ method returning a value + # (E) subclass defines __missing__ method raising RuntimeError + # (F) subclass sets __missing__ instance variable (no effect) + # (G) subclass doesn't define __missing__ at a all + class D(dict): + def __missing__(self, key): + return 42 + d = D({1: 2, 3: 4}) + self.assertEqual(d[1], 2) + self.assertEqual(d[3], 4) + self.assert_(2 not in d) + self.assert_(2 not in d.keys()) + self.assertEqual(d[2], 42) + class E(dict): + def __missing__(self, key): + raise RuntimeError(key) + e = E() + try: + e[42] + except RuntimeError, err: + self.assertEqual(err.args, (42,)) + else: + self.fail_("e[42] didn't raise RuntimeError") + class F(dict): + def __init__(self): + # An instance variable __missing__ should have no effect + self.__missing__ = lambda key: None + f = F() + try: + f[42] + except KeyError, err: + self.assertEqual(err.args, (42,)) + else: + self.fail_("f[42] didn't raise KeyError") + class G(dict): + pass + g = G() + try: + g[42] + except KeyError, err: + self.assertEqual(err.args, (42,)) + else: + self.fail_("g[42] didn't raise KeyError") + + import mapping_tests class GeneralMappingTests(mapping_tests.BasicTestMappingProtocol): diff --git a/Lib/test/test_userdict.py b/Lib/test/test_userdict.py index 2d5fa03..a4b7de4 100644 --- a/Lib/test/test_userdict.py +++ b/Lib/test/test_userdict.py @@ -148,6 +148,55 @@ class UserDictTest(mapping_tests.TestHashMappingProtocol): self.assertEqual(t.popitem(), ("x", 42)) self.assertRaises(KeyError, t.popitem) + def test_missing(self): + # Make sure UserDict doesn't have a __missing__ method + self.assertEqual(hasattr(UserDict, "__missing__"), False) + # Test several cases: + # (D) subclass defines __missing__ method returning a value + # (E) subclass defines __missing__ method raising RuntimeError + # (F) subclass sets __missing__ instance variable (no effect) + # (G) subclass doesn't define __missing__ at a all + class D(UserDict.UserDict): + def __missing__(self, key): + return 42 + d = D({1: 2, 3: 4}) + self.assertEqual(d[1], 2) + self.assertEqual(d[3], 4) + self.assert_(2 not in d) + self.assert_(2 not in d.keys()) + self.assertEqual(d[2], 42) + class E(UserDict.UserDict): + def __missing__(self, key): + raise RuntimeError(key) + e = E() + try: + e[42] + except RuntimeError, err: + self.assertEqual(err.args, (42,)) + else: + self.fail_("e[42] didn't raise RuntimeError") + class F(UserDict.UserDict): + def __init__(self): + # An instance variable __missing__ should have no effect + self.__missing__ = lambda key: None + UserDict.UserDict.__init__(self) + f = F() + try: + f[42] + except KeyError, err: + self.assertEqual(err.args, (42,)) + else: + self.fail_("f[42] didn't raise KeyError") + class G(UserDict.UserDict): + pass + g = G() + try: + g[42] + except KeyError, err: + self.assertEqual(err.args, (42,)) + else: + self.fail_("g[42] didn't raise KeyError") + ########################## # Test Dict Mixin |