From 998cf1f03a61de8a0cd3811faa97973d4022bc55 Mon Sep 17 00:00:00 2001 From: Forest Gregg Date: Mon, 26 Aug 2019 02:17:43 -0500 Subject: bpo-27575: port set intersection logic into dictview intersection (GH-7696) --- Lib/test/test_dictviews.py | 14 ++++ .../2018-06-14-13-55-45.bpo-27575.mMYgzv.rst | 2 + Objects/dictobject.c | 81 ++++++++++++++++++++-- 3 files changed, 93 insertions(+), 4 deletions(-) create mode 100644 Misc/NEWS.d/next/Core and Builtins/2018-06-14-13-55-45.bpo-27575.mMYgzv.rst diff --git a/Lib/test/test_dictviews.py b/Lib/test/test_dictviews.py index 2763cbf..b15cfeb 100644 --- a/Lib/test/test_dictviews.py +++ b/Lib/test/test_dictviews.py @@ -92,6 +92,12 @@ class DictSetTest(unittest.TestCase): d1 = {'a': 1, 'b': 2} d2 = {'b': 3, 'c': 2} d3 = {'d': 4, 'e': 5} + d4 = {'d': 4} + + class CustomSet(set): + def intersection(self, other): + return CustomSet(super().intersection(other)) + self.assertEqual(d1.keys() & d1.keys(), {'a', 'b'}) self.assertEqual(d1.keys() & d2.keys(), {'b'}) self.assertEqual(d1.keys() & d3.keys(), set()) @@ -99,6 +105,14 @@ class DictSetTest(unittest.TestCase): self.assertEqual(d1.keys() & set(d2.keys()), {'b'}) self.assertEqual(d1.keys() & set(d3.keys()), set()) self.assertEqual(d1.keys() & tuple(d1.keys()), {'a', 'b'}) + self.assertEqual(d3.keys() & d4.keys(), {'d'}) + self.assertEqual(d4.keys() & d3.keys(), {'d'}) + self.assertEqual(d4.keys() & set(d3.keys()), {'d'}) + self.assertIsInstance(d4.keys() & frozenset(d3.keys()), set) + self.assertIsInstance(frozenset(d3.keys()) & d4.keys(), set) + self.assertIs(type(d4.keys() & CustomSet(d3.keys())), set) + self.assertIs(type(d1.keys() & []), set) + self.assertIs(type([] & d1.keys()), set) self.assertEqual(d1.keys() | d1.keys(), {'a', 'b'}) self.assertEqual(d1.keys() | d2.keys(), {'a', 'b', 'c'}) diff --git a/Misc/NEWS.d/next/Core and Builtins/2018-06-14-13-55-45.bpo-27575.mMYgzv.rst b/Misc/NEWS.d/next/Core and Builtins/2018-06-14-13-55-45.bpo-27575.mMYgzv.rst new file mode 100644 index 0000000..2c250dc --- /dev/null +++ b/Misc/NEWS.d/next/Core and Builtins/2018-06-14-13-55-45.bpo-27575.mMYgzv.rst @@ -0,0 +1,2 @@ +Improve speed of dictview intersection by directly using set intersection +logic. Patch by David Su. diff --git a/Objects/dictobject.c b/Objects/dictobject.c index f168ad5..fec3a87 100644 --- a/Objects/dictobject.c +++ b/Objects/dictobject.c @@ -4169,24 +4169,97 @@ dictviews_sub(PyObject* self, PyObject *other) return result; } -PyObject* +static int +dictitems_contains(_PyDictViewObject *dv, PyObject *obj); + +PyObject * _PyDictView_Intersect(PyObject* self, PyObject *other) { - PyObject *result = PySet_New(self); + PyObject *result; + PyObject *it; + PyObject *key; + Py_ssize_t len_self; + int rv; + int (*dict_contains)(_PyDictViewObject *, PyObject *); PyObject *tmp; - _Py_IDENTIFIER(intersection_update); + /* Python interpreter swaps parameters when dict view + is on right side of & */ + if (!PyDictViewSet_Check(self)) { + PyObject *tmp = other; + other = self; + self = tmp; + } + + len_self = dictview_len((_PyDictViewObject *)self); + + /* if other is a set and self is smaller than other, + reuse set intersection logic */ + if (Py_TYPE(other) == &PySet_Type && len_self <= PyObject_Size(other)) { + _Py_IDENTIFIER(intersection); + return _PyObject_CallMethodIdObjArgs(other, &PyId_intersection, self, NULL); + } + + /* if other is another dict view, and it is bigger than self, + swap them */ + if (PyDictViewSet_Check(other)) { + Py_ssize_t len_other = dictview_len((_PyDictViewObject *)other); + if (len_other > len_self) { + PyObject *tmp = other; + other = self; + self = tmp; + } + } + + /* at this point, two things should be true + 1. self is a dictview + 2. if other is a dictview then it is smaller than self */ + result = PySet_New(NULL); if (result == NULL) return NULL; + it = PyObject_GetIter(other); + + _Py_IDENTIFIER(intersection_update); tmp = _PyObject_CallMethodIdOneArg(result, &PyId_intersection_update, other); if (tmp == NULL) { Py_DECREF(result); return NULL; } - Py_DECREF(tmp); + + if (PyDictKeys_Check(self)) { + dict_contains = dictkeys_contains; + } + /* else PyDictItems_Check(self) */ + else { + dict_contains = dictitems_contains; + } + + while ((key = PyIter_Next(it)) != NULL) { + rv = dict_contains((_PyDictViewObject *)self, key); + if (rv < 0) { + goto error; + } + if (rv) { + if (PySet_Add(result, key)) { + goto error; + } + } + Py_DECREF(key); + } + Py_DECREF(it); + if (PyErr_Occurred()) { + Py_DECREF(result); + return NULL; + } return result; + +error: + Py_DECREF(it); + Py_DECREF(result); + Py_DECREF(key); + return NULL; } static PyObject* -- cgit v0.12