diff options
-rw-r--r-- | Lib/test/test_pickle.py | 65 | ||||
-rw-r--r-- | Misc/NEWS.d/next/Library/2017-10-23-12-05-33.bpo-28416.Ldnw8X.rst | 3 | ||||
-rw-r--r-- | Modules/_pickle.c | 136 |
3 files changed, 164 insertions, 40 deletions
diff --git a/Lib/test/test_pickle.py b/Lib/test/test_pickle.py index ee71c63..895ed48 100644 --- a/Lib/test/test_pickle.py +++ b/Lib/test/test_pickle.py @@ -6,6 +6,7 @@ import io import collections import struct import sys +import weakref import unittest from test import support @@ -117,6 +118,66 @@ class PyIdPersPicklerTests(AbstractIdentityPersistentPicklerTests, pickler = pickle._Pickler unpickler = pickle._Unpickler + @support.cpython_only + def test_pickler_reference_cycle(self): + def check(Pickler): + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + f = io.BytesIO() + pickler = Pickler(f, proto) + pickler.dump('abc') + self.assertEqual(self.loads(f.getvalue()), 'abc') + pickler = Pickler(io.BytesIO()) + self.assertEqual(pickler.persistent_id('def'), 'def') + r = weakref.ref(pickler) + del pickler + self.assertIsNone(r()) + + class PersPickler(self.pickler): + def persistent_id(subself, obj): + return obj + check(PersPickler) + + class PersPickler(self.pickler): + @classmethod + def persistent_id(cls, obj): + return obj + check(PersPickler) + + class PersPickler(self.pickler): + @staticmethod + def persistent_id(obj): + return obj + check(PersPickler) + + @support.cpython_only + def test_unpickler_reference_cycle(self): + def check(Unpickler): + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + unpickler = Unpickler(io.BytesIO(self.dumps('abc', proto))) + self.assertEqual(unpickler.load(), 'abc') + unpickler = Unpickler(io.BytesIO()) + self.assertEqual(unpickler.persistent_load('def'), 'def') + r = weakref.ref(unpickler) + del unpickler + self.assertIsNone(r()) + + class PersUnpickler(self.unpickler): + def persistent_load(subself, pid): + return pid + check(PersUnpickler) + + class PersUnpickler(self.unpickler): + @classmethod + def persistent_load(cls, pid): + return pid + check(PersUnpickler) + + class PersUnpickler(self.unpickler): + @staticmethod + def persistent_load(pid): + return pid + check(PersUnpickler) + class PyPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests): @@ -197,7 +258,7 @@ if has_c_implementation: check_sizeof = support.check_sizeof def test_pickler(self): - basesize = support.calcobjsize('5P2n3i2n3iP') + basesize = support.calcobjsize('6P2n3i2n3iP') p = _pickle.Pickler(io.BytesIO()) self.assertEqual(object.__sizeof__(p), basesize) MT_size = struct.calcsize('3nP0n') @@ -214,7 +275,7 @@ if has_c_implementation: 0) # Write buffer is cleared after every dump(). def test_unpickler(self): - basesize = support.calcobjsize('2Pn2P 2P2n2i5P 2P3n6P2n2i') + basesize = support.calcobjsize('2P2n2P 2P2n2i5P 2P3n6P2n2i') unpickler = _pickle.Unpickler P = struct.calcsize('P') # Size of memo table entry. n = struct.calcsize('n') # Size of mark table entry. diff --git a/Misc/NEWS.d/next/Library/2017-10-23-12-05-33.bpo-28416.Ldnw8X.rst b/Misc/NEWS.d/next/Library/2017-10-23-12-05-33.bpo-28416.Ldnw8X.rst new file mode 100644 index 0000000..b101482 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2017-10-23-12-05-33.bpo-28416.Ldnw8X.rst @@ -0,0 +1,3 @@ +Instances of pickle.Pickler subclass with the persistent_id() method and +pickle.Unpickler subclass with the persistent_load() method no longer create +reference cycles. diff --git a/Modules/_pickle.c b/Modules/_pickle.c index 943c701..da915ef 100644 --- a/Modules/_pickle.c +++ b/Modules/_pickle.c @@ -360,6 +360,69 @@ _Pickle_FastCall(PyObject *func, PyObject *obj) /*************************************************************************/ +/* Retrieve and deconstruct a method for avoiding a reference cycle + (pickler -> bound method of pickler -> pickler) */ +static int +init_method_ref(PyObject *self, _Py_Identifier *name, + PyObject **method_func, PyObject **method_self) +{ + PyObject *func, *func2; + + /* *method_func and *method_self should be consistent. All refcount decrements + should be occurred after setting *method_self and *method_func. */ + func = _PyObject_GetAttrId(self, name); + if (func == NULL) { + *method_self = NULL; + Py_CLEAR(*method_func); + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) { + return -1; + } + PyErr_Clear(); + return 0; + } + + if (PyMethod_Check(func) && PyMethod_GET_SELF(func) == self) { + /* Deconstruct a bound Python method */ + func2 = PyMethod_GET_FUNCTION(func); + Py_INCREF(func2); + *method_self = self; /* borrowed */ + Py_XSETREF(*method_func, func2); + Py_DECREF(func); + return 0; + } + else { + *method_self = NULL; + Py_XSETREF(*method_func, func); + return 0; + } +} + +/* Bind a method if it was deconstructed */ +static PyObject * +reconstruct_method(PyObject *func, PyObject *self) +{ + if (self) { + return PyMethod_New(func, self); + } + else { + Py_INCREF(func); + return func; + } +} + +static PyObject * +call_method(PyObject *func, PyObject *self, PyObject *obj) +{ + if (self) { + return PyObject_CallFunctionObjArgs(func, self, obj, NULL); + } + else { + return PyObject_CallFunctionObjArgs(func, obj, NULL); + } +} + +/*************************************************************************/ + /* Internal data type used as the unpickling stack. */ typedef struct { PyObject_VAR_HEAD @@ -552,6 +615,8 @@ typedef struct PicklerObject { objects to support self-referential objects pickling. */ PyObject *pers_func; /* persistent_id() method, can be NULL */ + PyObject *pers_func_self; /* borrowed reference to self if pers_func + is an unbound method, NULL otherwise */ PyObject *dispatch_table; /* private dispatch_table, can be NULL */ PyObject *write; /* write() method of the output stream. */ @@ -590,6 +655,8 @@ typedef struct UnpicklerObject { Py_ssize_t memo_len; /* Number of objects in the memo */ PyObject *pers_func; /* persistent_load() method, can be NULL. */ + PyObject *pers_func_self; /* borrowed reference to self if pers_func + is an unbound method, NULL otherwise */ Py_buffer buffer; char *input_buffer; @@ -3444,7 +3511,7 @@ save_type(PicklerObject *self, PyObject *obj) } static int -save_pers(PicklerObject *self, PyObject *obj, PyObject *func) +save_pers(PicklerObject *self, PyObject *obj) { PyObject *pid = NULL; int status = 0; @@ -3452,8 +3519,7 @@ save_pers(PicklerObject *self, PyObject *obj, PyObject *func) const char persid_op = PERSID; const char binpersid_op = BINPERSID; - Py_INCREF(obj); - pid = _Pickle_FastCall(func, obj); + pid = call_method(self->pers_func, self->pers_func_self, obj); if (pid == NULL) return -1; @@ -3831,7 +3897,7 @@ save(PicklerObject *self, PyObject *obj, int pers_save) 0 if it did nothing successfully; 1 if a persistent id was saved. */ - if ((status = save_pers(self, obj, self->pers_func)) != 0) + if ((status = save_pers(self, obj)) != 0) goto done; } @@ -4246,13 +4312,10 @@ _pickle_Pickler___init___impl(PicklerObject *self, PyObject *file, self->fast_nesting = 0; self->fast_memo = NULL; - self->pers_func = _PyObject_GetAttrId((PyObject *)self, - &PyId_persistent_id); - if (self->pers_func == NULL) { - if (!PyErr_ExceptionMatches(PyExc_AttributeError)) { - return -1; - } - PyErr_Clear(); + if (init_method_ref((PyObject *)self, &PyId_persistent_id, + &self->pers_func, &self->pers_func_self) < 0) + { + return -1; } self->dispatch_table = _PyObject_GetAttrId((PyObject *)self, @@ -4519,11 +4582,11 @@ Pickler_set_memo(PicklerObject *self, PyObject *obj) static PyObject * Pickler_get_persid(PicklerObject *self) { - if (self->pers_func == NULL) + if (self->pers_func == NULL) { PyErr_SetString(PyExc_AttributeError, "persistent_id"); - else - Py_INCREF(self->pers_func); - return self->pers_func; + return NULL; + } + return reconstruct_method(self->pers_func, self->pers_func_self); } static int @@ -4540,6 +4603,7 @@ Pickler_set_persid(PicklerObject *self, PyObject *value) return -1; } + self->pers_func_self = NULL; Py_INCREF(value); Py_XSETREF(self->pers_func, value); @@ -5489,7 +5553,7 @@ load_stack_global(UnpicklerObject *self) static int load_persid(UnpicklerObject *self) { - PyObject *pid; + PyObject *pid, *obj; Py_ssize_t len; char *s; @@ -5509,13 +5573,12 @@ load_persid(UnpicklerObject *self) return -1; } - /* This does not leak since _Pickle_FastCall() steals the reference - to pid first. */ - pid = _Pickle_FastCall(self->pers_func, pid); - if (pid == NULL) + obj = call_method(self->pers_func, self->pers_func_self, pid); + Py_DECREF(pid); + if (obj == NULL) return -1; - PDATA_PUSH(self->stack, pid, -1); + PDATA_PUSH(self->stack, obj, -1); return 0; } else { @@ -5530,20 +5593,19 @@ load_persid(UnpicklerObject *self) static int load_binpersid(UnpicklerObject *self) { - PyObject *pid; + PyObject *pid, *obj; if (self->pers_func) { PDATA_POP(self->stack, pid); if (pid == NULL) return -1; - /* This does not leak since _Pickle_FastCall() steals the - reference to pid first. */ - pid = _Pickle_FastCall(self->pers_func, pid); - if (pid == NULL) + obj = call_method(self->pers_func, self->pers_func_self, pid); + Py_DECREF(pid); + if (obj == NULL) return -1; - PDATA_PUSH(self->stack, pid, -1); + PDATA_PUSH(self->stack, obj, -1); return 0; } else { @@ -6690,13 +6752,10 @@ _pickle_Unpickler___init___impl(UnpicklerObject *self, PyObject *file, self->fix_imports = fix_imports; - self->pers_func = _PyObject_GetAttrId((PyObject *)self, - &PyId_persistent_load); - if (self->pers_func == NULL) { - if (!PyErr_ExceptionMatches(PyExc_AttributeError)) { - return -1; - } - PyErr_Clear(); + if (init_method_ref((PyObject *)self, &PyId_persistent_load, + &self->pers_func, &self->pers_func_self) < 0) + { + return -1; } self->stack = (Pdata *)Pdata_New(); @@ -6983,11 +7042,11 @@ Unpickler_set_memo(UnpicklerObject *self, PyObject *obj) static PyObject * Unpickler_get_persload(UnpicklerObject *self) { - if (self->pers_func == NULL) + if (self->pers_func == NULL) { PyErr_SetString(PyExc_AttributeError, "persistent_load"); - else - Py_INCREF(self->pers_func); - return self->pers_func; + return NULL; + } + return reconstruct_method(self->pers_func, self->pers_func_self); } static int @@ -7005,6 +7064,7 @@ Unpickler_set_persload(UnpicklerObject *self, PyObject *value) return -1; } + self->pers_func_self = NULL; Py_INCREF(value); Py_XSETREF(self->pers_func, value); |