diff options
Diffstat (limited to 'Modules/_collectionsmodule.c')
-rw-r--r-- | Modules/_collectionsmodule.c | 29 |
1 files changed, 15 insertions, 14 deletions
diff --git a/Modules/_collectionsmodule.c b/Modules/_collectionsmodule.c index b244667..c6c7983 100644 --- a/Modules/_collectionsmodule.c +++ b/Modules/_collectionsmodule.c @@ -1689,17 +1689,17 @@ Count elements in the iterable, updating the mappping"); static PyObject * _count_elements(PyObject *self, PyObject *args) { - _Py_IDENTIFIER(__getitem__); + _Py_IDENTIFIER(get); _Py_IDENTIFIER(__setitem__); PyObject *it, *iterable, *mapping, *oldval; PyObject *newval = NULL; PyObject *key = NULL; PyObject *zero = NULL; PyObject *one = NULL; - PyObject *mapping_get = NULL; - PyObject *mapping_getitem; + PyObject *bound_get = NULL; + PyObject *mapping_get; + PyObject *dict_get; PyObject *mapping_setitem; - PyObject *dict_getitem; PyObject *dict_setitem; if (!PyArg_UnpackTuple(args, "_count_elements", 2, 2, &mapping, &iterable)) @@ -1713,15 +1713,16 @@ _count_elements(PyObject *self, PyObject *args) if (one == NULL) goto done; - mapping_getitem = _PyType_LookupId(Py_TYPE(mapping), &PyId___getitem__); - dict_getitem = _PyType_LookupId(&PyDict_Type, &PyId___getitem__); + /* Only take the fast path when get() and __setitem__() + * have not been overridden. + */ + mapping_get = _PyType_LookupId(Py_TYPE(mapping), &PyId_get); + dict_get = _PyType_LookupId(&PyDict_Type, &PyId_get); mapping_setitem = _PyType_LookupId(Py_TYPE(mapping), &PyId___setitem__); dict_setitem = _PyType_LookupId(&PyDict_Type, &PyId___setitem__); - if (mapping_getitem != NULL && - mapping_getitem == dict_getitem && - mapping_setitem != NULL && - mapping_setitem == dict_setitem) { + if (mapping_get != NULL && mapping_get == dict_get && + mapping_setitem != NULL && mapping_setitem == dict_setitem) { while (1) { key = PyIter_Next(it); if (key == NULL) @@ -1741,8 +1742,8 @@ _count_elements(PyObject *self, PyObject *args) Py_DECREF(key); } } else { - mapping_get = PyObject_GetAttrString(mapping, "get"); - if (mapping_get == NULL) + bound_get = PyObject_GetAttrString(mapping, "get"); + if (bound_get == NULL) goto done; zero = PyLong_FromLong(0); @@ -1753,7 +1754,7 @@ _count_elements(PyObject *self, PyObject *args) key = PyIter_Next(it); if (key == NULL) break; - oldval = PyObject_CallFunctionObjArgs(mapping_get, key, zero, NULL); + oldval = PyObject_CallFunctionObjArgs(bound_get, key, zero, NULL); if (oldval == NULL) break; newval = PyNumber_Add(oldval, one); @@ -1771,7 +1772,7 @@ done: Py_DECREF(it); Py_XDECREF(key); Py_XDECREF(newval); - Py_XDECREF(mapping_get); + Py_XDECREF(bound_get); Py_XDECREF(zero); Py_XDECREF(one); if (PyErr_Occurred()) |