diff options
-rw-r--r-- | Lib/test/test_richcmp.py | 28 | ||||
-rw-r--r-- | Misc/NEWS | 6 | ||||
-rw-r--r-- | Objects/dictobject.c | 74 |
3 files changed, 104 insertions, 4 deletions
diff --git a/Lib/test/test_richcmp.py b/Lib/test/test_richcmp.py index 7884c7e..4e7d459 100644 --- a/Lib/test/test_richcmp.py +++ b/Lib/test/test_richcmp.py @@ -221,6 +221,33 @@ def recursion(): check('not a==b') if verbose: print "recursion tests ok" +def dicts(): + # Verify that __eq__ and __ne__ work for dicts even if the keys and + # values don't support anything other than __eq__ and __ne__. Complex + # numbers are a fine example of that. + import random + imag1a = {} + for i in range(50): + imag1a[random.randrange(100)*1j] = random.randrange(100)*1j + items = imag1a.items() + random.shuffle(items) + imag1b = {} + for k, v in items: + imag1b[k] = v + imag2 = imag1b.copy() + imag2[k] = v + 1.0 + verify(imag1a == imag1a, "imag1a == imag1a should have worked") + verify(imag1a == imag1b, "imag1a == imag1b should have worked") + verify(imag2 == imag2, "imag2 == imag2 should have worked") + verify(imag1a != imag2, "imag1a != imag2 should have worked") + for op in "<", "<=", ">", ">=": + try: + eval("imag1a %s imag2" % op) + except TypeError: + pass + else: + raise TestFailed("expected TypeError from imag1a %s imag2" % op) + def main(): basic() tabulate() @@ -229,5 +256,6 @@ def main(): testvector() misbehavin() recursion() + dicts() main() @@ -17,14 +17,16 @@ Core - The following functions were generalized to work nicely with iterator arguments: - map(), filter(), reduce() + map(), filter(), reduce(), zip() list(), tuple() (PySequence_Tuple() and PySequence_Fast() in C API) max(), min() - zip() .join() method of strings 'x in y' and 'x not in y' (PySequence_Contains() in C API) operator.countOf() (PySequence_Count() in C API) +- Comparing dictionary objects via == and != is faster, and now works even + if the keys and values don't support comparisons other than ==. + What's New in Python 2.1 (final)? ================================= diff --git a/Objects/dictobject.c b/Objects/dictobject.c index 96d779d..56cc08f 100644 --- a/Objects/dictobject.c +++ b/Objects/dictobject.c @@ -1047,6 +1047,76 @@ dict_compare(dictobject *a, dictobject *b) return res; } +/* Return 1 if dicts equal, 0 if not, -1 if error. + * Gets out as soon as any difference is detected. + * Uses only Py_EQ comparison. + */ +static int +dict_equal(dictobject *a, dictobject *b) +{ + int i; + + if (a->ma_used != b->ma_used) + /* can't be equal if # of entries differ */ + return 0; + + /* Same # of entries -- check all of 'em. Exit early on any diff. */ + for (i = 0; i < a->ma_size; i++) { + PyObject *aval = a->ma_table[i].me_value; + if (aval != NULL) { + int cmp; + PyObject *bval; + PyObject *key = a->ma_table[i].me_key; + /* temporarily bump aval's refcount to ensure it stays + alive until we're done with it */ + Py_INCREF(aval); + bval = PyDict_GetItem((PyObject *)b, key); + if (bval == NULL) { + Py_DECREF(aval); + return 0; + } + cmp = PyObject_RichCompareBool(aval, bval, Py_EQ); + Py_DECREF(aval); + if (cmp <= 0) /* error or not equal */ + return cmp; + } + } + return 1; + } + +static PyObject * +dict_richcompare(PyObject *v, PyObject *w, int op) +{ + int cmp; + PyObject *res; + + if (!PyDict_Check(v) || !PyDict_Check(w)) { + res = Py_NotImplemented; + } + else if (op == Py_EQ || op == Py_NE) { + cmp = dict_equal((dictobject *)v, (dictobject *)w); + if (cmp < 0) + return NULL; + res = (cmp == (op == Py_EQ)) ? Py_True : Py_False; + } + else { + cmp = dict_compare((dictobject *)v, (dictobject *)w); + if (cmp < 0 && PyErr_Occurred()) + return NULL; + switch (op) { + case Py_LT: cmp = cmp < 0; break; + case Py_LE: cmp = cmp <= 0; break; + case Py_GT: cmp = cmp > 0; break; + case Py_GE: cmp = cmp >= 0; break; + default: + assert(!"op unexpected"); + } + res = cmp ? Py_True : Py_False; + } + Py_INCREF(res); + return res; + } + static PyObject * dict_has_key(register dictobject *mp, PyObject *args) { @@ -1410,7 +1480,7 @@ PyTypeObject PyDict_Type = { (printfunc)dict_print, /* tp_print */ (getattrfunc)dict_getattr, /* tp_getattr */ 0, /* tp_setattr */ - (cmpfunc)dict_compare, /* tp_compare */ + 0, /* tp_compare */ (reprfunc)dict_repr, /* tp_repr */ 0, /* tp_as_number */ &dict_as_sequence, /* tp_as_sequence */ @@ -1425,7 +1495,7 @@ PyTypeObject PyDict_Type = { 0, /* tp_doc */ (traverseproc)dict_traverse, /* tp_traverse */ (inquiry)dict_tp_clear, /* tp_clear */ - 0, /* tp_richcompare */ + dict_richcompare, /* tp_richcompare */ 0, /* tp_weaklistoffset */ (getiterfunc)dict_iter, /* tp_iter */ 0, /* tp_iternext */ |