summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSerhiy Storchaka <storchaka@gmail.com>2015-11-07 09:15:32 (GMT)
committerSerhiy Storchaka <storchaka@gmail.com>2015-11-07 09:15:32 (GMT)
commitda87e45add0d8a4834a43a57ba26e6c6d74a7ab8 (patch)
treec694d5651ac6a8e1087294a321f4e78480f530c3
parent43415ba5717fcdac45c36f6880e90a114dd538ad (diff)
downloadcpython-da87e45add0d8a4834a43a57ba26e6c6d74a7ab8.zip
cpython-da87e45add0d8a4834a43a57ba26e6c6d74a7ab8.tar.gz
cpython-da87e45add0d8a4834a43a57ba26e6c6d74a7ab8.tar.bz2
Issue #892902: Fixed pickling recursive objects.
-rw-r--r--Lib/pickle.py8
-rw-r--r--Lib/test/pickletester.py105
-rw-r--r--Lib/test/test_cpickle.py39
-rw-r--r--Misc/NEWS2
-rw-r--r--Modules/cPickle.c21
5 files changed, 145 insertions, 30 deletions
diff --git a/Lib/pickle.py b/Lib/pickle.py
index 299de16..1b3196f 100644
--- a/Lib/pickle.py
+++ b/Lib/pickle.py
@@ -402,7 +402,13 @@ class Pickler:
write(REDUCE)
if obj is not None:
- self.memoize(obj)
+ # If the object is already in the memo, this means it is
+ # recursive. In this case, throw away everything we put on the
+ # stack, and fetch the object back from the memo.
+ if id(obj) in self.memo:
+ write(POP + self.get(self.memo[id(obj)][0]))
+ else:
+ self.memoize(obj)
# More new special cases (that work with older protocols as
# well): when __reduce__ returns a tuple with 4 or 5 items,
diff --git a/Lib/test/pickletester.py b/Lib/test/pickletester.py
index f7b9225..d8346ea 100644
--- a/Lib/test/pickletester.py
+++ b/Lib/test/pickletester.py
@@ -117,6 +117,18 @@ class E(C):
def __getinitargs__(self):
return ()
+class H(object):
+ pass
+
+# Hashable mutable key
+class K(object):
+ def __init__(self, value):
+ self.value = value
+
+ def __reduce__(self):
+ # Shouldn't support the recursion itself
+ return K, (self.value,)
+
import __main__
__main__.C = C
C.__module__ = "__main__"
@@ -124,6 +136,10 @@ __main__.D = D
D.__module__ = "__main__"
__main__.E = E
E.__module__ = "__main__"
+__main__.H = H
+H.__module__ = "__main__"
+__main__.K = K
+K.__module__ = "__main__"
class myint(int):
def __init__(self, x):
@@ -676,18 +692,21 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols:
s = self.dumps(l, proto)
x = self.loads(s)
+ self.assertIsInstance(x, list)
self.assertEqual(len(x), 1)
- self.assertTrue(x is x[0])
+ self.assertIs(x[0], x)
- def test_recursive_tuple(self):
+ def test_recursive_tuple_and_list(self):
t = ([],)
t[0].append(t)
for proto in protocols:
s = self.dumps(t, proto)
x = self.loads(s)
+ self.assertIsInstance(x, tuple)
self.assertEqual(len(x), 1)
+ self.assertIsInstance(x[0], list)
self.assertEqual(len(x[0]), 1)
- self.assertTrue(x is x[0][0])
+ self.assertIs(x[0][0], x)
def test_recursive_dict(self):
d = {}
@@ -695,8 +714,50 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols:
s = self.dumps(d, proto)
x = self.loads(s)
+ self.assertIsInstance(x, dict)
self.assertEqual(x.keys(), [1])
- self.assertTrue(x[1] is x)
+ self.assertIs(x[1], x)
+
+ def test_recursive_dict_key(self):
+ d = {}
+ k = K(d)
+ d[k] = 1
+ for proto in protocols:
+ s = self.dumps(d, proto)
+ x = self.loads(s)
+ self.assertIsInstance(x, dict)
+ self.assertEqual(len(x.keys()), 1)
+ self.assertIsInstance(x.keys()[0], K)
+ self.assertIs(x.keys()[0].value, x)
+
+ def test_recursive_list_subclass(self):
+ y = MyList()
+ y.append(y)
+ s = self.dumps(y, 2)
+ x = self.loads(s)
+ self.assertIsInstance(x, MyList)
+ self.assertEqual(len(x), 1)
+ self.assertIs(x[0], x)
+
+ def test_recursive_dict_subclass(self):
+ d = MyDict()
+ d[1] = d
+ s = self.dumps(d, 2)
+ x = self.loads(s)
+ self.assertIsInstance(x, MyDict)
+ self.assertEqual(x.keys(), [1])
+ self.assertIs(x[1], x)
+
+ def test_recursive_dict_subclass_key(self):
+ d = MyDict()
+ k = K(d)
+ d[k] = 1
+ s = self.dumps(d, 2)
+ x = self.loads(s)
+ self.assertIsInstance(x, MyDict)
+ self.assertEqual(len(x.keys()), 1)
+ self.assertIsInstance(x.keys()[0], K)
+ self.assertIs(x.keys()[0].value, x)
def test_recursive_inst(self):
i = C()
@@ -721,6 +782,42 @@ class AbstractPickleTests(unittest.TestCase):
self.assertEqual(x[0].attr.keys(), [1])
self.assertTrue(x[0].attr[1] is x)
+ def check_recursive_collection_and_inst(self, factory):
+ h = H()
+ y = factory([h])
+ h.attr = y
+ for proto in protocols:
+ s = self.dumps(y, proto)
+ x = self.loads(s)
+ self.assertIsInstance(x, type(y))
+ self.assertEqual(len(x), 1)
+ self.assertIsInstance(list(x)[0], H)
+ self.assertIs(list(x)[0].attr, x)
+
+ def test_recursive_list_and_inst(self):
+ self.check_recursive_collection_and_inst(list)
+
+ def test_recursive_tuple_and_inst(self):
+ self.check_recursive_collection_and_inst(tuple)
+
+ def test_recursive_dict_and_inst(self):
+ self.check_recursive_collection_and_inst(dict.fromkeys)
+
+ def test_recursive_set_and_inst(self):
+ self.check_recursive_collection_and_inst(set)
+
+ def test_recursive_frozenset_and_inst(self):
+ self.check_recursive_collection_and_inst(frozenset)
+
+ def test_recursive_list_subclass_and_inst(self):
+ self.check_recursive_collection_and_inst(MyList)
+
+ def test_recursive_tuple_subclass_and_inst(self):
+ self.check_recursive_collection_and_inst(MyTuple)
+
+ def test_recursive_dict_subclass_and_inst(self):
+ self.check_recursive_collection_and_inst(MyDict.fromkeys)
+
if have_unicode:
def test_unicode(self):
endcases = [u'', u'<\\u>', u'<\\\u1234>', u'<\n>',
diff --git a/Lib/test/test_cpickle.py b/Lib/test/test_cpickle.py
index f6b3347..0a1eb43 100644
--- a/Lib/test/test_cpickle.py
+++ b/Lib/test/test_cpickle.py
@@ -1,6 +1,7 @@
import cPickle
import cStringIO
import io
+import functools
import unittest
from test.pickletester import (AbstractUnpickleTests,
AbstractPickleTests,
@@ -151,31 +152,6 @@ class cPickleFastPicklerTests(AbstractPickleTests):
finally:
self.close(f)
- def test_recursive_list(self):
- self.assertRaises(ValueError,
- AbstractPickleTests.test_recursive_list,
- self)
-
- def test_recursive_tuple(self):
- self.assertRaises(ValueError,
- AbstractPickleTests.test_recursive_tuple,
- self)
-
- def test_recursive_inst(self):
- self.assertRaises(ValueError,
- AbstractPickleTests.test_recursive_inst,
- self)
-
- def test_recursive_dict(self):
- self.assertRaises(ValueError,
- AbstractPickleTests.test_recursive_dict,
- self)
-
- def test_recursive_multi(self):
- self.assertRaises(ValueError,
- AbstractPickleTests.test_recursive_multi,
- self)
-
def test_nonrecursive_deep(self):
# If it's not cyclic, it should pickle OK even if the nesting
# depth exceeds PY_CPICKLE_FAST_LIMIT. That happens to be
@@ -187,6 +163,19 @@ class cPickleFastPicklerTests(AbstractPickleTests):
b = self.loads(self.dumps(a))
self.assertEqual(a, b)
+for name in dir(AbstractPickleTests):
+ if name.startswith('test_recursive_'):
+ func = getattr(AbstractPickleTests, name)
+ if '_subclass' in name and '_and_inst' not in name:
+ assert_args = RuntimeError, 'maximum recursion depth exceeded'
+ else:
+ assert_args = ValueError, "can't pickle cyclic objects"
+ def wrapper(self, func=func, assert_args=assert_args):
+ with self.assertRaisesRegexp(*assert_args):
+ func(self)
+ functools.update_wrapper(wrapper, func)
+ setattr(cPickleFastPicklerTests, name, wrapper)
+
class cStringIOCPicklerFastTests(cStringIOMixin, cPickleFastPicklerTests):
pass
diff --git a/Misc/NEWS b/Misc/NEWS
index 6f056b2..f9163d6 100644
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -46,6 +46,8 @@ Core and Builtins
Library
-------
+- Issue #892902: Fixed pickling recursive objects.
+
- Issue #18010: Fix the pydoc GUI's search function to handle exceptions
from importing packages.
diff --git a/Modules/cPickle.c b/Modules/cPickle.c
index 89448a6..0e93723 100644
--- a/Modules/cPickle.c
+++ b/Modules/cPickle.c
@@ -2533,6 +2533,27 @@ save_reduce(Picklerobject *self, PyObject *args, PyObject *fn, PyObject *ob)
/* Memoize. */
/* XXX How can ob be NULL? */
if (ob != NULL) {
+ /* If the object is already in the memo, this means it is
+ recursive. In this case, throw away everything we put on the
+ stack, and fetch the object back from the memo. */
+ if (Py_REFCNT(ob) > 1 && !self->fast) {
+ PyObject *py_ob_id = PyLong_FromVoidPtr(ob);
+ if (!py_ob_id)
+ return -1;
+ if (PyDict_GetItem(self->memo, py_ob_id)) {
+ const char pop_op = POP;
+ if (self->write_func(self, &pop_op, 1) < 0 ||
+ get(self, py_ob_id) < 0) {
+ Py_DECREF(py_ob_id);
+ return -1;
+ }
+ Py_DECREF(py_ob_id);
+ return 0;
+ }
+ Py_DECREF(py_ob_id);
+ if (PyErr_Occurred())
+ return -1;
+ }
if (state && !PyDict_Check(state)) {
if (put2(self, ob) < 0)
return -1;