diff options
author | Serhiy Storchaka <storchaka@gmail.com> | 2015-03-31 11:07:24 (GMT) |
---|---|---|
committer | Serhiy Storchaka <storchaka@gmail.com> | 2015-03-31 11:07:24 (GMT) |
commit | 58e4134a1cf668a16006b896fca093d5a79966e3 (patch) | |
tree | b4baced8cb92d58dda991491c786e39fdfa5725e | |
parent | 72e731cc03f29cdb8bf17bd9ea34c8448954c798 (diff) | |
download | cpython-58e4134a1cf668a16006b896fca093d5a79966e3.zip cpython-58e4134a1cf668a16006b896fca093d5a79966e3.tar.gz cpython-58e4134a1cf668a16006b896fca093d5a79966e3.tar.bz2 |
Issue #23611: Serializing more "lookupable" objects (such as unbound methods
or nested classes) is now supported with pickle protocols < 4.
-rw-r--r-- | Doc/whatsnew/3.5.rst | 7 | ||||
-rw-r--r-- | Lib/pickle.py | 33 | ||||
-rw-r--r-- | Lib/test/pickletester.py | 19 | ||||
-rw-r--r-- | Misc/NEWS | 3 | ||||
-rw-r--r-- | Modules/_pickle.c | 118 |
5 files changed, 114 insertions, 66 deletions
diff --git a/Doc/whatsnew/3.5.rst b/Doc/whatsnew/3.5.rst index 434c0a7..ad6475a 100644 --- a/Doc/whatsnew/3.5.rst +++ b/Doc/whatsnew/3.5.rst @@ -370,6 +370,13 @@ os * :class:`os.stat_result` now has a :attr:`~os.stat_result.st_file_attributes` attribute on Windows. (Contributed by Ben Hoyt in :issue:`21719`.) +pickle +------ + +* Serializing more "lookupable" objects (such as unbound methods or nested + classes) now are supported with pickle protocols < 4. + (Contributed by Serhiy Storchaka in :issue:`23611`.) + re -- diff --git a/Lib/pickle.py b/Lib/pickle.py index 67382ae..6c26c5e 100644 --- a/Lib/pickle.py +++ b/Lib/pickle.py @@ -258,24 +258,20 @@ class _Unframer: # Tools used for pickling. -def _getattribute(obj, name, allow_qualname=False): - dotted_path = name.split(".") - if not allow_qualname and len(dotted_path) > 1: - raise AttributeError("Can't get qualified attribute {!r} on {!r}; " + - "use protocols >= 4 to enable support" - .format(name, obj)) - for subpath in dotted_path: +def _getattribute(obj, name): + for subpath in name.split('.'): if subpath == '<locals>': raise AttributeError("Can't get local attribute {!r} on {!r}" .format(name, obj)) try: + parent = obj obj = getattr(obj, subpath) except AttributeError: raise AttributeError("Can't get attribute {!r} on {!r}" .format(name, obj)) - return obj + return obj, parent -def whichmodule(obj, name, allow_qualname=False): +def whichmodule(obj, name): """Find the module an object belong to.""" module_name = getattr(obj, '__module__', None) if module_name is not None: @@ -286,7 +282,7 @@ def whichmodule(obj, name, allow_qualname=False): if module_name == '__main__' or module is None: continue try: - if _getattribute(module, name, allow_qualname) is obj: + if _getattribute(module, name)[0] is obj: return module_name except AttributeError: pass @@ -899,16 +895,16 @@ class _Pickler: write = self.write memo = self.memo - if name is None and self.proto >= 4: + if name is None: name = getattr(obj, 
'__qualname__', None) if name is None: name = obj.__name__ - module_name = whichmodule(obj, name, allow_qualname=self.proto >= 4) + module_name = whichmodule(obj, name) try: __import__(module_name, level=0) module = sys.modules[module_name] - obj2 = _getattribute(module, name, allow_qualname=self.proto >= 4) + obj2, parent = _getattribute(module, name) except (ImportError, KeyError, AttributeError): raise PicklingError( "Can't pickle %r: it's not found as %s.%s" % @@ -930,11 +926,16 @@ class _Pickler: else: write(EXT4 + pack("<i", code)) return + lastname = name.rpartition('.')[2] + if parent is module: + name = lastname # Non-ASCII identifiers are supported only with protocols >= 3. if self.proto >= 4: self.save(module_name) self.save(name) write(STACK_GLOBAL) + elif parent is not module: + self.save_reduce(getattr, (parent, lastname)) elif self.proto >= 3: write(GLOBAL + bytes(module_name, "utf-8") + b'\n' + bytes(name, "utf-8") + b'\n') @@ -1373,8 +1374,10 @@ class _Unpickler: elif module in _compat_pickle.IMPORT_MAPPING: module = _compat_pickle.IMPORT_MAPPING[module] __import__(module, level=0) - return _getattribute(sys.modules[module], name, - allow_qualname=self.proto >= 4) + if self.proto >= 4: + return _getattribute(sys.modules[module], name)[0] + else: + return getattr(sys.modules[module], name) def load_reduce(self): stack = self.stack diff --git a/Lib/test/pickletester.py b/Lib/test/pickletester.py index c95fb22..a0c7a0a 100644 --- a/Lib/test/pickletester.py +++ b/Lib/test/pickletester.py @@ -1602,13 +1602,24 @@ class AbstractPickleTests(unittest.TestCase): class B: class C: pass - - for proto in range(4, pickle.HIGHEST_PROTOCOL + 1): + for proto in range(pickle.HIGHEST_PROTOCOL + 1): for obj in [Nested.A, Nested.A.B, Nested.A.B.C]: with self.subTest(proto=proto, obj=obj): unpickled = self.loads(self.dumps(obj, proto)) self.assertIs(obj, unpickled) + def test_recursive_nested_names(self): + global Recursive + class Recursive: + pass + Recursive.mod = 
sys.modules[Recursive.__module__] + Recursive.__qualname__ = 'Recursive.mod.Recursive' + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto): + unpickled = self.loads(self.dumps(Recursive, proto)) + self.assertIs(unpickled, Recursive) + del Recursive.mod # break reference loop + def test_py_methods(self): global PyMethodsTest class PyMethodsTest: @@ -1647,7 +1658,7 @@ class AbstractPickleTests(unittest.TestCase): (PyMethodsTest.biscuits, PyMethodsTest), (PyMethodsTest.Nested.pie, PyMethodsTest.Nested) ) - for proto in range(4, pickle.HIGHEST_PROTOCOL + 1): + for proto in range(pickle.HIGHEST_PROTOCOL + 1): for method in py_methods: with self.subTest(proto=proto, method=method): unpickled = self.loads(self.dumps(method, proto)) @@ -1687,7 +1698,7 @@ class AbstractPickleTests(unittest.TestCase): (Subclass.Nested("sweet").count, ("e",)), (Subclass.Nested.count, (Subclass.Nested("sweet"), "e")), ) - for proto in range(4, pickle.HIGHEST_PROTOCOL + 1): + for proto in range(pickle.HIGHEST_PROTOCOL + 1): for method, args in c_methods: with self.subTest(proto=proto, method=method): unpickled = self.loads(self.dumps(method, proto)) @@ -13,6 +13,9 @@ Core and Builtins Library ------- +- Issue #23611: Serializing more "lookupable" objects (such as unbound methods + or nested classes) now are supported with pickle protocols < 4. + - Issue #13583: sqlite3.Row now supports slice indexing. - Issue #18473: Fixed 2to3 and 3to2 compatible pickle mappings. Fixed diff --git a/Modules/_pickle.c b/Modules/_pickle.c index c1e2b40..d4130d1 100644 --- a/Modules/_pickle.c +++ b/Modules/_pickle.c @@ -152,6 +152,8 @@ typedef struct { /* codecs.encode, used for saving bytes in older protocols */ PyObject *codecs_encode; + /* builtins.getattr, used for saving nested names with protocol < 4 */ + PyObject *getattr; } PickleState; /* Forward declaration of the _pickle module definition. 
*/ @@ -188,16 +190,26 @@ _Pickle_ClearState(PickleState *st) Py_CLEAR(st->name_mapping_3to2); Py_CLEAR(st->import_mapping_3to2); Py_CLEAR(st->codecs_encode); + Py_CLEAR(st->getattr); } /* Initialize the given pickle module state. */ static int _Pickle_InitState(PickleState *st) { + PyObject *builtins; PyObject *copyreg = NULL; PyObject *compat_pickle = NULL; PyObject *codecs = NULL; + builtins = PyEval_GetBuiltins(); + if (builtins == NULL) + goto error; + st->getattr = PyDict_GetItemString(builtins, "getattr"); + if (st->getattr == NULL) + goto error; + Py_INCREF(st->getattr); + copyreg = PyImport_ImportModule("copyreg"); if (!copyreg) goto error; @@ -1535,7 +1547,7 @@ memo_put(PicklerObject *self, PyObject *obj) } static PyObject * -get_dotted_path(PyObject *obj, PyObject *name, int allow_qualname) { +get_dotted_path(PyObject *obj, PyObject *name) { _Py_static_string(PyId_dot, "."); _Py_static_string(PyId_locals, "<locals>"); PyObject *dotted_path; @@ -1546,20 +1558,6 @@ get_dotted_path(PyObject *obj, PyObject *name, int allow_qualname) { return NULL; n = PyList_GET_SIZE(dotted_path); assert(n >= 1); - if (!allow_qualname && n > 1) { - if (obj == NULL) - PyErr_Format(PyExc_AttributeError, - "Can't pickle qualified object %R; " - "use protocols >= 4 to enable support", - name); - else - PyErr_Format(PyExc_AttributeError, - "Can't pickle qualified attribute %R on %R; " - "use protocols >= 4 to enable support", - name, obj); - Py_DECREF(dotted_path); - return NULL; - } for (i = 0; i < n; i++) { PyObject *subpath = PyList_GET_ITEM(dotted_path, i); PyObject *result = PyUnicode_RichCompare( @@ -1582,22 +1580,28 @@ get_dotted_path(PyObject *obj, PyObject *name, int allow_qualname) { } static PyObject * -get_deep_attribute(PyObject *obj, PyObject *names) +get_deep_attribute(PyObject *obj, PyObject *names, PyObject **pparent) { Py_ssize_t i, n; + PyObject *parent = NULL; assert(PyList_CheckExact(names)); Py_INCREF(obj); n = PyList_GET_SIZE(names); for (i = 0; i < n; i++) 
{ PyObject *name = PyList_GET_ITEM(names, i); - PyObject *tmp; - tmp = PyObject_GetAttr(obj, name); - Py_DECREF(obj); - if (tmp == NULL) + Py_XDECREF(parent); + parent = obj; + obj = PyObject_GetAttr(parent, name); + if (obj == NULL) { + Py_DECREF(parent); return NULL; - obj = tmp; + } } + if (pparent != NULL) + *pparent = parent; + else + Py_XDECREF(parent); return obj; } @@ -1617,18 +1621,22 @@ getattribute(PyObject *obj, PyObject *name, int allow_qualname) { PyObject *dotted_path, *attr; - dotted_path = get_dotted_path(obj, name, allow_qualname); - if (dotted_path == NULL) - return NULL; - attr = get_deep_attribute(obj, dotted_path); - Py_DECREF(dotted_path); + if (allow_qualname) { + dotted_path = get_dotted_path(obj, name); + if (dotted_path == NULL) + return NULL; + attr = get_deep_attribute(obj, dotted_path, NULL); + Py_DECREF(dotted_path); + } + else + attr = PyObject_GetAttr(obj, name); if (attr == NULL) reformat_attribute_error(obj, name); return attr; } static PyObject * -whichmodule(PyObject *global, PyObject *global_name, int allow_qualname) +whichmodule(PyObject *global, PyObject *dotted_path) { PyObject *module_name; PyObject *modules_dict; @@ -1637,7 +1645,6 @@ whichmodule(PyObject *global, PyObject *global_name, int allow_qualname) _Py_IDENTIFIER(__module__); _Py_IDENTIFIER(modules); _Py_IDENTIFIER(__main__); - PyObject *dotted_path; module_name = _PyObject_GetAttrId(global, &PyId___module__); @@ -1663,10 +1670,6 @@ whichmodule(PyObject *global, PyObject *global_name, int allow_qualname) return NULL; } - dotted_path = get_dotted_path(NULL, global_name, allow_qualname); - if (dotted_path == NULL) - return NULL; - i = 0; while (PyDict_Next(modules_dict, &i, &module_name, &module)) { PyObject *candidate; @@ -1676,19 +1679,16 @@ whichmodule(PyObject *global, PyObject *global_name, int allow_qualname) if (module == Py_None) continue; - candidate = get_deep_attribute(module, dotted_path); + candidate = get_deep_attribute(module, dotted_path, NULL); if 
(candidate == NULL) { - if (!PyErr_ExceptionMatches(PyExc_AttributeError)) { - Py_DECREF(dotted_path); + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) return NULL; - } PyErr_Clear(); continue; } if (candidate == global) { Py_INCREF(module_name); - Py_DECREF(dotted_path); Py_DECREF(candidate); return module_name; } @@ -1698,7 +1698,6 @@ whichmodule(PyObject *global, PyObject *global_name, int allow_qualname) /* If no module is found, use __main__. */ module_name = _PyUnicode_FromId(&PyId___main__); Py_INCREF(module_name); - Py_DECREF(dotted_path); return module_name; } @@ -3105,6 +3104,9 @@ save_global(PicklerObject *self, PyObject *obj, PyObject *name) PyObject *global_name = NULL; PyObject *module_name = NULL; PyObject *module = NULL; + PyObject *parent = NULL; + PyObject *dotted_path = NULL; + PyObject *lastname = NULL; PyObject *cls; PickleState *st = _Pickle_GetGlobalState(); int status = 0; @@ -3118,13 +3120,11 @@ save_global(PicklerObject *self, PyObject *obj, PyObject *name) global_name = name; } else { - if (self->proto >= 4) { - global_name = _PyObject_GetAttrId(obj, &PyId___qualname__); - if (global_name == NULL) { - if (!PyErr_ExceptionMatches(PyExc_AttributeError)) - goto error; - PyErr_Clear(); - } + global_name = _PyObject_GetAttrId(obj, &PyId___qualname__); + if (global_name == NULL) { + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) + goto error; + PyErr_Clear(); } if (global_name == NULL) { global_name = _PyObject_GetAttrId(obj, &PyId___name__); @@ -3133,7 +3133,10 @@ save_global(PicklerObject *self, PyObject *obj, PyObject *name) } } - module_name = whichmodule(obj, global_name, self->proto >= 4); + dotted_path = get_dotted_path(module, global_name); + if (dotted_path == NULL) + goto error; + module_name = whichmodule(obj, dotted_path); if (module_name == NULL) goto error; @@ -3152,7 +3155,10 @@ save_global(PicklerObject *self, PyObject *obj, PyObject *name) obj, module_name); goto error; } - cls = getattribute(module, global_name, 
self->proto >= 4); + lastname = PyList_GET_ITEM(dotted_path, PyList_GET_SIZE(dotted_path)-1); + Py_INCREF(lastname); + cls = get_deep_attribute(module, dotted_path, &parent); + Py_CLEAR(dotted_path); if (cls == NULL) { PyErr_Format(st->PicklingError, "Can't pickle %R: attribute lookup %S on %S failed", @@ -3239,6 +3245,11 @@ save_global(PicklerObject *self, PyObject *obj, PyObject *name) } else { gen_global: + if (parent == module) { + Py_INCREF(lastname); + Py_DECREF(global_name); + global_name = lastname; + } if (self->proto >= 4) { const char stack_global_op = STACK_GLOBAL; @@ -3250,6 +3261,15 @@ save_global(PicklerObject *self, PyObject *obj, PyObject *name) if (_Pickler_Write(self, &stack_global_op, 1) < 0) goto error; } + else if (parent != module) { + PickleState *st = _Pickle_GetGlobalState(); + PyObject *reduce_value = Py_BuildValue("(O(OO))", + st->getattr, parent, lastname); + status = save_reduce(self, reduce_value, NULL); + Py_DECREF(reduce_value); + if (status < 0) + goto error; + } else { /* Generate a normal global opcode if we are using a pickle protocol < 4, or if the object is not registered in the @@ -3328,6 +3348,9 @@ save_global(PicklerObject *self, PyObject *obj, PyObject *name) Py_XDECREF(module_name); Py_XDECREF(global_name); Py_XDECREF(module); + Py_XDECREF(parent); + Py_XDECREF(dotted_path); + Py_XDECREF(lastname); return status; } @@ -7150,6 +7173,7 @@ pickle_traverse(PyObject *m, visitproc visit, void *arg) Py_VISIT(st->name_mapping_3to2); Py_VISIT(st->import_mapping_3to2); Py_VISIT(st->codecs_encode); + Py_VISIT(st->getattr); return 0; } |