diff options
-rw-r--r-- | Include/setobject.h | 3 | ||||
-rw-r--r-- | Lib/test/test_set.py | 8 | ||||
-rw-r--r-- | Objects/dictobject.c | 18 | ||||
-rw-r--r-- | Objects/setobject.c | 20 |
4 files changed, 44 insertions, 5 deletions
diff --git a/Include/setobject.h b/Include/setobject.h index a16c2f7..750a2a8 100644 --- a/Include/setobject.h +++ b/Include/setobject.h @@ -82,7 +82,8 @@ PyAPI_FUNC(int) PySet_Clear(PyObject *set); PyAPI_FUNC(int) PySet_Contains(PyObject *anyset, PyObject *key); PyAPI_FUNC(int) PySet_Discard(PyObject *set, PyObject *key); PyAPI_FUNC(int) PySet_Add(PyObject *set, PyObject *key); -PyAPI_FUNC(int) _PySet_Next(PyObject *set, Py_ssize_t *pos, PyObject **entry); +PyAPI_FUNC(int) _PySet_Next(PyObject *set, Py_ssize_t *pos, PyObject **key); +PyAPI_FUNC(int) _PySet_NextEntry(PyObject *set, Py_ssize_t *pos, PyObject **key, long *hash); PyAPI_FUNC(PyObject *) PySet_Pop(PyObject *set); PyAPI_FUNC(int) _PySet_Update(PyObject *set, PyObject *iterable); diff --git a/Lib/test/test_set.py b/Lib/test/test_set.py index 45f61b2..b46cac4 100644 --- a/Lib/test/test_set.py +++ b/Lib/test/test_set.py @@ -285,10 +285,14 @@ class TestJointOps(unittest.TestCase): s = self.thetype(d) self.assertEqual(sum(elem.hash_count for elem in d), n) s.difference(d) - self.assertEqual(sum(elem.hash_count for elem in d), n) + self.assertEqual(sum(elem.hash_count for elem in d), n) if hasattr(s, 'symmetric_difference_update'): s.symmetric_difference_update(d) - self.assertEqual(sum(elem.hash_count for elem in d), n) + self.assertEqual(sum(elem.hash_count for elem in d), n) + d2 = dict.fromkeys(set(d)) + self.assertEqual(sum(elem.hash_count for elem in d), n) + d3 = dict.fromkeys(frozenset(d)) + self.assertEqual(sum(elem.hash_count for elem in d), n) class TestSet(TestJointOps): thetype = set diff --git a/Objects/dictobject.c b/Objects/dictobject.c index 1cb3ee6..acf5ae3 100644 --- a/Objects/dictobject.c +++ b/Objects/dictobject.c @@ -1175,6 +1175,24 @@ dict_fromkeys(PyObject *cls, PyObject *args) if (d == NULL) return NULL; + if (PyDict_CheckExact(d) && PyAnySet_CheckExact(seq)) { + dictobject *mp = (dictobject *)d; + Py_ssize_t pos = 0; + PyObject *key; + long hash; + + if (dictresize(mp, PySet_GET_SIZE(seq))) + return NULL; + + while (_PySet_NextEntry(seq, &pos, &key, &hash)) { + Py_INCREF(key); + Py_INCREF(Py_None); + if (insertdict(mp, key, hash, Py_None)) + return NULL; + } + return d; + } + it = PyObject_GetIter(seq); if (it == NULL){ Py_DECREF(d); diff --git a/Objects/setobject.c b/Objects/setobject.c index 07ba996..a896d93 100644 --- a/Objects/setobject.c +++ b/Objects/setobject.c @@ -2137,7 +2137,7 @@ PySet_Add(PyObject *set, PyObject *key) } int -_PySet_Next(PyObject *set, Py_ssize_t *pos, PyObject **entry) +_PySet_Next(PyObject *set, Py_ssize_t *pos, PyObject **key) { setentry *entry_ptr; @@ -2147,7 +2147,23 @@ _PySet_Next(PyObject *set, Py_ssize_t *pos, PyObject **entry) } if (set_next((PySetObject *)set, pos, &entry_ptr) == 0) return 0; - *entry = entry_ptr->key; + *key = entry_ptr->key; + return 1; +} + +int +_PySet_NextEntry(PyObject *set, Py_ssize_t *pos, PyObject **key, long *hash) +{ + setentry *entry; + + if (!PyAnySet_Check(set)) { + PyErr_BadInternalCall(); + return -1; + } + if (set_next((PySetObject *)set, pos, &entry) == 0) + return 0; + *key = entry->key; + *hash = entry->hash; return 1; } |