From 67a05de17ca811459e0e856d8e51d0eaf0f76232 Mon Sep 17 00:00:00 2001 From: Irit Katriel <1055913+iritkatriel@users.noreply.github.com> Date: Thu, 4 Jul 2024 14:47:21 +0100 Subject: gh-121272: move async for/with validation from compiler to symtable (#121361) --- Python/compile.c | 13 ------------- Python/symtable.c | 22 ++++++++++++++++++++++ 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/Python/compile.c b/Python/compile.c index 30708e1..1d6b54d 100644 --- a/Python/compile.c +++ b/Python/compile.c @@ -3058,11 +3058,6 @@ static int compiler_async_for(struct compiler *c, stmt_ty s) { location loc = LOC(s); - if (IS_TOP_LEVEL_AWAIT(c)){ - assert(c->u->u_ste->ste_coroutine == 1); - } else if (c->u->u_scope_type != COMPILER_SCOPE_ASYNC_FUNCTION) { - return compiler_error(c, loc, "'async for' outside async function"); - } NEW_JUMP_TARGET_LABEL(c, start); NEW_JUMP_TARGET_LABEL(c, except); @@ -5781,9 +5776,6 @@ compiler_comprehension(struct compiler *c, expr_ty e, int type, co = optimize_and_assemble(c, 1); compiler_exit_scope(c); - if (is_top_level_await && is_async_generator){ - assert(c->u->u_ste->ste_coroutine == 1); - } if (co == NULL) { goto error; } @@ -5925,11 +5917,6 @@ compiler_async_with(struct compiler *c, stmt_ty s, int pos) withitem_ty item = asdl_seq_GET(s->v.AsyncWith.items, pos); assert(s->kind == AsyncWith_kind); - if (IS_TOP_LEVEL_AWAIT(c)){ - assert(c->u->u_ste->ste_coroutine == 1); - } else if (c->u->u_scope_type != COMPILER_SCOPE_ASYNC_FUNCTION){ - return compiler_error(c, loc, "'async with' outside async function"); - } NEW_JUMP_TARGET_LABEL(c, block); NEW_JUMP_TARGET_LABEL(c, final); diff --git a/Python/symtable.c b/Python/symtable.c index 6ff0707..10103db 100644 --- a/Python/symtable.c +++ b/Python/symtable.c @@ -70,6 +70,11 @@ #define DUPLICATE_TYPE_PARAM \ "duplicate type parameter '%U'" +#define ASYNC_WITH_OUTISDE_ASYNC_FUNC \ +"'async with' outside async function" + +#define ASYNC_FOR_OUTISDE_ASYNC_FUNC \ +"'async for' outside async function" #define LOCATION(x) SRC_LOCATION_FROM_AST(x) @@ -251,6 +256,7 @@ static int symtable_visit_withitem(struct symtable *st, withitem_ty item); static int symtable_visit_match_case(struct symtable *st, match_case_ty m); static int symtable_visit_pattern(struct symtable *st, pattern_ty s); static int symtable_raise_if_annotation_block(struct symtable *st, const char *, expr_ty); +static int symtable_raise_if_not_coroutine(struct symtable *st, const char *msg, _Py_SourceLocation loc); static int symtable_raise_if_comprehension_block(struct symtable *st, expr_ty); /* For debugging purposes only */ @@ -2048,11 +2054,17 @@ symtable_visit_stmt(struct symtable *st, stmt_ty s) } case AsyncWith_kind: maybe_set_ste_coroutine_for_module(st, s); + if (!symtable_raise_if_not_coroutine(st, ASYNC_WITH_OUTISDE_ASYNC_FUNC, LOCATION(s))) { + VISIT_QUIT(st, 0); + } VISIT_SEQ(st, withitem, s->v.AsyncWith.items); VISIT_SEQ(st, stmt, s->v.AsyncWith.body); break; case AsyncFor_kind: maybe_set_ste_coroutine_for_module(st, s); + if (!symtable_raise_if_not_coroutine(st, ASYNC_FOR_OUTISDE_ASYNC_FUNC, LOCATION(s))) { + VISIT_QUIT(st, 0); + } VISIT(st, expr, s->v.AsyncFor.target); VISIT(st, expr, s->v.AsyncFor.iter); VISIT_SEQ(st, stmt, s->v.AsyncFor.body); @@ -2865,6 +2877,16 @@ symtable_raise_if_comprehension_block(struct symtable *st, expr_ty e) { VISIT_QUIT(st, 0); } +static int +symtable_raise_if_not_coroutine(struct symtable *st, const char *msg, _Py_SourceLocation loc) { + if (!st->st_cur->ste_coroutine) { + PyErr_SetString(PyExc_SyntaxError, msg); + SET_ERROR_LOCATION(st->st_filename, loc); + return 0; + } + return 1; +} + struct symtable * _Py_SymtableStringObjectFlags(const char *str, PyObject *filename, int start, PyCompilerFlags *flags) -- cgit v0.12