diff options
author | Nick Coghlan <ncoghlan@gmail.com> | 2012-08-25 07:59:50 (GMT) |
---|---|---|
committer | Nick Coghlan <ncoghlan@gmail.com> | 2012-08-25 07:59:50 (GMT) |
commit | 06e1ab0a6b51c35e4637bb82c3aa18548b6412b0 (patch) | |
tree | e24cfb1c7f4c51a518cdc3589b5a92c913ab4b76 /Objects | |
parent | 5c0b1ca55ec76c0891a2ae0557e2e40391e1c74f (diff) | |
download | cpython-06e1ab0a6b51c35e4637bb82c3aa18548b6412b0.zip cpython-06e1ab0a6b51c35e4637bb82c3aa18548b6412b0.tar.gz cpython-06e1ab0a6b51c35e4637bb82c3aa18548b6412b0.tar.bz2 |
Close #15573: use value-based memoryview comparisons (patch by Stefan Krah)
Diffstat (limited to 'Objects')
-rw-r--r-- | Objects/memoryobject.c | 326 |
1 files changed, 283 insertions, 43 deletions
diff --git a/Objects/memoryobject.c b/Objects/memoryobject.c index 46a8416..f547983 100644 --- a/Objects/memoryobject.c +++ b/Objects/memoryobject.c @@ -246,7 +246,7 @@ Create a new memoryview object which references the given object."); (view->suboffsets && view->suboffsets[dest->ndim-1] >= 0) Py_LOCAL_INLINE(int) -last_dim_is_contiguous(Py_buffer *dest, Py_buffer *src) +last_dim_is_contiguous(const Py_buffer *dest, const Py_buffer *src) { assert(dest->ndim > 0 && src->ndim > 0); return (!HAVE_SUBOFFSETS_IN_LAST_DIM(dest) && @@ -255,37 +255,63 @@ last_dim_is_contiguous(Py_buffer *dest, Py_buffer *src) src->strides[src->ndim-1] == src->itemsize); } -/* Check that the logical structure of the destination and source buffers - is identical. */ -static int -cmp_structure(Py_buffer *dest, Py_buffer *src) +/* This is not a general function for determining format equivalence. + It is used in copy_single() and copy_buffer() to weed out non-matching + formats. Skipping the '@' character is specifically used in slice + assignments, where the lvalue is already known to have a single character + format. This is a performance hack that could be rewritten (if properly + benchmarked). */ +Py_LOCAL_INLINE(int) +equiv_format(const Py_buffer *dest, const Py_buffer *src) { const char *dfmt, *sfmt; - int i; assert(dest->format && src->format); dfmt = dest->format[0] == '@' ? dest->format+1 : dest->format; sfmt = src->format[0] == '@' ? src->format+1 : src->format; if (strcmp(dfmt, sfmt) != 0 || - dest->itemsize != src->itemsize || - dest->ndim != src->ndim) { - goto value_error; + dest->itemsize != src->itemsize) { + return 0; } + return 1; +} + +/* Two shapes are equivalent if they are either equal or identical up + to a zero element at the same position. For example, in NumPy arrays + the shapes [1, 0, 5] and [1, 0, 7] are equivalent. */ +Py_LOCAL_INLINE(int) +equiv_shape(const Py_buffer *dest, const Py_buffer *src) +{ + int i; + + if (dest->ndim != src->ndim) + return 0; + for (i = 0; i < dest->ndim; i++) { if (dest->shape[i] != src->shape[i]) - goto value_error; + return 0; if (dest->shape[i] == 0) break; } - return 0; + return 1; +} -value_error: - PyErr_SetString(PyExc_ValueError, - "ndarray assignment: lvalue and rvalue have different structures"); - return -1; +/* Check that the logical structure of the destination and source buffers + is identical. */ +static int +equiv_structure(const Py_buffer *dest, const Py_buffer *src) +{ + if (!equiv_format(dest, src) || + !equiv_shape(dest, src)) { + PyErr_SetString(PyExc_ValueError, + "ndarray assignment: lvalue and rvalue have different structures"); + return 0; + } + + return 1; } /* Base case for recursive multi-dimensional copying. Contiguous arrays are @@ -358,7 +384,7 @@ copy_single(Py_buffer *dest, Py_buffer *src) assert(dest->ndim == 1); - if (cmp_structure(dest, src) < 0) + if (!equiv_structure(dest, src)) return -1; if (!last_dim_is_contiguous(dest, src)) { @@ -390,7 +416,7 @@ copy_buffer(Py_buffer *dest, Py_buffer *src) assert(dest->ndim > 0); - if (cmp_structure(dest, src) < 0) + if (!equiv_structure(dest, src)) return -1; if (!last_dim_is_contiguous(dest, src)) { @@ -1828,6 +1854,131 @@ err_format: /****************************************************************************/ +/* unpack using the struct module */ +/****************************************************************************/ + +/* For reasonable performance it is necessary to cache all objects required + for unpacking. An unpacker can handle the format passed to unpack_from(). + Invariant: All pointer fields of the struct should either be NULL or valid + pointers. */ +struct unpacker { + PyObject *unpack_from; /* Struct.unpack_from(format) */ + PyObject *mview; /* cached memoryview */ + char *item; /* buffer for mview */ + Py_ssize_t itemsize; /* len(item) */ +}; + +static struct unpacker * +unpacker_new(void) +{ + struct unpacker *x = PyMem_Malloc(sizeof *x); + + if (x == NULL) { + PyErr_NoMemory(); + return NULL; + } + + x->unpack_from = NULL; + x->mview = NULL; + x->item = NULL; + x->itemsize = 0; + + return x; +} + +static void +unpacker_free(struct unpacker *x) +{ + if (x) { + Py_XDECREF(x->unpack_from); + Py_XDECREF(x->mview); + PyMem_Free(x->item); + PyMem_Free(x); + } +} + +/* Return a new unpacker for the given format. */ +static struct unpacker * +struct_get_unpacker(const char *fmt, Py_ssize_t itemsize) +{ + PyObject *structmodule; /* XXX cache these two */ + PyObject *Struct = NULL; /* XXX in globals? */ + PyObject *structobj = NULL; + PyObject *format = NULL; + struct unpacker *x = NULL; + + structmodule = PyImport_ImportModule("struct"); + if (structmodule == NULL) + return NULL; + + Struct = PyObject_GetAttrString(structmodule, "Struct"); + Py_DECREF(structmodule); + if (Struct == NULL) + return NULL; + + x = unpacker_new(); + if (x == NULL) + goto error; + + format = PyBytes_FromString(fmt); + if (format == NULL) + goto error; + + structobj = PyObject_CallFunctionObjArgs(Struct, format, NULL); + if (structobj == NULL) + goto error; + + x->unpack_from = PyObject_GetAttrString(structobj, "unpack_from"); + if (x->unpack_from == NULL) + goto error; + + x->item = PyMem_Malloc(itemsize); + if (x->item == NULL) { + PyErr_NoMemory(); + goto error; + } + x->itemsize = itemsize; + + x->mview = PyMemoryView_FromMemory(x->item, itemsize, PyBUF_WRITE); + if (x->mview == NULL) + goto error; + + +out: + Py_XDECREF(Struct); + Py_XDECREF(format); + Py_XDECREF(structobj); + return x; + +error: + unpacker_free(x); + x = NULL; + goto out; +} + +/* unpack a single item */ +static PyObject * +struct_unpack_single(const char *ptr, struct unpacker *x) +{ + PyObject *v; + + memcpy(x->item, ptr, x->itemsize); + v = PyObject_CallFunctionObjArgs(x->unpack_from, x->mview, NULL); + if (v == NULL) + return NULL; + + if (PyTuple_GET_SIZE(v) == 1) { + PyObject *tmp = PyTuple_GET_ITEM(v, 0); + Py_INCREF(tmp); + Py_DECREF(v); + return tmp; + } + + return v; +} + + +/****************************************************************************/ /* Representations */ /****************************************************************************/ @@ -2261,6 +2412,58 @@ static PySequenceMethods memory_as_sequence = { /* Comparisons */ /**************************************************************************/ +#define MV_COMPARE_EX -1 /* exception */ +#define MV_COMPARE_NOT_IMPL -2 /* not implemented */ + +/* Translate a StructError to "not equal". Preserve other exceptions. */ +static int +fix_struct_error_int(void) +{ + assert(PyErr_Occurred()); + /* XXX Cannot get at StructError directly? */ + if (PyErr_ExceptionMatches(PyExc_ImportError) || + PyErr_ExceptionMatches(PyExc_MemoryError)) { + return MV_COMPARE_EX; + } + /* StructError: invalid or unknown format -> not equal */ + PyErr_Clear(); + return 0; +} + +/* Unpack and compare single items of p and q using the struct module. */ +static int +struct_unpack_cmp(const char *p, const char *q, + struct unpacker *unpack_p, struct unpacker *unpack_q) +{ + PyObject *v, *w; + int ret; + + /* At this point any exception from the struct module should not be + StructError, since both formats have been accepted already. */ + v = struct_unpack_single(p, unpack_p); + if (v == NULL) + return MV_COMPARE_EX; + + w = struct_unpack_single(q, unpack_q); + if (w == NULL) { + Py_DECREF(v); + return MV_COMPARE_EX; + } + + /* MV_COMPARE_EX == -1: exceptions are preserved */ + ret = PyObject_RichCompareBool(v, w, Py_EQ); + Py_DECREF(v); + Py_DECREF(w); + + return ret; +} + +/* Unpack and compare single items of p and q. If both p and q have the same + single element native format, the comparison uses a fast path (gcc creates + a jump table and converts memcpy into simple assignments on x86/x64). + + Otherwise, the comparison is delegated to the struct module, which is + 30-60x slower. */ #define CMP_SINGLE(p, q, type) \ do { \ type x; \ @@ -2271,11 +2474,12 @@ static PySequenceMethods memory_as_sequence = { } while (0) Py_LOCAL_INLINE(int) -unpack_cmp(const char *p, const char *q, const char *fmt) +unpack_cmp(const char *p, const char *q, char fmt, + struct unpacker *unpack_p, struct unpacker *unpack_q) { int equal; - switch (fmt[0]) { + switch (fmt) { /* signed integers and fast path for 'B' */ case 'B': return *((unsigned char *)p) == *((unsigned char *)q); @@ -2317,9 +2521,17 @@ unpack_cmp(const char *p, const char *q, const char *fmt) /* pointer */ case 'P': CMP_SINGLE(p, q, void *); return equal; - /* Py_NotImplemented */ - default: return -1; + /* use the struct module */ + case '_': + assert(unpack_p); + assert(unpack_q); + return struct_unpack_cmp(p, q, unpack_p, unpack_q); } + + /* NOT REACHED */ + PyErr_SetString(PyExc_RuntimeError, + "memoryview: internal error in richcompare"); + return MV_COMPARE_EX; } /* Base case for recursive array comparisons. Assumption: ndim == 1. */ @@ -2327,7 +2539,7 @@ static int cmp_base(const char *p, const char *q, const Py_ssize_t *shape, const Py_ssize_t *pstrides, const Py_ssize_t *psuboffsets, const Py_ssize_t *qstrides, const Py_ssize_t *qsuboffsets, - const char *fmt) + char fmt, struct unpacker *unpack_p, struct unpacker *unpack_q) { Py_ssize_t i; int equal; @@ -2335,7 +2547,7 @@ cmp_base(const char *p, const char *q, const Py_ssize_t *shape, for (i = 0; i < shape[0]; p+=pstrides[0], q+=qstrides[0], i++) { const char *xp = ADJUST_PTR(p, psuboffsets); const char *xq = ADJUST_PTR(q, qsuboffsets); - equal = unpack_cmp(xp, xq, fmt); + equal = unpack_cmp(xp, xq, fmt, unpack_p, unpack_q); if (equal <= 0) return equal; } @@ -2350,7 +2562,7 @@ cmp_rec(const char *p, const char *q, Py_ssize_t ndim, const Py_ssize_t *shape, const Py_ssize_t *pstrides, const Py_ssize_t *psuboffsets, const Py_ssize_t *qstrides, const Py_ssize_t *qsuboffsets, - const char *fmt) + char fmt, struct unpacker *unpack_p, struct unpacker *unpack_q) { Py_ssize_t i; int equal; @@ -2364,7 +2576,7 @@ cmp_rec(const char *p, const char *q, return cmp_base(p, q, shape, pstrides, psuboffsets, qstrides, qsuboffsets, - fmt); + fmt, unpack_p, unpack_q); } for (i = 0; i < shape[0]; p+=pstrides[0], q+=qstrides[0], i++) { @@ -2373,7 +2585,7 @@ cmp_rec(const char *p, const char *q, equal = cmp_rec(xp, xq, ndim-1, shape+1, pstrides+1, psuboffsets ? psuboffsets+1 : NULL, qstrides+1, qsuboffsets ? qsuboffsets+1 : NULL, - fmt); + fmt, unpack_p, unpack_q); if (equal <= 0) return equal; } @@ -2385,9 +2597,12 @@ static PyObject * memory_richcompare(PyObject *v, PyObject *w, int op) { PyObject *res; - Py_buffer wbuf, *vv, *ww = NULL; - const char *vfmt, *wfmt; - int equal = -1; /* Py_NotImplemented */ + Py_buffer wbuf, *vv; + Py_buffer *ww = NULL; + struct unpacker *unpack_v = NULL; + struct unpacker *unpack_w = NULL; + char vfmt, wfmt; + int equal = MV_COMPARE_NOT_IMPL; if (op != Py_EQ && op != Py_NE) goto result; /* Py_NotImplemented */ @@ -2414,38 +2629,59 @@ memory_richcompare(PyObject *v, PyObject *w, int op) ww = &wbuf; } - vfmt = adjust_fmt(vv); - wfmt = adjust_fmt(ww); - if (vfmt == NULL || wfmt == NULL) { - PyErr_Clear(); - goto result; /* Py_NotImplemented */ - } - - if (cmp_structure(vv, ww) < 0) { + if (!equiv_shape(vv, ww)) { PyErr_Clear(); equal = 0; goto result; } + /* Use fast unpacking for identical primitive C type formats. */ + if (get_native_fmtchar(&vfmt, vv->format) < 0) + vfmt = '_'; + if (get_native_fmtchar(&wfmt, ww->format) < 0) + wfmt = '_'; + if (vfmt == '_' || wfmt == '_' || vfmt != wfmt) { + /* Use struct module unpacking. NOTE: Even for equal format strings, + memcmp() cannot be used for item comparison since it would give + incorrect results in the case of NaNs or uninitialized padding + bytes. */ + vfmt = '_'; + unpack_v = struct_get_unpacker(vv->format, vv->itemsize); + if (unpack_v == NULL) { + equal = fix_struct_error_int(); + goto result; + } + unpack_w = struct_get_unpacker(ww->format, ww->itemsize); + if (unpack_w == NULL) { + equal = fix_struct_error_int(); + goto result; + } + } + if (vv->ndim == 0) { - equal = unpack_cmp(vv->buf, ww->buf, vfmt); + equal = unpack_cmp(vv->buf, ww->buf, + vfmt, unpack_v, unpack_w); } else if (vv->ndim == 1) { equal = cmp_base(vv->buf, ww->buf, vv->shape, vv->strides, vv->suboffsets, ww->strides, ww->suboffsets, - vfmt); + vfmt, unpack_v, unpack_w); } else { equal = cmp_rec(vv->buf, ww->buf, vv->ndim, vv->shape, vv->strides, vv->suboffsets, ww->strides, ww->suboffsets, - vfmt); + vfmt, unpack_v, unpack_w); } result: - if (equal < 0) - res = Py_NotImplemented; + if (equal < 0) { + if (equal == MV_COMPARE_NOT_IMPL) + res = Py_NotImplemented; + else /* exception */ + res = NULL; + } else if ((equal && op == Py_EQ) || (!equal && op == Py_NE)) res = Py_True; else @@ -2453,7 +2689,11 @@ result: if (ww == &wbuf) PyBuffer_Release(ww); - Py_INCREF(res); + + unpacker_free(unpack_v); + unpacker_free(unpack_w); + + Py_XINCREF(res); return res; } |