From 2205642fe0af9c00bbfa713dae1c8ba4562d2236 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 24 Sep 2001 17:52:04 +0000 Subject: Do the same thing to complex that I did to str: the rich comparison function returns NotImplemented when comparing objects whose tp_richcompare slot is not itself. --- Lib/test/test_descr.py | 15 +++++++++++++++ Objects/complexobject.c | 17 ++++++++++------- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/Lib/test/test_descr.py b/Lib/test/test_descr.py index ed3cea4..42e1384 100644 --- a/Lib/test/test_descr.py +++ b/Lib/test/test_descr.py @@ -1863,6 +1863,21 @@ def classic_comparisons(): def rich_comparisons(): if verbose: print "Testing rich comparisons..." + class Z(complex): + pass + z = Z(1) + verify(z == 1+0j) + verify(1+0j == z) + class ZZ(complex): + def __eq__(self, other): + try: + return abs(self - other) <= 1e-6 + except: + return NotImplemented + zz = ZZ(1.0000003) + verify(zz == 1+0j) + verify(1+0j == zz) + class classic: pass for base in (classic, int, object, list): diff --git a/Objects/complexobject.c b/Objects/complexobject.c index 191dcba..32f2b24 100644 --- a/Objects/complexobject.c +++ b/Objects/complexobject.c @@ -553,12 +553,6 @@ complex_richcompare(PyObject *v, PyObject *w, int op) Py_complex i, j; PyObject *res; - if (op != Py_EQ && op != Py_NE) { - PyErr_SetString(PyExc_TypeError, - "cannot compare complex numbers using <, <=, >, >="); - return NULL; - } - c = PyNumber_CoerceEx(&v, &w); if (c < 0) return NULL; @@ -566,7 +560,10 @@ complex_richcompare(PyObject *v, PyObject *w, int op) Py_INCREF(Py_NotImplemented); return Py_NotImplemented; } - if (!PyComplex_Check(v) || !PyComplex_Check(w)) { + /* May sure both arguments use complex comparison. + This implies PyComplex_Check(a) && PyComplex_Check(b). */ + if (v->ob_type->tp_richcompare != complex_richcompare || + w->ob_type->tp_richcompare != complex_richcompare) { Py_DECREF(v); Py_DECREF(w); Py_INCREF(Py_NotImplemented); @@ -578,6 +575,12 @@ complex_richcompare(PyObject *v, PyObject *w, int op) Py_DECREF(v); Py_DECREF(w); + if (op != Py_EQ && op != Py_NE) { + PyErr_SetString(PyExc_TypeError, + "cannot compare complex numbers using <, <=, >, >="); + return NULL; + } + if ((i.real == j.real && i.imag == j.imag) == (op == Py_EQ)) res = Py_True; else -- cgit v0.12