diff options
-rw-r--r-- | Lib/test/test_descrtut.py | 1 | ||||
-rw-r--r-- | Objects/iterobject.c | 24 | ||||
-rw-r--r-- | Objects/listobject.c | 119 |
3 files changed, 126 insertions, 18 deletions
diff --git a/Lib/test/test_descrtut.py b/Lib/test/test_descrtut.py index 2c93b7e..32ca73d 100644 --- a/Lib/test/test_descrtut.py +++ b/Lib/test/test_descrtut.py @@ -202,6 +202,7 @@ Instead, you can get the same information from the list type: '__iadd__', '__imul__', '__init__', + '__iter__', '__le__', '__len__', '__lt__', diff --git a/Objects/iterobject.c b/Objects/iterobject.c index de9f2f9..ce1fe3d 100644 --- a/Objects/iterobject.c +++ b/Objects/iterobject.c @@ -68,25 +68,15 @@ iter_iternext(PyObject *iterator) it = (seqiterobject *)iterator; seq = it->it_seq; - if (PyList_CheckExact(seq)) { - PyObject *item; - if (it->it_index >= PyList_GET_SIZE(seq)) { - return NULL; - } - item = PyList_GET_ITEM(seq, it->it_index); - it->it_index++; - Py_INCREF(item); - return item; - } if (PyTuple_CheckExact(seq)) { - PyObject *item; - if (it->it_index >= PyTuple_GET_SIZE(seq)) { - return NULL; + if (it->it_index < PyTuple_GET_SIZE(seq)) { + PyObject *item; + item = PyTuple_GET_ITEM(seq, it->it_index); + it->it_index++; + Py_INCREF(item); + return item; } - item = PyTuple_GET_ITEM(seq, it->it_index); - it->it_index++; - Py_INCREF(item); - return item; + return NULL; } else { PyObject *result = PySequence_ITEM(seq, it->it_index); diff --git a/Objects/listobject.c b/Objects/listobject.c index 83a8c70..c2892c6 100644 --- a/Objects/listobject.c +++ b/Objects/listobject.c @@ -1682,6 +1682,8 @@ static char list_doc[] = "list() -> new list\n" "list(sequence) -> new list initialized from sequence's items"; +staticforward PyObject * list_iter(PyObject *seq); + PyTypeObject PyList_Type = { PyObject_HEAD_INIT(&PyType_Type) 0, @@ -1710,7 +1712,7 @@ PyTypeObject PyList_Type = { (inquiry)list_clear, /* tp_clear */ list_richcompare, /* tp_richcompare */ 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ + list_iter, /* tp_iter */ 0, /* tp_iternext */ list_methods, /* tp_methods */ 0, /* tp_members */ @@ -1811,3 +1813,118 @@ static PyTypeObject immutable_list_type = { 0, /* tp_init */ /* NOTE: This is *not* the standard list_type struct! */ }; + + +/*********************** List Iterator **************************/ + +typedef struct { + PyObject_HEAD + long it_index; + PyObject *it_seq; +} listiterobject; + +PyTypeObject PyListIter_Type; + +PyObject * +list_iter(PyObject *seq) +{ + listiterobject *it; + + if (!PyList_Check(seq)) { + PyErr_BadInternalCall(); + return NULL; + } + it = PyObject_GC_New(listiterobject, &PyListIter_Type); + if (it == NULL) + return NULL; + it->it_index = 0; + Py_INCREF(seq); + it->it_seq = seq; + _PyObject_GC_TRACK(it); + return (PyObject *)it; +} + +static void +listiter_dealloc(listiterobject *it) +{ + _PyObject_GC_UNTRACK(it); + Py_DECREF(it->it_seq); + PyObject_GC_Del(it); +} + +static int +listiter_traverse(listiterobject *it, visitproc visit, void *arg) +{ + return visit(it->it_seq, arg); +} + + +static PyObject * +listiter_getiter(PyObject *it) +{ + Py_INCREF(it); + return it; +} + +static PyObject * +listiter_next(PyObject *it) +{ + PyObject *seq; + PyObject *item; + + assert(PyList_Check(it)); + seq = ((listiterobject *)it)->it_seq; + + if (((listiterobject *)it)->it_index < PyList_GET_SIZE(seq)) { + item = ((PyListObject *)(seq))->ob_item[((listiterobject *)it)->it_index++]; + Py_INCREF(item); + return item; + } + return NULL; +} + +static PyMethodDef listiter_methods[] = { + {"next", (PyCFunction)listiter_next, METH_NOARGS, + "it.next() -- get the next value, or raise StopIteration"}, + {NULL, NULL} /* sentinel */ +}; + +PyTypeObject PyListIter_Type = { + PyObject_HEAD_INIT(&PyType_Type) + 0, /* ob_size */ + "listiterator", /* tp_name */ + sizeof(listiterobject), /* tp_basicsize */ + 0, /* tp_itemsize */ + /* methods */ + (destructor)listiter_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_compare */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + PyObject_GenericGetAttr, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,/* tp_flags */ + 0, /* tp_doc */ + (traverseproc)listiter_traverse, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + (getiterfunc)listiter_getiter, /* tp_iter */ + (iternextfunc)listiter_next, /* tp_iternext */ + listiter_methods, /* tp_methods */ + 0, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ +}; + |