summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
Diffstat (limited to 'Lib')
-rw-r--r--Lib/copy.py4
-rw-r--r--Lib/test/test_copy.py88
-rw-r--r--Lib/weakref.py22
3 files changed, 113 insertions, 1 deletions
diff --git a/Lib/copy.py b/Lib/copy.py
index a334b79..2646350 100644
--- a/Lib/copy.py
+++ b/Lib/copy.py
@@ -49,6 +49,7 @@ __getstate__() and __setstate__(). See the documentation for module
"""
import types
+import weakref
from copyreg import dispatch_table
class Error(Exception):
@@ -102,7 +103,7 @@ def _copy_immutable(x):
for t in (type(None), int, float, bool, str, tuple,
frozenset, type, range,
types.BuiltinFunctionType, type(Ellipsis),
- types.FunctionType):
+ types.FunctionType, weakref.ref):
d[t] = _copy_immutable
t = getattr(types, "CodeType", None)
if t is not None:
@@ -198,6 +199,7 @@ d[type] = _deepcopy_atomic
d[range] = _deepcopy_atomic
d[types.BuiltinFunctionType] = _deepcopy_atomic
d[types.FunctionType] = _deepcopy_atomic
+d[weakref.ref] = _deepcopy_atomic
def _deepcopy_list(x, memo):
y = []
diff --git a/Lib/test/test_copy.py b/Lib/test/test_copy.py
index 133c888..502bf3f 100644
--- a/Lib/test/test_copy.py
+++ b/Lib/test/test_copy.py
@@ -2,7 +2,9 @@
import copy
import copyreg
+import weakref
from operator import le, lt, ge, gt, eq, ne
+
import unittest
from test import support
@@ -590,6 +592,92 @@ class TestCopy(unittest.TestCase):
bar = lambda: None
self.assertEqual(copy.deepcopy(bar), bar)
+ def _check_weakref(self, _copy):
+ class C(object):
+ pass
+ obj = C()
+ x = weakref.ref(obj)
+ y = _copy(x)
+ self.assertTrue(y is x)
+ del obj
+ y = _copy(x)
+ self.assertTrue(y is x)
+
+ def test_copy_weakref(self):
+ self._check_weakref(copy.copy)
+
+ def test_deepcopy_weakref(self):
+ self._check_weakref(copy.deepcopy)
+
+ def _check_copy_weakdict(self, _dicttype):
+ class C(object):
+ pass
+ a, b, c, d = [C() for i in range(4)]
+ u = _dicttype()
+ u[a] = b
+ u[c] = d
+ v = copy.copy(u)
+ self.assertFalse(v is u)
+ self.assertEqual(v, u)
+ self.assertEqual(v[a], b)
+ self.assertEqual(v[c], d)
+ self.assertEqual(len(v), 2)
+ del c, d
+ self.assertEqual(len(v), 1)
+ x, y = C(), C()
+ # The underlying containers are decoupled
+ v[x] = y
+ self.assertFalse(x in u)
+
+ def test_copy_weakkeydict(self):
+ self._check_copy_weakdict(weakref.WeakKeyDictionary)
+
+ def test_copy_weakvaluedict(self):
+ self._check_copy_weakdict(weakref.WeakValueDictionary)
+
+ def test_deepcopy_weakkeydict(self):
+ class C(object):
+ def __init__(self, i):
+ self.i = i
+ a, b, c, d = [C(i) for i in range(4)]
+ u = weakref.WeakKeyDictionary()
+ u[a] = b
+ u[c] = d
+ # Keys aren't copied, values are
+ v = copy.deepcopy(u)
+ self.assertNotEqual(v, u)
+ self.assertEqual(len(v), 2)
+ self.assertFalse(v[a] is b)
+ self.assertFalse(v[c] is d)
+ self.assertEqual(v[a].i, b.i)
+ self.assertEqual(v[c].i, d.i)
+ del c
+ self.assertEqual(len(v), 1)
+
+ def test_deepcopy_weakvaluedict(self):
+ class C(object):
+ def __init__(self, i):
+ self.i = i
+ a, b, c, d = [C(i) for i in range(4)]
+ u = weakref.WeakValueDictionary()
+ u[a] = b
+ u[c] = d
+ # Keys are copied, values aren't
+ v = copy.deepcopy(u)
+ self.assertNotEqual(v, u)
+ self.assertEqual(len(v), 2)
+ (x, y), (z, t) = sorted(v.items(), key=lambda pair: pair[0].i)
+ self.assertFalse(x is a)
+ self.assertEqual(x.i, a.i)
+ self.assertTrue(y is b)
+ self.assertFalse(z is c)
+ self.assertEqual(z.i, c.i)
+ self.assertTrue(t is d)
+ del x, y, z, t
+ del d
+ self.assertEqual(len(v), 1)
+
+
def global_foo(x, y): return x+y
def test_main():
diff --git a/Lib/weakref.py b/Lib/weakref.py
index 6663c26..0276dfd 100644
--- a/Lib/weakref.py
+++ b/Lib/weakref.py
@@ -85,6 +85,17 @@ class WeakValueDictionary(collections.MutableMapping):
new[key] = o
return new
+ __copy__ = copy
+
+ def __deepcopy__(self, memo):
+ from copy import deepcopy
+ new = self.__class__()
+ for key, wr in self.data.items():
+ o = wr()
+ if o is not None:
+ new[deepcopy(key, memo)] = o
+ return new
+
def get(self, key, default=None):
try:
wr = self.data[key]
@@ -251,6 +262,17 @@ class WeakKeyDictionary(collections.MutableMapping):
new[o] = value
return new
+ __copy__ = copy
+
+ def __deepcopy__(self, memo):
+ from copy import deepcopy
+ new = self.__class__()
+ for key, value in self.data.items():
+ o = key()
+ if o is not None:
+ new[o] = deepcopy(value, memo)
+ return new
+
def get(self, key, default=None):
return self.data.get(ref(key),default)