summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/test/test_collections.py15
-rw-r--r--Modules/_collectionsmodule.c71
2 files changed, 63 insertions, 23 deletions
diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py
index deda1cd..d785fcb 100644
--- a/Lib/test/test_collections.py
+++ b/Lib/test/test_collections.py
@@ -3,7 +3,7 @@
import unittest, doctest, operator
import inspect
from test import support
-from collections import namedtuple, Counter, OrderedDict
+from collections import namedtuple, Counter, OrderedDict, _count_elements
from test import mapping_tests
import pickle, copy
from random import randrange, shuffle
@@ -775,6 +775,19 @@ class TestCounter(unittest.TestCase):
c.subtract('aaaabbcce')
self.assertEqual(c, Counter(a=-1, b=0, c=-1, d=1, e=-1))
+ def test_helper_function(self):
+ # two paths, one for real dicts and one for other mappings
+ elems = list('abracadabra')
+
+ d = dict()
+ _count_elements(d, elems)
+ self.assertEqual(d, {'a': 5, 'r': 2, 'b': 2, 'c': 1, 'd': 1})
+
+ m = OrderedDict()
+ _count_elements(m, elems)
+ self.assertEqual(m,
+ OrderedDict([('a', 5), ('b', 2), ('r', 2), ('c', 1), ('d', 1)]))
+
class TestOrderedDict(unittest.TestCase):
def test_init(self):
diff --git a/Modules/_collectionsmodule.c b/Modules/_collectionsmodule.c
index 684b873..f4a2c8b 100644
--- a/Modules/_collectionsmodule.c
+++ b/Modules/_collectionsmodule.c
@@ -1536,41 +1536,68 @@ _count_elements(PyObject *self, PyObject *args)
if (!PyArg_UnpackTuple(args, "_count_elements", 2, 2, &mapping, &iterable))
return NULL;
- if (!PyDict_Check(mapping)) {
- PyErr_SetString(PyExc_TypeError,
- "Expected mapping argument to be a dictionary");
- return NULL;
- }
-
it = PyObject_GetIter(iterable);
if (it == NULL)
return NULL;
+
one = PyLong_FromLong(1);
if (one == NULL) {
Py_DECREF(it);
return NULL;
}
- while (1) {
- key = PyIter_Next(it);
- if (key == NULL) {
- if (PyErr_Occurred() && PyErr_ExceptionMatches(PyExc_StopIteration))
- PyErr_Clear();
- break;
+
+ if (PyDict_CheckExact(mapping)) {
+ while (1) {
+ key = PyIter_Next(it);
+ if (key == NULL) {
+ if (PyErr_Occurred() && PyErr_ExceptionMatches(PyExc_StopIteration))
+ PyErr_Clear();
+ else
+ break;
+ }
+ oldval = PyDict_GetItem(mapping, key);
+ if (oldval == NULL) {
+ if (PyDict_SetItem(mapping, key, one) == -1)
+ break;
+ } else {
+ newval = PyNumber_Add(oldval, one);
+ if (newval == NULL)
+ break;
+ if (PyDict_SetItem(mapping, key, newval) == -1)
+ break;
+ Py_CLEAR(newval);
+ }
+ Py_DECREF(key);
}
- oldval = PyDict_GetItem(mapping, key);
- if (oldval == NULL) {
- if (PyDict_SetItem(mapping, key, one) == -1)
- break;
- } else {
- newval = PyNumber_Add(oldval, one);
- if (newval == NULL)
- break;
- if (PyDict_SetItem(mapping, key, newval) == -1)
+ } else {
+ while (1) {
+ key = PyIter_Next(it);
+ if (key == NULL) {
+ if (PyErr_Occurred() && PyErr_ExceptionMatches(PyExc_StopIteration))
+ PyErr_Clear();
+ else
+ break;
+ }
+ oldval = PyObject_GetItem(mapping, key);
+ if (oldval == NULL) {
+ if (!PyErr_Occurred() || !PyErr_ExceptionMatches(PyExc_KeyError))
+ break;
+ PyErr_Clear();
+ Py_INCREF(one);
+ newval = one;
+ } else {
+ newval = PyNumber_Add(oldval, one);
+ Py_DECREF(oldval);
+ if (newval == NULL)
+ break;
+ }
+ if (PyObject_SetItem(mapping, key, newval) == -1)
break;
Py_CLEAR(newval);
+ Py_DECREF(key);
}
- Py_DECREF(key);
}
+
Py_DECREF(it);
Py_XDECREF(key);
Py_XDECREF(newval);