summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Include/complexobject.h1
-rw-r--r--Lib/test/test_descr.py24
-rw-r--r--Objects/complexobject.c30
-rw-r--r--Objects/floatobject.c2
4 files changed, 45 insertions, 12 deletions
diff --git a/Include/complexobject.h b/Include/complexobject.h
index edd2069..b82eb90 100644
--- a/Include/complexobject.h
+++ b/Include/complexobject.h
@@ -43,6 +43,7 @@ typedef struct {
extern DL_IMPORT(PyTypeObject) PyComplex_Type;
#define PyComplex_Check(op) PyObject_TypeCheck(op, &PyComplex_Type)
+#define PyComplex_CheckExact(op) ((op)->ob_type == &PyComplex_Type)
extern DL_IMPORT(PyObject *) PyComplex_FromCComplex(Py_complex);
extern DL_IMPORT(PyObject *) PyComplex_FromDoubles(double real, double imag);
diff --git a/Lib/test/test_descr.py b/Lib/test/test_descr.py
index b029979..f8b7fc7 100644
--- a/Lib/test/test_descr.py
+++ b/Lib/test/test_descr.py
@@ -1430,6 +1430,30 @@ def inherits():
verify(hash(a) == hash(12345.0))
verify((+a).__class__ is float)
+ class madcomplex(complex):
+ def __repr__(self):
+ return "%.17gj%+.17g" % (self.imag, self.real)
+ a = madcomplex(-3, 4)
+ verify(repr(a) == "4j-3")
+ base = complex(-3, 4)
+ verify(base.__class__ is complex)
+ verify(complex(a) == base)
+ verify(complex(a).__class__ is complex)
+ a = madcomplex(a) # just trying another form of the constructor
+ verify(repr(a) == "4j-3")
+ verify(complex(a) == base)
+ verify(complex(a).__class__ is complex)
+ verify(hash(a) == hash(base))
+ verify((+a).__class__ is complex)
+ verify((a + 0).__class__ is complex)
+ verify(a + 0 == base)
+ verify((a - 0).__class__ is complex)
+ verify(a - 0 == base)
+ verify((a * 1).__class__ is complex)
+ verify(a * 1 == base)
+ verify((a / 1).__class__ is complex)
+ verify(a / 1 == base)
+
class madtuple(tuple):
_rev = None
def rev(self):
diff --git a/Objects/complexobject.c b/Objects/complexobject.c
index 7404993..a8419e3 100644
--- a/Objects/complexobject.c
+++ b/Objects/complexobject.c
@@ -489,8 +489,12 @@ complex_neg(PyComplexObject *v)
static PyObject *
complex_pos(PyComplexObject *v)
{
- Py_INCREF(v);
- return (PyObject *)v;
+ if (PyComplex_CheckExact(v)) {
+ Py_INCREF(v);
+ return (PyObject *)v;
+ }
+ else
+ return PyComplex_FromCComplex(v->cval);
}
static PyObject *
@@ -792,11 +796,12 @@ complex_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
return NULL;
if (PyString_Check(r) || PyUnicode_Check(r))
return complex_subtype_from_string(type, r);
- if ((nbr = r->ob_type->tp_as_number) == NULL ||
- nbr->nb_float == NULL ||
- (i != NULL &&
- ((nbi = i->ob_type->tp_as_number) == NULL ||
- nbi->nb_float == NULL))) {
+
+ nbr = r->ob_type->tp_as_number;
+ if (i != NULL)
+ nbi = i->ob_type->tp_as_number;
+ if (nbr == NULL || nbr->nb_float == NULL ||
+ ((i != NULL) && (nbi == NULL || nbi->nb_float == NULL))) {
PyErr_SetString(PyExc_TypeError,
"complex() arg can't be converted to complex");
return NULL;
@@ -826,6 +831,9 @@ complex_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
}
}
if (PyComplex_Check(r)) {
+ /* Note that if r is of a complex subtype, we're only
+ retaining its real & imag parts here, and the return
+ value is (properly) of the builtin complex type. */
cr = ((PyComplexObject*)r)->cval;
if (own_r) {
Py_DECREF(r);
@@ -868,10 +876,10 @@ complex_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
}
static char complex_doc[] =
-"complex(real[, imag]) -> complex number\n\
-\n\
-Create a complex number from a real part and an optional imaginary part.\n\
-This is equivalent to (real + imag*1j) where imag defaults to 0.";
+"complex(real[, imag]) -> complex number\n"
+"\n"
+"Create a complex number from a real part and an optional imaginary part.\n"
+"This is equivalent to (real + imag*1j) where imag defaults to 0.";
static PyNumberMethods complex_as_number = {
(binaryfunc)complex_add, /* nb_add */
diff --git a/Objects/floatobject.c b/Objects/floatobject.c
index 880eb0e..b9a5e1b 100644
--- a/Objects/floatobject.c
+++ b/Objects/floatobject.c
@@ -659,7 +659,7 @@ float_subtype_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
tmp = float_new(&PyFloat_Type, args, kwds);
if (tmp == NULL)
return NULL;
- assert(PyFloat_Check(tmp));
+ assert(PyFloat_CheckExact(tmp));
new = type->tp_alloc(type, 0);
if (new == NULL)
return NULL;