diff options
author | Armin Rigo <arigo@tunes.org> | 2005-12-29 15:59:19 (GMT) |
---|---|---|
committer | Armin Rigo <arigo@tunes.org> | 2005-12-29 15:59:19 (GMT) |
commit | fd163f92cee2aa8189879bd43670782f4cfd2cf8 (patch) | |
tree | 9bd3785dde016396b17a006059c68b3e4b60023e | |
parent | c4308d5be64a622ee7be685c5eb05f90782711c1 (diff) | |
download | cpython-fd163f92cee2aa8189879bd43670782f4cfd2cf8.zip cpython-fd163f92cee2aa8189879bd43670782f4cfd2cf8.tar.gz cpython-fd163f92cee2aa8189879bd43670782f4cfd2cf8.tar.bz2 |
SF patch #1390657:
* set sq_repeat and sq_concat to NULL for user-defined new-style
classes, as a way to fix a number of related problems. See
test_descr.notimplemented()). One of these problems was fixed
in r25556 and r25557 but many more existed; this is a general
fix and thus reverts r25556-r25557.
* to avoid having PySequence_Repeat()/PySequence_Concat() failing
on user-defined classes, they now fall back to nb_add/nb_mul if
sq_concat/sq_repeat are not defined and the arguments appear to
be sequences.
* added tests.
Backport candidate.
-rw-r--r-- | Lib/test/test_descr.py | 72 | ||||
-rw-r--r-- | Lib/test/test_operator.py | 40 | ||||
-rw-r--r-- | Objects/abstract.c | 51 | ||||
-rw-r--r-- | Objects/typeobject.c | 31 |
4 files changed, 172 insertions, 22 deletions
diff --git a/Lib/test/test_descr.py b/Lib/test/test_descr.py index f594ca8..2ea8186 100644 --- a/Lib/test/test_descr.py +++ b/Lib/test/test_descr.py @@ -3990,6 +3990,77 @@ def methodwrapper(): verify(l.__add__.__objclass__ is list) vereq(l.__add__.__doc__, list.__add__.__doc__) +def notimplemented(): + # all binary methods should be able to return a NotImplemented + if verbose: + print "Testing NotImplemented..." + + import sys + import types + import operator + + def specialmethod(self, other): + return NotImplemented + + def check(expr, x, y): + try: + exec expr in {'x': x, 'y': y, 'operator': operator} + except TypeError: + pass + else: + raise TestFailed("no TypeError from %r" % (expr,)) + + N1 = sys.maxint + 1L # might trigger OverflowErrors instead of TypeErrors + N2 = sys.maxint # if sizeof(int) < sizeof(long), might trigger + # ValueErrors instead of TypeErrors + for metaclass in [type, types.ClassType]: + for name, expr, iexpr in [ + ('__add__', 'x + y', 'x += y'), + ('__sub__', 'x - y', 'x -= y'), + ('__mul__', 'x * y', 'x *= y'), + ('__truediv__', 'operator.truediv(x, y)', None), + ('__floordiv__', 'operator.floordiv(x, y)', None), + ('__div__', 'x / y', 'x /= y'), + ('__mod__', 'x % y', 'x %= y'), + ('__divmod__', 'divmod(x, y)', None), + ('__pow__', 'x ** y', 'x **= y'), + ('__lshift__', 'x << y', 'x <<= y'), + ('__rshift__', 'x >> y', 'x >>= y'), + ('__and__', 'x & y', 'x &= y'), + ('__or__', 'x | y', 'x |= y'), + ('__xor__', 'x ^ y', 'x ^= y'), + ('__coerce__', 'coerce(x, y)', None)]: + if name == '__coerce__': + rname = name + else: + rname = '__r' + name[2:] + A = metaclass('A', (), {name: specialmethod}) + B = metaclass('B', (), {rname: specialmethod}) + a = A() + b = B() + check(expr, a, a) + check(expr, a, b) + check(expr, b, a) + check(expr, b, b) + check(expr, a, N1) + check(expr, a, N2) + check(expr, N1, b) + check(expr, N2, b) + if iexpr: + check(iexpr, a, a) + check(iexpr, a, b) + check(iexpr, b, a) + check(iexpr, b, b) + check(iexpr, a, N1) + check(iexpr, a, N2) + iname = '__i' + name[2:] + C = metaclass('C', (), {iname: specialmethod}) + c = C() + check(iexpr, c, a) + check(iexpr, c, b) + check(iexpr, c, N1) + check(iexpr, c, N2) + def test_main(): weakref_segfault() # Must be first, somehow do_this_first() @@ -4084,6 +4155,7 @@ def test_main(): vicious_descriptor_nonsense() test_init() methodwrapper() + notimplemented() if verbose: print "All OK" diff --git a/Lib/test/test_operator.py b/Lib/test/test_operator.py index 725b2d9..6cc7945 100644 --- a/Lib/test/test_operator.py +++ b/Lib/test/test_operator.py @@ -3,6 +3,34 @@ import unittest from test import test_support +class Seq1: + def __init__(self, lst): + self.lst = lst + def __len__(self): + return len(self.lst) + def __getitem__(self, i): + return self.lst[i] + def __add__(self, other): + return self.lst + other.lst + def __mul__(self, other): + return self.lst * other + def __rmul__(self, other): + return other * self.lst + +class Seq2(object): + def __init__(self, lst): + self.lst = lst + def __len__(self): + return len(self.lst) + def __getitem__(self, i): + return self.lst[i] + def __add__(self, other): + return self.lst + other.lst + def __mul__(self, other): + return self.lst * other + def __rmul__(self, other): + return other * self.lst + class OperatorTestCase(unittest.TestCase): def test_lt(self): @@ -92,6 +120,9 @@ class OperatorTestCase(unittest.TestCase): self.failUnlessRaises(TypeError, operator.concat, None, None) self.failUnless(operator.concat('py', 'thon') == 'python') self.failUnless(operator.concat([1, 2], [3, 4]) == [1, 2, 3, 4]) + self.failUnless(operator.concat(Seq1([5, 6]), Seq1([7])) == [5, 6, 7]) + self.failUnless(operator.concat(Seq2([5, 6]), Seq2([7])) == [5, 6, 7]) + self.failUnlessRaises(TypeError, operator.concat, 13, 29) def test_countOf(self): self.failUnlessRaises(TypeError, operator.countOf) @@ -246,6 +277,15 @@ class OperatorTestCase(unittest.TestCase): self.failUnless(operator.repeat(a, 2) == a+a) self.failUnless(operator.repeat(a, 1) == a) self.failUnless(operator.repeat(a, 0) == '') + a = Seq1([4, 5, 6]) + self.failUnless(operator.repeat(a, 2) == [4, 5, 6, 4, 5, 6]) + self.failUnless(operator.repeat(a, 1) == [4, 5, 6]) + self.failUnless(operator.repeat(a, 0) == []) + a = Seq2([4, 5, 6]) + self.failUnless(operator.repeat(a, 2) == [4, 5, 6, 4, 5, 6]) + self.failUnless(operator.repeat(a, 1) == [4, 5, 6]) + self.failUnless(operator.repeat(a, 0) == []) + self.failUnlessRaises(TypeError, operator.repeat, 6, 7) def test_rshift(self): self.failUnlessRaises(TypeError, operator.rshift) diff --git a/Objects/abstract.c b/Objects/abstract.c index 1f8feb5..6e070a9 100644 --- a/Objects/abstract.c +++ b/Objects/abstract.c @@ -635,14 +635,11 @@ PyNumber_Add(PyObject *v, PyObject *w) PyObject *result = binary_op1(v, w, NB_SLOT(nb_add)); if (result == Py_NotImplemented) { PySequenceMethods *m = v->ob_type->tp_as_sequence; + Py_DECREF(result); if (m && m->sq_concat) { - Py_DECREF(result); - result = (*m->sq_concat)(v, w); + return (*m->sq_concat)(v, w); } - if (result == Py_NotImplemented) { - Py_DECREF(result); - return binop_type_error(v, w, "+"); - } + result = binop_type_error(v, w, "+"); } return result; } @@ -1144,6 +1141,15 @@ PySequence_Concat(PyObject *s, PyObject *o) if (m && m->sq_concat) return m->sq_concat(s, o); + /* Instances of user classes defining an __add__() method only + have an nb_add slot, not an sq_concat slot. So we fall back + to nb_add if both arguments appear to be sequences. */ + if (PySequence_Check(s) && PySequence_Check(o)) { + PyObject *result = binary_op1(s, o, NB_SLOT(nb_add)); + if (result != Py_NotImplemented) + return result; + Py_DECREF(result); + } return type_error("object can't be concatenated"); } @@ -1159,6 +1165,20 @@ PySequence_Repeat(PyObject *o, int count) if (m && m->sq_repeat) return m->sq_repeat(o, count); + /* Instances of user classes defining a __mul__() method only + have an nb_multiply slot, not an sq_repeat slot. so we fall back + to nb_multiply if o appears to be a sequence. */ + if (PySequence_Check(o)) { + PyObject *n, *result; + n = PyInt_FromLong(count); + if (n == NULL) + return NULL; + result = binary_op1(o, n, NB_SLOT(nb_multiply)); + Py_DECREF(n); + if (result != Py_NotImplemented) + return result; + Py_DECREF(result); + } return type_error("object can't be repeated"); } @@ -1176,6 +1196,13 @@ PySequence_InPlaceConcat(PyObject *s, PyObject *o) if (m && m->sq_concat) return m->sq_concat(s, o); + if (PySequence_Check(s) && PySequence_Check(o)) { + PyObject *result = binary_iop1(s, o, NB_SLOT(nb_inplace_add), + NB_SLOT(nb_add)); + if (result != Py_NotImplemented) + return result; + Py_DECREF(result); + } return type_error("object can't be concatenated"); } @@ -1193,6 +1220,18 @@ PySequence_InPlaceRepeat(PyObject *o, int count) if (m && m->sq_repeat) return m->sq_repeat(o, count); + if (PySequence_Check(o)) { + PyObject *n, *result; + n = PyInt_FromLong(count); + if (n == NULL) + return NULL; + result = binary_iop1(o, n, NB_SLOT(nb_inplace_multiply), + NB_SLOT(nb_multiply)); + Py_DECREF(n); + if (result != Py_NotImplemented) + return result; + Py_DECREF(result); + } return type_error("object can't be repeated"); } diff --git a/Objects/typeobject.c b/Objects/typeobject.c index 7c36ba4..b74fa1a 100644 --- a/Objects/typeobject.c +++ b/Objects/typeobject.c @@ -4095,9 +4095,6 @@ slot_sq_length(PyObject *self) return len; } -SLOT1(slot_sq_concat, "__add__", PyObject *, "O") -SLOT1(slot_sq_repeat, "__mul__", int, "i") - /* Super-optimized version of slot_sq_item. Other slots could do the same... */ static PyObject * @@ -4211,9 +4208,6 @@ slot_sq_contains(PyObject *self, PyObject *value) return result; } -SLOT1(slot_sq_inplace_concat, "__iadd__", PyObject *, "O") -SLOT1(slot_sq_inplace_repeat, "__imul__", int, "i") - #define slot_mp_length slot_sq_length SLOT1(slot_mp_subscript, "__getitem__", PyObject *, "O") @@ -4926,12 +4920,17 @@ typedef struct wrapperbase slotdef; static slotdef slotdefs[] = { SQSLOT("__len__", sq_length, slot_sq_length, wrap_inquiry, "x.__len__() <==> len(x)"), - SQSLOT("__add__", sq_concat, slot_sq_concat, wrap_binaryfunc, - "x.__add__(y) <==> x+y"), - SQSLOT("__mul__", sq_repeat, slot_sq_repeat, wrap_intargfunc, - "x.__mul__(n) <==> x*n"), - SQSLOT("__rmul__", sq_repeat, slot_sq_repeat, wrap_intargfunc, - "x.__rmul__(n) <==> n*x"), + /* Heap types defining __add__/__mul__ have sq_concat/sq_repeat == NULL. + The logic in abstract.c always falls back to nb_add/nb_multiply in + this case. Defining both the nb_* and the sq_* slots to call the + user-defined methods has unexpected side-effects, as shown by + test_descr.notimplemented() */ + SQSLOT("__add__", sq_concat, NULL, wrap_binaryfunc, + "x.__add__(y) <==> x+y"), + SQSLOT("__mul__", sq_repeat, NULL, wrap_intargfunc, + "x.__mul__(n) <==> x*n"), + SQSLOT("__rmul__", sq_repeat, NULL, wrap_intargfunc, + "x.__rmul__(n) <==> n*x"), SQSLOT("__getitem__", sq_item, slot_sq_item, wrap_sq_item, "x.__getitem__(y) <==> x[y]"), SQSLOT("__getslice__", sq_slice, slot_sq_slice, wrap_intintargfunc, @@ -4953,10 +4952,10 @@ static slotdef slotdefs[] = { Use of negative indices is not supported."), SQSLOT("__contains__", sq_contains, slot_sq_contains, wrap_objobjproc, "x.__contains__(y) <==> y in x"), - SQSLOT("__iadd__", sq_inplace_concat, slot_sq_inplace_concat, - wrap_binaryfunc, "x.__iadd__(y) <==> x+=y"), - SQSLOT("__imul__", sq_inplace_repeat, slot_sq_inplace_repeat, - wrap_intargfunc, "x.__imul__(y) <==> x*=y"), + SQSLOT("__iadd__", sq_inplace_concat, NULL, + wrap_binaryfunc, "x.__iadd__(y) <==> x+=y"), + SQSLOT("__imul__", sq_inplace_repeat, NULL, + wrap_intargfunc, "x.__imul__(y) <==> x*=y"), MPSLOT("__len__", mp_length, slot_mp_length, wrap_inquiry, "x.__len__() <==> len(x)"), |