diff options
-rw-r--r-- | Include/genobject.h | 1 | ||||
-rw-r--r-- | Lib/test/test_coroutines.py | 18 | ||||
-rw-r--r-- | Objects/genobject.c | 12 | ||||
-rw-r--r-- | Python/ceval.c | 15 |
4 files changed, 40 insertions, 6 deletions
diff --git a/Include/genobject.h b/Include/genobject.h index 4c71861..30cb023 100644 --- a/Include/genobject.h +++ b/Include/genobject.h @@ -43,6 +43,7 @@ PyAPI_FUNC(PyObject *) PyGen_NewWithQualName(struct _frame *, PyAPI_FUNC(int) PyGen_NeedsFinalizing(PyGenObject *); PyAPI_FUNC(int) _PyGen_FetchStopIterationValue(PyObject **); PyObject *_PyGen_Send(PyGenObject *, PyObject *); +PyObject *_PyGen_yf(PyGenObject *); PyAPI_FUNC(void) _PyGen_Finalize(PyObject *self); #ifndef Py_LIMITED_API diff --git a/Lib/test/test_coroutines.py b/Lib/test/test_coroutines.py index 954a9a1..187348d 100644 --- a/Lib/test/test_coroutines.py +++ b/Lib/test/test_coroutines.py @@ -942,6 +942,24 @@ class CoroutineTest(unittest.TestCase): with self.assertRaises(Marker): c.throw(ZeroDivisionError) + def test_await_15(self): + @types.coroutine + def nop(): + yield + + async def coroutine(): + await nop() + + async def waiter(coro): + await coro + + coro = coroutine() + coro.send(None) + + with self.assertRaisesRegex(RuntimeError, + "coroutine is being awaited already"): + waiter(coro).send(None) + def test_with_1(self): class Manager: def __init__(self, name): diff --git a/Objects/genobject.c b/Objects/genobject.c index 72d44c1..fdb3c03 100644 --- a/Objects/genobject.c +++ b/Objects/genobject.c @@ -267,8 +267,8 @@ gen_close_iter(PyObject *yf) return 0; } -static PyObject * -gen_yf(PyGenObject *gen) +PyObject * +_PyGen_yf(PyGenObject *gen) { PyObject *yf = NULL; PyFrameObject *f = gen->gi_frame; @@ -290,7 +290,7 @@ static PyObject * gen_close(PyGenObject *gen, PyObject *args) { PyObject *retval; - PyObject *yf = gen_yf(gen); + PyObject *yf = _PyGen_yf(gen); int err = 0; if (yf) { @@ -330,7 +330,7 @@ gen_throw(PyGenObject *gen, PyObject *args) PyObject *typ; PyObject *tb = NULL; PyObject *val = NULL; - PyObject *yf = gen_yf(gen); + PyObject *yf = _PyGen_yf(gen); _Py_IDENTIFIER(throw); if (!PyArg_UnpackTuple(args, "throw", 1, 3, &typ, &val, &tb)) @@ -556,7 +556,7 @@ gen_set_qualname(PyGenObject *op, PyObject *value) static PyObject * gen_getyieldfrom(PyGenObject *gen) { - PyObject *yf = gen_yf(gen); + PyObject *yf = _PyGen_yf(gen); if (yf == NULL) Py_RETURN_NONE; return yf; @@ -791,7 +791,7 @@ coro_await(PyCoroObject *coro) static PyObject * coro_get_cr_await(PyCoroObject *coro) { - PyObject *yf = gen_yf((PyGenObject *) coro); + PyObject *yf = _PyGen_yf((PyGenObject *) coro); if (yf == NULL) Py_RETURN_NONE; return yf; diff --git a/Python/ceval.c b/Python/ceval.c index 8904d7a..7b07475 100644 --- a/Python/ceval.c +++ b/Python/ceval.c @@ -2021,6 +2021,21 @@ PyEval_EvalFrameEx(PyFrameObject *f, int throwflag) Py_DECREF(iterable); + if (iter != NULL && PyCoro_CheckExact(iter)) { + PyObject *yf = _PyGen_yf((PyGenObject*)iter); + if (yf != NULL) { + /* `iter` is a coroutine object that is being + awaited, `yf` is a pointer to the current awaitable + being awaited on. */ + Py_DECREF(yf); + Py_CLEAR(iter); + PyErr_SetString( + PyExc_RuntimeError, + "coroutine is being awaited already"); + /* The code below jumps to `error` if `iter` is NULL. */ + } + } + SET_TOP(iter); /* Even if it's NULL */ if (iter == NULL) { |