diff options
-rw-r--r-- | Lib/test/test_iter.py | 35 | ||||
-rw-r--r-- | Misc/NEWS | 1 | ||||
-rw-r--r-- | Python/bltinmodule.c | 134 |
3 files changed, 104 insertions, 66 deletions
diff --git a/Lib/test/test_iter.py b/Lib/test/test_iter.py index 3563661..c87f5ec 100644 --- a/Lib/test/test_iter.py +++ b/Lib/test/test_iter.py @@ -351,4 +351,39 @@ class TestCase(unittest.TestCase): except OSError: pass + # Test map()'s use of iterators. + def test_builtin_map(self): + self.assertEqual(map(None, SequenceClass(5)), range(5)) + self.assertEqual(map(lambda x: x+1, SequenceClass(5)), range(1, 6)) + + d = {"one": 1, "two": 2, "three": 3} + self.assertEqual(map(None, d), d.keys()) + self.assertEqual(map(lambda k, d=d: (k, d[k]), d), d.items()) + dkeys = d.keys() + expected = [(i < len(d) and dkeys[i] or None, + i, + i < len(d) and dkeys[i] or None) + for i in range(5)] + self.assertEqual(map(None, d, + SequenceClass(5), + iter(d.iterkeys())), + expected) + + f = open(TESTFN, "w") + try: + for i in range(10): + f.write("xy" * i + "\n") # line i has len 2*i+1 + finally: + f.close() + f = open(TESTFN, "r") + try: + self.assertEqual(map(len, f), range(1, 21, 2)) + f.seek(0, 0) + finally: + f.close() + try: + unlink(TESTFN) + except OSError: + pass + run_unittest(TestCase) @@ -19,6 +19,7 @@ Core arguments: filter() list() + map() max() min() diff --git a/Python/bltinmodule.c b/Python/bltinmodule.c index 9e8a227..0c20d10 100644 --- a/Python/bltinmodule.c +++ b/Python/bltinmodule.c @@ -936,9 +936,8 @@ static PyObject * builtin_map(PyObject *self, PyObject *args) { typedef struct { - PyObject *seq; - PySequenceMethods *sqf; - int saw_IndexError; + PyObject *it; /* the iterator object */ + int saw_StopIteration; /* bool: did the iterator end? */ } sequence; PyObject *func, *result; @@ -961,104 +960,105 @@ builtin_map(PyObject *self, PyObject *args) return PySequence_List(PyTuple_GetItem(args, 1)); } + /* Get space for sequence descriptors. Must NULL out the iterator + * pointers so that jumping to Fail_2 later doesn't see trash. + */ if ((seqs = PyMem_NEW(sequence, n)) == NULL) { PyErr_NoMemory(); - goto Fail_2; + return NULL; + } + for (i = 0; i < n; ++i) { + seqs[i].it = (PyObject*)NULL; + seqs[i].saw_StopIteration = 0; } - /* Do a first pass to (a) verify the args are sequences; (b) set - * len to the largest of their lengths; (c) initialize the seqs - * descriptor vector. + /* Do a first pass to obtain iterators for the arguments, and set len + * to the largest of their lengths. */ - for (len = 0, i = 0, sqp = seqs; i < n; ++i, ++sqp) { + len = 0; + for (i = 0, sqp = seqs; i < n; ++i, ++sqp) { + PyObject *curseq; int curlen; - PySequenceMethods *sqf; - if ((sqp->seq = PyTuple_GetItem(args, i + 1)) == NULL) - goto Fail_2; - - sqp->saw_IndexError = 0; - - sqp->sqf = sqf = sqp->seq->ob_type->tp_as_sequence; - if (sqf == NULL || - sqf->sq_item == NULL) - { + /* Get iterator. */ + curseq = PyTuple_GetItem(args, i+1); + sqp->it = PyObject_GetIter(curseq); + if (sqp->it == NULL) { static char errmsg[] = - "argument %d to map() must be a sequence object"; + "argument %d to map() must support iteration"; char errbuf[sizeof(errmsg) + 25]; - sprintf(errbuf, errmsg, i+2); PyErr_SetString(PyExc_TypeError, errbuf); goto Fail_2; } - if (sqf->sq_length == NULL) - /* doesn't matter -- make something up */ - curlen = 8; - else - curlen = (*sqf->sq_length)(sqp->seq); + /* Update len. */ + curlen = -1; /* unknown */ + if (PySequence_Check(curseq) && + curseq->ob_type->tp_as_sequence->sq_length) { + curlen = PySequence_Size(curseq); + if (curlen < 0) + PyErr_Clear(); + } if (curlen < 0) - goto Fail_2; + curlen = 8; /* arbitrary */ if (curlen > len) len = curlen; } + /* Get space for the result list. */ if ((result = (PyObject *) PyList_New(len)) == NULL) goto Fail_2; - /* Iterate over the sequences until all have raised IndexError. */ + /* Iterate over the sequences until all have stopped. */ for (i = 0; ; ++i) { PyObject *alist, *item=NULL, *value; - int any = 0; + int numactive = 0; if (func == Py_None && n == 1) alist = NULL; - else { - if ((alist = PyTuple_New(n)) == NULL) - goto Fail_1; - } + else if ((alist = PyTuple_New(n)) == NULL) + goto Fail_1; for (j = 0, sqp = seqs; j < n; ++j, ++sqp) { - if (sqp->saw_IndexError) { + if (sqp->saw_StopIteration) { Py_INCREF(Py_None); item = Py_None; } else { - item = (*sqp->sqf->sq_item)(sqp->seq, i); - if (item == NULL) { - if (PyErr_ExceptionMatches( - PyExc_IndexError)) - { - PyErr_Clear(); - Py_INCREF(Py_None); - item = Py_None; - sqp->saw_IndexError = 1; - } - else { - goto Fail_0; + item = PyIter_Next(sqp->it); + if (item) + ++numactive; + else { + /* StopIteration is *implied* by a + * NULL return from PyIter_Next() if + * PyErr_Occurred() is false. + */ + if (PyErr_Occurred()) { + if (PyErr_ExceptionMatches( + PyExc_StopIteration)) + PyErr_Clear(); + else { + Py_XDECREF(alist); + goto Fail_1; + } } + Py_INCREF(Py_None); + item = Py_None; + sqp->saw_StopIteration = 1; } - else - any = 1; } - if (!alist) + if (alist) + PyTuple_SET_ITEM(alist, j, item); + else break; - if (PyTuple_SetItem(alist, j, item) < 0) { - Py_DECREF(item); - goto Fail_0; - } - continue; - - Fail_0: - Py_XDECREF(alist); - goto Fail_1; } if (!alist) alist = item; - if (!any) { + if (numactive == 0) { Py_DECREF(alist); break; } @@ -1077,23 +1077,25 @@ builtin_map(PyObject *self, PyObject *args) if (status < 0) goto Fail_1; } - else { - if (PyList_SetItem(result, i, value) < 0) - goto Fail_1; - } + else if (PyList_SetItem(result, i, value) < 0) + goto Fail_1; } if (i < len && PyList_SetSlice(result, i, len, NULL) < 0) goto Fail_1; - PyMem_DEL(seqs); - return result; + goto Succeed; Fail_1: Py_DECREF(result); Fail_2: - if (seqs) PyMem_DEL(seqs); - return NULL; + result = NULL; +Succeed: + assert(seqs); + for (i = 0; i < n; ++i) + Py_XDECREF(seqs[i].it); + PyMem_DEL(seqs); + return result; } static char map_doc[] = |