summaryrefslogtreecommitdiffstats
path: root/Modules/cPickle.c
diff options
context:
space:
mode:
Diffstat (limited to 'Modules/cPickle.c')
-rw-r--r--Modules/cPickle.c264
1 files changed, 195 insertions, 69 deletions
diff --git a/Modules/cPickle.c b/Modules/cPickle.c
index 09485b9..f09e502 100644
--- a/Modules/cPickle.c
+++ b/Modules/cPickle.c
@@ -119,6 +119,12 @@ static PyObject *extension_cache;
/* For looking up name pairs in copy_reg._extension_registry. */
static PyObject *two_tuple;
+/* object.__reduce__, the default reduce callable. */
+PyObject *object_reduce;
+
+/* copy_reg._better_reduce, the protocol 2 reduction function. */
+PyObject *better_reduce;
+
static PyObject *__class___str, *__getinitargs___str, *__dict___str,
*__getstate___str, *__setstate___str, *__name___str, *__reduce___str,
*write_str, *append_str,
@@ -2181,38 +2187,142 @@ save_pers(Picklerobject *self, PyObject *args, PyObject *f)
return res;
}
-
+/* We're saving ob, and args is the 2-thru-5 tuple returned by the
+ * appropriate __reduce__ method for ob.
+ */
static int
-save_reduce(Picklerobject *self, PyObject *callable,
- PyObject *tup, PyObject *state, PyObject *ob)
-{
- static char reduce = REDUCE, build = BUILD;
-
- if (save(self, callable, 0) < 0)
+save_reduce(Picklerobject *self, PyObject *args, PyObject *ob)
+{
+ PyObject *callable;
+ PyObject *argtup;
+ PyObject *state = NULL;
+ PyObject *listitems = NULL;
+ PyObject *dictitems = NULL;
+
+ int use_newobj = self->proto >= 2;
+
+ static char reduce = REDUCE;
+ static char build = BUILD;
+ static char newobj = NEWOBJ;
+
+ if (! PyArg_UnpackTuple(args, "save_reduce", 2, 5,
+ &callable,
+ &argtup,
+ &state,
+ &listitems,
+ &dictitems))
return -1;
- if (save(self, tup, 0) < 0)
- return -1;
+ if (state == Py_None)
+ state = NULL;
+ if (listitems == Py_None)
+ listitems = NULL;
+ if (dictitems == Py_None)
+ dictitems = NULL;
- if (self->write_func(self, &reduce, 1) < 0)
- return -1;
+ /* Protocol 2 special case: if callable's name is __newobj__, use
+ * NEWOBJ. This consumes a lot of code.
+ */
+ if (use_newobj) {
+ PyObject *temp = PyObject_GetAttr(callable, __name___str);
+
+ if (temp == NULL) {
+ PyErr_Clear();
+ use_newobj = 0;
+ }
+ else {
+ use_newobj = PyString_Check(temp) &&
+ strcmp(PyString_AS_STRING(temp),
+ "__newobj__") == 0;
+ Py_DECREF(temp);
+ }
+ }
+ if (use_newobj) {
+ PyObject *cls;
+ PyObject *newargtup;
+ int n, i;
+
+ /* Sanity checks. */
+ n = PyTuple_Size(argtup);
+ if (n < 1) {
+ PyErr_SetString(PicklingError, "__newobj__ arglist "
+ "is empty");
+ return -1;
+ }
+
+ cls = PyTuple_GET_ITEM(argtup, 0);
+ if (! PyObject_HasAttrString(cls, "__new__")) {
+ PyErr_SetString(PicklingError, "args[0] from "
+ "__newobj__ args has no __new__");
+ return -1;
+ }
+
+ /* XXX How could ob be NULL? */
+ if (ob != NULL) {
+ PyObject *ob_dot_class;
+ ob_dot_class = PyObject_GetAttr(ob, __class___str);
+ if (ob_dot_class == NULL)
+ PyErr_Clear();
+ i = ob_dot_class != cls; /* true iff a problem */
+ Py_XDECREF(ob_dot_class);
+ if (i) {
+ PyErr_SetString(PicklingError, "args[0] from "
+ "__newobj__ args has the wrong class");
+ return -1;
+ }
+ }
+
+ /* Save the class and its __new__ arguments. */
+ if (save(self, cls, 0) < 0)
+ return -1;
+
+ newargtup = PyTuple_New(n-1); /* argtup[1:] */
+ if (newargtup == NULL)
+ return -1;
+ for (i = 1; i < n; ++i) {
+ PyObject *temp = PyTuple_GET_ITEM(argtup, i);
+ Py_INCREF(temp);
+ PyTuple_SET_ITEM(newargtup, i-1, temp);
+ }
+ i = save(self, newargtup, 0) < 0;
+ Py_DECREF(newargtup);
+ if (i < 0)
+ return -1;
+
+ /* Add NEWOBJ opcode. */
+ if (self->write_func(self, &newobj, 1) < 0)
+ return -1;
+ }
+ else {
+ /* Not using NEWOBJ. */
+ if (save(self, callable, 0) < 0 ||
+ save(self, argtup, 0) < 0 ||
+ self->write_func(self, &reduce, 1) < 0)
+ return -1;
+ }
+
+ /* Memoize. */
+ /* XXX How can ob be NULL? */
if (ob != NULL) {
if (state && !PyDict_Check(state)) {
if (put2(self, ob) < 0)
return -1;
}
- else {
- if (put(self, ob) < 0)
+ else if (put(self, ob) < 0)
return -1;
- }
}
- if (state) {
- if (save(self, state, 0) < 0)
- return -1;
- if (self->write_func(self, &build, 1) < 0)
+ if (listitems && batch_list(self, listitems) < 0)
+ return -1;
+
+ if (dictitems && batch_dict(self, dictitems) < 0)
+ return -1;
+
+ if (state) {
+ if (save(self, state, 0) < 0 ||
+ self->write_func(self, &build, 1) < 0)
return -1;
}
@@ -2223,9 +2333,10 @@ static int
save(Picklerobject *self, PyObject *args, int pers_save)
{
PyTypeObject *type;
- PyObject *py_ob_id = 0, *__reduce__ = 0, *t = 0, *arg_tup = 0,
- *callable = 0, *state = 0;
- int res = -1, tmp, size;
+ PyObject *py_ob_id = 0, *__reduce__ = 0, *t = 0;
+ PyObject *arg_tup;
+ int res = -1;
+ int tmp, size;
if (self->nesting++ > Py_GetRecursionLimit()){
PyErr_SetString(PyExc_RuntimeError,
@@ -2392,72 +2503,80 @@ save(Picklerobject *self, PyObject *args, int pers_save)
goto finally;
}
- assert(t == NULL); /* just a reminder */
+ /* Get a reduction callable. This may come from
+ * copy_reg.dispatch_table, the object's __reduce__ method,
+ * the default object.__reduce__, or copy_reg._better_reduce.
+ */
__reduce__ = PyDict_GetItem(dispatch_table, (PyObject *)type);
if (__reduce__ != NULL) {
Py_INCREF(__reduce__);
- Py_INCREF(args);
- ARG_TUP(self, args);
- if (self->arg) {
- t = PyObject_Call(__reduce__, self->arg, NULL);
- FREE_ARG_TUP(self);
- }
- if (! t) goto finally;
}
else {
- __reduce__ = PyObject_GetAttr(args, __reduce___str);
- if (__reduce__ == NULL)
+ /* Check for a __reduce__ method.
+ * Subtle: get the unbound method from the class, so that
+ * protocol 2 can override the default __reduce__ that all
+ * classes inherit from object.
+ * XXX object.__reduce__ should really be rewritten so that
+ * XXX we don't need to call back into Python code here
+ * XXX (better_reduce), but no time to do that.
+ */
+ __reduce__ = PyObject_GetAttr((PyObject *)type,
+ __reduce___str);
+ if (__reduce__ == NULL) {
PyErr_Clear();
- else {
- t = PyObject_Call(__reduce__, empty_tuple, NULL);
- if (!t)
- goto finally;
- }
- }
-
- if (t) {
- if (PyString_Check(t)) {
- res = save_global(self, args, t);
- goto finally;
- }
-
- if (!PyTuple_Check(t)) {
- cPickle_ErrFormat(PicklingError, "Value returned by "
- "%s must be a tuple",
- "O", __reduce__);
+ PyErr_SetObject(UnpickleableError, args);
goto finally;
}
- size = PyTuple_Size(t);
-
- if (size != 3 && size != 2) {
- cPickle_ErrFormat(PicklingError, "tuple returned by "
- "%s must contain only two or three elements",
- "O", __reduce__);
- goto finally;
+ if (self->proto >= 2 && __reduce__ == object_reduce) {
+ /* Proto 2 can do better than the default. */
+ Py_DECREF(__reduce__);
+ Py_INCREF(better_reduce);
+ __reduce__ = better_reduce;
}
+ }
- callable = PyTuple_GET_ITEM(t, 0);
- arg_tup = PyTuple_GET_ITEM(t, 1);
+ /* Call the reduction callable, setting t to the result. */
+ assert(__reduce__ != NULL);
+ assert(t == NULL);
+ Py_INCREF(args);
+ ARG_TUP(self, args);
+ if (self->arg) {
+ t = PyObject_Call(__reduce__, self->arg, NULL);
+ FREE_ARG_TUP(self);
+ }
+ if (t == NULL)
+ goto finally;
- if (size > 2) {
- state = PyTuple_GET_ITEM(t, 2);
- if (state == Py_None)
- state = NULL;
- }
+ if (PyString_Check(t)) {
+ res = save_global(self, args, t);
+ goto finally;
+ }
- if (!( PyTuple_Check(arg_tup) || arg_tup==Py_None )) {
- cPickle_ErrFormat(PicklingError, "Second element of "
- "tuple returned by %s must be a tuple",
+ if (! PyTuple_Check(t)) {
+ cPickle_ErrFormat(PicklingError, "Value returned by "
+ "%s must be string or tuple",
"O", __reduce__);
- goto finally;
- }
+ goto finally;
+ }
- res = save_reduce(self, callable, arg_tup, state, args);
+ size = PyTuple_Size(t);
+ if (size < 2 || size > 5) {
+ cPickle_ErrFormat(PicklingError, "tuple returned by "
+ "%s must contain 2 through 5 elements",
+ "O", __reduce__);
goto finally;
}
- PyErr_SetObject(UnpickleableError, args);
+ arg_tup = PyTuple_GET_ITEM(t, 1);
+ if (!(PyTuple_Check(arg_tup) || arg_tup == Py_None)) {
+ cPickle_ErrFormat(PicklingError, "Second element of "
+ "tuple returned by %s must be a tuple",
+ "O", __reduce__);
+ goto finally;
+ }
+
+ res = save_reduce(self, t, args);
finally:
self->nesting--;
@@ -5447,8 +5566,15 @@ init_stuff(PyObject *module_dict)
"_extension_cache");
if (!extension_cache) return -1;
+ better_reduce = PyObject_GetAttrString(copy_reg, "_better_reduce");
+ if (!better_reduce) return -1;
+
Py_DECREF(copy_reg);
+ object_reduce = PyObject_GetAttrString((PyObject *)&PyBaseObject_Type,
+ "__reduce__");
+ if (object_reduce == NULL) return -1;
+
if (!(empty_tuple = PyTuple_New(0)))
return -1;