diff options
-rw-r--r-- | Lib/test/test_collections.py | 15 | ||||
-rw-r--r-- | Modules/_collectionsmodule.c | 71 |
2 files changed, 63 insertions, 23 deletions
diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py index deda1cd..d785fcb 100644 --- a/Lib/test/test_collections.py +++ b/Lib/test/test_collections.py @@ -3,7 +3,7 @@ import unittest, doctest, operator import inspect from test import support -from collections import namedtuple, Counter, OrderedDict +from collections import namedtuple, Counter, OrderedDict, _count_elements from test import mapping_tests import pickle, copy from random import randrange, shuffle @@ -775,6 +775,19 @@ class TestCounter(unittest.TestCase): c.subtract('aaaabbcce') self.assertEqual(c, Counter(a=-1, b=0, c=-1, d=1, e=-1)) + def test_helper_function(self): + # two paths, one for real dicts and one for other mappings + elems = list('abracadabra') + + d = dict() + _count_elements(d, elems) + self.assertEqual(d, {'a': 5, 'r': 2, 'b': 2, 'c': 1, 'd': 1}) + + m = OrderedDict() + _count_elements(m, elems) + self.assertEqual(m, + OrderedDict([('a', 5), ('b', 2), ('r', 2), ('c', 1), ('d', 1)])) + class TestOrderedDict(unittest.TestCase): def test_init(self): diff --git a/Modules/_collectionsmodule.c b/Modules/_collectionsmodule.c index 684b873..f4a2c8b 100644 --- a/Modules/_collectionsmodule.c +++ b/Modules/_collectionsmodule.c @@ -1536,41 +1536,68 @@ _count_elements(PyObject *self, PyObject *args) if (!PyArg_UnpackTuple(args, "_count_elements", 2, 2, &mapping, &iterable)) return NULL; - if (!PyDict_Check(mapping)) { - PyErr_SetString(PyExc_TypeError, - "Expected mapping argument to be a dictionary"); - return NULL; - } - it = PyObject_GetIter(iterable); if (it == NULL) return NULL; + one = PyLong_FromLong(1); if (one == NULL) { Py_DECREF(it); return NULL; } - while (1) { - key = PyIter_Next(it); - if (key == NULL) { - if (PyErr_Occurred() && PyErr_ExceptionMatches(PyExc_StopIteration)) - PyErr_Clear(); - break; + + if (PyDict_CheckExact(mapping)) { + while (1) { + key = PyIter_Next(it); + if (key == NULL) { + if (PyErr_Occurred() && PyErr_ExceptionMatches(PyExc_StopIteration)) + PyErr_Clear(); + else + break; + } + oldval = PyDict_GetItem(mapping, key); + if (oldval == NULL) { + if (PyDict_SetItem(mapping, key, one) == -1) + break; + } else { + newval = PyNumber_Add(oldval, one); + if (newval == NULL) + break; + if (PyDict_SetItem(mapping, key, newval) == -1) + break; + Py_CLEAR(newval); + } + Py_DECREF(key); } - oldval = PyDict_GetItem(mapping, key); - if (oldval == NULL) { - if (PyDict_SetItem(mapping, key, one) == -1) - break; - } else { - newval = PyNumber_Add(oldval, one); - if (newval == NULL) - break; - if (PyDict_SetItem(mapping, key, newval) == -1) + } else { + while (1) { + key = PyIter_Next(it); + if (key == NULL) { + if (PyErr_Occurred() && PyErr_ExceptionMatches(PyExc_StopIteration)) + PyErr_Clear(); + else + break; + } + oldval = PyObject_GetItem(mapping, key); + if (oldval == NULL) { + if (!PyErr_Occurred() || !PyErr_ExceptionMatches(PyExc_KeyError)) + break; + PyErr_Clear(); + Py_INCREF(one); + newval = one; + } else { + newval = PyNumber_Add(oldval, one); + Py_DECREF(oldval); + if (newval == NULL) + break; + } + if (PyObject_SetItem(mapping, key, newval) == -1) break; Py_CLEAR(newval); + Py_DECREF(key); } - Py_DECREF(key); } + Py_DECREF(it); Py_XDECREF(key); Py_XDECREF(newval); |