From 93156880efd14ad7adc7d3512552b434f5543890 Mon Sep 17 00:00:00 2001 From: Irit Katriel <1055913+iritkatriel@users.noreply.github.com> Date: Wed, 3 Jul 2024 10:18:34 +0100 Subject: gh-121272: set ste_coroutine during symtable construction (#121297) compiler no longer modifies the symtable after this. --- Python/compile.c | 6 +++--- Python/symtable.c | 12 ++++++++++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/Python/compile.c b/Python/compile.c index d33db69..30708e1 100644 --- a/Python/compile.c +++ b/Python/compile.c @@ -3059,7 +3059,7 @@ compiler_async_for(struct compiler *c, stmt_ty s) { location loc = LOC(s); if (IS_TOP_LEVEL_AWAIT(c)){ - c->u->u_ste->ste_coroutine = 1; + assert(c->u->u_ste->ste_coroutine == 1); } else if (c->u->u_scope_type != COMPILER_SCOPE_ASYNC_FUNCTION) { return compiler_error(c, loc, "'async for' outside async function"); } @@ -5782,7 +5782,7 @@ compiler_comprehension(struct compiler *c, expr_ty e, int type, co = optimize_and_assemble(c, 1); compiler_exit_scope(c); if (is_top_level_await && is_async_generator){ - c->u->u_ste->ste_coroutine = 1; + assert(c->u->u_ste->ste_coroutine == 1); } if (co == NULL) { goto error; @@ -5926,7 +5926,7 @@ compiler_async_with(struct compiler *c, stmt_ty s, int pos) assert(s->kind == AsyncWith_kind); if (IS_TOP_LEVEL_AWAIT(c)){ - c->u->u_ste->ste_coroutine = 1; + assert(c->u->u_ste->ste_coroutine == 1); } else if (c->u->u_scope_type != COMPILER_SCOPE_ASYNC_FUNCTION){ return compiler_error(c, loc, "'async with' outside async function"); } diff --git a/Python/symtable.c b/Python/symtable.c index 61fa5c6..65677f8 100644 --- a/Python/symtable.c +++ b/Python/symtable.c @@ -1681,6 +1681,16 @@ check_import_from(struct symtable *st, stmt_ty s) return 1; } +static void +maybe_set_ste_coroutine_for_module(struct symtable *st, stmt_ty s) +{ + if ((st->st_future->ff_features & PyCF_ALLOW_TOP_LEVEL_AWAIT) && + (st->st_cur->ste_type == ModuleBlock)) + { + st->st_cur->ste_coroutine = 1; + } +} + static int symtable_visit_stmt(struct symtable *st, stmt_ty s) { @@ -2074,10 +2084,12 @@ symtable_visit_stmt(struct symtable *st, stmt_ty s) break; } case AsyncWith_kind: + maybe_set_ste_coroutine_for_module(st, s); VISIT_SEQ(st, withitem, s->v.AsyncWith.items); VISIT_SEQ(st, stmt, s->v.AsyncWith.body); break; case AsyncFor_kind: + maybe_set_ste_coroutine_for_module(st, s); VISIT(st, expr, s->v.AsyncFor.target); VISIT(st, expr, s->v.AsyncFor.iter); VISIT_SEQ(st, stmt, s->v.AsyncFor.body); -- cgit v0.12