summaryrefslogtreecommitdiffstats
path: root/Objects
diff options
context:
space:
mode:
Diffstat (limited to 'Objects')
-rw-r--r--Objects/object.c16
-rw-r--r--Objects/typeobject.c11
2 files changed, 21 insertions, 6 deletions
diff --git a/Objects/object.c b/Objects/object.c
index c56c3be..668bd4f 100644
--- a/Objects/object.c
+++ b/Objects/object.c
@@ -455,11 +455,25 @@ try_3way_compare(PyObject *v, PyObject *w)
/* Comparisons involving instances are given to instance_compare,
which has the same return conventions as this function. */
+ f = v->ob_type->tp_compare;
if (PyInstance_Check(v))
- return (*v->ob_type->tp_compare)(v, w);
+ return (*f)(v, w);
if (PyInstance_Check(w))
return (*w->ob_type->tp_compare)(v, w);
+ /* If both have the same (non-NULL) tp_compare, use it. */
+ if (f != NULL && f == w->ob_type->tp_compare) {
+ c = (*f)(v, w);
+ if (c < 0 && PyErr_Occurred())
+ return -1;
+ return c < 0 ? -1 : c > 0 ? 1 : 0;
+ }
+
+ /* If either tp_compare is _PyObject_SlotCompare, that's safe. */
+ if (f == _PyObject_SlotCompare ||
+ w->ob_type->tp_compare == _PyObject_SlotCompare)
+ return _PyObject_SlotCompare(v, w);
+
/* Try coercion; if it fails, give up */
c = PyNumber_CoerceEx(&v, &w);
if (c < 0)
diff --git a/Objects/typeobject.c b/Objects/typeobject.c
index 792a9f3..26ddabe 100644
--- a/Objects/typeobject.c
+++ b/Objects/typeobject.c
@@ -2761,17 +2761,18 @@ half_compare(PyObject *self, PyObject *other)
return 2;
}
-static int
-slot_tp_compare(PyObject *self, PyObject *other)
+/* This slot is published for the benefit of try_3way_compare in object.c */
+int
+_PyObject_SlotCompare(PyObject *self, PyObject *other)
{
int c;
- if (self->ob_type->tp_compare == slot_tp_compare) {
+ if (self->ob_type->tp_compare == _PyObject_SlotCompare) {
c = half_compare(self, other);
if (c <= 1)
return c;
}
- if (other->ob_type->tp_compare == slot_tp_compare) {
+ if (other->ob_type->tp_compare == _PyObject_SlotCompare) {
c = half_compare(other, self);
if (c < -1)
return -2;
@@ -3190,7 +3191,7 @@ override_slots(PyTypeObject *type, PyObject *dict)
PyDict_GetItemString(dict, "__repr__"))
type->tp_print = NULL;
- TPSLOT("__cmp__", tp_compare, slot_tp_compare);
+ TPSLOT("__cmp__", tp_compare, _PyObject_SlotCompare);
TPSLOT("__repr__", tp_repr, slot_tp_repr);
TPSLOT("__hash__", tp_hash, slot_tp_hash);
TPSLOT("__call__", tp_call, slot_tp_call);