diff options
-rw-r--r-- | Include/setobject.h | 6 | ||||
-rw-r--r-- | Lib/test/test_set.py | 25 | ||||
-rw-r--r-- | Objects/setobject.c | 106 |
3 files changed, 97 insertions, 40 deletions
diff --git a/Include/setobject.h b/Include/setobject.h index eeffa8a..6289f9c 100644 --- a/Include/setobject.h +++ b/Include/setobject.h @@ -20,6 +20,12 @@ typedef struct { PyAPI_DATA(PyTypeObject) PySet_Type; PyAPI_DATA(PyTypeObject) PyFrozenSet_Type; + +#define PyAnySet_Check(ob) \ + ((ob)->ob_type == &PySet_Type || (ob)->ob_type == &PyFrozenSet_Type || \ + PyType_IsSubtype((ob)->ob_type, &PySet_Type) || \ + PyType_IsSubtype((ob)->ob_type, &PyFrozenSet_Type)) + #ifdef __cplusplus } #endif diff --git a/Lib/test/test_set.py b/Lib/test/test_set.py index 1edb2dd..8329fd1 100644 --- a/Lib/test/test_set.py +++ b/Lib/test/test_set.py @@ -152,6 +152,13 @@ class TestJointOps(unittest.TestCase): class TestSet(TestJointOps): thetype = set + def test_init(self): + s = set() + s.__init__(self.word) + self.assertEqual(s, set(self.word)) + s.__init__(self.otherword) + self.assertEqual(s, set(self.otherword)) + def test_hash(self): self.assertRaises(TypeError, hash, self.s) @@ -252,10 +259,20 @@ class TestSet(TestJointOps): else: self.assert_(c not in self.s) +class SetSubclass(set): + pass + +class TestSetSubclass(TestSet): + thetype = SetSubclass class TestFrozenSet(TestJointOps): thetype = frozenset + def test_init(self): + s = frozenset() + s.__init__(self.word) + self.assertEqual(s, frozenset()) + def test_hash(self): self.assertEqual(hash(frozenset('abcdeb')), hash(frozenset('ebecda'))) @@ -273,6 +290,12 @@ class TestFrozenSet(TestJointOps): f = frozenset('abcdcda') self.assertEqual(hash(f), hash(f)) +class FrozenSetSubclass(frozenset): + pass + +class TestFrozenSetSubclass(TestFrozenSet): + thetype = FrozenSetSubclass + # Tests taken from test_sets.py ============================================= empty_set = set() @@ -1137,7 +1160,9 @@ def test_main(verbose=None): from test import test_sets test_classes = ( TestSet, + TestSetSubclass, TestFrozenSet, + TestFrozenSetSubclass, TestSetOfSets, TestExceptionPropagation, TestBasicOpsEmpty, diff --git a/Objects/setobject.c b/Objects/setobject.c index 61ba853..7ad8af0 100644 --- a/Objects/setobject.c +++ b/Objects/setobject.c @@ -12,7 +12,6 @@ /* Fast access macros */ #define DICT_CONTAINS(d, k) (d->ob_type->tp_as_sequence->sq_contains(d, k)) -#define IS_SET(so) (so->ob_type == &PySet_Type || so->ob_type == &PyFrozenSet_Type) /* set object **********************************************************/ @@ -42,8 +41,6 @@ make_new_set(PyTypeObject *type, PyObject *iterable) Py_DECREF(it); Py_DECREF(data); Py_DECREF(item); - PyErr_SetString(PyExc_TypeError, - "all set entries must be immutable"); return NULL; } Py_DECREF(item); @@ -67,7 +64,7 @@ make_new_set(PyTypeObject *type, PyObject *iterable) } static PyObject * -set_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +frozenset_new(PyTypeObject *type, PyObject *args, PyObject *kwds) { PyObject *iterable = NULL; @@ -76,6 +73,14 @@ set_new(PyTypeObject *type, PyObject *args, PyObject *kwds) return make_new_set(type, iterable); } +static PyObject * +set_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + PyObject *iterable = NULL; + + return make_new_set(type, NULL); +} + static void set_dealloc(PySetObject *so) { @@ -139,6 +144,8 @@ set_union(PySetObject *so, PyObject *other) PyObject *item, *data, *it; result = (PySetObject *)set_copy(so); + if (result == NULL) + return NULL; it = PyObject_GetIter(other); if (it == NULL) { Py_DECREF(result); @@ -150,8 +157,6 @@ set_union(PySetObject *so, PyObject *other) Py_DECREF(it); Py_DECREF(result); Py_DECREF(item); - PyErr_SetString(PyExc_TypeError, - "all set entries must be immutable"); return NULL; } Py_DECREF(item); @@ -183,8 +188,6 @@ set_union_update(PySetObject *so, PyObject *other) if (PyDict_SetItem(data, item, Py_True) == -1) { Py_DECREF(it); Py_DECREF(item); - PyErr_SetString(PyExc_TypeError, - "all set entries must be immutable"); return NULL; } Py_DECREF(item); @@ -201,7 +204,7 @@ PyDoc_STRVAR(union_update_doc, static PyObject * set_or(PySetObject *so, PyObject *other) { - if (!IS_SET(so) || !IS_SET(other)) { + if (!PyAnySet_Check(so) || !PyAnySet_Check(other)) { Py_INCREF(Py_NotImplemented); return Py_NotImplemented; } @@ -213,7 +216,7 @@ set_ior(PySetObject *so, PyObject *other) { PyObject *result; - if (!IS_SET(other)) { + if (!PyAnySet_Check(other)) { Py_INCREF(Py_NotImplemented); return Py_NotImplemented; } @@ -249,8 +252,6 @@ set_intersection(PySetObject *so, PyObject *other) Py_DECREF(it); Py_DECREF(result); Py_DECREF(item); - PyErr_SetString(PyExc_TypeError, - "all set entries must be immutable"); return NULL; } } @@ -291,8 +292,6 @@ set_intersection_update(PySetObject *so, PyObject *other) Py_DECREF(newdict); Py_DECREF(it); Py_DECREF(item); - PyErr_SetString(PyExc_TypeError, - "all set entries must be immutable"); return NULL; } } @@ -315,7 +314,7 @@ PyDoc_STRVAR(intersection_update_doc, static PyObject * set_and(PySetObject *so, PyObject *other) { - if (!IS_SET(so) || !IS_SET(other)) { + if (!PyAnySet_Check(so) || !PyAnySet_Check(other)) { Py_INCREF(Py_NotImplemented); return Py_NotImplemented; } @@ -327,7 +326,7 @@ set_iand(PySetObject *so, PyObject *other) { PyObject *result; - if (!IS_SET(other)) { + if (!PyAnySet_Check(other)) { Py_INCREF(Py_NotImplemented); return Py_NotImplemented; } @@ -416,7 +415,7 @@ PyDoc_STRVAR(difference_update_doc, static PyObject * set_sub(PySetObject *so, PyObject *other) { - if (!IS_SET(so) || !IS_SET(other)) { + if (!PyAnySet_Check(so) || !PyAnySet_Check(other)) { Py_INCREF(Py_NotImplemented); return Py_NotImplemented; } @@ -428,7 +427,7 @@ set_isub(PySetObject *so, PyObject *other) { PyObject *result; - if (!IS_SET(other)) { + if (!PyAnySet_Check(other)) { Py_INCREF(Py_NotImplemented); return Py_NotImplemented; } @@ -475,8 +474,6 @@ set_symmetric_difference(PySetObject *so, PyObject *other) Py_DECREF(otherset); Py_DECREF(result); Py_DECREF(item); - PyErr_SetString(PyExc_TypeError, - "all set entries must be immutable"); return NULL; } } @@ -506,7 +503,7 @@ set_symmetric_difference_update(PySetObject *so, PyObject *other) if (PyDict_Check(other)) otherdata = other; - else if (IS_SET(other)) + else if (PyAnySet_Check(other)) otherdata = ((PySetObject *)other)->data; else { otherset = (PySetObject *)make_new_set(so->ob_type, other); @@ -525,8 +522,6 @@ set_symmetric_difference_update(PySetObject *so, PyObject *other) Py_XDECREF(otherset); Py_DECREF(it); Py_DECREF(item); - PyErr_SetString(PyExc_TypeError, - "all set entries must be immutable"); return NULL; } } else { @@ -534,8 +529,6 @@ set_symmetric_difference_update(PySetObject *so, PyObject *other) Py_XDECREF(otherset); Py_DECREF(it); Py_DECREF(item); - PyErr_SetString(PyExc_TypeError, - "all set entries must be immutable"); return NULL; } } @@ -554,7 +547,7 @@ PyDoc_STRVAR(symmetric_difference_update_doc, static PyObject * set_xor(PySetObject *so, PyObject *other) { - if (!IS_SET(so) || !IS_SET(other)) { + if (!PyAnySet_Check(so) || !PyAnySet_Check(other)) { Py_INCREF(Py_NotImplemented); return Py_NotImplemented; } @@ -566,7 +559,7 @@ set_ixor(PySetObject *so, PyObject *other) { PyObject *result; - if (!IS_SET(other)) { + if (!PyAnySet_Check(other)) { Py_INCREF(Py_NotImplemented); return Py_NotImplemented; } @@ -583,7 +576,7 @@ set_issubset(PySetObject *so, PyObject *other) { PyObject *otherdata, *it, *item; - if (!IS_SET(other)) { + if (!PyAnySet_Check(other)) { PyErr_SetString(PyExc_TypeError, "can only compare to a set"); return NULL; } @@ -604,6 +597,8 @@ set_issubset(PySetObject *so, PyObject *other) Py_DECREF(item); } Py_DECREF(it); + if (PyErr_Occurred()) + return NULL; Py_RETURN_TRUE; } @@ -612,7 +607,7 @@ PyDoc_STRVAR(issubset_doc, "Report whether another set contains this set."); static PyObject * set_issuperset(PySetObject *so, PyObject *other) { - if (!IS_SET(other)) { + if (!PyAnySet_Check(other)) { PyErr_SetString(PyExc_TypeError, "can only compare to a set"); return NULL; } @@ -653,20 +648,21 @@ frozenset_hash(PyObject *self) hash ^= PyObject_Hash(item); Py_DECREF(item); } - so->hash = hash; Py_DECREF(it); + if (PyErr_Occurred()) + return -1; + so->hash = hash; return hash; } static PyObject * set_richcompare(PySetObject *v, PyObject *w, int op) { - /* XXX factor out is_set test */ - if (op == Py_EQ && !IS_SET(w)) - Py_RETURN_FALSE; - else if (op == Py_NE && !IS_SET(w)) - Py_RETURN_TRUE; - if (!IS_SET(w)) { + if(!PyAnySet_Check(w)) { + if (op == Py_EQ) + Py_RETURN_FALSE; + if (op == Py_NE) + Py_RETURN_TRUE; PyErr_SetString(PyExc_TypeError, "can only compare to a set"); return NULL; } @@ -698,8 +694,12 @@ set_repr(PySetObject *so) PyObject *keys, *result, *listrepr; keys = PyDict_Keys(so->data); + if (keys == NULL) + return NULL; listrepr = PyObject_Repr(keys); Py_DECREF(keys); + if (listrepr == NULL) + return NULL; result = PyString_FromFormat("%s(%s)", so->ob_type->tp_name, PyString_AS_STRING(listrepr)); @@ -732,6 +732,8 @@ set_tp_print(PySetObject *so, FILE *fp, int flags) } Py_DECREF(it); fprintf(fp, "])"); + if (PyErr_Occurred()) + return -1; return 0; } @@ -810,8 +812,10 @@ set_pop(PySetObject *so) return NULL; } Py_INCREF(key); - if (PyDict_DelItem(so->data, key) == -1) - PyErr_Clear(); + if (PyDict_DelItem(so->data, key) == -1) { + Py_DECREF(key); + return NULL; + } return key; } @@ -837,6 +841,28 @@ done: PyDoc_STRVAR(reduce_doc, "Return state information for pickling."); +static int +set_init(PySetObject *self, PyObject *args, PyObject *kwds) +{ + PyObject *iterable = NULL; + PyObject *result; + + if (!PyAnySet_Check(self)) + return -1; + if (!PyArg_UnpackTuple(args, self->ob_type->tp_name, 0, 1, &iterable)) + return -1; + PyDict_Clear(self->data); + self->hash = -1; + if (iterable == NULL) + return 0; + result = set_union_update(self, iterable); + if (result != NULL) { + Py_DECREF(result); + return 0; + } + return -1; +} + static PySequenceMethods set_as_sequence = { (inquiry)set_len, /* sq_length */ 0, /* sq_concat */ @@ -971,7 +997,7 @@ PyTypeObject PySet_Type = { 0, /* tp_descr_get */ 0, /* tp_descr_set */ 0, /* tp_dictoffset */ - 0, /* tp_init */ + (initproc)set_init, /* tp_init */ PyType_GenericAlloc, /* tp_alloc */ set_new, /* tp_new */ PyObject_GC_Del, /* tp_free */ @@ -1068,6 +1094,6 @@ PyTypeObject PyFrozenSet_Type = { 0, /* tp_dictoffset */ 0, /* tp_init */ PyType_GenericAlloc, /* tp_alloc */ - set_new, /* tp_new */ + frozenset_new, /* tp_new */ PyObject_GC_Del, /* tp_free */ }; |