summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/test/test_dictviews.py14
-rw-r--r--Misc/NEWS.d/next/Core and Builtins/2018-06-14-13-55-45.bpo-27575.mMYgzv.rst2
-rw-r--r--Objects/dictobject.c81
3 files changed, 93 insertions, 4 deletions
diff --git a/Lib/test/test_dictviews.py b/Lib/test/test_dictviews.py
index 2763cbf..b15cfeb 100644
--- a/Lib/test/test_dictviews.py
+++ b/Lib/test/test_dictviews.py
@@ -92,6 +92,12 @@ class DictSetTest(unittest.TestCase):
d1 = {'a': 1, 'b': 2}
d2 = {'b': 3, 'c': 2}
d3 = {'d': 4, 'e': 5}
+ d4 = {'d': 4}
+
+ class CustomSet(set):
+ def intersection(self, other):
+ return CustomSet(super().intersection(other))
+
self.assertEqual(d1.keys() & d1.keys(), {'a', 'b'})
self.assertEqual(d1.keys() & d2.keys(), {'b'})
self.assertEqual(d1.keys() & d3.keys(), set())
@@ -99,6 +105,14 @@ class DictSetTest(unittest.TestCase):
self.assertEqual(d1.keys() & set(d2.keys()), {'b'})
self.assertEqual(d1.keys() & set(d3.keys()), set())
self.assertEqual(d1.keys() & tuple(d1.keys()), {'a', 'b'})
+ self.assertEqual(d3.keys() & d4.keys(), {'d'})
+ self.assertEqual(d4.keys() & d3.keys(), {'d'})
+ self.assertEqual(d4.keys() & set(d3.keys()), {'d'})
+ self.assertIsInstance(d4.keys() & frozenset(d3.keys()), set)
+ self.assertIsInstance(frozenset(d3.keys()) & d4.keys(), set)
+ self.assertIs(type(d4.keys() & CustomSet(d3.keys())), set)
+ self.assertIs(type(d1.keys() & []), set)
+ self.assertIs(type([] & d1.keys()), set)
self.assertEqual(d1.keys() | d1.keys(), {'a', 'b'})
self.assertEqual(d1.keys() | d2.keys(), {'a', 'b', 'c'})
diff --git a/Misc/NEWS.d/next/Core and Builtins/2018-06-14-13-55-45.bpo-27575.mMYgzv.rst b/Misc/NEWS.d/next/Core and Builtins/2018-06-14-13-55-45.bpo-27575.mMYgzv.rst
new file mode 100644
index 0000000..2c250dc
--- /dev/null
+++ b/Misc/NEWS.d/next/Core and Builtins/2018-06-14-13-55-45.bpo-27575.mMYgzv.rst
@@ -0,0 +1,2 @@
+Improve speed of dictview intersection by directly using set intersection
+logic. Patch by David Su.
diff --git a/Objects/dictobject.c b/Objects/dictobject.c
index f168ad5..fec3a87 100644
--- a/Objects/dictobject.c
+++ b/Objects/dictobject.c
@@ -4169,24 +4169,97 @@ dictviews_sub(PyObject* self, PyObject *other)
return result;
}
-PyObject*
+static int
+dictitems_contains(_PyDictViewObject *dv, PyObject *obj);
+
+PyObject *
_PyDictView_Intersect(PyObject* self, PyObject *other)
{
- PyObject *result = PySet_New(self);
+ PyObject *result;
+ PyObject *it;
+ PyObject *key;
+ Py_ssize_t len_self;
+ int rv;
+ int (*dict_contains)(_PyDictViewObject *, PyObject *);
PyObject *tmp;
- _Py_IDENTIFIER(intersection_update);
+ /* Python interpreter swaps parameters when dict view
+ is on right side of & */
+ if (!PyDictViewSet_Check(self)) {
+ PyObject *tmp = other;
+ other = self;
+ self = tmp;
+ }
+
+ len_self = dictview_len((_PyDictViewObject *)self);
+
+ /* if other is a set and self is smaller than other,
+ reuse set intersection logic */
+ if (Py_TYPE(other) == &PySet_Type && len_self <= PyObject_Size(other)) {
+ _Py_IDENTIFIER(intersection);
+ return _PyObject_CallMethodIdObjArgs(other, &PyId_intersection, self, NULL);
+ }
+
+ /* if other is another dict view, and it is bigger than self,
+ swap them */
+ if (PyDictViewSet_Check(other)) {
+ Py_ssize_t len_other = dictview_len((_PyDictViewObject *)other);
+ if (len_other > len_self) {
+ PyObject *tmp = other;
+ other = self;
+ self = tmp;
+ }
+ }
+
+ /* at this point, two things should be true
+ 1. self is a dictview
+ 2. if other is a dictview then it is smaller than self */
+ result = PySet_New(NULL);
if (result == NULL)
return NULL;
+ it = PyObject_GetIter(other);
+
+ _Py_IDENTIFIER(intersection_update);
tmp = _PyObject_CallMethodIdOneArg(result, &PyId_intersection_update, other);
if (tmp == NULL) {
Py_DECREF(result);
return NULL;
}
-
Py_DECREF(tmp);
+
+ if (PyDictKeys_Check(self)) {
+ dict_contains = dictkeys_contains;
+ }
+ /* else PyDictItems_Check(self) */
+ else {
+ dict_contains = dictitems_contains;
+ }
+
+ while ((key = PyIter_Next(it)) != NULL) {
+ rv = dict_contains((_PyDictViewObject *)self, key);
+ if (rv < 0) {
+ goto error;
+ }
+ if (rv) {
+ if (PySet_Add(result, key)) {
+ goto error;
+ }
+ }
+ Py_DECREF(key);
+ }
+ Py_DECREF(it);
+ if (PyErr_Occurred()) {
+ Py_DECREF(result);
+ return NULL;
+ }
return result;
+
+error:
+ Py_DECREF(it);
+ Py_DECREF(result);
+ Py_DECREF(key);
+ return NULL;
}
static PyObject*