diff options
author | Raymond Hettinger <python@rcn.com> | 2013-10-04 23:52:39 (GMT) |
---|---|---|
committer | Raymond Hettinger <python@rcn.com> | 2013-10-04 23:52:39 (GMT) |
commit | 07573d7b2418dd3a043c8fc95f40517fa4543048 (patch) | |
tree | 049d77259a79df33d154deddf43ab9cb59bae219 | |
parent | 3ad327ec3ad1887a9e5ab738ba12b5f78751a791 (diff) | |
parent | cb1d96f782f48b5ad1bc39a1024845492a4a123b (diff) | |
download | cpython-07573d7b2418dd3a043c8fc95f40517fa4543048.zip cpython-07573d7b2418dd3a043c8fc95f40517fa4543048.tar.gz cpython-07573d7b2418dd3a043c8fc95f40517fa4543048.tar.bz2 |
merge
-rw-r--r-- | Lib/test/test_collections.py | 24 | ||||
-rw-r--r-- | Modules/_collectionsmodule.c | 29 |
2 files changed, 39 insertions, 14 deletions
diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py index 6c733ee..ade6ee7 100644 --- a/Lib/test/test_collections.py +++ b/Lib/test/test_collections.py @@ -852,6 +852,24 @@ class TestCollectionABCs(ABCTestCase): ### Counter ################################################################################ +class CounterSubclassWithSetItem(Counter): + # Test a counter subclass that overrides __setitem__ + def __init__(self, *args, **kwds): + self.called = False + Counter.__init__(self, *args, **kwds) + def __setitem__(self, key, value): + self.called = True + Counter.__setitem__(self, key, value) + +class CounterSubclassWithGet(Counter): + # Test a counter subclass that overrides get() + def __init__(self, *args, **kwds): + self.called = False + Counter.__init__(self, *args, **kwds) + def get(self, key, default): + self.called = True + return Counter.get(self, key, default) + class TestCounter(unittest.TestCase): def test_basics(self): @@ -1059,6 +1077,12 @@ class TestCounter(unittest.TestCase): self.assertEqual(m, OrderedDict([('a', 5), ('b', 2), ('r', 2), ('c', 1), ('d', 1)])) + # test fidelity to the pure python version + c = CounterSubclassWithSetItem('abracadabra') + self.assertTrue(c.called) + c = CounterSubclassWithGet('abracadabra') + self.assertTrue(c.called) + ################################################################################ ### OrderedDict diff --git a/Modules/_collectionsmodule.c b/Modules/_collectionsmodule.c index 3fadb70..0cc013b 100644 --- a/Modules/_collectionsmodule.c +++ b/Modules/_collectionsmodule.c @@ -1763,17 +1763,17 @@ Count elements in the iterable, updating the mappping"); static PyObject * _count_elements(PyObject *self, PyObject *args) { - _Py_IDENTIFIER(__getitem__); + _Py_IDENTIFIER(get); _Py_IDENTIFIER(__setitem__); PyObject *it, *iterable, *mapping, *oldval; PyObject *newval = NULL; PyObject *key = NULL; PyObject *zero = NULL; PyObject *one = NULL; - PyObject *mapping_get = NULL; - PyObject *mapping_getitem; + PyObject *bound_get = NULL; + PyObject *mapping_get; + PyObject *dict_get; PyObject *mapping_setitem; - PyObject *dict_getitem; PyObject *dict_setitem; if (!PyArg_UnpackTuple(args, "_count_elements", 2, 2, &mapping, &iterable)) @@ -1787,15 +1787,16 @@ _count_elements(PyObject *self, PyObject *args) if (one == NULL) goto done; - mapping_getitem = _PyType_LookupId(Py_TYPE(mapping), &PyId___getitem__); - dict_getitem = _PyType_LookupId(&PyDict_Type, &PyId___getitem__); + /* Only take the fast path when get() and __setitem__() + * have not been overridden. + */ + mapping_get = _PyType_LookupId(Py_TYPE(mapping), &PyId_get); + dict_get = _PyType_LookupId(&PyDict_Type, &PyId_get); mapping_setitem = _PyType_LookupId(Py_TYPE(mapping), &PyId___setitem__); dict_setitem = _PyType_LookupId(&PyDict_Type, &PyId___setitem__); - if (mapping_getitem != NULL && - mapping_getitem == dict_getitem && - mapping_setitem != NULL && - mapping_setitem == dict_setitem) { + if (mapping_get != NULL && mapping_get == dict_get && + mapping_setitem != NULL && mapping_setitem == dict_setitem) { while (1) { key = PyIter_Next(it); if (key == NULL) @@ -1815,8 +1816,8 @@ _count_elements(PyObject *self, PyObject *args) Py_DECREF(key); } } else { - mapping_get = PyObject_GetAttrString(mapping, "get"); - if (mapping_get == NULL) + bound_get = PyObject_GetAttrString(mapping, "get"); + if (bound_get == NULL) goto done; zero = PyLong_FromLong(0); @@ -1827,7 +1828,7 @@ _count_elements(PyObject *self, PyObject *args) key = PyIter_Next(it); if (key == NULL) break; - oldval = PyObject_CallFunctionObjArgs(mapping_get, key, zero, NULL); + oldval = PyObject_CallFunctionObjArgs(bound_get, key, zero, NULL); if (oldval == NULL) break; newval = PyNumber_Add(oldval, one); @@ -1845,7 +1846,7 @@ done: Py_DECREF(it); Py_XDECREF(key); Py_XDECREF(newval); - Py_XDECREF(mapping_get); + Py_XDECREF(bound_get); Py_XDECREF(zero); Py_XDECREF(one); if (PyErr_Occurred()) |