diff options
author | Nick Coghlan <ncoghlan@gmail.com> | 2021-04-29 05:58:44 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-04-29 05:58:44 (GMT) |
commit | 1e7b858575d0ad782939f86aae4a2fa1c29e9f14 (patch) | |
tree | 9445a7a82905c5bb253564853f33dacfceac6e93 /Python | |
parent | e52ab42cedd2a5ef4c3c1a47d0cf96a8f06d051f (diff) | |
download | cpython-1e7b858575d0ad782939f86aae4a2fa1c29e9f14.zip cpython-1e7b858575d0ad782939f86aae4a2fa1c29e9f14.tar.gz cpython-1e7b858575d0ad782939f86aae4a2fa1c29e9f14.tar.bz2 |
bpo-43892: Make match patterns explicit in the AST (GH-25585)
Co-authored-by: Brandt Bucher <brandtbucher@gmail.com>
Diffstat (limited to 'Python')
-rw-r--r-- | Python/Python-ast.c | 1160 | ||||
-rw-r--r-- | Python/ast.c | 238 | ||||
-rw-r--r-- | Python/ast_opt.c | 124 | ||||
-rw-r--r-- | Python/ast_unparse.c | 13 | ||||
-rw-r--r-- | Python/bltinmodule.c | 6 | ||||
-rw-r--r-- | Python/compile.c | 381 | ||||
-rw-r--r-- | Python/symtable.c | 64 |
7 files changed, 1498 insertions, 488 deletions
diff --git a/Python/Python-ast.c b/Python/Python-ast.c index 779ee3e..5d7a0ae 100644 --- a/Python/Python-ast.c +++ b/Python/Python-ast.c @@ -106,7 +106,13 @@ void _PyAST_Fini(PyInterpreterState *interp) Py_CLEAR(state->MatMult_singleton); Py_CLEAR(state->MatMult_type); Py_CLEAR(state->MatchAs_type); + Py_CLEAR(state->MatchClass_type); + Py_CLEAR(state->MatchMapping_type); Py_CLEAR(state->MatchOr_type); + Py_CLEAR(state->MatchSequence_type); + Py_CLEAR(state->MatchSingleton_type); + Py_CLEAR(state->MatchStar_type); + Py_CLEAR(state->MatchValue_type); Py_CLEAR(state->Match_type); Py_CLEAR(state->Mod_singleton); Py_CLEAR(state->Mod_type); @@ -173,6 +179,7 @@ void _PyAST_Fini(PyInterpreterState *interp) Py_CLEAR(state->boolop_type); Py_CLEAR(state->cases); Py_CLEAR(state->cause); + Py_CLEAR(state->cls); Py_CLEAR(state->cmpop_type); Py_CLEAR(state->col_offset); Py_CLEAR(state->comparators); @@ -208,6 +215,8 @@ void _PyAST_Fini(PyInterpreterState *interp) Py_CLEAR(state->kind); Py_CLEAR(state->kw_defaults); Py_CLEAR(state->kwarg); + Py_CLEAR(state->kwd_attrs); + Py_CLEAR(state->kwd_patterns); Py_CLEAR(state->kwonlyargs); Py_CLEAR(state->left); Py_CLEAR(state->level); @@ -226,8 +235,10 @@ void _PyAST_Fini(PyInterpreterState *interp) Py_CLEAR(state->optional_vars); Py_CLEAR(state->orelse); Py_CLEAR(state->pattern); + Py_CLEAR(state->pattern_type); Py_CLEAR(state->patterns); Py_CLEAR(state->posonlyargs); + Py_CLEAR(state->rest); Py_CLEAR(state->returns); Py_CLEAR(state->right); Py_CLEAR(state->simple); @@ -276,6 +287,7 @@ static int init_identifiers(struct ast_state *state) if ((state->body = PyUnicode_InternFromString("body")) == NULL) return 0; if ((state->cases = PyUnicode_InternFromString("cases")) == NULL) return 0; if ((state->cause = PyUnicode_InternFromString("cause")) == NULL) return 0; + if ((state->cls = PyUnicode_InternFromString("cls")) == NULL) return 0; if ((state->col_offset = PyUnicode_InternFromString("col_offset")) == NULL) return 0; if ((state->comparators = PyUnicode_InternFromString("comparators")) == NULL) return 0; if ((state->context_expr = PyUnicode_InternFromString("context_expr")) == NULL) return 0; @@ -305,6 +317,8 @@ static int init_identifiers(struct ast_state *state) if ((state->kind = PyUnicode_InternFromString("kind")) == NULL) return 0; if ((state->kw_defaults = PyUnicode_InternFromString("kw_defaults")) == NULL) return 0; if ((state->kwarg = PyUnicode_InternFromString("kwarg")) == NULL) return 0; + if ((state->kwd_attrs = PyUnicode_InternFromString("kwd_attrs")) == NULL) return 0; + if ((state->kwd_patterns = PyUnicode_InternFromString("kwd_patterns")) == NULL) return 0; if ((state->kwonlyargs = PyUnicode_InternFromString("kwonlyargs")) == NULL) return 0; if ((state->left = PyUnicode_InternFromString("left")) == NULL) return 0; if ((state->level = PyUnicode_InternFromString("level")) == NULL) return 0; @@ -322,6 +336,7 @@ static int init_identifiers(struct ast_state *state) if ((state->pattern = PyUnicode_InternFromString("pattern")) == NULL) return 0; if ((state->patterns = PyUnicode_InternFromString("patterns")) == NULL) return 0; if ((state->posonlyargs = PyUnicode_InternFromString("posonlyargs")) == NULL) return 0; + if ((state->rest = PyUnicode_InternFromString("rest")) == NULL) return 0; if ((state->returns = PyUnicode_InternFromString("returns")) == NULL) return 0; if ((state->right = PyUnicode_InternFromString("right")) == NULL) return 0; if ((state->simple = PyUnicode_InternFromString("simple")) == NULL) return 0; @@ -353,6 +368,7 @@ GENERATE_ASDL_SEQ_CONSTRUCTOR(keyword, keyword_ty) GENERATE_ASDL_SEQ_CONSTRUCTOR(alias, alias_ty) GENERATE_ASDL_SEQ_CONSTRUCTOR(withitem, withitem_ty) GENERATE_ASDL_SEQ_CONSTRUCTOR(match_case, match_case_ty) +GENERATE_ASDL_SEQ_CONSTRUCTOR(pattern, pattern_ty) GENERATE_ASDL_SEQ_CONSTRUCTOR(type_ignore, type_ignore_ty) static PyObject* ast2obj_mod(struct ast_state *state, void*); @@ -610,13 +626,6 @@ static const char * const Slice_fields[]={ "upper", "step", }; -static const char * const MatchAs_fields[]={ - "pattern", - "name", -}; -static const char * const MatchOr_fields[]={ - "patterns", -}; static PyObject* ast2obj_expr_context(struct ast_state *state, expr_context_ty); static PyObject* ast2obj_boolop(struct ast_state *state, boolop_ty); static PyObject* ast2obj_operator(struct ast_state *state, operator_ty); @@ -696,6 +705,43 @@ static const char * const match_case_fields[]={ "guard", "body", }; +static const char * const pattern_attributes[] = { + "lineno", + "col_offset", + "end_lineno", + "end_col_offset", +}; +static PyObject* ast2obj_pattern(struct ast_state *state, void*); +static const char * const MatchValue_fields[]={ + "value", +}; +static const char * const MatchSingleton_fields[]={ + "value", +}; +static const char * const MatchSequence_fields[]={ + "patterns", +}; +static const char * const MatchMapping_fields[]={ + "keys", + "patterns", + "rest", +}; +static const char * const MatchClass_fields[]={ + "cls", + "patterns", + "kwd_attrs", + "kwd_patterns", +}; +static const char * const MatchStar_fields[]={ + "name", +}; +static const char * const MatchAs_fields[]={ + "pattern", + "name", +}; +static const char * const MatchOr_fields[]={ + "patterns", +}; static PyObject* ast2obj_type_ignore(struct ast_state *state, void*); static const char * const TypeIgnore_fields[]={ "lineno", @@ -1275,9 +1321,7 @@ init_types(struct ast_state *state) " | Name(identifier id, expr_context ctx)\n" " | List(expr* elts, expr_context ctx)\n" " | Tuple(expr* elts, expr_context ctx)\n" - " | Slice(expr? lower, expr? upper, expr? step)\n" - " | MatchAs(expr pattern, identifier name)\n" - " | MatchOr(expr* patterns)"); + " | Slice(expr? lower, expr? upper, expr? step)"); if (!state->expr_type) return 0; if (!add_attributes(state, state->expr_type, expr_attributes, 4)) return 0; if (PyObject_SetAttr(state->expr_type, state->end_lineno, Py_None) == -1) @@ -1410,14 +1454,6 @@ init_types(struct ast_state *state) return 0; if (PyObject_SetAttr(state->Slice_type, state->step, Py_None) == -1) return 0; - state->MatchAs_type = make_type(state, "MatchAs", state->expr_type, - MatchAs_fields, 2, - "MatchAs(expr pattern, identifier name)"); - if (!state->MatchAs_type) return 0; - state->MatchOr_type = make_type(state, "MatchOr", state->expr_type, - MatchOr_fields, 1, - "MatchOr(expr* patterns)"); - if (!state->MatchOr_type) return 0; state->expr_context_type = make_type(state, "expr_context", state->AST_type, NULL, 0, "expr_context = Load | Store | Del"); @@ -1732,11 +1768,68 @@ init_types(struct ast_state *state) return 0; state->match_case_type = make_type(state, "match_case", state->AST_type, match_case_fields, 3, - "match_case(expr pattern, expr? guard, stmt* body)"); + "match_case(pattern pattern, expr? guard, stmt* body)"); if (!state->match_case_type) return 0; if (!add_attributes(state, state->match_case_type, NULL, 0)) return 0; if (PyObject_SetAttr(state->match_case_type, state->guard, Py_None) == -1) return 0; + state->pattern_type = make_type(state, "pattern", state->AST_type, NULL, 0, + "pattern = MatchValue(expr value)\n" + " | MatchSingleton(constant value)\n" + " | MatchSequence(pattern* patterns)\n" + " | MatchMapping(expr* keys, pattern* patterns, identifier? rest)\n" + " | MatchClass(expr cls, pattern* patterns, identifier* kwd_attrs, pattern* kwd_patterns)\n" + " | MatchStar(identifier? name)\n" + " | MatchAs(pattern? pattern, identifier? name)\n" + " | MatchOr(pattern* patterns)"); + if (!state->pattern_type) return 0; + if (!add_attributes(state, state->pattern_type, pattern_attributes, 4)) + return 0; + state->MatchValue_type = make_type(state, "MatchValue", + state->pattern_type, MatchValue_fields, + 1, + "MatchValue(expr value)"); + if (!state->MatchValue_type) return 0; + state->MatchSingleton_type = make_type(state, "MatchSingleton", + state->pattern_type, + MatchSingleton_fields, 1, + "MatchSingleton(constant value)"); + if (!state->MatchSingleton_type) return 0; + state->MatchSequence_type = make_type(state, "MatchSequence", + state->pattern_type, + MatchSequence_fields, 1, + "MatchSequence(pattern* patterns)"); + if (!state->MatchSequence_type) return 0; + state->MatchMapping_type = make_type(state, "MatchMapping", + state->pattern_type, + MatchMapping_fields, 3, + "MatchMapping(expr* keys, pattern* patterns, identifier? rest)"); + if (!state->MatchMapping_type) return 0; + if (PyObject_SetAttr(state->MatchMapping_type, state->rest, Py_None) == -1) + return 0; + state->MatchClass_type = make_type(state, "MatchClass", + state->pattern_type, MatchClass_fields, + 4, + "MatchClass(expr cls, pattern* patterns, identifier* kwd_attrs, pattern* kwd_patterns)"); + if (!state->MatchClass_type) return 0; + state->MatchStar_type = make_type(state, "MatchStar", state->pattern_type, + MatchStar_fields, 1, + "MatchStar(identifier? name)"); + if (!state->MatchStar_type) return 0; + if (PyObject_SetAttr(state->MatchStar_type, state->name, Py_None) == -1) + return 0; + state->MatchAs_type = make_type(state, "MatchAs", state->pattern_type, + MatchAs_fields, 2, + "MatchAs(pattern? pattern, identifier? name)"); + if (!state->MatchAs_type) return 0; + if (PyObject_SetAttr(state->MatchAs_type, state->pattern, Py_None) == -1) + return 0; + if (PyObject_SetAttr(state->MatchAs_type, state->name, Py_None) == -1) + return 0; + state->MatchOr_type = make_type(state, "MatchOr", state->pattern_type, + MatchOr_fields, 1, + "MatchOr(pattern* patterns)"); + if (!state->MatchOr_type) return 0; state->type_ignore_type = make_type(state, "type_ignore", state->AST_type, NULL, 0, "type_ignore = TypeIgnore(int lineno, string tag)"); @@ -1784,6 +1877,8 @@ static int obj2ast_withitem(struct ast_state *state, PyObject* obj, withitem_ty* out, PyArena* arena); static int obj2ast_match_case(struct ast_state *state, PyObject* obj, match_case_ty* out, PyArena* arena); +static int obj2ast_pattern(struct ast_state *state, PyObject* obj, pattern_ty* + out, PyArena* arena); static int obj2ast_type_ignore(struct ast_state *state, PyObject* obj, type_ignore_ty* out, PyArena* arena); @@ -3127,51 +3222,6 @@ _PyAST_Slice(expr_ty lower, expr_ty upper, expr_ty step, int lineno, int return p; } -expr_ty -_PyAST_MatchAs(expr_ty pattern, identifier name, int lineno, int col_offset, - int end_lineno, int end_col_offset, PyArena *arena) -{ - expr_ty p; - if (!pattern) { - PyErr_SetString(PyExc_ValueError, - "field 'pattern' is required for MatchAs"); - return NULL; - } - if (!name) { - PyErr_SetString(PyExc_ValueError, - "field 'name' is required for MatchAs"); - return NULL; - } - p = (expr_ty)_PyArena_Malloc(arena, sizeof(*p)); - if (!p) - return NULL; - p->kind = MatchAs_kind; - p->v.MatchAs.pattern = pattern; - p->v.MatchAs.name = name; - p->lineno = lineno; - p->col_offset = col_offset; - p->end_lineno = end_lineno; - p->end_col_offset = end_col_offset; - return p; -} - -expr_ty -_PyAST_MatchOr(asdl_expr_seq * patterns, int lineno, int col_offset, int - end_lineno, int end_col_offset, PyArena *arena) -{ - expr_ty p; - p = (expr_ty)_PyArena_Malloc(arena, sizeof(*p)); - if (!p) - return NULL; - p->kind = MatchOr_kind; - p->v.MatchOr.patterns = patterns; - p->lineno = lineno; - p->col_offset = col_offset; - p->end_lineno = end_lineno; - p->end_col_offset = end_col_offset; - return p; -} - comprehension_ty _PyAST_comprehension(expr_ty target, expr_ty iter, asdl_expr_seq * ifs, int is_async, PyArena *arena) @@ -3322,8 +3372,8 @@ _PyAST_withitem(expr_ty context_expr, expr_ty optional_vars, PyArena *arena) } match_case_ty -_PyAST_match_case(expr_ty pattern, expr_ty guard, asdl_stmt_seq * body, PyArena - *arena) +_PyAST_match_case(pattern_ty pattern, expr_ty guard, asdl_stmt_seq * body, + PyArena *arena) { match_case_ty p; if (!pattern) { @@ -3340,6 +3390,166 @@ _PyAST_match_case(expr_ty pattern, expr_ty guard, asdl_stmt_seq * body, PyArena return p; } +pattern_ty +_PyAST_MatchValue(expr_ty value, int lineno, int col_offset, int end_lineno, + int end_col_offset, PyArena *arena) +{ + pattern_ty p; + if (!value) { + PyErr_SetString(PyExc_ValueError, + "field 'value' is required for MatchValue"); + return NULL; + } + p = (pattern_ty)_PyArena_Malloc(arena, sizeof(*p)); + if (!p) + return NULL; + p->kind = MatchValue_kind; + p->v.MatchValue.value = value; + p->lineno = lineno; + p->col_offset = col_offset; + p->end_lineno = end_lineno; + p->end_col_offset = end_col_offset; + return p; +} + +pattern_ty +_PyAST_MatchSingleton(constant value, int lineno, int col_offset, int + end_lineno, int end_col_offset, PyArena *arena) +{ + pattern_ty p; + if (!value) { + PyErr_SetString(PyExc_ValueError, + "field 'value' is required for MatchSingleton"); + return NULL; + } + p = (pattern_ty)_PyArena_Malloc(arena, sizeof(*p)); + if (!p) + return NULL; + p->kind = MatchSingleton_kind; + p->v.MatchSingleton.value = value; + p->lineno = lineno; + p->col_offset = col_offset; + p->end_lineno = end_lineno; + p->end_col_offset = end_col_offset; + return p; +} + +pattern_ty +_PyAST_MatchSequence(asdl_pattern_seq * patterns, int lineno, int col_offset, + int end_lineno, int end_col_offset, PyArena *arena) +{ + pattern_ty p; + p = (pattern_ty)_PyArena_Malloc(arena, sizeof(*p)); + if (!p) + return NULL; + p->kind = MatchSequence_kind; + p->v.MatchSequence.patterns = patterns; + p->lineno = lineno; + p->col_offset = col_offset; + p->end_lineno = end_lineno; + p->end_col_offset = end_col_offset; + return p; +} + +pattern_ty +_PyAST_MatchMapping(asdl_expr_seq * keys, asdl_pattern_seq * patterns, + identifier rest, int lineno, int col_offset, int + end_lineno, int end_col_offset, PyArena *arena) +{ + pattern_ty p; + p = (pattern_ty)_PyArena_Malloc(arena, sizeof(*p)); + if (!p) + return NULL; + p->kind = MatchMapping_kind; + p->v.MatchMapping.keys = keys; + p->v.MatchMapping.patterns = patterns; + p->v.MatchMapping.rest = rest; + p->lineno = lineno; + p->col_offset = col_offset; + p->end_lineno = end_lineno; + p->end_col_offset = end_col_offset; + return p; +} + +pattern_ty +_PyAST_MatchClass(expr_ty cls, asdl_pattern_seq * patterns, asdl_identifier_seq + * kwd_attrs, asdl_pattern_seq * kwd_patterns, int lineno, int + col_offset, int end_lineno, int end_col_offset, PyArena + *arena) +{ + pattern_ty p; + if (!cls) { + PyErr_SetString(PyExc_ValueError, + "field 'cls' is required for MatchClass"); + return NULL; + } + p = (pattern_ty)_PyArena_Malloc(arena, sizeof(*p)); + if (!p) + return NULL; + p->kind = MatchClass_kind; + p->v.MatchClass.cls = cls; + p->v.MatchClass.patterns = patterns; + p->v.MatchClass.kwd_attrs = kwd_attrs; + p->v.MatchClass.kwd_patterns = kwd_patterns; + p->lineno = lineno; + p->col_offset = col_offset; + p->end_lineno = end_lineno; + p->end_col_offset = end_col_offset; + return p; +} + +pattern_ty +_PyAST_MatchStar(identifier name, int lineno, int col_offset, int end_lineno, + int end_col_offset, PyArena *arena) +{ + pattern_ty p; + p = (pattern_ty)_PyArena_Malloc(arena, sizeof(*p)); + if (!p) + return NULL; + p->kind = MatchStar_kind; + p->v.MatchStar.name = name; + p->lineno = lineno; + p->col_offset = col_offset; + p->end_lineno = end_lineno; + p->end_col_offset = end_col_offset; + return p; +} + +pattern_ty +_PyAST_MatchAs(pattern_ty pattern, identifier name, int lineno, int col_offset, + int end_lineno, int end_col_offset, PyArena *arena) +{ + pattern_ty p; + p = (pattern_ty)_PyArena_Malloc(arena, sizeof(*p)); + if (!p) + return NULL; + p->kind = MatchAs_kind; + p->v.MatchAs.pattern = pattern; + p->v.MatchAs.name = name; + p->lineno = lineno; + p->col_offset = col_offset; + p->end_lineno = end_lineno; + p->end_col_offset = end_col_offset; + return p; +} + +pattern_ty +_PyAST_MatchOr(asdl_pattern_seq * patterns, int lineno, int col_offset, int + end_lineno, int end_col_offset, PyArena *arena) +{ + pattern_ty p; + p = (pattern_ty)_PyArena_Malloc(arena, sizeof(*p)); + if (!p) + return NULL; + p->kind = MatchOr_kind; + p->v.MatchOr.patterns = patterns; + p->lineno = lineno; + p->col_offset = col_offset; + p->end_lineno = end_lineno; + p->end_col_offset = end_col_offset; + return p; +} + type_ignore_ty _PyAST_TypeIgnore(int lineno, string tag, PyArena *arena) { @@ -4410,32 +4620,6 @@ ast2obj_expr(struct ast_state *state, void* _o) goto failed; Py_DECREF(value); break; - case MatchAs_kind: - tp = (PyTypeObject *)state->MatchAs_type; - result = PyType_GenericNew(tp, NULL, NULL); - if (!result) goto failed; - value = ast2obj_expr(state, o->v.MatchAs.pattern); - if (!value) goto failed; - if (PyObject_SetAttr(result, state->pattern, value) == -1) - goto failed; - Py_DECREF(value); - value = ast2obj_identifier(state, o->v.MatchAs.name); - if (!value) goto failed; - if (PyObject_SetAttr(result, state->name, value) == -1) - goto failed; - Py_DECREF(value); - break; - case MatchOr_kind: - tp = (PyTypeObject *)state->MatchOr_type; - result = PyType_GenericNew(tp, NULL, NULL); - if (!result) goto failed; - value = ast2obj_list(state, (asdl_seq*)o->v.MatchOr.patterns, - ast2obj_expr); - if (!value) goto failed; - if (PyObject_SetAttr(result, state->patterns, value) == -1) - goto failed; - Py_DECREF(value); - break; } value = ast2obj_int(state, o->lineno); if (!value) goto failed; @@ -4935,7 +5119,7 @@ ast2obj_match_case(struct ast_state *state, void* _o) tp = (PyTypeObject *)state->match_case_type; result = PyType_GenericNew(tp, NULL, NULL); if (!result) return NULL; - value = ast2obj_expr(state, o->pattern); + value = ast2obj_pattern(state, o->pattern); if (!value) goto failed; if (PyObject_SetAttr(result, state->pattern, value) == -1) goto failed; @@ -4958,6 +5142,161 @@ failed: } PyObject* +ast2obj_pattern(struct ast_state *state, void* _o) +{ + pattern_ty o = (pattern_ty)_o; + PyObject *result = NULL, *value = NULL; + PyTypeObject *tp; + if (!o) { + Py_RETURN_NONE; + } + switch (o->kind) { + case MatchValue_kind: + tp = (PyTypeObject *)state->MatchValue_type; + result = PyType_GenericNew(tp, NULL, NULL); + if (!result) goto failed; + value = ast2obj_expr(state, o->v.MatchValue.value); + if (!value) goto failed; + if (PyObject_SetAttr(result, state->value, value) == -1) + goto failed; + Py_DECREF(value); + break; + case MatchSingleton_kind: + tp = (PyTypeObject *)state->MatchSingleton_type; + result = PyType_GenericNew(tp, NULL, NULL); + if (!result) goto failed; + value = ast2obj_constant(state, o->v.MatchSingleton.value); + if (!value) goto failed; + if (PyObject_SetAttr(result, state->value, value) == -1) + goto failed; + Py_DECREF(value); + break; + case MatchSequence_kind: + tp = (PyTypeObject *)state->MatchSequence_type; + result = PyType_GenericNew(tp, NULL, NULL); + if (!result) goto failed; + value = ast2obj_list(state, (asdl_seq*)o->v.MatchSequence.patterns, + ast2obj_pattern); + if (!value) goto failed; + if (PyObject_SetAttr(result, state->patterns, value) == -1) + goto failed; + Py_DECREF(value); + break; + case MatchMapping_kind: + tp = (PyTypeObject *)state->MatchMapping_type; + result = PyType_GenericNew(tp, NULL, NULL); + if (!result) goto failed; + value = ast2obj_list(state, (asdl_seq*)o->v.MatchMapping.keys, + ast2obj_expr); + if (!value) goto failed; + if (PyObject_SetAttr(result, state->keys, value) == -1) + goto failed; + Py_DECREF(value); + value = ast2obj_list(state, (asdl_seq*)o->v.MatchMapping.patterns, + ast2obj_pattern); + if (!value) goto failed; + if (PyObject_SetAttr(result, state->patterns, value) == -1) + goto failed; + Py_DECREF(value); + value = ast2obj_identifier(state, o->v.MatchMapping.rest); + if (!value) goto failed; + if (PyObject_SetAttr(result, state->rest, value) == -1) + goto failed; + Py_DECREF(value); + break; + case MatchClass_kind: + tp = (PyTypeObject *)state->MatchClass_type; + result = PyType_GenericNew(tp, NULL, NULL); + if (!result) goto failed; + value = ast2obj_expr(state, o->v.MatchClass.cls); + if (!value) goto failed; + if (PyObject_SetAttr(result, state->cls, value) == -1) + goto failed; + Py_DECREF(value); + value = ast2obj_list(state, (asdl_seq*)o->v.MatchClass.patterns, + ast2obj_pattern); + if (!value) goto failed; + if (PyObject_SetAttr(result, state->patterns, value) == -1) + goto failed; + Py_DECREF(value); + value = ast2obj_list(state, (asdl_seq*)o->v.MatchClass.kwd_attrs, + ast2obj_identifier); + if (!value) goto failed; + if (PyObject_SetAttr(result, state->kwd_attrs, value) == -1) + goto failed; + Py_DECREF(value); + value = ast2obj_list(state, (asdl_seq*)o->v.MatchClass.kwd_patterns, + ast2obj_pattern); + if (!value) goto failed; + if (PyObject_SetAttr(result, state->kwd_patterns, value) == -1) + goto failed; + Py_DECREF(value); + break; + case MatchStar_kind: + tp = (PyTypeObject *)state->MatchStar_type; + result = PyType_GenericNew(tp, NULL, NULL); + if (!result) goto failed; + value = ast2obj_identifier(state, o->v.MatchStar.name); + if (!value) goto failed; + if (PyObject_SetAttr(result, state->name, value) == -1) + goto failed; + Py_DECREF(value); + break; + case MatchAs_kind: + tp = (PyTypeObject *)state->MatchAs_type; + result = PyType_GenericNew(tp, NULL, NULL); + if (!result) goto failed; + value = ast2obj_pattern(state, o->v.MatchAs.pattern); + if (!value) goto failed; + if (PyObject_SetAttr(result, state->pattern, value) == -1) + goto failed; + Py_DECREF(value); + value = ast2obj_identifier(state, o->v.MatchAs.name); + if (!value) goto failed; + if (PyObject_SetAttr(result, state->name, value) == -1) + goto failed; + Py_DECREF(value); + break; + case MatchOr_kind: + tp = (PyTypeObject *)state->MatchOr_type; + result = PyType_GenericNew(tp, NULL, NULL); + if (!result) goto failed; + value = ast2obj_list(state, (asdl_seq*)o->v.MatchOr.patterns, + ast2obj_pattern); + if (!value) goto failed; + if (PyObject_SetAttr(result, state->patterns, value) == -1) + goto failed; + Py_DECREF(value); + break; + } + value = ast2obj_int(state, o->lineno); + if (!value) goto failed; + if (PyObject_SetAttr(result, state->lineno, value) < 0) + goto failed; + Py_DECREF(value); + value = ast2obj_int(state, o->col_offset); + if (!value) goto failed; + if (PyObject_SetAttr(result, state->col_offset, value) < 0) + goto failed; + Py_DECREF(value); + value = ast2obj_int(state, o->end_lineno); + if (!value) goto failed; + if (PyObject_SetAttr(result, state->end_lineno, value) < 0) + goto failed; + Py_DECREF(value); + value = ast2obj_int(state, o->end_col_offset); + if (!value) goto failed; + if (PyObject_SetAttr(result, state->end_col_offset, value) < 0) + goto failed; + Py_DECREF(value); + return result; +failed: + Py_XDECREF(value); + Py_XDECREF(result); + return NULL; +} + +PyObject* ast2obj_type_ignore(struct ast_state *state, void* _o) { type_ignore_ty o = (type_ignore_ty)_o; @@ -8689,92 +9028,6 @@ obj2ast_expr(struct ast_state *state, PyObject* obj, expr_ty* out, PyArena* if (*out == NULL) goto failed; return 0; } - tp = state->MatchAs_type; - isinstance = PyObject_IsInstance(obj, tp); - if (isinstance == -1) { - return 1; - } - if (isinstance) { - expr_ty pattern; - identifier name; - - if (_PyObject_LookupAttr(obj, state->pattern, &tmp) < 0) { - return 1; - } - if (tmp == NULL) { - PyErr_SetString(PyExc_TypeError, "required field \"pattern\" missing from MatchAs"); - return 1; - } - else { - int res; - res = obj2ast_expr(state, tmp, &pattern, arena); - if (res != 0) goto failed; - Py_CLEAR(tmp); - } - if (_PyObject_LookupAttr(obj, state->name, &tmp) < 0) { - return 1; - } - if (tmp == NULL) { - PyErr_SetString(PyExc_TypeError, "required field \"name\" missing from MatchAs"); - return 1; - } - else { - int res; - res = obj2ast_identifier(state, tmp, &name, arena); - if (res != 0) goto failed; - Py_CLEAR(tmp); - } - *out = _PyAST_MatchAs(pattern, name, lineno, col_offset, end_lineno, - end_col_offset, arena); - if (*out == NULL) goto failed; - return 0; - } - tp = state->MatchOr_type; - isinstance = PyObject_IsInstance(obj, tp); - if (isinstance == -1) { - return 1; - } - if (isinstance) { - asdl_expr_seq* patterns; - - if (_PyObject_LookupAttr(obj, state->patterns, &tmp) < 0) { - return 1; - } - if (tmp == NULL) { - PyErr_SetString(PyExc_TypeError, "required field \"patterns\" missing from MatchOr"); - return 1; - } - else { - int res; - Py_ssize_t len; - Py_ssize_t i; - if (!PyList_Check(tmp)) { - PyErr_Format(PyExc_TypeError, "MatchOr field \"patterns\" must be a list, not a %.200s", _PyType_Name(Py_TYPE(tmp))); - goto failed; - } - len = PyList_GET_SIZE(tmp); - patterns = _Py_asdl_expr_seq_new(len, arena); - if (patterns == NULL) goto failed; - for (i = 0; i < len; i++) { - expr_ty val; - PyObject *tmp2 = PyList_GET_ITEM(tmp, i); - Py_INCREF(tmp2); - res = obj2ast_expr(state, tmp2, &val, arena); - Py_DECREF(tmp2); - if (res != 0) goto failed; - if (len != PyList_GET_SIZE(tmp)) { - PyErr_SetString(PyExc_RuntimeError, "MatchOr field \"patterns\" changed size during iteration"); - goto failed; - } - asdl_seq_SET(patterns, i, val); - } - Py_CLEAR(tmp); - } - *out = _PyAST_MatchOr(patterns, lineno, col_offset, end_lineno, - end_col_offset, arena); - if (*out == NULL) goto failed; - return 0; - } PyErr_Format(PyExc_TypeError, "expected some sort of expr, but got %R", obj); failed: @@ -9897,7 +10150,7 @@ obj2ast_match_case(struct ast_state *state, PyObject* obj, match_case_ty* out, PyArena* arena) { PyObject* tmp = NULL; - expr_ty pattern; + pattern_ty pattern; expr_ty guard; asdl_stmt_seq* body; @@ -9910,7 +10163,7 @@ obj2ast_match_case(struct ast_state *state, PyObject* obj, match_case_ty* out, } else { int res; - res = obj2ast_expr(state, tmp, &pattern, arena); + res = obj2ast_pattern(state, tmp, &pattern, arena); if (res != 0) goto failed; Py_CLEAR(tmp); } @@ -9968,6 +10221,515 @@ failed: } int +obj2ast_pattern(struct ast_state *state, PyObject* obj, pattern_ty* out, + PyArena* arena) +{ + int isinstance; + + PyObject *tmp = NULL; + PyObject *tp; + int lineno; + int col_offset; + int end_lineno; + int end_col_offset; + + if (obj == Py_None) { + *out = NULL; + return 0; + } + if (_PyObject_LookupAttr(obj, state->lineno, &tmp) < 0) { + return 1; + } + if (tmp == NULL) { + PyErr_SetString(PyExc_TypeError, "required field \"lineno\" missing from pattern"); + return 1; + } + else { + int res; + res = obj2ast_int(state, tmp, &lineno, arena); + if (res != 0) goto failed; + Py_CLEAR(tmp); + } + if (_PyObject_LookupAttr(obj, state->col_offset, &tmp) < 0) { + return 1; + } + if (tmp == NULL) { + PyErr_SetString(PyExc_TypeError, "required field \"col_offset\" missing from pattern"); + return 1; + } + else { + int res; + res = obj2ast_int(state, tmp, &col_offset, arena); + if (res != 0) goto failed; + Py_CLEAR(tmp); + } + if (_PyObject_LookupAttr(obj, state->end_lineno, &tmp) < 0) { + return 1; + } + if (tmp == NULL) { + PyErr_SetString(PyExc_TypeError, "required field \"end_lineno\" missing from pattern"); + return 1; + } + else { + int res; + res = obj2ast_int(state, tmp, &end_lineno, arena); + if (res != 0) goto failed; + Py_CLEAR(tmp); + } + if (_PyObject_LookupAttr(obj, state->end_col_offset, &tmp) < 0) { + return 1; + } + if (tmp == NULL) { + PyErr_SetString(PyExc_TypeError, "required field \"end_col_offset\" missing from pattern"); + return 1; + } + else { + int res; + res = obj2ast_int(state, tmp, &end_col_offset, arena); + if (res != 0) goto failed; + Py_CLEAR(tmp); + } + tp = state->MatchValue_type; + isinstance = PyObject_IsInstance(obj, tp); + if (isinstance == -1) { + return 1; + } + if (isinstance) { + expr_ty value; + + if (_PyObject_LookupAttr(obj, state->value, &tmp) < 0) { + return 1; + } + if (tmp == NULL) { + PyErr_SetString(PyExc_TypeError, "required field \"value\" missing from MatchValue"); + return 1; + } + else { + int res; + res = obj2ast_expr(state, tmp, &value, arena); + if (res != 0) goto failed; + Py_CLEAR(tmp); + } + *out = _PyAST_MatchValue(value, lineno, col_offset, end_lineno, + end_col_offset, arena); + if (*out == NULL) goto failed; + return 0; + } + tp = state->MatchSingleton_type; + isinstance = PyObject_IsInstance(obj, tp); + if (isinstance == -1) { + return 1; + } + if (isinstance) { + constant value; + + if (_PyObject_LookupAttr(obj, state->value, &tmp) < 0) { + return 1; + } + if (tmp == NULL) { + PyErr_SetString(PyExc_TypeError, "required field \"value\" missing from MatchSingleton"); + return 1; + } + else { + int res; + res = obj2ast_constant(state, tmp, &value, arena); + if (res != 0) goto failed; + Py_CLEAR(tmp); + } + *out = _PyAST_MatchSingleton(value, lineno, col_offset, end_lineno, + end_col_offset, arena); + if (*out == NULL) goto failed; + return 0; + } + tp = state->MatchSequence_type; + isinstance = PyObject_IsInstance(obj, tp); + if (isinstance == -1) { + return 1; + } + if (isinstance) { + asdl_pattern_seq* patterns; + + if (_PyObject_LookupAttr(obj, state->patterns, &tmp) < 0) { + return 1; + } + if (tmp == NULL) { + PyErr_SetString(PyExc_TypeError, "required field \"patterns\" missing from MatchSequence"); + return 1; + } + else { + int res; + Py_ssize_t len; + Py_ssize_t i; + if (!PyList_Check(tmp)) { + PyErr_Format(PyExc_TypeError, "MatchSequence field \"patterns\" must be a list, not a %.200s", _PyType_Name(Py_TYPE(tmp))); + goto failed; + } + len = PyList_GET_SIZE(tmp); + patterns = _Py_asdl_pattern_seq_new(len, arena); + if (patterns == NULL) goto failed; + for (i = 0; i < len; i++) { + pattern_ty val; + PyObject *tmp2 = PyList_GET_ITEM(tmp, i); + Py_INCREF(tmp2); + res = obj2ast_pattern(state, tmp2, &val, arena); + Py_DECREF(tmp2); + if (res != 0) goto failed; + if (len != PyList_GET_SIZE(tmp)) { + PyErr_SetString(PyExc_RuntimeError, "MatchSequence field \"patterns\" changed size during iteration"); + goto failed; + } + asdl_seq_SET(patterns, i, val); + } + Py_CLEAR(tmp); + } + *out = _PyAST_MatchSequence(patterns, lineno, col_offset, end_lineno, + end_col_offset, arena); + if (*out == NULL) goto failed; + return 0; + } + tp = state->MatchMapping_type; + isinstance = PyObject_IsInstance(obj, tp); + if (isinstance == -1) { + return 1; + } + if (isinstance) { + asdl_expr_seq* keys; + asdl_pattern_seq* patterns; + identifier rest; + + if (_PyObject_LookupAttr(obj, state->keys, &tmp) < 0) { + return 1; + } + if (tmp == NULL) { + PyErr_SetString(PyExc_TypeError, "required field \"keys\" missing from MatchMapping"); + return 1; + } + else { + int res; + Py_ssize_t len; + Py_ssize_t i; + if (!PyList_Check(tmp)) { + PyErr_Format(PyExc_TypeError, "MatchMapping field \"keys\" must be a list, not a %.200s", _PyType_Name(Py_TYPE(tmp))); + goto failed; + } + len = PyList_GET_SIZE(tmp); + keys = _Py_asdl_expr_seq_new(len, arena); + if (keys == NULL) goto failed; + for (i = 0; i < len; i++) { + expr_ty val; + PyObject *tmp2 = PyList_GET_ITEM(tmp, i); + Py_INCREF(tmp2); + res = obj2ast_expr(state, tmp2, &val, arena); + Py_DECREF(tmp2); + if (res != 0) goto failed; + if (len != PyList_GET_SIZE(tmp)) { + PyErr_SetString(PyExc_RuntimeError, "MatchMapping field \"keys\" changed size during iteration"); + goto failed; + } + asdl_seq_SET(keys, i, val); + } + Py_CLEAR(tmp); + } + if (_PyObject_LookupAttr(obj, state->patterns, &tmp) < 0) { + return 1; + } + if (tmp == NULL) { + PyErr_SetString(PyExc_TypeError, "required field \"patterns\" missing from MatchMapping"); + return 1; + } + else { + int res; + Py_ssize_t len; + Py_ssize_t i; + if (!PyList_Check(tmp)) { + PyErr_Format(PyExc_TypeError, "MatchMapping field \"patterns\" must be a list, not a %.200s", _PyType_Name(Py_TYPE(tmp))); + goto failed; + } + len = PyList_GET_SIZE(tmp); + patterns = _Py_asdl_pattern_seq_new(len, arena); + if (patterns == NULL) goto failed; + for (i = 0; i < len; i++) { + pattern_ty val; + PyObject *tmp2 = PyList_GET_ITEM(tmp, i); + Py_INCREF(tmp2); + res = obj2ast_pattern(state, tmp2, &val, arena); + Py_DECREF(tmp2); + if (res != 0) goto failed; + if (len != PyList_GET_SIZE(tmp)) { + PyErr_SetString(PyExc_RuntimeError, "MatchMapping field \"patterns\" changed size during iteration"); + goto failed; + } + asdl_seq_SET(patterns, i, val); + } + Py_CLEAR(tmp); + } + if (_PyObject_LookupAttr(obj, state->rest, &tmp) < 0) { + return 1; + } + if (tmp == NULL || tmp == Py_None) { + Py_CLEAR(tmp); + rest = NULL; + } + else { + int res; + res = obj2ast_identifier(state, tmp, &rest, arena); + if (res != 0) goto failed; + Py_CLEAR(tmp); + } + *out = _PyAST_MatchMapping(keys, patterns, rest, lineno, col_offset, + end_lineno, end_col_offset, arena); + if (*out == NULL) goto failed; + return 0; + } + tp = state->MatchClass_type; + isinstance = PyObject_IsInstance(obj, tp); + if (isinstance == -1) { + return 1; + } + if (isinstance) { + expr_ty cls; + asdl_pattern_seq* patterns; + asdl_identifier_seq* kwd_attrs; + asdl_pattern_seq* kwd_patterns; + + if (_PyObject_LookupAttr(obj, state->cls, &tmp) < 0) { + return 1; + } + if (tmp == NULL) { + PyErr_SetString(PyExc_TypeError, "required field \"cls\" missing from MatchClass"); + return 1; + } + else { + int res; + res = obj2ast_expr(state, tmp, &cls, arena); + if (res != 0) goto failed; + Py_CLEAR(tmp); + } + if (_PyObject_LookupAttr(obj, state->patterns, &tmp) < 0) { + return 1; + } + if (tmp == NULL) { + PyErr_SetString(PyExc_TypeError, "required field \"patterns\" missing from MatchClass"); + return 1; + } + else { + int res; + Py_ssize_t len; + Py_ssize_t i; + if (!PyList_Check(tmp)) { + PyErr_Format(PyExc_TypeError, "MatchClass field \"patterns\" must be a list, not a %.200s", _PyType_Name(Py_TYPE(tmp))); + goto failed; + } + len = PyList_GET_SIZE(tmp); + patterns = _Py_asdl_pattern_seq_new(len, arena); + if (patterns == NULL) goto failed; + for (i = 0; i < len; i++) { + pattern_ty val; + PyObject *tmp2 = PyList_GET_ITEM(tmp, i); + Py_INCREF(tmp2); + res = obj2ast_pattern(state, tmp2, &val, arena); + Py_DECREF(tmp2); + if (res != 0) goto failed; + if (len != PyList_GET_SIZE(tmp)) { + PyErr_SetString(PyExc_RuntimeError, "MatchClass field \"patterns\" changed size during iteration"); + goto failed; + } + asdl_seq_SET(patterns, i, val); + } + Py_CLEAR(tmp); + } + if (_PyObject_LookupAttr(obj, state->kwd_attrs, &tmp) < 0) { + return 1; + } + if (tmp == NULL) { + PyErr_SetString(PyExc_TypeError, "required field \"kwd_attrs\" missing from MatchClass"); + return 1; + } + else { + int res; + Py_ssize_t len; + Py_ssize_t i; + if (!PyList_Check(tmp)) { + PyErr_Format(PyExc_TypeError, "MatchClass field \"kwd_attrs\" must be a list, not a %.200s", _PyType_Name(Py_TYPE(tmp))); + goto failed; + } + len = PyList_GET_SIZE(tmp); + kwd_attrs = _Py_asdl_identifier_seq_new(len, arena); + if (kwd_attrs == NULL) goto failed; + for (i = 0; i < len; i++) { + identifier val; + PyObject *tmp2 = PyList_GET_ITEM(tmp, i); + Py_INCREF(tmp2); + res = obj2ast_identifier(state, tmp2, &val, arena); + Py_DECREF(tmp2); + if (res != 0) goto failed; + if (len != PyList_GET_SIZE(tmp)) { + PyErr_SetString(PyExc_RuntimeError, "MatchClass field \"kwd_attrs\" changed size during iteration"); + goto failed; + } + asdl_seq_SET(kwd_attrs, i, val); + } + Py_CLEAR(tmp); + } + if (_PyObject_LookupAttr(obj, state->kwd_patterns, &tmp) < 0) { + return 1; + } + if (tmp == NULL) { + PyErr_SetString(PyExc_TypeError, "required field \"kwd_patterns\" missing from MatchClass"); + return 1; + } + else { + int res; + Py_ssize_t len; + Py_ssize_t i; + if (!PyList_Check(tmp)) { + PyErr_Format(PyExc_TypeError, "MatchClass field \"kwd_patterns\" must be a list, not a %.200s", _PyType_Name(Py_TYPE(tmp))); + goto failed; + } + len = PyList_GET_SIZE(tmp); + kwd_patterns = _Py_asdl_pattern_seq_new(len, arena); + if (kwd_patterns == NULL) goto failed; + for (i = 0; i < len; i++) { + pattern_ty val; + PyObject *tmp2 = PyList_GET_ITEM(tmp, i); + Py_INCREF(tmp2); + res = obj2ast_pattern(state, tmp2, &val, arena); + Py_DECREF(tmp2); + if (res != 0) goto failed; + if (len != PyList_GET_SIZE(tmp)) { + PyErr_SetString(PyExc_RuntimeError, "MatchClass field \"kwd_patterns\" changed size during iteration"); + goto failed; + } + asdl_seq_SET(kwd_patterns, i, val); + } + Py_CLEAR(tmp); + } + *out = _PyAST_MatchClass(cls, patterns, kwd_attrs, kwd_patterns, + lineno, col_offset, end_lineno, + end_col_offset, arena); + if (*out == NULL) goto failed; + return 0; + } + tp = state->MatchStar_type; + isinstance = PyObject_IsInstance(obj, tp); + if (isinstance == -1) { + return 1; + } + if (isinstance) { + identifier name; + + if (_PyObject_LookupAttr(obj, state->name, &tmp) < 0) { + return 1; + } + if (tmp == NULL || tmp == Py_None) { + Py_CLEAR(tmp); + name = NULL; + } + else { + int res; + res = obj2ast_identifier(state, tmp, &name, arena); + if (res != 0) goto failed; + Py_CLEAR(tmp); + } + *out = _PyAST_MatchStar(name, lineno, col_offset, end_lineno, + end_col_offset, arena); + if (*out == NULL) goto failed; + return 0; + } + tp = state->MatchAs_type; + isinstance = PyObject_IsInstance(obj, tp); + if (isinstance == -1) { + return 1; + } + if (isinstance) { + pattern_ty pattern; + identifier name; + + if (_PyObject_LookupAttr(obj, state->pattern, &tmp) < 0) { + return 1; + } + if (tmp == NULL || tmp == Py_None) { + Py_CLEAR(tmp); + pattern = NULL; + } + else { + int res; + res = obj2ast_pattern(state, tmp, &pattern, arena); + if (res != 0) goto failed; + Py_CLEAR(tmp); + } + if (_PyObject_LookupAttr(obj, state->name, &tmp) < 0) { + return 1; + } + if (tmp == NULL || tmp == Py_None) { + Py_CLEAR(tmp); + name = NULL; + } + else { + int res; + res = obj2ast_identifier(state, tmp, &name, arena); + if (res != 0) goto failed; + Py_CLEAR(tmp); + } + *out = _PyAST_MatchAs(pattern, name, lineno, col_offset, end_lineno, + end_col_offset, arena); + if (*out == NULL) goto failed; + return 0; + } + tp = state->MatchOr_type; + isinstance = PyObject_IsInstance(obj, tp); + if (isinstance == -1) { + return 1; + } + if (isinstance) { + asdl_pattern_seq* patterns; + + if (_PyObject_LookupAttr(obj, state->patterns, &tmp) < 0) { + return 1; + } + if (tmp == NULL) { + PyErr_SetString(PyExc_TypeError, "required field \"patterns\" missing from MatchOr"); + return 1; + } + else { + int res; + Py_ssize_t len; + Py_ssize_t i; + if (!PyList_Check(tmp)) { + PyErr_Format(PyExc_TypeError, "MatchOr field \"patterns\" must be a list, not a %.200s", _PyType_Name(Py_TYPE(tmp))); + goto failed; + } + len = PyList_GET_SIZE(tmp); + patterns = _Py_asdl_pattern_seq_new(len, arena); + if (patterns == NULL) goto failed; + for (i = 0; i < len; i++) { + pattern_ty val; + PyObject *tmp2 = PyList_GET_ITEM(tmp, i); + Py_INCREF(tmp2); + res = obj2ast_pattern(state, tmp2, &val, arena); + Py_DECREF(tmp2); + if (res != 0) goto failed; + if (len != PyList_GET_SIZE(tmp)) { + PyErr_SetString(PyExc_RuntimeError, "MatchOr field \"patterns\" changed size during iteration"); + goto failed; + } + asdl_seq_SET(patterns, i, val); + } + Py_CLEAR(tmp); + } + *out = _PyAST_MatchOr(patterns, lineno, col_offset, end_lineno, + end_col_offset, arena); + if (*out == NULL) goto failed; + return 0; + } + + PyErr_Format(PyExc_TypeError, "expected some sort of pattern, but got %R", obj); + failed: + Py_XDECREF(tmp); + return 1; +} + +int obj2ast_type_ignore(struct ast_state *state, PyObject* obj, type_ignore_ty* out, PyArena* arena) { @@ -10230,12 +10992,6 @@ astmodule_exec(PyObject *m) if (PyModule_AddObjectRef(m, "Slice", state->Slice_type) < 0) { return -1; } - if (PyModule_AddObjectRef(m, "MatchAs", state->MatchAs_type) < 0) { - return -1; - } - if (PyModule_AddObjectRef(m, "MatchOr", state->MatchOr_type) < 0) { - return -1; - } if (PyModule_AddObjectRef(m, "expr_context", state->expr_context_type) < 0) { return -1; @@ -10378,6 +11134,36 @@ astmodule_exec(PyObject *m) if (PyModule_AddObjectRef(m, "match_case", state->match_case_type) < 0) { return -1; } + if (PyModule_AddObjectRef(m, "pattern", state->pattern_type) < 0) { + return -1; + } + if (PyModule_AddObjectRef(m, "MatchValue", state->MatchValue_type) < 0) { + return -1; + } + if (PyModule_AddObjectRef(m, "MatchSingleton", state->MatchSingleton_type) + < 0) { + return -1; + } + if (PyModule_AddObjectRef(m, "MatchSequence", state->MatchSequence_type) < + 0) { + return -1; + } + if (PyModule_AddObjectRef(m, "MatchMapping", state->MatchMapping_type) < 0) + { + return -1; + } + if (PyModule_AddObjectRef(m, "MatchClass", state->MatchClass_type) < 0) { + return -1; + } + if (PyModule_AddObjectRef(m, "MatchStar", state->MatchStar_type) < 0) { + return -1; + } + if (PyModule_AddObjectRef(m, "MatchAs", state->MatchAs_type) < 0) { + return -1; + } + if (PyModule_AddObjectRef(m, "MatchOr", state->MatchOr_type) < 0) { + return -1; + } if (PyModule_AddObjectRef(m, "type_ignore", state->type_ignore_type) < 0) { return -1; } diff --git a/Python/ast.c b/Python/ast.c index 2b96543..1fc83f6 100644 --- a/Python/ast.c +++ b/Python/ast.c @@ -7,6 +7,7 @@ #include "pycore_pystate.h" // _PyThreadState_GET() #include <assert.h> +#include <stdbool.h> struct validator { int recursion_depth; /* current recursion depth */ @@ -18,6 +19,7 @@ static int validate_exprs(struct validator *, asdl_expr_seq*, expr_context_ty, i static int _validate_nonempty_seq(asdl_seq *, const char *, const char *); static int validate_stmt(struct validator *, stmt_ty); static int validate_expr(struct validator *, expr_ty, expr_context_ty); +static int validate_pattern(struct validator *, pattern_ty); static int validate_name(PyObject *name) @@ -88,9 +90,9 @@ expr_context_name(expr_context_ty ctx) return "Store"; case Del: return "Del"; - default: - Py_UNREACHABLE(); + // No default case so compiler emits warning for unhandled cases } + Py_UNREACHABLE(); } static int @@ -180,7 +182,7 @@ validate_constant(struct validator *state, PyObject *value) static int validate_expr(struct validator *state, expr_ty exp, expr_context_ty ctx) { - int ret; + int ret = -1; if (++state->recursion_depth > state->recursion_limit) { PyErr_SetString(PyExc_RecursionError, "maximum recursion depth exceeded during compilation"); @@ -351,34 +353,216 @@ validate_expr(struct validator *state, expr_ty exp, expr_context_ty ctx) case NamedExpr_kind: ret = validate_expr(state, exp->v.NamedExpr.value, Load); break; - case MatchAs_kind: - PyErr_SetString(PyExc_ValueError, - "MatchAs is only valid in match_case patterns"); - return 0; - case MatchOr_kind: - PyErr_SetString(PyExc_ValueError, - "MatchOr is only valid in match_case patterns"); - return 0; /* This last case doesn't have any checking. */ case Name_kind: ret = 1; break; - default: + // No default case so compiler emits warning for unhandled cases + } + if (ret < 0) { PyErr_SetString(PyExc_SystemError, "unexpected expression"); - return 0; + ret = 0; } state->recursion_depth--; return ret; } + +// Note: the ensure_literal_* functions are only used to validate a restricted +// set of non-recursive literals that have already been checked with +// validate_expr, so they don't accept the validator state +static int +ensure_literal_number(expr_ty exp, bool allow_real, bool allow_imaginary) +{ + assert(exp->kind == Constant_kind); + PyObject *value = exp->v.Constant.value; + return (allow_real && PyFloat_CheckExact(value)) || + (allow_real && PyLong_CheckExact(value)) || + (allow_imaginary && PyComplex_CheckExact(value)); +} + static int -validate_pattern(expr_ty p) +ensure_literal_negative(expr_ty exp, bool allow_real, bool allow_imaginary) { - // Coming soon (thanks Batuhan)! + assert(exp->kind == UnaryOp_kind); + // Must be negation ... + if (exp->v.UnaryOp.op != USub) { + return 0; + } + // ... of a constant ... + expr_ty operand = exp->v.UnaryOp.operand; + if (operand->kind != Constant_kind) { + return 0; + } + // ... number + return ensure_literal_number(operand, allow_real, allow_imaginary); +} + +static int +ensure_literal_complex(expr_ty exp) +{ + assert(exp->kind == BinOp_kind); + expr_ty left = exp->v.BinOp.left; + expr_ty right = exp->v.BinOp.right; + // Ensure op is addition or subtraction + if (exp->v.BinOp.op != Add && exp->v.BinOp.op != Sub) { + return 0; + } + // Check LHS is a real number (potentially signed) + switch (left->kind) + { + case Constant_kind: + if (!ensure_literal_number(left, /*real=*/true, /*imaginary=*/false)) { + return 0; + } + break; + case UnaryOp_kind: + if (!ensure_literal_negative(left, /*real=*/true, /*imaginary=*/false)) { + return 0; + } + break; + default: + return 0; + } + // Check RHS is an imaginary number (no separate sign allowed) + switch (right->kind) + { + case Constant_kind: + if (!ensure_literal_number(right, /*real=*/false, /*imaginary=*/true)) { + return 0; + } + break; + default: + return 0; + } return 1; } static int +validate_pattern_match_value(struct validator *state, expr_ty exp) +{ + if (!validate_expr(state, exp, Load)) { + return 0; + } + + switch (exp->kind) + { + case Constant_kind: + case Attribute_kind: + // Constants and attribute lookups are always permitted + return 1; + case UnaryOp_kind: + // Negated numbers are permitted (whether real or imaginary) + // Compiler will complain if AST folding doesn't create a constant + if (ensure_literal_negative(exp, /*real=*/true, /*imaginary=*/true)) { + return 1; + } + break; + case BinOp_kind: + // Complex literals are permitted + // Compiler will complain if AST folding doesn't create a constant + if (ensure_literal_complex(exp)) { + return 1; + } + break; + default: + break; + } + PyErr_SetString(PyExc_SyntaxError, + "patterns may only match literals and attribute lookups"); + return 0; +} + +static int +validate_pattern(struct validator *state, pattern_ty p) +{ + int ret = -1; + if (++state->recursion_depth > state->recursion_limit) { + PyErr_SetString(PyExc_RecursionError, + "maximum recursion depth exceeded during compilation"); + return 0; + } + // Coming soon: https://bugs.python.org/issue43897 (thanks Batuhan)! + // TODO: Ensure no subnodes use "_" as an ordinary identifier + switch (p->kind) { + case MatchValue_kind: + ret = validate_pattern_match_value(state, p->v.MatchValue.value); + break; + case MatchSingleton_kind: + // TODO: Check constant is specifically None, True, or False + ret = validate_constant(state, p->v.MatchSingleton.value); + break; + case MatchSequence_kind: + // TODO: Validate all subpatterns + // return validate_patterns(state, p->v.MatchSequence.patterns); + ret = 1; + break; + case MatchMapping_kind: + // TODO: check "rest" target name is valid + if (asdl_seq_LEN(p->v.MatchMapping.keys) != asdl_seq_LEN(p->v.MatchMapping.patterns)) { + PyErr_SetString(PyExc_ValueError, + "MatchMapping doesn't have the same number of keys as patterns"); + return 0; + } + // null_ok=0 for key expressions, as rest-of-mapping is captured in "rest" + // TODO: replace with more restrictive expression validator, as per MatchValue above + if (!validate_exprs(state, p->v.MatchMapping.keys, Load, /*null_ok=*/ 0)) { + return 0; + } + // TODO: Validate all subpatterns + // ret = validate_patterns(state, p->v.MatchMapping.patterns); + ret = 1; + break; + case MatchClass_kind: + if (asdl_seq_LEN(p->v.MatchClass.kwd_attrs) != asdl_seq_LEN(p->v.MatchClass.kwd_patterns)) { + PyErr_SetString(PyExc_ValueError, + "MatchClass doesn't have the same number of keyword attributes as patterns"); + return 0; + } + // TODO: Restrict cls lookup to being a name or attribute + if (!validate_expr(state, p->v.MatchClass.cls, Load)) { + return 0; + } + // TODO: Validate all subpatterns + // return validate_patterns(state, p->v.MatchClass.patterns) && + // validate_patterns(state, p->v.MatchClass.kwd_patterns); + ret = 1; + break; + case MatchStar_kind: + // TODO: check target name is valid + ret = 1; + break; + case MatchAs_kind: + // TODO: check target name is valid + if (p->v.MatchAs.pattern == NULL) { + ret = 1; + } + else if (p->v.MatchAs.name == NULL) { + PyErr_SetString(PyExc_ValueError, + "MatchAs must specify a target name if a pattern is given"); + return 0; + } + else { + ret = validate_pattern(state, p->v.MatchAs.pattern); + } + break; + case MatchOr_kind: + // TODO: Validate all subpatterns + // return validate_patterns(state, p->v.MatchOr.patterns); + ret = 1; + break; + // No default case, so the compiler will emit a warning if new pattern + // kinds are added without being handled here + } + if (ret < 0) { + PyErr_SetString(PyExc_SystemError, "unexpected pattern"); + ret = 0; + } + state->recursion_depth--; + return ret; +} + +static int _validate_nonempty_seq(asdl_seq *seq, const char *what, const char *owner) { if (asdl_seq_LEN(seq)) @@ -404,7 +588,7 @@ validate_body(struct validator *state, asdl_stmt_seq *body, const char *owner) static int validate_stmt(struct validator *state, stmt_ty stmt) { - int ret; + int ret = -1; Py_ssize_t i; if (++state->recursion_depth > state->recursion_limit) { PyErr_SetString(PyExc_RecursionError, @@ -502,7 +686,7 @@ validate_stmt(struct validator *state, stmt_ty stmt) } for (i = 0; i < asdl_seq_LEN(stmt->v.Match.cases); i++) { match_case_ty m = asdl_seq_GET(stmt->v.Match.cases, i); - if (!validate_pattern(m->pattern) + if (!validate_pattern(state, m->pattern) || (m->guard && !validate_expr(state, m->guard, Load)) || !validate_body(state, m->body, "match_case")) { return 0; @@ -582,9 +766,11 @@ validate_stmt(struct validator *state, stmt_ty stmt) case Continue_kind: ret = 1; break; - default: + // No default case so compiler emits warning for unhandled cases + } + if (ret < 0) { PyErr_SetString(PyExc_SystemError, "unexpected statement"); - return 0; + ret = 0; } state->recursion_depth--; return ret; @@ -635,7 +821,7 @@ validate_exprs(struct validator *state, asdl_expr_seq *exprs, expr_context_ty ct int _PyAST_Validate(mod_ty mod) { - int res = 0; + int res = -1; struct validator state; PyThreadState *tstate; int recursion_limit = Py_GetRecursionLimit(); @@ -663,10 +849,16 @@ _PyAST_Validate(mod_ty mod) case Expression_kind: res = validate_expr(&state, mod->v.Expression.body, Load); break; - default: - PyErr_SetString(PyExc_SystemError, "impossible module node"); - res = 0; + case FunctionType_kind: + res = validate_exprs(&state, mod->v.FunctionType.argtypes, Load, /*null_ok=*/0) && + validate_expr(&state, mod->v.FunctionType.returns, Load); break; + // No default case so compiler emits warning for unhandled cases + } + + if (res < 0) { + PyErr_SetString(PyExc_SystemError, "impossible module node"); + return 0; } /* Check that the recursion depth counting balanced correctly */ diff --git a/Python/ast_opt.c b/Python/ast_opt.c index 6eb514e..254dd64 100644 --- a/Python/ast_opt.c +++ b/Python/ast_opt.c @@ -411,7 +411,7 @@ static int astfold_arg(arg_ty node_, PyArena *ctx_, _PyASTOptimizeState *state); static int astfold_withitem(withitem_ty node_, PyArena *ctx_, _PyASTOptimizeState *state); static int astfold_excepthandler(excepthandler_ty node_, PyArena *ctx_, _PyASTOptimizeState *state); static int astfold_match_case(match_case_ty node_, PyArena *ctx_, _PyASTOptimizeState *state); -static int astfold_pattern(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state); +static int astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state); #define CALL(FUNC, TYPE, ARG) \ if (!FUNC((ARG), ctx_, state)) \ @@ -602,10 +602,6 @@ astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state) case Constant_kind: // Already a constant, nothing further to do break; - case MatchAs_kind: - case MatchOr_kind: - // These can't occur outside of patterns. - Py_UNREACHABLE(); // No default case, so the compiler will emit a warning if new expression // kinds are added without being handled here } @@ -797,112 +793,48 @@ astfold_withitem(withitem_ty node_, PyArena *ctx_, _PyASTOptimizeState *state) } static int -astfold_pattern_negative(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state) +astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state) { - assert(node_->kind == UnaryOp_kind); - assert(node_->v.UnaryOp.op == USub); - assert(node_->v.UnaryOp.operand->kind == Constant_kind); - PyObject *value = node_->v.UnaryOp.operand->v.Constant.value; - assert(PyComplex_CheckExact(value) || - PyFloat_CheckExact(value) || - PyLong_CheckExact(value)); - PyObject *negated = PyNumber_Negative(value); - if (negated == NULL) { - return 0; - } - assert(PyComplex_CheckExact(negated) || - PyFloat_CheckExact(negated) || - PyLong_CheckExact(negated)); - return make_const(node_, negated, ctx_); -} - -static int -astfold_pattern_complex(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state) -{ - expr_ty left = node_->v.BinOp.left; - expr_ty right = node_->v.BinOp.right; - if (left->kind == UnaryOp_kind) { - CALL(astfold_pattern_negative, expr_ty, left); - } - assert(left->kind = Constant_kind); - assert(right->kind = Constant_kind); - // LHS must be real, RHS must be imaginary: - if (!(PyFloat_CheckExact(left->v.Constant.value) || - PyLong_CheckExact(left->v.Constant.value)) || - !PyComplex_CheckExact(right->v.Constant.value)) - { - // Not actually valid, but it's the compiler's job to complain: - return 1; - } - PyObject *new; - if (node_->v.BinOp.op == Add) { - new = PyNumber_Add(left->v.Constant.value, right->v.Constant.value); - } - else { - assert(node_->v.BinOp.op == Sub); - new = PyNumber_Subtract(left->v.Constant.value, right->v.Constant.value); - } - if (new == NULL) { + // Currently, this is really only used to form complex/negative numeric + // constants in MatchValue and MatchMapping nodes + // We still recurse into all subexpressions and subpatterns anyway + if (++state->recursion_depth > state->recursion_limit) { + PyErr_SetString(PyExc_RecursionError, + "maximum recursion depth exceeded during compilation"); return 0; } - assert(PyComplex_CheckExact(new)); - return make_const(node_, new, ctx_); -} - -static int -astfold_pattern_keyword(keyword_ty node_, PyArena *ctx_, _PyASTOptimizeState *state) -{ - CALL(astfold_pattern, expr_ty, node_->value); - return 1; -} - -static int -astfold_pattern(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state) -{ - // Don't blindly optimize the pattern as an expr; it plays by its own rules! - // Currently, this is only used to form complex/negative numeric constants. switch (node_->kind) { - case Attribute_kind: + case MatchValue_kind: + CALL(astfold_expr, expr_ty, node_->v.MatchValue.value); break; - case BinOp_kind: - CALL(astfold_pattern_complex, expr_ty, node_); + case MatchSingleton_kind: break; - case Call_kind: - CALL_SEQ(astfold_pattern, expr, node_->v.Call.args); - CALL_SEQ(astfold_pattern_keyword, keyword, node_->v.Call.keywords); + case MatchSequence_kind: + CALL_SEQ(astfold_pattern, pattern, node_->v.MatchSequence.patterns); break; - case Constant_kind: + case MatchMapping_kind: + CALL_SEQ(astfold_expr, expr, node_->v.MatchMapping.keys); + CALL_SEQ(astfold_pattern, pattern, node_->v.MatchMapping.patterns); break; - case Dict_kind: - CALL_SEQ(astfold_pattern, expr, node_->v.Dict.keys); - CALL_SEQ(astfold_pattern, expr, node_->v.Dict.values); + case MatchClass_kind: + CALL(astfold_expr, expr_ty, node_->v.MatchClass.cls); + CALL_SEQ(astfold_pattern, pattern, node_->v.MatchClass.patterns); + CALL_SEQ(astfold_pattern, pattern, node_->v.MatchClass.kwd_patterns); break; - // Not actually valid, but it's the compiler's job to complain: - case JoinedStr_kind: - break; - case List_kind: - CALL_SEQ(astfold_pattern, expr, node_->v.List.elts); + case MatchStar_kind: break; case MatchAs_kind: - CALL(astfold_pattern, expr_ty, node_->v.MatchAs.pattern); + if (node_->v.MatchAs.pattern) { + CALL(astfold_pattern, expr_ty, node_->v.MatchAs.pattern); + } break; case MatchOr_kind: - CALL_SEQ(astfold_pattern, expr, node_->v.MatchOr.patterns); - break; - case Name_kind: - break; - case Starred_kind: - CALL(astfold_pattern, expr_ty, node_->v.Starred.value); + CALL_SEQ(astfold_pattern, pattern, node_->v.MatchOr.patterns); break; - case Tuple_kind: - CALL_SEQ(astfold_pattern, expr, node_->v.Tuple.elts); - break; - case UnaryOp_kind: - CALL(astfold_pattern_negative, expr_ty, node_); - break; - default: - Py_UNREACHABLE(); + // No default case, so the compiler will emit a warning if new pattern + // kinds are added without being handled here } + state->recursion_depth--; return 1; } diff --git a/Python/ast_unparse.c b/Python/ast_unparse.c index 5276b2f..126e904 100644 --- a/Python/ast_unparse.c +++ b/Python/ast_unparse.c @@ -3,6 +3,11 @@ #include <float.h> // DBL_MAX_10_EXP #include <stdbool.h> +/* This limited unparser is used to convert annotations back to strings + * during compilation rather than being a full AST unparser. + * See ast.unparse for a full unparser (written in Python) + */ + static PyObject *_str_open_br; static PyObject *_str_dbl_open_br; static PyObject *_str_close_br; @@ -912,11 +917,11 @@ append_ast_expr(_PyUnicodeWriter *writer, expr_ty e, int level) return append_ast_tuple(writer, e, level); case NamedExpr_kind: return append_named_expr(writer, e, level); - default: - PyErr_SetString(PyExc_SystemError, - "unknown expression kind"); - return -1; + // No default so compiler emits a warning for unhandled cases } + PyErr_SetString(PyExc_SystemError, + "unknown expression kind"); + return -1; } static int diff --git a/Python/bltinmodule.c b/Python/bltinmodule.c index 3b0e59a..66a74cb 100644 --- a/Python/bltinmodule.c +++ b/Python/bltinmodule.c @@ -831,11 +831,7 @@ builtin_compile_impl(PyObject *module, PyObject *source, PyObject *filename, if (arena == NULL) goto error; mod = PyAST_obj2mod(source, arena, compile_mode); - if (mod == NULL) { - _PyArena_Free(arena); - goto error; - } - if (!_PyAST_Validate(mod)) { + if (mod == NULL || !_PyAST_Validate(mod)) { _PyArena_Free(arena); goto error; } diff --git a/Python/compile.c b/Python/compile.c index 2cf2f4a..3cf6122 100644 --- a/Python/compile.c +++ b/Python/compile.c @@ -280,9 +280,9 @@ static int compiler_async_comprehension_generator( int depth, expr_ty elt, expr_ty val, int type); -static int compiler_pattern(struct compiler *, expr_ty, pattern_context *); +static int compiler_pattern(struct compiler *, pattern_ty, pattern_context *); static int compiler_match(struct compiler *, stmt_ty); -static int compiler_pattern_subpattern(struct compiler *, expr_ty, +static int compiler_pattern_subpattern(struct compiler *, pattern_ty, pattern_context *); static PyCodeObject *assemble(struct compiler *, int addNone); @@ -5263,10 +5263,6 @@ compiler_visit_expr1(struct compiler *c, expr_ty e) return compiler_list(c, e); case Tuple_kind: return compiler_tuple(c, e); - case MatchAs_kind: - case MatchOr_kind: - // Can only occur in patterns, which are handled elsewhere. - Py_UNREACHABLE(); } return 1; } @@ -5600,16 +5596,22 @@ compiler_slice(struct compiler *c, expr_ty s) // that it's much easier to smooth out any redundant pushing, popping, and // jumping in the peephole optimizer than to detect or predict it here. - #define WILDCARD_CHECK(N) \ - ((N)->kind == Name_kind && \ - _PyUnicode_EqualToASCIIString((N)->v.Name.id, "_")) + ((N)->kind == MatchAs_kind && !(N)->v.MatchAs.name) +#define WILDCARD_STAR_CHECK(N) \ + ((N)->kind == MatchStar_kind && !(N)->v.MatchStar.name) + +// Limit permitted subexpressions, even if the parser & AST validator let them through +#define MATCH_VALUE_EXPR(N) \ + ((N)->kind == Constant_kind || (N)->kind == Attribute_kind) static int pattern_helper_store_name(struct compiler *c, identifier n, pattern_context *pc) { - assert(!_PyUnicode_EqualToASCIIString(n, "_")); + if (forbidden_name(c, n, Store)) { + return 0; + } // Can't assign to the same name twice: if (pc->stores == NULL) { RETURN_IF_FALSE(pc->stores = PySet_New(NULL)); @@ -5631,16 +5633,43 @@ pattern_helper_store_name(struct compiler *c, identifier n, pattern_context *pc) static int -pattern_helper_sequence_unpack(struct compiler *c, asdl_expr_seq *values, +pattern_unpack_helper(struct compiler *c, asdl_pattern_seq *elts) +{ + Py_ssize_t n = asdl_seq_LEN(elts); + int seen_star = 0; + for (Py_ssize_t i = 0; i < n; i++) { + pattern_ty elt = asdl_seq_GET(elts, i); + if (elt->kind == MatchStar_kind && !seen_star) { + if ((i >= (1 << 8)) || + (n-i-1 >= (INT_MAX >> 8))) + return compiler_error(c, + "too many expressions in " + "star-unpacking sequence pattern"); + ADDOP_I(c, UNPACK_EX, (i + ((n-i-1) << 8))); + seen_star = 1; + } + else if (elt->kind == MatchStar_kind) { + return compiler_error(c, + "multiple starred expressions in sequence pattern"); + } + } + if (!seen_star) { + ADDOP_I(c, UNPACK_SEQUENCE, n); + } + return 1; +} + +static int +pattern_helper_sequence_unpack(struct compiler *c, asdl_pattern_seq *patterns, Py_ssize_t star, pattern_context *pc) { - RETURN_IF_FALSE(unpack_helper(c, values)); + RETURN_IF_FALSE(pattern_unpack_helper(c, patterns)); // We've now got a bunch of new subjects on the stack. If any of them fail // to match, we need to pop everything else off, then finally push False. // fails is an array of blocks that correspond to the necessary amount of // popping for each element: basicblock **fails; - Py_ssize_t size = asdl_seq_LEN(values); + Py_ssize_t size = asdl_seq_LEN(patterns); fails = (basicblock **)PyObject_Malloc(sizeof(basicblock*) * size); if (fails == NULL) { PyErr_NoMemory(); @@ -5655,12 +5684,9 @@ pattern_helper_sequence_unpack(struct compiler *c, asdl_expr_seq *values, } } for (Py_ssize_t i = 0; i < size; i++) { - expr_ty value = asdl_seq_GET(values, i); - if (i == star) { - assert(value->kind == Starred_kind); - value = value->v.Starred.value; - } - if (!compiler_pattern_subpattern(c, value, pc) || + pattern_ty pattern = asdl_seq_GET(patterns, i); + assert((i == star) == (pattern->kind == MatchStar_kind)); + if (!compiler_pattern_subpattern(c, pattern, pc) || !compiler_addop_j(c, POP_JUMP_IF_FALSE, fails[i]) || compiler_next_block(c) == NULL) { @@ -5703,21 +5729,20 @@ error: // UNPACK_SEQUENCE / UNPACK_EX. This is more efficient for patterns with a // starred wildcard like [first, *_] / [first, *_, last] / [*_, last] / etc. static int -pattern_helper_sequence_subscr(struct compiler *c, asdl_expr_seq *values, +pattern_helper_sequence_subscr(struct compiler *c, asdl_pattern_seq *patterns, Py_ssize_t star, pattern_context *pc) { basicblock *end, *fail_pop_1; RETURN_IF_FALSE(end = compiler_new_block(c)); RETURN_IF_FALSE(fail_pop_1 = compiler_new_block(c)); - Py_ssize_t size = asdl_seq_LEN(values); + Py_ssize_t size = asdl_seq_LEN(patterns); for (Py_ssize_t i = 0; i < size; i++) { - expr_ty value = asdl_seq_GET(values, i); - if (WILDCARD_CHECK(value)) { + pattern_ty pattern = asdl_seq_GET(patterns, i); + if (WILDCARD_CHECK(pattern)) { continue; } if (i == star) { - assert(value->kind == Starred_kind); - assert(WILDCARD_CHECK(value->v.Starred.value)); + assert(WILDCARD_STAR_CHECK(pattern)); continue; } ADDOP(c, DUP_TOP); @@ -5732,7 +5757,7 @@ pattern_helper_sequence_subscr(struct compiler *c, asdl_expr_seq *values, ADDOP(c, BINARY_SUBTRACT); } ADDOP(c, BINARY_SUBSCR); - RETURN_IF_FALSE(compiler_pattern_subpattern(c, value, pc)); + RETURN_IF_FALSE(compiler_pattern_subpattern(c, pattern, pc)); ADDOP_JUMP(c, POP_JUMP_IF_FALSE, fail_pop_1); NEXT_BLOCK(c); } @@ -5746,10 +5771,9 @@ pattern_helper_sequence_subscr(struct compiler *c, asdl_expr_seq *values, return 1; } - // Like compiler_pattern, but turn off checks for irrefutability. static int -compiler_pattern_subpattern(struct compiler *c, expr_ty p, pattern_context *pc) +compiler_pattern_subpattern(struct compiler *c, pattern_ty p, pattern_context *pc) { int allow_irrefutable = pc->allow_irrefutable; pc->allow_irrefutable = 1; @@ -5758,11 +5782,43 @@ compiler_pattern_subpattern(struct compiler *c, expr_ty p, pattern_context *pc) return 1; } +static int +compiler_pattern_capture(struct compiler *c, identifier n, pattern_context *pc) +{ + RETURN_IF_FALSE(pattern_helper_store_name(c, n, pc)); + ADDOP_LOAD_CONST(c, Py_True); + return 1; +} static int -compiler_pattern_as(struct compiler *c, expr_ty p, pattern_context *pc) +compiler_pattern_wildcard(struct compiler *c, pattern_ty p, pattern_context *pc) { assert(p->kind == MatchAs_kind); + if (!pc->allow_irrefutable) { + // Whoops, can't have a wildcard here! + const char *e = "wildcard makes remaining patterns unreachable"; + return compiler_error(c, e); + } + ADDOP(c, POP_TOP); + ADDOP_LOAD_CONST(c, Py_True); + return 1; +} + +static int +compiler_pattern_as(struct compiler *c, pattern_ty p, pattern_context *pc) +{ + assert(p->kind == MatchAs_kind); + if (p->v.MatchAs.name == NULL) { + return compiler_pattern_wildcard(c, p, pc); + } + if (p->v.MatchAs.pattern == NULL) { + if (!pc->allow_irrefutable) { + // Whoops, can't have a name capture here! + const char *e = "name capture %R makes remaining patterns unreachable"; + return compiler_error(c, e, p->v.MatchAs.name); + } + return compiler_pattern_capture(c, p->v.MatchAs.name, pc); + } basicblock *end, *fail_pop_1; RETURN_IF_FALSE(end = compiler_new_block(c)); RETURN_IF_FALSE(fail_pop_1 = compiler_new_block(c)); @@ -5782,71 +5838,101 @@ compiler_pattern_as(struct compiler *c, expr_ty p, pattern_context *pc) return 1; } - static int -compiler_pattern_capture(struct compiler *c, expr_ty p, pattern_context *pc) +compiler_pattern_star(struct compiler *c, pattern_ty p, pattern_context *pc) { - assert(p->kind == Name_kind); - assert(p->v.Name.ctx == Store); - assert(!WILDCARD_CHECK(p)); + assert(p->kind == MatchStar_kind); if (!pc->allow_irrefutable) { - // Whoops, can't have a name capture here! - const char *e = "name capture %R makes remaining patterns unreachable"; - return compiler_error(c, e, p->v.Name.id); + // Whoops, can't have a star capture here! + const char *e = "star captures are only allowed as part of sequence patterns"; + return compiler_error(c, e); } - RETURN_IF_FALSE(pattern_helper_store_name(c, p->v.Name.id, pc)); - ADDOP_LOAD_CONST(c, Py_True); - return 1; + return compiler_pattern_capture(c, p->v.MatchStar.name, pc); } +static int +validate_kwd_attrs(struct compiler *c, asdl_identifier_seq *attrs, asdl_pattern_seq* patterns) +{ + // Any errors will point to the pattern rather than the arg name as the + // parser is only supplying identifiers rather than Name or keyword nodes + Py_ssize_t nattrs = asdl_seq_LEN(attrs); + for (Py_ssize_t i = 0; i < nattrs; i++) { + identifier attr = ((identifier)asdl_seq_GET(attrs, i)); + c->u->u_col_offset = ((pattern_ty) asdl_seq_GET(patterns, i))->col_offset; + if (forbidden_name(c, attr, Store)) { + return -1; + } + for (Py_ssize_t j = i + 1; j < nattrs; j++) { + identifier other = ((identifier)asdl_seq_GET(attrs, j)); + if (!PyUnicode_Compare(attr, other)) { + c->u->u_col_offset = ((pattern_ty) asdl_seq_GET(patterns, j))->col_offset; + compiler_error(c, "attribute name repeated in class pattern: %U", attr); + return -1; + } + } + } + return 0; +} static int -compiler_pattern_class(struct compiler *c, expr_ty p, pattern_context *pc) +compiler_pattern_class(struct compiler *c, pattern_ty p, pattern_context *pc) { - asdl_expr_seq *args = p->v.Call.args; - asdl_keyword_seq *kwargs = p->v.Call.keywords; - Py_ssize_t nargs = asdl_seq_LEN(args); - Py_ssize_t nkwargs = asdl_seq_LEN(kwargs); - if (INT_MAX < nargs || INT_MAX < nargs + nkwargs - 1) { + assert(p->kind == MatchClass_kind); + asdl_pattern_seq *patterns = p->v.MatchClass.patterns; + asdl_identifier_seq *kwd_attrs = p->v.MatchClass.kwd_attrs; + asdl_pattern_seq *kwd_patterns = p->v.MatchClass.kwd_patterns; + Py_ssize_t nargs = asdl_seq_LEN(patterns); + Py_ssize_t nattrs = asdl_seq_LEN(kwd_attrs); + Py_ssize_t nkwd_patterns = asdl_seq_LEN(kwd_patterns); + if (nattrs != nkwd_patterns) { + // AST validator shouldn't let this happen, but if it does, + // just fail, don't crash out of the interpreter + const char * e = "kwd_attrs (%d) / kwd_patterns (%d) length mismatch in class pattern"; + return compiler_error(c, e, nattrs, nkwd_patterns); + } + if (INT_MAX < nargs || INT_MAX < nargs + nattrs - 1) { const char *e = "too many sub-patterns in class pattern %R"; - return compiler_error(c, e, p->v.Call.func); + return compiler_error(c, e, p->v.MatchClass.cls); + } + if (nattrs) { + RETURN_IF_FALSE(!validate_kwd_attrs(c, kwd_attrs, kwd_patterns)); + c->u->u_col_offset = p->col_offset; // validate_kwd_attrs moves this } - RETURN_IF_FALSE(!validate_keywords(c, kwargs)); basicblock *end, *fail_pop_1; RETURN_IF_FALSE(end = compiler_new_block(c)); RETURN_IF_FALSE(fail_pop_1 = compiler_new_block(c)); - VISIT(c, expr, p->v.Call.func); - PyObject *kwnames; - RETURN_IF_FALSE(kwnames = PyTuple_New(nkwargs)); + VISIT(c, expr, p->v.MatchClass.cls); + PyObject *attr_names; + RETURN_IF_FALSE(attr_names = PyTuple_New(nattrs)); Py_ssize_t i; - for (i = 0; i < nkwargs; i++) { - PyObject *name = ((keyword_ty) asdl_seq_GET(kwargs, i))->arg; + for (i = 0; i < nattrs; i++) { + PyObject *name = asdl_seq_GET(kwd_attrs, i); Py_INCREF(name); - PyTuple_SET_ITEM(kwnames, i, name); + PyTuple_SET_ITEM(attr_names, i, name); } - ADDOP_LOAD_CONST_NEW(c, kwnames); + ADDOP_LOAD_CONST_NEW(c, attr_names); ADDOP_I(c, MATCH_CLASS, nargs); ADDOP_JUMP(c, POP_JUMP_IF_FALSE, fail_pop_1); NEXT_BLOCK(c); // TOS is now a tuple of (nargs + nkwargs) attributes. - for (i = 0; i < nargs + nkwargs; i++) { - expr_ty arg; + for (i = 0; i < nargs + nattrs; i++) { + pattern_ty pattern; if (i < nargs) { // Positional: - arg = asdl_seq_GET(args, i); + pattern = asdl_seq_GET(patterns, i); } else { // Keyword: - arg = ((keyword_ty) asdl_seq_GET(kwargs, i - nargs))->value; + pattern = asdl_seq_GET(kwd_patterns, i - nargs); } - if (WILDCARD_CHECK(arg)) { + if (WILDCARD_CHECK(pattern)) { continue; } // Get the i-th attribute, and match it against the i-th pattern: ADDOP(c, DUP_TOP); ADDOP_LOAD_CONST_NEW(c, PyLong_FromSsize_t(i)); ADDOP(c, BINARY_SUBSCR); - RETURN_IF_FALSE(compiler_pattern_subpattern(c, arg, pc)); + RETURN_IF_FALSE(compiler_pattern_subpattern(c, pattern, pc)); ADDOP_JUMP(c, POP_JUMP_IF_FALSE, fail_pop_1); NEXT_BLOCK(c); } @@ -5861,36 +5947,30 @@ compiler_pattern_class(struct compiler *c, expr_ty p, pattern_context *pc) return 1; } - static int -compiler_pattern_literal(struct compiler *c, expr_ty p, pattern_context *pc) -{ - assert(p->kind == Constant_kind); - PyObject *v = p->v.Constant.value; - ADDOP_LOAD_CONST(c, v); - // Literal True, False, and None are compared by identity. All others use - // equality: - ADDOP_COMPARE(c, (v == Py_None || PyBool_Check(v)) ? Is : Eq); - return 1; -} - - -static int -compiler_pattern_mapping(struct compiler *c, expr_ty p, pattern_context *pc) +compiler_pattern_mapping(struct compiler *c, pattern_ty p, pattern_context *pc) { + assert(p->kind == MatchMapping_kind); basicblock *end, *fail_pop_1, *fail_pop_3; RETURN_IF_FALSE(end = compiler_new_block(c)); RETURN_IF_FALSE(fail_pop_1 = compiler_new_block(c)); RETURN_IF_FALSE(fail_pop_3 = compiler_new_block(c)); - asdl_expr_seq *keys = p->v.Dict.keys; - asdl_expr_seq *values = p->v.Dict.values; - Py_ssize_t size = asdl_seq_LEN(values); - // A starred pattern will be a keyless value. It is guaranteed to be last: - int star = size ? !asdl_seq_GET(keys, size - 1) : 0; + asdl_expr_seq *keys = p->v.MatchMapping.keys; + asdl_pattern_seq *patterns = p->v.MatchMapping.patterns; + Py_ssize_t size = asdl_seq_LEN(keys); + Py_ssize_t npatterns = asdl_seq_LEN(patterns); + if (size != npatterns) { + // AST validator shouldn't let this happen, but if it does, + // just fail, don't crash out of the interpreter + const char * e = "keys (%d) / patterns (%d) length mismatch in mapping pattern"; + return compiler_error(c, e, size, npatterns); + } + // We have a double-star target if "rest" is set + PyObject *star_target = p->v.MatchMapping.rest; ADDOP(c, MATCH_MAPPING); ADDOP_JUMP(c, POP_JUMP_IF_FALSE, fail_pop_1); NEXT_BLOCK(c); - if (!size) { + if (!size && !star_target) { // If the pattern is just "{}", we're done! ADDOP(c, POP_TOP); ADDOP_LOAD_CONST(c, Py_True); @@ -5901,53 +5981,57 @@ compiler_pattern_mapping(struct compiler *c, expr_ty p, pattern_context *pc) compiler_use_next_block(c, end); return 1; } - if (size - star) { + if (size) { // If the pattern has any keys in it, perform a length check: ADDOP(c, GET_LEN); - ADDOP_LOAD_CONST_NEW(c, PyLong_FromSsize_t(size - star)); + ADDOP_LOAD_CONST_NEW(c, PyLong_FromSsize_t(size)); ADDOP_COMPARE(c, GtE); ADDOP_JUMP(c, POP_JUMP_IF_FALSE, fail_pop_1); NEXT_BLOCK(c); } - if (INT_MAX < size - star - 1) { + if (INT_MAX < size - 1) { return compiler_error(c, "too many sub-patterns in mapping pattern"); } // Collect all of the keys into a tuple for MATCH_KEYS and // COPY_DICT_WITHOUT_KEYS. They can either be dotted names or literals: - for (Py_ssize_t i = 0; i < size - star; i++) { + for (Py_ssize_t i = 0; i < size; i++) { expr_ty key = asdl_seq_GET(keys, i); if (key == NULL) { - const char *e = "can't use starred name here " - "(consider moving to end)"; + const char *e = "can't use NULL keys in MatchMapping " + "(set 'rest' parameter instead)"; + c->u->u_col_offset = ((pattern_ty) asdl_seq_GET(patterns, i))->col_offset; + return compiler_error(c, e); + } + if (!MATCH_VALUE_EXPR(key)) { + const char *e = "mapping pattern keys may only match literals and attribute lookups"; return compiler_error(c, e); } VISIT(c, expr, key); } - ADDOP_I(c, BUILD_TUPLE, size - star); + ADDOP_I(c, BUILD_TUPLE, size); ADDOP(c, MATCH_KEYS); ADDOP_JUMP(c, POP_JUMP_IF_FALSE, fail_pop_3); NEXT_BLOCK(c); // So far so good. There's now a tuple of values on the stack to match // sub-patterns against: - for (Py_ssize_t i = 0; i < size - star; i++) { - expr_ty value = asdl_seq_GET(values, i); - if (WILDCARD_CHECK(value)) { + for (Py_ssize_t i = 0; i < size; i++) { + pattern_ty pattern = asdl_seq_GET(patterns, i); + if (WILDCARD_CHECK(pattern)) { continue; } ADDOP(c, DUP_TOP); ADDOP_LOAD_CONST_NEW(c, PyLong_FromSsize_t(i)); ADDOP(c, BINARY_SUBSCR); - RETURN_IF_FALSE(compiler_pattern_subpattern(c, value, pc)); + RETURN_IF_FALSE(compiler_pattern_subpattern(c, pattern, pc)); ADDOP_JUMP(c, POP_JUMP_IF_FALSE, fail_pop_3); NEXT_BLOCK(c); } // If we get this far, it's a match! We're done with that tuple of values. ADDOP(c, POP_TOP); - if (star) { - // If we had a starred name, bind a dict of remaining items to it: + if (star_target) { + // If we have a starred name, bind a dict of remaining items to it: ADDOP(c, COPY_DICT_WITHOUT_KEYS); - PyObject *id = asdl_seq_GET(values, size - 1)->v.Name.id; - RETURN_IF_FALSE(pattern_helper_store_name(c, id, pc)); + RETURN_IF_FALSE(pattern_helper_store_name(c, star_target, pc)); } else { // Otherwise, we don't care about this tuple of keys anymore: @@ -5970,9 +6054,8 @@ compiler_pattern_mapping(struct compiler *c, expr_ty p, pattern_context *pc) return 1; } - static int -compiler_pattern_or(struct compiler *c, expr_ty p, pattern_context *pc) +compiler_pattern_or(struct compiler *c, pattern_ty p, pattern_context *pc) { assert(p->kind == MatchOr_kind); // control is the set of names bound by the first alternative. If all of the @@ -5988,7 +6071,7 @@ compiler_pattern_or(struct compiler *c, expr_ty p, pattern_context *pc) int allow_irrefutable = pc->allow_irrefutable; for (Py_ssize_t i = 0; i < size; i++) { // NOTE: Can't use our nice returning macros in here: they'll leak sets! - expr_ty alt = asdl_seq_GET(p->v.MatchOr.patterns, i); + pattern_ty alt = asdl_seq_GET(p->v.MatchOr.patterns, i); pc->stores = PySet_New(stores_init); // An irrefutable sub-pattern must be last, if it is allowed at all: int is_last = i == size - 1; @@ -6044,28 +6127,28 @@ fail: static int -compiler_pattern_sequence(struct compiler *c, expr_ty p, pattern_context *pc) +compiler_pattern_sequence(struct compiler *c, pattern_ty p, pattern_context *pc) { - assert(p->kind == List_kind || p->kind == Tuple_kind); - asdl_expr_seq *values = (p->kind == Tuple_kind) ? p->v.Tuple.elts - : p->v.List.elts; - Py_ssize_t size = asdl_seq_LEN(values); + assert(p->kind == MatchSequence_kind); + asdl_pattern_seq *patterns = p->v.MatchSequence.patterns; + Py_ssize_t size = asdl_seq_LEN(patterns); Py_ssize_t star = -1; int only_wildcard = 1; int star_wildcard = 0; // Find a starred name, if it exists. There may be at most one: for (Py_ssize_t i = 0; i < size; i++) { - expr_ty value = asdl_seq_GET(values, i); - if (value->kind == Starred_kind) { - value = value->v.Starred.value; + pattern_ty pattern = asdl_seq_GET(patterns, i); + if (pattern->kind == MatchStar_kind) { if (star >= 0) { const char *e = "multiple starred names in sequence pattern"; return compiler_error(c, e); } - star_wildcard = WILDCARD_CHECK(value); + star_wildcard = WILDCARD_STAR_CHECK(pattern); + only_wildcard &= star_wildcard; star = i; + continue; } - only_wildcard &= WILDCARD_CHECK(value); + only_wildcard &= WILDCARD_CHECK(pattern); } basicblock *end, *fail_pop_1; RETURN_IF_FALSE(end = compiler_new_block(c)); @@ -6095,10 +6178,10 @@ compiler_pattern_sequence(struct compiler *c, expr_ty p, pattern_context *pc) ADDOP_LOAD_CONST(c, Py_True); } else if (star_wildcard) { - RETURN_IF_FALSE(pattern_helper_sequence_subscr(c, values, star, pc)); + RETURN_IF_FALSE(pattern_helper_sequence_subscr(c, patterns, star, pc)); } else { - RETURN_IF_FALSE(pattern_helper_sequence_unpack(c, values, star, pc)); + RETURN_IF_FALSE(pattern_helper_sequence_unpack(c, patterns, star, pc)); } ADDOP_JUMP(c, JUMP_FORWARD, end); compiler_use_next_block(c, fail_pop_1); @@ -6108,72 +6191,57 @@ compiler_pattern_sequence(struct compiler *c, expr_ty p, pattern_context *pc) return 1; } - static int -compiler_pattern_value(struct compiler *c, expr_ty p, pattern_context *pc) +compiler_pattern_value(struct compiler *c, pattern_ty p, pattern_context *pc) { - assert(p->kind == Attribute_kind); - assert(p->v.Attribute.ctx == Load); - VISIT(c, expr, p); + assert(p->kind == MatchValue_kind); + expr_ty value = p->v.MatchValue.value; + if (!MATCH_VALUE_EXPR(value)) { + const char *e = "patterns may only match literals and attribute lookups"; + return compiler_error(c, e); + } + VISIT(c, expr, value); ADDOP_COMPARE(c, Eq); return 1; } - static int -compiler_pattern_wildcard(struct compiler *c, expr_ty p, pattern_context *pc) +compiler_pattern_constant(struct compiler *c, pattern_ty p, pattern_context *pc) { - assert(p->kind == Name_kind); - assert(p->v.Name.ctx == Store); - assert(WILDCARD_CHECK(p)); - if (!pc->allow_irrefutable) { - // Whoops, can't have a wildcard here! - const char *e = "wildcard makes remaining patterns unreachable"; - return compiler_error(c, e); - } - ADDOP(c, POP_TOP); - ADDOP_LOAD_CONST(c, Py_True); + assert(p->kind == MatchSingleton_kind); + ADDOP_LOAD_CONST(c, p->v.MatchSingleton.value); + ADDOP_COMPARE(c, Is); return 1; } - static int -compiler_pattern(struct compiler *c, expr_ty p, pattern_context *pc) +compiler_pattern(struct compiler *c, pattern_ty p, pattern_context *pc) { SET_LOC(c, p); switch (p->kind) { - case Attribute_kind: + case MatchValue_kind: return compiler_pattern_value(c, p, pc); - case BinOp_kind: - // Because we allow "2+2j", things like "2+2" make it this far: - return compiler_error(c, "patterns cannot include operators"); - case Call_kind: - return compiler_pattern_class(c, p, pc); - case Constant_kind: - return compiler_pattern_literal(c, p, pc); - case Dict_kind: - return compiler_pattern_mapping(c, p, pc); - case JoinedStr_kind: - // Because we allow strings, f-strings make it this far: - return compiler_error(c, "patterns cannot include f-strings"); - case List_kind: - case Tuple_kind: + case MatchSingleton_kind: + return compiler_pattern_constant(c, p, pc); + case MatchSequence_kind: return compiler_pattern_sequence(c, p, pc); + case MatchMapping_kind: + return compiler_pattern_mapping(c, p, pc); + case MatchClass_kind: + return compiler_pattern_class(c, p, pc); + case MatchStar_kind: + return compiler_pattern_star(c, p, pc); case MatchAs_kind: return compiler_pattern_as(c, p, pc); case MatchOr_kind: return compiler_pattern_or(c, p, pc); - case Name_kind: - if (WILDCARD_CHECK(p)) { - return compiler_pattern_wildcard(c, p, pc); - } - return compiler_pattern_capture(c, p, pc); - default: - Py_UNREACHABLE(); } + // AST validator shouldn't let this happen, but if it does, + // just fail, don't crash out of the interpreter + const char *e = "invalid match pattern node in AST (kind=%d)"; + return compiler_error(c, e, p->kind); } - static int compiler_match(struct compiler *c, stmt_ty s) { @@ -6181,7 +6249,7 @@ compiler_match(struct compiler *c, stmt_ty s) basicblock *next, *end; RETURN_IF_FALSE(end = compiler_new_block(c)); Py_ssize_t cases = asdl_seq_LEN(s->v.Match.cases); - assert(cases); + assert(cases > 0); pattern_context pc; // We use pc.stores to track: // - Repeated name assignments in the same pattern. @@ -6235,9 +6303,8 @@ compiler_match(struct compiler *c, stmt_ty s) return 1; } - #undef WILDCARD_CHECK - +#undef WILDCARD_STAR_CHECK /* End of the compiler section, beginning of the assembler section */ diff --git a/Python/symtable.c b/Python/symtable.c index c6f8694..e620f1e 100644 --- a/Python/symtable.c +++ b/Python/symtable.c @@ -214,6 +214,7 @@ static int symtable_implicit_arg(struct symtable *st, int pos); static int symtable_visit_annotations(struct symtable *st, arguments_ty, expr_ty); static int symtable_visit_withitem(struct symtable *st, withitem_ty item); static int symtable_visit_match_case(struct symtable *st, match_case_ty m); +static int symtable_visit_pattern(struct symtable *st, pattern_ty s); static identifier top = NULL, lambda = NULL, genexpr = NULL, @@ -246,7 +247,6 @@ symtable_new(void) goto fail; st->st_cur = NULL; st->st_private = NULL; - st->in_pattern = 0; return st; fail: _PySymtable_Free(st); @@ -1676,13 +1676,6 @@ symtable_visit_expr(struct symtable *st, expr_ty e) VISIT(st, expr, e->v.Slice.step) break; case Name_kind: - // Don't make "_" a local when used in a pattern: - if (st->in_pattern && - e->v.Name.ctx == Store && - _PyUnicode_EqualToASCIIString(e->v.Name.id, "_")) - { - break; - } if (!symtable_add_def(st, e->v.Name.id, e->v.Name.ctx == Load ? USE : DEF_LOCAL)) VISIT_QUIT(st, 0); @@ -1702,12 +1695,55 @@ symtable_visit_expr(struct symtable *st, expr_ty e) case Tuple_kind: VISIT_SEQ(st, expr, e->v.Tuple.elts); break; + } + VISIT_QUIT(st, 1); +} + +static int +symtable_visit_pattern(struct symtable *st, pattern_ty p) +{ + if (++st->recursion_depth > st->recursion_limit) { + PyErr_SetString(PyExc_RecursionError, + "maximum recursion depth exceeded during compilation"); + VISIT_QUIT(st, 0); + } + switch (p->kind) { + case MatchValue_kind: + VISIT(st, expr, p->v.MatchValue.value); + break; + case MatchSingleton_kind: + /* Nothing to do here. */ + break; + case MatchSequence_kind: + VISIT_SEQ(st, pattern, p->v.MatchSequence.patterns); + break; + case MatchStar_kind: + if (p->v.MatchStar.name) { + symtable_add_def(st, p->v.MatchStar.name, DEF_LOCAL); + } + break; + case MatchMapping_kind: + VISIT_SEQ(st, expr, p->v.MatchMapping.keys); + VISIT_SEQ(st, pattern, p->v.MatchMapping.patterns); + if (p->v.MatchMapping.rest) { + symtable_add_def(st, p->v.MatchMapping.rest, DEF_LOCAL); + } + break; + case MatchClass_kind: + VISIT(st, expr, p->v.MatchClass.cls); + VISIT_SEQ(st, pattern, p->v.MatchClass.patterns); + VISIT_SEQ(st, pattern, p->v.MatchClass.kwd_patterns); + break; case MatchAs_kind: - VISIT(st, expr, e->v.MatchAs.pattern); - symtable_add_def(st, e->v.MatchAs.name, DEF_LOCAL); + if (p->v.MatchAs.pattern) { + VISIT(st, pattern, p->v.MatchAs.pattern); + } + if (p->v.MatchAs.name) { + symtable_add_def(st, p->v.MatchAs.name, DEF_LOCAL); + } break; case MatchOr_kind: - VISIT_SEQ(st, expr, e->v.MatchOr.patterns); + VISIT_SEQ(st, pattern, p->v.MatchOr.patterns); break; } VISIT_QUIT(st, 1); @@ -1830,11 +1866,7 @@ symtable_visit_withitem(struct symtable *st, withitem_ty item) static int symtable_visit_match_case(struct symtable *st, match_case_ty m) { - assert(!st->in_pattern); - st->in_pattern = 1; - VISIT(st, expr, m->pattern); - assert(st->in_pattern); - st->in_pattern = 0; + VISIT(st, pattern, m->pattern); if (m->guard) { VISIT(st, expr, m->guard); } |