From 2b3eb4062c5e50abf854f7e68038243ca7c07217 Mon Sep 17 00:00:00 2001 From: Armin Rigo Date: Tue, 28 Oct 2003 12:05:48 +0000 Subject: Deleting cyclic object comparison. SF patch 825639 http://mail.python.org/pipermail/python-dev/2003-October/039445.html --- Include/ceval.h | 14 ++++ Lib/test/pickletester.py | 25 +++---- Lib/test/test_builtin.py | 12 +-- Lib/test/test_copy.py | 12 +-- Lib/test/test_richcmp.py | 63 ++++++---------- Misc/NEWS | 4 + Objects/classobject.c | 10 +-- Objects/object.c | 188 ++--------------------------------------------- Python/ceval.c | 56 +++++++++----- 9 files changed, 109 insertions(+), 275 deletions(-) diff --git a/Include/ceval.h b/Include/ceval.h index 411cf3e..dc3864b 100644 --- a/Include/ceval.h +++ b/Include/ceval.h @@ -43,9 +43,23 @@ PyAPI_FUNC(int) Py_FlushLine(void); PyAPI_FUNC(int) Py_AddPendingCall(int (*func)(void *), void *arg); PyAPI_FUNC(int) Py_MakePendingCalls(void); +/* Protection against deeply nested recursive calls */ PyAPI_FUNC(void) Py_SetRecursionLimit(int); PyAPI_FUNC(int) Py_GetRecursionLimit(void); +#define Py_EnterRecursiveCall(where) \ + (_Py_MakeRecCheck(PyThreadState_GET()->recursion_depth) && \ + _Py_CheckRecursiveCall(where)) +#define Py_LeaveRecursiveCall() \ + (--PyThreadState_GET()->recursion_depth) +PyAPI_FUNC(int) _Py_CheckRecursiveCall(char *where); +PyAPI_DATA(int) _Py_CheckRecursionLimit; +#ifdef USE_STACKCHECK +# define _Py_MakeRecCheck(x) (++(x) > --_Py_CheckRecursionLimit) +#else +# define _Py_MakeRecCheck(x) (++(x) > _Py_CheckRecursionLimit) +#endif + PyAPI_FUNC(char *) PyEval_GetFuncName(PyObject *); PyAPI_FUNC(char *) PyEval_GetFuncDesc(PyObject *); diff --git a/Lib/test/pickletester.py b/Lib/test/pickletester.py index cf1bb37..6e6d97d 100644 --- a/Lib/test/pickletester.py +++ b/Lib/test/pickletester.py @@ -424,9 +424,8 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(l, proto) x = self.loads(s) - self.assertEqual(x, l) - self.assertEqual(x, x[0]) - self.assertEqual(id(x), id(x[0])) + self.assertEqual(len(x), 1) + self.assert_(x is x[0]) def test_recursive_dict(self): d = {} @@ -434,9 +433,8 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(d, proto) x = self.loads(s) - self.assertEqual(x, d) - self.assertEqual(x[1], x) - self.assertEqual(id(x[1]), id(x)) + self.assertEqual(x.keys(), [1]) + self.assert_(x[1] is x) def test_recursive_inst(self): i = C() @@ -444,9 +442,8 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(i, 2) x = self.loads(s) - self.assertEqual(x, i) - self.assertEqual(x.attr, x) - self.assertEqual(id(x.attr), id(x)) + self.assertEqual(dir(x), dir(i)) + self.assert_(x.attr is x) def test_recursive_multi(self): l = [] @@ -457,12 +454,10 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(l, proto) x = self.loads(s) - self.assertEqual(x, l) - self.assertEqual(x[0], i) - self.assertEqual(x[0].attr, d) - self.assertEqual(x[0].attr[1], x) - self.assertEqual(x[0].attr[1][0], i) - self.assertEqual(x[0].attr[1][0].attr, d) + self.assertEqual(len(x), 1) + self.assertEqual(dir(x[0]), dir(i)) + self.assertEqual(x[0].attr.keys(), [1]) + self.assert_(x[0].attr[1] is x) def test_garyp(self): self.assertRaises(self.error, self.loads, 'garyp') diff --git a/Lib/test/test_builtin.py b/Lib/test/test_builtin.py index 6521634..e84cfbd 100644 --- a/Lib/test/test_builtin.py +++ b/Lib/test/test_builtin.py @@ -167,16 +167,16 @@ class BuiltinTest(unittest.TestCase): self.assertEqual(cmp(-1, 1), -1) self.assertEqual(cmp(1, -1), 1) self.assertEqual(cmp(1, 1), 0) - # verify that circular objects are handled + # verify that circular objects are not handled a = []; a.append(a) b = []; b.append(b) from UserList import UserList c = UserList(); c.append(c) - self.assertEqual(cmp(a, b), 0) - self.assertEqual(cmp(b, c), 0) - self.assertEqual(cmp(c, a), 0) - self.assertEqual(cmp(a, c), 0) - # okay, now break the cycles + self.assertRaises(RuntimeError, cmp, a, b) + self.assertRaises(RuntimeError, cmp, b, c) + self.assertRaises(RuntimeError, cmp, c, a) + self.assertRaises(RuntimeError, cmp, a, c) + # okay, now break the cycles a.pop(); b.pop(); c.pop() self.assertRaises(TypeError, cmp) diff --git a/Lib/test/test_copy.py b/Lib/test/test_copy.py index 3d44304..6e32ddd 100644 --- a/Lib/test/test_copy.py +++ b/Lib/test/test_copy.py @@ -272,10 +272,10 @@ class TestCopy(unittest.TestCase): x = [] x.append(x) y = copy.deepcopy(x) - self.assertEqual(y, x) + self.assertRaises(RuntimeError, cmp, y, x) self.assert_(y is not x) - self.assert_(y[0] is not x[0]) - self.assert_(y is y[0]) + self.assert_(y[0] is y) + self.assertEqual(len(y), 1) def test_deepcopy_tuple(self): x = ([1, 2], 3) @@ -288,7 +288,7 @@ class TestCopy(unittest.TestCase): x = ([],) x[0].append(x) y = copy.deepcopy(x) - self.assertEqual(y, x) + self.assertRaises(RuntimeError, cmp, y, x) self.assert_(y is not x) self.assert_(y[0] is not x[0]) self.assert_(y[0][0] is y) @@ -304,10 +304,10 @@ class TestCopy(unittest.TestCase): x = {} x['foo'] = x y = copy.deepcopy(x) - self.assertEqual(y, x) + self.assertRaises(RuntimeError, cmp, y, x) self.assert_(y is not x) self.assert_(y['foo'] is y) - self.assertEqual(y, {'foo': y}) + self.assertEqual(len(y), 1) def test_deepcopy_keepalive(self): memo = {} diff --git a/Lib/test/test_richcmp.py b/Lib/test/test_richcmp.py index 5ade8ed..006b152 100644 --- a/Lib/test/test_richcmp.py +++ b/Lib/test/test_richcmp.py @@ -224,57 +224,36 @@ class MiscTest(unittest.TestCase): self.assertRaises(Exc, func, Bad()) def test_recursion(self): - # Check comparison for recursive objects + # Check that comparison for recursive objects fails gracefully from UserList import UserList - a = UserList(); a.append(a) - b = UserList(); b.append(b) - - self.assert_(a == b) - self.assert_(not a != b) - a.append(1) - self.assert_(a == a[0]) - self.assert_(not a != a[0]) - self.assert_(a != b) - self.assert_(not a == b) - b.append(0) - self.assert_(a != b) - self.assert_(not a == b) - a[1] = -1 - self.assert_(a != b) - self.assert_(not a == b) - a = UserList() b = UserList() a.append(b) b.append(a) - self.assert_(a == b) - self.assert_(not a != b) + self.assertRaises(RuntimeError, operator.eq, a, b) + self.assertRaises(RuntimeError, operator.ne, a, b) + self.assertRaises(RuntimeError, operator.lt, a, b) + self.assertRaises(RuntimeError, operator.le, a, b) + self.assertRaises(RuntimeError, operator.gt, a, b) + self.assertRaises(RuntimeError, operator.ge, a, b) b.append(17) + # Even recursive lists of different lengths are different, + # but they cannot be ordered + self.assert_(not (a == b)) self.assert_(a != b) - self.assert_(not a == b) + self.assertRaises(RuntimeError, operator.lt, a, b) + self.assertRaises(RuntimeError, operator.le, a, b) + self.assertRaises(RuntimeError, operator.gt, a, b) + self.assertRaises(RuntimeError, operator.ge, a, b) a.append(17) - self.assert_(a == b) - self.assert_(not a != b) - - def test_recursion2(self): - # This test exercises the circular structure handling code - # in PyObject_RichCompare() - class Weird(object): - def __eq__(self, other): - return self != other - def __ne__(self, other): - return self == other - def __lt__(self, other): - return self > other - def __gt__(self, other): - return self < other - - self.assert_(Weird() == Weird()) - self.assert_(not (Weird() != Weird())) - - for op in opmap["lt"]: - self.assertRaises(ValueError, op, Weird(), Weird()) + self.assertRaises(RuntimeError, operator.eq, a, b) + self.assertRaises(RuntimeError, operator.ne, a, b) + a.insert(0, 11) + b.insert(0, 12) + self.assert_(not (a == b)) + self.assert_(a != b) + self.assert_(a < b) class DictTest(unittest.TestCase): diff --git a/Misc/NEWS b/Misc/NEWS index 74096c0..a16f119 100644 --- a/Misc/NEWS +++ b/Misc/NEWS @@ -47,6 +47,10 @@ Core and builtins - obj.__contains__() now returns True/False instead of 1/0. SF patch 820195. +- Python no longer tries to be smart about recursive comparisons. + When comparing containers with cyclic references to themselves it + will now just hit the recursion limit. See SF patch 825639. + Extension modules ----------------- diff --git a/Objects/classobject.c b/Objects/classobject.c index b0e1934..84b297c 100644 --- a/Objects/classobject.c +++ b/Objects/classobject.c @@ -1970,7 +1970,6 @@ instance_iternext(PyInstanceObject *self) static PyObject * instance_call(PyObject *func, PyObject *arg, PyObject *kw) { - PyThreadState *tstate = PyThreadState_GET(); PyObject *res, *call = PyObject_GetAttrString(func, "__call__"); if (call == NULL) { PyInstanceObject *inst = (PyInstanceObject*) func; @@ -1990,14 +1989,13 @@ instance_call(PyObject *func, PyObject *arg, PyObject *kw) a() # infinite recursion This bounces between instance_call() and PyObject_Call() without ever hitting eval_frame() (which has the main recursion check). */ - if (tstate->recursion_depth++ > Py_GetRecursionLimit()) { - PyErr_SetString(PyExc_RuntimeError, - "maximum __call__ recursion depth exceeded"); + if (Py_EnterRecursiveCall(" in __call__")) { res = NULL; } - else + else { res = PyObject_Call(call, arg, kw); - tstate->recursion_depth--; + Py_LeaveRecursiveCall(); + } Py_DECREF(call); return res; } diff --git a/Objects/object.c b/Objects/object.c index 8c4bd0e..d85f697 100644 --- a/Objects/object.c +++ b/Objects/object.c @@ -740,120 +740,6 @@ do_cmp(PyObject *v, PyObject *w) return default_3way_compare(v, w); } -/* compare_nesting is incremented before calling compare (for - some types) and decremented on exit. If the count exceeds the - nesting limit, enable code to detect circular data structures. - - This is a tunable parameter that should only affect the performance - of comparisons, nothing else. Setting it high makes comparing deeply - nested non-cyclical data structures faster, but makes comparing cyclical - data structures slower. -*/ -#define NESTING_LIMIT 20 - -static int compare_nesting = 0; - -static PyObject* -get_inprogress_dict(void) -{ - static PyObject *key; - PyObject *tstate_dict, *inprogress; - - if (key == NULL) { - key = PyString_InternFromString("cmp_state"); - if (key == NULL) - return NULL; - } - - tstate_dict = PyThreadState_GetDict(); - if (tstate_dict == NULL) { - PyErr_BadInternalCall(); - return NULL; - } - - inprogress = PyDict_GetItem(tstate_dict, key); - if (inprogress == NULL) { - inprogress = PyDict_New(); - if (inprogress == NULL) - return NULL; - if (PyDict_SetItem(tstate_dict, key, inprogress) == -1) { - Py_DECREF(inprogress); - return NULL; - } - Py_DECREF(inprogress); - } - - return inprogress; -} - -/* If the comparison "v op w" is already in progress in this thread, returns - * a borrowed reference to Py_None (the caller must not decref). - * If it's not already in progress, returns "a token" which must eventually - * be passed to delete_token(). The caller must not decref this either - * (delete_token decrefs it). The token must not survive beyond any point - * where v or w may die. - * If an error occurs (out-of-memory), returns NULL. - */ -static PyObject * -check_recursion(PyObject *v, PyObject *w, int op) -{ - PyObject *inprogress; - PyObject *token; - Py_uintptr_t iv = (Py_uintptr_t)v; - Py_uintptr_t iw = (Py_uintptr_t)w; - PyObject *x, *y, *z; - - inprogress = get_inprogress_dict(); - if (inprogress == NULL) - return NULL; - - token = PyTuple_New(3); - if (token == NULL) - return NULL; - - if (iv <= iw) { - PyTuple_SET_ITEM(token, 0, x = PyLong_FromVoidPtr((void *)v)); - PyTuple_SET_ITEM(token, 1, y = PyLong_FromVoidPtr((void *)w)); - if (op >= 0) - op = swapped_op[op]; - } else { - PyTuple_SET_ITEM(token, 0, x = PyLong_FromVoidPtr((void *)w)); - PyTuple_SET_ITEM(token, 1, y = PyLong_FromVoidPtr((void *)v)); - } - PyTuple_SET_ITEM(token, 2, z = PyInt_FromLong((long)op)); - if (x == NULL || y == NULL || z == NULL) { - Py_DECREF(token); - return NULL; - } - - if (PyDict_GetItem(inprogress, token) != NULL) { - Py_DECREF(token); - return Py_None; /* Without INCREF! */ - } - - if (PyDict_SetItem(inprogress, token, token) < 0) { - Py_DECREF(token); - return NULL; - } - - return token; -} - -static void -delete_token(PyObject *token) -{ - PyObject *inprogress; - - if (token == NULL || token == Py_None) - return; - inprogress = get_inprogress_dict(); - if (inprogress == NULL) - PyErr_Clear(); - else - PyDict_DelItem(inprogress, token); - Py_DECREF(token); -} - /* Compare v to w. Return -1 if v < w or exception (PyErr_Occurred() true in latter case). 0 if v == w. @@ -867,12 +753,6 @@ PyObject_Compare(PyObject *v, PyObject *w) PyTypeObject *vtp; int result; -#if defined(USE_STACKCHECK) - if (PyOS_CheckStack()) { - PyErr_SetString(PyExc_MemoryError, "Stack overflow"); - return -1; - } -#endif if (v == NULL || w == NULL) { PyErr_BadInternalCall(); return -1; @@ -880,31 +760,10 @@ PyObject_Compare(PyObject *v, PyObject *w) if (v == w) return 0; vtp = v->ob_type; - compare_nesting++; - if (compare_nesting > NESTING_LIMIT && - (vtp->tp_as_mapping || vtp->tp_as_sequence) && - !PyString_CheckExact(v) && - !PyTuple_CheckExact(v)) { - /* try to detect circular data structures */ - PyObject *token = check_recursion(v, w, -1); - - if (token == NULL) { - result = -1; - } - else if (token == Py_None) { - /* already comparing these objects. assume - they're equal until shown otherwise */ - result = 0; - } - else { - result = do_cmp(v, w); - delete_token(token); - } - } - else { - result = do_cmp(v, w); - } - compare_nesting--; + if (Py_EnterRecursiveCall(" in cmp")) + return -1; + result = do_cmp(v, w); + Py_LeaveRecursiveCall(); return result < 0 ? -1 : result; } @@ -975,41 +834,10 @@ PyObject_RichCompare(PyObject *v, PyObject *w, int op) PyObject *res; assert(Py_LT <= op && op <= Py_GE); + if (Py_EnterRecursiveCall(" in cmp")) + return NULL; - compare_nesting++; - if (compare_nesting > NESTING_LIMIT && - (v->ob_type->tp_as_mapping || v->ob_type->tp_as_sequence) && - !PyString_CheckExact(v) && - !PyTuple_CheckExact(v)) { - /* try to detect circular data structures */ - PyObject *token = check_recursion(v, w, op); - if (token == NULL) { - res = NULL; - goto Done; - } - else if (token == Py_None) { - /* already comparing these objects with this operator. - assume they're equal until shown otherwise */ - if (op == Py_EQ) - res = Py_True; - else if (op == Py_NE) - res = Py_False; - else { - PyErr_SetString(PyExc_ValueError, - "can't order recursive values"); - res = NULL; - } - Py_XINCREF(res); - } - else { - res = do_richcmp(v, w, op); - delete_token(token); - } - goto Done; - } - - /* No nesting extremism. - If the types are equal, and not old-style instances, try to + /* If the types are equal, and not old-style instances, try to get out cheap (don't bother with coercions etc.). */ if (v->ob_type == w->ob_type && !PyInstance_Check(v)) { cmpfunc fcmp; @@ -1041,7 +869,7 @@ PyObject_RichCompare(PyObject *v, PyObject *w, int op) /* Fast path not taken, or couldn't deliver a useful result. */ res = do_richcmp(v, w, op); Done: - compare_nesting--; + Py_LeaveRecursiveCall(); return res; } diff --git a/Python/ceval.c b/Python/ceval.c index e6b7424..fe8aca5 100644 --- a/Python/ceval.c +++ b/Python/ceval.c @@ -497,6 +497,7 @@ Py_MakePendingCalls(void) /* The interpreter's recursion limit */ static int recursion_limit = 1000; +int _Py_CheckRecursionLimit = 1000; int Py_GetRecursionLimit(void) @@ -508,8 +509,38 @@ void Py_SetRecursionLimit(int new_limit) { recursion_limit = new_limit; + _Py_CheckRecursionLimit = recursion_limit; } +/* the macro Py_EnterRecursiveCall() only calls _Py_CheckRecursiveCall() + if the recursion_depth reaches _Py_CheckRecursionLimit. + If USE_STACKCHECK, the macro decrements _Py_CheckRecursionLimit + to guarantee that _Py_CheckRecursiveCall() is regularly called. + Without USE_STACKCHECK, there is no need for this. */ +int +_Py_CheckRecursiveCall(char *where) +{ + PyThreadState *tstate = PyThreadState_GET(); + +#ifdef USE_STACKCHECK + if (PyOS_CheckStack()) { + --tstate->recursion_depth; + PyErr_SetString(PyExc_MemoryError, "Stack overflow"); + return -1; + } +#endif + if (tstate->recursion_depth > recursion_limit) { + --tstate->recursion_depth; + PyErr_Format(PyExc_RuntimeError, + "maximum recursion depth exceeded%s", + where); + return -1; + } + _Py_CheckRecursionLimit = recursion_limit; + return 0; +} + + /* Status code for main loop (reason for stack unwind) */ enum why_code { @@ -674,21 +705,9 @@ eval_frame(PyFrameObject *f) if (f == NULL) return NULL; -#ifdef USE_STACKCHECK - if (tstate->recursion_depth%10 == 0 && PyOS_CheckStack()) { - PyErr_SetString(PyExc_MemoryError, "Stack overflow"); - return NULL; - } -#endif - /* push frame */ - if (++tstate->recursion_depth > recursion_limit) { - --tstate->recursion_depth; - PyErr_SetString(PyExc_RuntimeError, - "maximum recursion depth exceeded"); - tstate->frame = f->f_back; + if (Py_EnterRecursiveCall("")) return NULL; - } tstate->frame = f; @@ -710,9 +729,7 @@ eval_frame(PyFrameObject *f) if (call_trace(tstate->c_tracefunc, tstate->c_traceobj, f, PyTrace_CALL, Py_None)) { /* Trace function raised an error */ - --tstate->recursion_depth; - tstate->frame = f->f_back; - return NULL; + goto exit_eval_frame; } } if (tstate->c_profilefunc != NULL) { @@ -722,9 +739,7 @@ eval_frame(PyFrameObject *f) tstate->c_profileobj, f, PyTrace_CALL, Py_None)) { /* Profile function raised an error */ - --tstate->recursion_depth; - tstate->frame = f->f_back; - return NULL; + goto exit_eval_frame; } } } @@ -2428,7 +2443,8 @@ eval_frame(PyFrameObject *f) reset_exc_info(tstate); /* pop frame */ - --tstate->recursion_depth; + exit_eval_frame: + Py_LeaveRecursiveCall(); tstate->frame = f->f_back; return retval; -- cgit v0.12