diff options
-rw-r--r-- | Objects/stringobject.c | 90 |
1 files changed, 52 insertions, 38 deletions
diff --git a/Objects/stringobject.c b/Objects/stringobject.c index eed4687..df3ab49 100644 --- a/Objects/stringobject.c +++ b/Objects/stringobject.c @@ -794,46 +794,55 @@ static PyObject * string_join(PyStringObject *self, PyObject *args) { char *sep = PyString_AS_STRING(self); - int seplen = PyString_GET_SIZE(self); + const int seplen = PyString_GET_SIZE(self); PyObject *res = NULL; - int reslen = 0; char *p; int seqlen = 0; - int sz = 100; - int i, slen, sz_incr; + size_t sz = 0; + int i; PyObject *orig, *seq, *item; if (!PyArg_ParseTuple(args, "O:join", &orig)) return NULL; - if (!(seq = PySequence_Fast(orig, ""))) { + seq = PySequence_Fast(orig, ""); + if (seq == NULL) { if (PyErr_ExceptionMatches(PyExc_TypeError)) PyErr_Format(PyExc_TypeError, "sequence expected, %.80s found", orig->ob_type->tp_name); return NULL; } - /* From here on out, errors go through finally: for proper - * reference count manipulations. - */ + seqlen = PySequence_Size(seq); + if (seqlen == 0) { + Py_DECREF(seq); + return PyString_FromString(""); + } if (seqlen == 1) { item = PySequence_Fast_GET_ITEM(seq, 0); + if (!PyString_Check(item) && !PyUnicode_Check(item)) { + PyErr_Format(PyExc_TypeError, + "sequence item 0: expected string," + " %.80s found", + item->ob_type->tp_name); + Py_DECREF(seq); + return NULL; + } Py_INCREF(item); Py_DECREF(seq); return item; } - if (!(res = PyString_FromStringAndSize((char*)NULL, sz))) - goto finally; - - p = PyString_AS_STRING(res); - + /* There are at least two things to join. Do a pre-pass to figure out + * the total amount of space we'll need (sz), see whether any argument + * is absurd, and defer to the Unicode join if appropriate. + */ for (i = 0; i < seqlen; i++) { + const size_t old_sz = sz; item = PySequence_Fast_GET_ITEM(seq, i); if (!PyString_Check(item)){ if (PyUnicode_Check(item)) { - Py_DECREF(res); Py_DECREF(seq); return PyUnicode_Join((PyObject *)self, orig); } @@ -841,40 +850,45 @@ string_join(PyStringObject *self, PyObject *args) "sequence item %i: expected string," " %.80s found", i, item->ob_type->tp_name); - goto finally; + Py_DECREF(seq); + return NULL; } - slen = PyString_GET_SIZE(item); - while (reslen + slen + seplen >= sz) { - /* at least double the size of the string */ - sz_incr = slen + seplen > sz ? slen + seplen : sz; - if (_PyString_Resize(&res, sz + sz_incr)) { - goto finally; - } - sz += sz_incr; - p = PyString_AS_STRING(res) + reslen; + sz += PyString_GET_SIZE(item); + if (i != 0) + sz += seplen; + if (sz < old_sz || sz > INT_MAX) { + PyErr_SetString(PyExc_OverflowError, + "join() is too long for a Python string"); + Py_DECREF(seq); + return NULL; } - if (i > 0) { + } + + /* Allocate result space. */ + res = PyString_FromStringAndSize((char*)NULL, (int)sz); + if (res == NULL) { + Py_DECREF(seq); + return NULL; + } + + /* Catenate everything. */ + p = PyString_AS_STRING(res); + for (i = 0; i < seqlen; ++i) { + size_t n; + item = PySequence_Fast_GET_ITEM(seq, i); + n = PyString_GET_SIZE(item); + memcpy(p, PyString_AS_STRING(item), n); + p += n; + if (i < seqlen - 1) { memcpy(p, sep, seplen); p += seplen; - reslen += seplen; } - memcpy(p, PyString_AS_STRING(item), slen); - p += slen; - reslen += slen; } - if (_PyString_Resize(&res, reslen)) - goto finally; - Py_DECREF(seq); - return res; - finally: Py_DECREF(seq); - Py_XDECREF(res); - return NULL; + return res; } - - static long string_find_internal(PyStringObject *self, PyObject *args, int dir) { |