summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Doc/library/sys.rst14
-rw-r--r--Include/ceval.h3
-rw-r--r--Include/pystate.h1
-rw-r--r--Lib/test/test_coroutines.py20
-rw-r--r--Python/ceval.c39
-rw-r--r--Python/pystate.c1
6 files changed, 68 insertions, 10 deletions
diff --git a/Doc/library/sys.rst b/Doc/library/sys.rst
index 3e8fd82..f9733b2 100644
--- a/Doc/library/sys.rst
+++ b/Doc/library/sys.rst
@@ -1085,6 +1085,20 @@ always available.
If called twice, the new wrapper replaces the previous one. The function
is thread-specific.
+ The *wrapper* callable cannot define new coroutines directly or indirectly::
+
+ def wrapper(coro):
+ async def wrap(coro):
+ return await coro
+ return wrap(coro)
+ sys.set_coroutine_wrapper(wrapper)
+
+ async def foo(): pass
+
+ # The following line will fail with a RuntimeError, because
+ # `wrapper` creates a `wrap(coro)` coroutine:
+ foo()
+
See also :func:`get_coroutine_wrapper`.
.. versionadded:: 3.5
diff --git a/Include/ceval.h b/Include/ceval.h
index e558594..9f4d3f1 100644
--- a/Include/ceval.h
+++ b/Include/ceval.h
@@ -23,8 +23,9 @@ PyAPI_FUNC(PyObject *) PyEval_CallMethod(PyObject *obj,
#ifndef Py_LIMITED_API
PyAPI_FUNC(void) PyEval_SetProfile(Py_tracefunc, PyObject *);
PyAPI_FUNC(void) PyEval_SetTrace(Py_tracefunc, PyObject *);
-PyAPI_FUNC(void) _PyEval_SetCoroutineWrapper(PyObject *wrapper);
+PyAPI_FUNC(void) _PyEval_SetCoroutineWrapper(PyObject *);
PyAPI_FUNC(PyObject *) _PyEval_GetCoroutineWrapper(void);
+PyAPI_FUNC(PyObject *) _PyEval_ApplyCoroutineWrapper(PyObject *);
#endif
struct _frame; /* Avoid including frameobject.h */
diff --git a/Include/pystate.h b/Include/pystate.h
index 2ee81df..a2fd803 100644
--- a/Include/pystate.h
+++ b/Include/pystate.h
@@ -135,6 +135,7 @@ typedef struct _ts {
void *on_delete_data;
PyObject *coroutine_wrapper;
+ int in_coroutine_wrapper;
/* XXX signal handlers should also be here */
diff --git a/Lib/test/test_coroutines.py b/Lib/test/test_coroutines.py
index e79896a..670852d 100644
--- a/Lib/test/test_coroutines.py
+++ b/Lib/test/test_coroutines.py
@@ -995,6 +995,26 @@ class SysSetCoroWrapperTest(unittest.TestCase):
sys.set_coroutine_wrapper(1)
self.assertIsNone(sys.get_coroutine_wrapper())
+ def test_set_wrapper_3(self):
+ async def foo():
+ return 'spam'
+
+ def wrapper(coro):
+ async def wrap(coro):
+ return await coro
+ return wrap(coro)
+
+ sys.set_coroutine_wrapper(wrapper)
+ try:
+ with self.assertRaisesRegex(
+ RuntimeError,
+ "coroutine wrapper.*\.wrapper at 0x.*attempted to "
+ "recursively wrap <coroutine.*\.wrap"):
+
+ foo()
+ finally:
+ sys.set_coroutine_wrapper(None)
+
class CAPITest(unittest.TestCase):
diff --git a/Python/ceval.c b/Python/ceval.c
index bb2c0b9..2a1db17 100644
--- a/Python/ceval.c
+++ b/Python/ceval.c
@@ -3921,7 +3921,6 @@ _PyEval_EvalCodeWithName(PyObject *_co, PyObject *globals, PyObject *locals,
if (co->co_flags & CO_GENERATOR) {
PyObject *gen;
- PyObject *coroutine_wrapper;
/* Don't need to keep the reference to f_back, it will be set
* when the generator is resumed. */
@@ -3935,14 +3934,9 @@ _PyEval_EvalCodeWithName(PyObject *_co, PyObject *globals, PyObject *locals,
if (gen == NULL)
return NULL;
- if (co->co_flags & (CO_COROUTINE | CO_ITERABLE_COROUTINE)) {
- coroutine_wrapper = _PyEval_GetCoroutineWrapper();
- if (coroutine_wrapper != NULL) {
- PyObject *wrapped =
- PyObject_CallFunction(coroutine_wrapper, "N", gen);
- gen = wrapped;
- }
- }
+ if (co->co_flags & (CO_COROUTINE | CO_ITERABLE_COROUTINE))
+ return _PyEval_ApplyCoroutineWrapper(gen);
+
return gen;
}
@@ -4408,6 +4402,33 @@ _PyEval_GetCoroutineWrapper(void)
}
PyObject *
+_PyEval_ApplyCoroutineWrapper(PyObject *gen)
+{
+ PyObject *wrapped;
+ PyThreadState *tstate = PyThreadState_GET();
+ PyObject *wrapper = tstate->coroutine_wrapper;
+
+ if (tstate->in_coroutine_wrapper) {
+ assert(wrapper != NULL);
+ PyErr_Format(PyExc_RuntimeError,
+ "coroutine wrapper %.150R attempted "
+ "to recursively wrap %.150R",
+ wrapper,
+ gen);
+ return NULL;
+ }
+
+ if (wrapper == NULL) {
+ return gen;
+ }
+
+ tstate->in_coroutine_wrapper = 1;
+ wrapped = PyObject_CallFunction(wrapper, "N", gen);
+ tstate->in_coroutine_wrapper = 0;
+ return wrapped;
+}
+
+PyObject *
PyEval_GetBuiltins(void)
{
PyFrameObject *current_frame = PyEval_GetFrame();
diff --git a/Python/pystate.c b/Python/pystate.c
index 4ac05d6..7e0267a 100644
--- a/Python/pystate.c
+++ b/Python/pystate.c
@@ -213,6 +213,7 @@ new_threadstate(PyInterpreterState *interp, int init)
tstate->on_delete_data = NULL;
tstate->coroutine_wrapper = NULL;
+ tstate->in_coroutine_wrapper = 0;
if (init)
_PyThreadState_Init(tstate);