diff options
author | Raymond Hettinger <python@rcn.com> | 2013-10-04 23:51:02 (GMT) |
---|---|---|
committer | Raymond Hettinger <python@rcn.com> | 2013-10-04 23:51:02 (GMT) |
commit | cb1d96f782f48b5ad1bc39a1024845492a4a123b (patch) | |
tree | 398b5d760459656e6a0ca06afce715ab963b06ec | |
parent | 5b22dd87aa39087f07987f788a0bbd2464e2a8b5 (diff) | |
download | cpython-cb1d96f782f48b5ad1bc39a1024845492a4a123b.zip cpython-cb1d96f782f48b5ad1bc39a1024845492a4a123b.tar.gz cpython-cb1d96f782f48b5ad1bc39a1024845492a4a123b.tar.bz2 |
Issue #18594: Make the C code more closely match the pure python code.
-rw-r--r-- | Lib/test/test_collections.py | 24 | ||||
-rw-r--r-- | Modules/_collectionsmodule.c | 29 |
2 files changed, 39 insertions, 14 deletions
diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py index af27d22..ff52755 100644 --- a/Lib/test/test_collections.py +++ b/Lib/test/test_collections.py @@ -818,6 +818,24 @@ class TestCollectionABCs(ABCTestCase): ### Counter ################################################################################ +class CounterSubclassWithSetItem(Counter): + # Test a counter subclass that overrides __setitem__ + def __init__(self, *args, **kwds): + self.called = False + Counter.__init__(self, *args, **kwds) + def __setitem__(self, key, value): + self.called = True + Counter.__setitem__(self, key, value) + +class CounterSubclassWithGet(Counter): + # Test a counter subclass that overrides get() + def __init__(self, *args, **kwds): + self.called = False + Counter.__init__(self, *args, **kwds) + def get(self, key, default): + self.called = True + return Counter.get(self, key, default) + class TestCounter(unittest.TestCase): def test_basics(self): @@ -1022,6 +1040,12 @@ class TestCounter(unittest.TestCase): self.assertEqual(m, OrderedDict([('a', 5), ('b', 2), ('r', 2), ('c', 1), ('d', 1)])) + # test fidelity to the pure python version + c = CounterSubclassWithSetItem('abracadabra') + self.assertTrue(c.called) + c = CounterSubclassWithGet('abracadabra') + self.assertTrue(c.called) + ################################################################################ ### OrderedDict diff --git a/Modules/_collectionsmodule.c b/Modules/_collectionsmodule.c index b244667..c6c7983 100644 --- a/Modules/_collectionsmodule.c +++ b/Modules/_collectionsmodule.c @@ -1689,17 +1689,17 @@ Count elements in the iterable, updating the mappping"); static PyObject * _count_elements(PyObject *self, PyObject *args) { - _Py_IDENTIFIER(__getitem__); + _Py_IDENTIFIER(get); _Py_IDENTIFIER(__setitem__); PyObject *it, *iterable, *mapping, *oldval; PyObject *newval = NULL; PyObject *key = NULL; PyObject *zero = NULL; PyObject *one = NULL; - PyObject *mapping_get = NULL; - PyObject *mapping_getitem; + PyObject *bound_get = NULL; + PyObject *mapping_get; + PyObject *dict_get; PyObject *mapping_setitem; - PyObject *dict_getitem; PyObject *dict_setitem; if (!PyArg_UnpackTuple(args, "_count_elements", 2, 2, &mapping, &iterable)) @@ -1713,15 +1713,16 @@ _count_elements(PyObject *self, PyObject *args) if (one == NULL) goto done; - mapping_getitem = _PyType_LookupId(Py_TYPE(mapping), &PyId___getitem__); - dict_getitem = _PyType_LookupId(&PyDict_Type, &PyId___getitem__); + /* Only take the fast path when get() and __setitem__() + * have not been overridden. + */ + mapping_get = _PyType_LookupId(Py_TYPE(mapping), &PyId_get); + dict_get = _PyType_LookupId(&PyDict_Type, &PyId_get); mapping_setitem = _PyType_LookupId(Py_TYPE(mapping), &PyId___setitem__); dict_setitem = _PyType_LookupId(&PyDict_Type, &PyId___setitem__); - if (mapping_getitem != NULL && - mapping_getitem == dict_getitem && - mapping_setitem != NULL && - mapping_setitem == dict_setitem) { + if (mapping_get != NULL && mapping_get == dict_get && + mapping_setitem != NULL && mapping_setitem == dict_setitem) { while (1) { key = PyIter_Next(it); if (key == NULL) @@ -1741,8 +1742,8 @@ _count_elements(PyObject *self, PyObject *args) Py_DECREF(key); } } else { - mapping_get = PyObject_GetAttrString(mapping, "get"); - if (mapping_get == NULL) + bound_get = PyObject_GetAttrString(mapping, "get"); + if (bound_get == NULL) goto done; zero = PyLong_FromLong(0); @@ -1753,7 +1754,7 @@ _count_elements(PyObject *self, PyObject *args) key = PyIter_Next(it); if (key == NULL) break; - oldval = PyObject_CallFunctionObjArgs(mapping_get, key, zero, NULL); + oldval = PyObject_CallFunctionObjArgs(bound_get, key, zero, NULL); if (oldval == NULL) break; newval = PyNumber_Add(oldval, one); @@ -1771,7 +1772,7 @@ done: Py_DECREF(it); Py_XDECREF(key); Py_XDECREF(newval); - Py_XDECREF(mapping_get); + Py_XDECREF(bound_get); Py_XDECREF(zero); Py_XDECREF(one); if (PyErr_Occurred()) |