diff options
-rw-r--r-- | Lib/test/test_coroutines.py | 4 | ||||
-rw-r--r-- | Python/ceval.c | 50 |
2 files changed, 21 insertions, 33 deletions
diff --git a/Lib/test/test_coroutines.py b/Lib/test/test_coroutines.py index 670852d..e3a3304 100644 --- a/Lib/test/test_coroutines.py +++ b/Lib/test/test_coroutines.py @@ -1006,10 +1006,10 @@ class SysSetCoroWrapperTest(unittest.TestCase): sys.set_coroutine_wrapper(wrapper) try: - with self.assertRaisesRegex( + with silence_coro_gc(), self.assertRaisesRegex( RuntimeError, "coroutine wrapper.*\.wrapper at 0x.*attempted to " - "recursively wrap <coroutine.*\.wrap"): + "recursively wrap .* wrap .*"): foo() finally: diff --git a/Python/ceval.c b/Python/ceval.c index 96ed6ed..641f9db 100644 --- a/Python/ceval.c +++ b/Python/ceval.c @@ -146,8 +146,6 @@ static void format_exc_unbound(PyCodeObject *co, int oparg); static PyObject * unicode_concatenate(PyObject *, PyObject *, PyFrameObject *, unsigned char *); static PyObject * special_lookup(PyObject *, _Py_Identifier *); -static PyObject * apply_coroutine_wrapper(PyObject *); - #define NAME_ERROR_MSG \ "name '%.200s' is not defined" @@ -3923,6 +3921,18 @@ _PyEval_EvalCodeWithName(PyObject *_co, PyObject *globals, PyObject *locals, if (co->co_flags & CO_GENERATOR) { PyObject *gen; + PyObject *coro_wrapper = tstate->coroutine_wrapper; + int is_coro = co->co_flags & (CO_COROUTINE | CO_ITERABLE_COROUTINE); + + if (is_coro && tstate->in_coroutine_wrapper) { + assert(coro_wrapper != NULL); + PyErr_Format(PyExc_RuntimeError, + "coroutine wrapper %.200R attempted " + "to recursively wrap %.200R", + coro_wrapper, + co); + goto fail; + } /* Don't need to keep the reference to f_back, it will be set * when the generator is resumed. */ @@ -3936,8 +3946,13 @@ _PyEval_EvalCodeWithName(PyObject *_co, PyObject *globals, PyObject *locals, if (gen == NULL) return NULL; - if (co->co_flags & (CO_COROUTINE | CO_ITERABLE_COROUTINE)) - return apply_coroutine_wrapper(gen); + if (is_coro && coro_wrapper != NULL) { + PyObject *wrapped; + tstate->in_coroutine_wrapper = 1; + wrapped = PyObject_CallFunction(coro_wrapper, "N", gen); + tstate->in_coroutine_wrapper = 0; + return wrapped; + } return gen; } @@ -5232,33 +5247,6 @@ unicode_concatenate(PyObject *v, PyObject *w, return res; } -static PyObject * -apply_coroutine_wrapper(PyObject *gen) -{ - PyObject *wrapped; - PyThreadState *tstate = PyThreadState_GET(); - PyObject *wrapper = tstate->coroutine_wrapper; - - if (tstate->in_coroutine_wrapper) { - assert(wrapper != NULL); - PyErr_Format(PyExc_RuntimeError, - "coroutine wrapper %.200R attempted " - "to recursively wrap %.200R", - wrapper, - gen); - return NULL; - } - - if (wrapper == NULL) { - return gen; - } - - tstate->in_coroutine_wrapper = 1; - wrapped = PyObject_CallFunction(wrapper, "N", gen); - tstate->in_coroutine_wrapper = 0; - return wrapped; -} - #ifdef DYNAMIC_EXECUTION_PROFILE static PyObject * |