summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Include/genobject.h1
-rw-r--r--Lib/test/test_coroutines.py18
-rw-r--r--Objects/genobject.c12
-rw-r--r--Python/ceval.c15
4 files changed, 40 insertions, 6 deletions
diff --git a/Include/genobject.h b/Include/genobject.h
index 4c71861..30cb023 100644
--- a/Include/genobject.h
+++ b/Include/genobject.h
@@ -43,6 +43,7 @@ PyAPI_FUNC(PyObject *) PyGen_NewWithQualName(struct _frame *,
PyAPI_FUNC(int) PyGen_NeedsFinalizing(PyGenObject *);
PyAPI_FUNC(int) _PyGen_FetchStopIterationValue(PyObject **);
PyObject *_PyGen_Send(PyGenObject *, PyObject *);
+PyObject *_PyGen_yf(PyGenObject *);
PyAPI_FUNC(void) _PyGen_Finalize(PyObject *self);
#ifndef Py_LIMITED_API
diff --git a/Lib/test/test_coroutines.py b/Lib/test/test_coroutines.py
index 954a9a1..187348d 100644
--- a/Lib/test/test_coroutines.py
+++ b/Lib/test/test_coroutines.py
@@ -942,6 +942,24 @@ class CoroutineTest(unittest.TestCase):
with self.assertRaises(Marker):
c.throw(ZeroDivisionError)
+ def test_await_15(self):
+ @types.coroutine
+ def nop():
+ yield
+
+ async def coroutine():
+ await nop()
+
+ async def waiter(coro):
+ await coro
+
+ coro = coroutine()
+ coro.send(None)
+
+ with self.assertRaisesRegex(RuntimeError,
+ "coroutine is being awaited already"):
+ waiter(coro).send(None)
+
def test_with_1(self):
class Manager:
def __init__(self, name):
diff --git a/Objects/genobject.c b/Objects/genobject.c
index 72d44c1..fdb3c03 100644
--- a/Objects/genobject.c
+++ b/Objects/genobject.c
@@ -267,8 +267,8 @@ gen_close_iter(PyObject *yf)
return 0;
}
-static PyObject *
-gen_yf(PyGenObject *gen)
+PyObject *
+_PyGen_yf(PyGenObject *gen)
{
PyObject *yf = NULL;
PyFrameObject *f = gen->gi_frame;
@@ -290,7 +290,7 @@ static PyObject *
gen_close(PyGenObject *gen, PyObject *args)
{
PyObject *retval;
- PyObject *yf = gen_yf(gen);
+ PyObject *yf = _PyGen_yf(gen);
int err = 0;
if (yf) {
@@ -330,7 +330,7 @@ gen_throw(PyGenObject *gen, PyObject *args)
PyObject *typ;
PyObject *tb = NULL;
PyObject *val = NULL;
- PyObject *yf = gen_yf(gen);
+ PyObject *yf = _PyGen_yf(gen);
_Py_IDENTIFIER(throw);
if (!PyArg_UnpackTuple(args, "throw", 1, 3, &typ, &val, &tb))
@@ -556,7 +556,7 @@ gen_set_qualname(PyGenObject *op, PyObject *value)
static PyObject *
gen_getyieldfrom(PyGenObject *gen)
{
- PyObject *yf = gen_yf(gen);
+ PyObject *yf = _PyGen_yf(gen);
if (yf == NULL)
Py_RETURN_NONE;
return yf;
@@ -791,7 +791,7 @@ coro_await(PyCoroObject *coro)
static PyObject *
coro_get_cr_await(PyCoroObject *coro)
{
- PyObject *yf = gen_yf((PyGenObject *) coro);
+ PyObject *yf = _PyGen_yf((PyGenObject *) coro);
if (yf == NULL)
Py_RETURN_NONE;
return yf;
diff --git a/Python/ceval.c b/Python/ceval.c
index 8904d7a..7b07475 100644
--- a/Python/ceval.c
+++ b/Python/ceval.c
@@ -2021,6 +2021,21 @@ PyEval_EvalFrameEx(PyFrameObject *f, int throwflag)
Py_DECREF(iterable);
+ if (iter != NULL && PyCoro_CheckExact(iter)) {
+ PyObject *yf = _PyGen_yf((PyGenObject*)iter);
+ if (yf != NULL) {
+ /* `iter` is a coroutine object that is being
+ awaited, `yf` is a pointer to the current awaitable
+ being awaited on. */
+ Py_DECREF(yf);
+ Py_CLEAR(iter);
+ PyErr_SetString(
+ PyExc_RuntimeError,
+ "coroutine is being awaited already");
+ /* The code below jumps to `error` if `iter` is NULL. */
+ }
+ }
+
SET_TOP(iter); /* Even if it's NULL */
if (iter == NULL) {