diff options
Diffstat (limited to 'Python/ast.c')
-rw-r--r-- | Python/ast.c | 238 |
1 files changed, 215 insertions, 23 deletions
diff --git a/Python/ast.c b/Python/ast.c index 2b96543..1fc83f6 100644 --- a/Python/ast.c +++ b/Python/ast.c @@ -7,6 +7,7 @@ #include "pycore_pystate.h" // _PyThreadState_GET() #include <assert.h> +#include <stdbool.h> struct validator { int recursion_depth; /* current recursion depth */ @@ -18,6 +19,7 @@ static int validate_exprs(struct validator *, asdl_expr_seq*, expr_context_ty, i static int _validate_nonempty_seq(asdl_seq *, const char *, const char *); static int validate_stmt(struct validator *, stmt_ty); static int validate_expr(struct validator *, expr_ty, expr_context_ty); +static int validate_pattern(struct validator *, pattern_ty); static int validate_name(PyObject *name) @@ -88,9 +90,9 @@ expr_context_name(expr_context_ty ctx) return "Store"; case Del: return "Del"; - default: - Py_UNREACHABLE(); + // No default case so compiler emits warning for unhandled cases } + Py_UNREACHABLE(); } static int @@ -180,7 +182,7 @@ validate_constant(struct validator *state, PyObject *value) static int validate_expr(struct validator *state, expr_ty exp, expr_context_ty ctx) { - int ret; + int ret = -1; if (++state->recursion_depth > state->recursion_limit) { PyErr_SetString(PyExc_RecursionError, "maximum recursion depth exceeded during compilation"); @@ -351,34 +353,216 @@ validate_expr(struct validator *state, expr_ty exp, expr_context_ty ctx) case NamedExpr_kind: ret = validate_expr(state, exp->v.NamedExpr.value, Load); break; - case MatchAs_kind: - PyErr_SetString(PyExc_ValueError, - "MatchAs is only valid in match_case patterns"); - return 0; - case MatchOr_kind: - PyErr_SetString(PyExc_ValueError, - "MatchOr is only valid in match_case patterns"); - return 0; /* This last case doesn't have any checking. */ case Name_kind: ret = 1; break; - default: + // No default case so compiler emits warning for unhandled cases + } + if (ret < 0) { PyErr_SetString(PyExc_SystemError, "unexpected expression"); - return 0; + ret = 0; } state->recursion_depth--; return ret; } + +// Note: the ensure_literal_* functions are only used to validate a restricted +// set of non-recursive literals that have already been checked with +// validate_expr, so they don't accept the validator state +static int +ensure_literal_number(expr_ty exp, bool allow_real, bool allow_imaginary) +{ + assert(exp->kind == Constant_kind); + PyObject *value = exp->v.Constant.value; + return (allow_real && PyFloat_CheckExact(value)) || + (allow_real && PyLong_CheckExact(value)) || + (allow_imaginary && PyComplex_CheckExact(value)); +} + static int -validate_pattern(expr_ty p) +ensure_literal_negative(expr_ty exp, bool allow_real, bool allow_imaginary) { - // Coming soon (thanks Batuhan)! + assert(exp->kind == UnaryOp_kind); + // Must be negation ... + if (exp->v.UnaryOp.op != USub) { + return 0; + } + // ... of a constant ... + expr_ty operand = exp->v.UnaryOp.operand; + if (operand->kind != Constant_kind) { + return 0; + } + // ... number + return ensure_literal_number(operand, allow_real, allow_imaginary); +} + +static int +ensure_literal_complex(expr_ty exp) +{ + assert(exp->kind == BinOp_kind); + expr_ty left = exp->v.BinOp.left; + expr_ty right = exp->v.BinOp.right; + // Ensure op is addition or subtraction + if (exp->v.BinOp.op != Add && exp->v.BinOp.op != Sub) { + return 0; + } + // Check LHS is a real number (potentially signed) + switch (left->kind) + { + case Constant_kind: + if (!ensure_literal_number(left, /*real=*/true, /*imaginary=*/false)) { + return 0; + } + break; + case UnaryOp_kind: + if (!ensure_literal_negative(left, /*real=*/true, /*imaginary=*/false)) { + return 0; + } + break; + default: + return 0; + } + // Check RHS is an imaginary number (no separate sign allowed) + switch (right->kind) + { + case Constant_kind: + if (!ensure_literal_number(right, /*real=*/false, /*imaginary=*/true)) { + return 0; + } + break; + default: + return 0; + } return 1; } static int +validate_pattern_match_value(struct validator *state, expr_ty exp) +{ + if (!validate_expr(state, exp, Load)) { + return 0; + } + + switch (exp->kind) + { + case Constant_kind: + case Attribute_kind: + // Constants and attribute lookups are always permitted + return 1; + case UnaryOp_kind: + // Negated numbers are permitted (whether real or imaginary) + // Compiler will complain if AST folding doesn't create a constant + if (ensure_literal_negative(exp, /*real=*/true, /*imaginary=*/true)) { + return 1; + } + break; + case BinOp_kind: + // Complex literals are permitted + // Compiler will complain if AST folding doesn't create a constant + if (ensure_literal_complex(exp)) { + return 1; + } + break; + default: + break; + } + PyErr_SetString(PyExc_SyntaxError, + "patterns may only match literals and attribute lookups"); + return 0; +} + +static int +validate_pattern(struct validator *state, pattern_ty p) +{ + int ret = -1; + if (++state->recursion_depth > state->recursion_limit) { + PyErr_SetString(PyExc_RecursionError, + "maximum recursion depth exceeded during compilation"); + return 0; + } + // Coming soon: https://bugs.python.org/issue43897 (thanks Batuhan)! + // TODO: Ensure no subnodes use "_" as an ordinary identifier + switch (p->kind) { + case MatchValue_kind: + ret = validate_pattern_match_value(state, p->v.MatchValue.value); + break; + case MatchSingleton_kind: + // TODO: Check constant is specifically None, True, or False + ret = validate_constant(state, p->v.MatchSingleton.value); + break; + case MatchSequence_kind: + // TODO: Validate all subpatterns + // return validate_patterns(state, p->v.MatchSequence.patterns); + ret = 1; + break; + case MatchMapping_kind: + // TODO: check "rest" target name is valid + if (asdl_seq_LEN(p->v.MatchMapping.keys) != asdl_seq_LEN(p->v.MatchMapping.patterns)) { + PyErr_SetString(PyExc_ValueError, + "MatchMapping doesn't have the same number of keys as patterns"); + return 0; + } + // null_ok=0 for key expressions, as rest-of-mapping is captured in "rest" + // TODO: replace with more restrictive expression validator, as per MatchValue above + if (!validate_exprs(state, p->v.MatchMapping.keys, Load, /*null_ok=*/ 0)) { + return 0; + } + // TODO: Validate all subpatterns + // ret = validate_patterns(state, p->v.MatchMapping.patterns); + ret = 1; + break; + case MatchClass_kind: + if (asdl_seq_LEN(p->v.MatchClass.kwd_attrs) != asdl_seq_LEN(p->v.MatchClass.kwd_patterns)) { + PyErr_SetString(PyExc_ValueError, + "MatchClass doesn't have the same number of keyword attributes as patterns"); + return 0; + } + // TODO: Restrict cls lookup to being a name or attribute + if (!validate_expr(state, p->v.MatchClass.cls, Load)) { + return 0; + } + // TODO: Validate all subpatterns + // return validate_patterns(state, p->v.MatchClass.patterns) && + // validate_patterns(state, p->v.MatchClass.kwd_patterns); + ret = 1; + break; + case MatchStar_kind: + // TODO: check target name is valid + ret = 1; + break; + case MatchAs_kind: + // TODO: check target name is valid + if (p->v.MatchAs.pattern == NULL) { + ret = 1; + } + else if (p->v.MatchAs.name == NULL) { + PyErr_SetString(PyExc_ValueError, + "MatchAs must specify a target name if a pattern is given"); + return 0; + } + else { + ret = validate_pattern(state, p->v.MatchAs.pattern); + } + break; + case MatchOr_kind: + // TODO: Validate all subpatterns + // return validate_patterns(state, p->v.MatchOr.patterns); + ret = 1; + break; + // No default case, so the compiler will emit a warning if new pattern + // kinds are added without being handled here + } + if (ret < 0) { + PyErr_SetString(PyExc_SystemError, "unexpected pattern"); + ret = 0; + } + state->recursion_depth--; + return ret; +} + +static int _validate_nonempty_seq(asdl_seq *seq, const char *what, const char *owner) { if (asdl_seq_LEN(seq)) @@ -404,7 +588,7 @@ validate_body(struct validator *state, asdl_stmt_seq *body, const char *owner) static int validate_stmt(struct validator *state, stmt_ty stmt) { - int ret; + int ret = -1; Py_ssize_t i; if (++state->recursion_depth > state->recursion_limit) { PyErr_SetString(PyExc_RecursionError, @@ -502,7 +686,7 @@ validate_stmt(struct validator *state, stmt_ty stmt) } for (i = 0; i < asdl_seq_LEN(stmt->v.Match.cases); i++) { match_case_ty m = asdl_seq_GET(stmt->v.Match.cases, i); - if (!validate_pattern(m->pattern) + if (!validate_pattern(state, m->pattern) || (m->guard && !validate_expr(state, m->guard, Load)) || !validate_body(state, m->body, "match_case")) { return 0; @@ -582,9 +766,11 @@ validate_stmt(struct validator *state, stmt_ty stmt) case Continue_kind: ret = 1; break; - default: + // No default case so compiler emits warning for unhandled cases + } + if (ret < 0) { PyErr_SetString(PyExc_SystemError, "unexpected statement"); - return 0; + ret = 0; } state->recursion_depth--; return ret; @@ -635,7 +821,7 @@ validate_exprs(struct validator *state, asdl_expr_seq *exprs, expr_context_ty ct int _PyAST_Validate(mod_ty mod) { - int res = 0; + int res = -1; struct validator state; PyThreadState *tstate; int recursion_limit = Py_GetRecursionLimit(); @@ -663,10 +849,16 @@ _PyAST_Validate(mod_ty mod) case Expression_kind: res = validate_expr(&state, mod->v.Expression.body, Load); break; - default: - PyErr_SetString(PyExc_SystemError, "impossible module node"); - res = 0; + case FunctionType_kind: + res = validate_exprs(&state, mod->v.FunctionType.argtypes, Load, /*null_ok=*/0) && + validate_expr(&state, mod->v.FunctionType.returns, Load); break; + // No default case so compiler emits warning for unhandled cases + } + + if (res < 0) { + PyErr_SetString(PyExc_SystemError, "impossible module node"); + return 0; } /* Check that the recursion depth counting balanced correctly */ |