summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/test/test_richcmp.py28
-rw-r--r--Misc/NEWS6
-rw-r--r--Objects/dictobject.c74
3 files changed, 104 insertions, 4 deletions
diff --git a/Lib/test/test_richcmp.py b/Lib/test/test_richcmp.py
index 7884c7e..4e7d459 100644
--- a/Lib/test/test_richcmp.py
+++ b/Lib/test/test_richcmp.py
@@ -221,6 +221,33 @@ def recursion():
check('not a==b')
if verbose: print "recursion tests ok"
+def dicts():
+ # Verify that __eq__ and __ne__ work for dicts even if the keys and
+ # values don't support anything other than __eq__ and __ne__. Complex
+ # numbers are a fine example of that.
+ import random
+ imag1a = {}
+ for i in range(50):
+ imag1a[random.randrange(100)*1j] = random.randrange(100)*1j
+ items = imag1a.items()
+ random.shuffle(items)
+ imag1b = {}
+ for k, v in items:
+ imag1b[k] = v
+ imag2 = imag1b.copy()
+ imag2[k] = v + 1.0
+ verify(imag1a == imag1a, "imag1a == imag1a should have worked")
+ verify(imag1a == imag1b, "imag1a == imag1b should have worked")
+ verify(imag2 == imag2, "imag2 == imag2 should have worked")
+ verify(imag1a != imag2, "imag1a != imag2 should have worked")
+ for op in "<", "<=", ">", ">=":
+ try:
+ eval("imag1a %s imag2" % op)
+ except TypeError:
+ pass
+ else:
+ raise TestFailed("expected TypeError from imag1a %s imag2" % op)
+
def main():
basic()
tabulate()
@@ -229,5 +256,6 @@ def main():
testvector()
misbehavin()
recursion()
+ dicts()
main()
diff --git a/Misc/NEWS b/Misc/NEWS
index 1f971cd..f2150d5 100644
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -17,14 +17,16 @@ Core
- The following functions were generalized to work nicely with iterator
arguments:
- map(), filter(), reduce()
+ map(), filter(), reduce(), zip()
list(), tuple() (PySequence_Tuple() and PySequence_Fast() in C API)
max(), min()
- zip()
.join() method of strings
'x in y' and 'x not in y' (PySequence_Contains() in C API)
operator.countOf() (PySequence_Count() in C API)
+- Comparing dictionary objects via == and != is faster, and now works even
+ if the keys and values don't support comparisons other than ==.
+
What's New in Python 2.1 (final)?
=================================
diff --git a/Objects/dictobject.c b/Objects/dictobject.c
index 96d779d..56cc08f 100644
--- a/Objects/dictobject.c
+++ b/Objects/dictobject.c
@@ -1047,6 +1047,76 @@ dict_compare(dictobject *a, dictobject *b)
return res;
}
+/* Return 1 if dicts equal, 0 if not, -1 if error.
+ * Gets out as soon as any difference is detected.
+ * Uses only Py_EQ comparison.
+ */
+static int
+dict_equal(dictobject *a, dictobject *b)
+{
+ int i;
+
+ if (a->ma_used != b->ma_used)
+ /* can't be equal if # of entries differ */
+ return 0;
+
+ /* Same # of entries -- check all of 'em. Exit early on any diff. */
+ for (i = 0; i < a->ma_size; i++) {
+ PyObject *aval = a->ma_table[i].me_value;
+ if (aval != NULL) {
+ int cmp;
+ PyObject *bval;
+ PyObject *key = a->ma_table[i].me_key;
+ /* temporarily bump aval's refcount to ensure it stays
+ alive until we're done with it */
+ Py_INCREF(aval);
+ bval = PyDict_GetItem((PyObject *)b, key);
+ if (bval == NULL) {
+ Py_DECREF(aval);
+ return 0;
+ }
+ cmp = PyObject_RichCompareBool(aval, bval, Py_EQ);
+ Py_DECREF(aval);
+ if (cmp <= 0) /* error or not equal */
+ return cmp;
+ }
+ }
+ return 1;
+ }
+
+static PyObject *
+dict_richcompare(PyObject *v, PyObject *w, int op)
+{
+ int cmp;
+ PyObject *res;
+
+ if (!PyDict_Check(v) || !PyDict_Check(w)) {
+ res = Py_NotImplemented;
+ }
+ else if (op == Py_EQ || op == Py_NE) {
+ cmp = dict_equal((dictobject *)v, (dictobject *)w);
+ if (cmp < 0)
+ return NULL;
+ res = (cmp == (op == Py_EQ)) ? Py_True : Py_False;
+ }
+ else {
+ cmp = dict_compare((dictobject *)v, (dictobject *)w);
+ if (cmp < 0 && PyErr_Occurred())
+ return NULL;
+ switch (op) {
+ case Py_LT: cmp = cmp < 0; break;
+ case Py_LE: cmp = cmp <= 0; break;
+ case Py_GT: cmp = cmp > 0; break;
+ case Py_GE: cmp = cmp >= 0; break;
+ default:
+ assert(!"op unexpected");
+ }
+ res = cmp ? Py_True : Py_False;
+ }
+ Py_INCREF(res);
+ return res;
+ }
+
static PyObject *
dict_has_key(register dictobject *mp, PyObject *args)
{
@@ -1410,7 +1480,7 @@ PyTypeObject PyDict_Type = {
(printfunc)dict_print, /* tp_print */
(getattrfunc)dict_getattr, /* tp_getattr */
0, /* tp_setattr */
- (cmpfunc)dict_compare, /* tp_compare */
+ 0, /* tp_compare */
(reprfunc)dict_repr, /* tp_repr */
0, /* tp_as_number */
&dict_as_sequence, /* tp_as_sequence */
@@ -1425,7 +1495,7 @@ PyTypeObject PyDict_Type = {
0, /* tp_doc */
(traverseproc)dict_traverse, /* tp_traverse */
(inquiry)dict_tp_clear, /* tp_clear */
- 0, /* tp_richcompare */
+ dict_richcompare, /* tp_richcompare */
0, /* tp_weaklistoffset */
(getiterfunc)dict_iter, /* tp_iter */
0, /* tp_iternext */