diff options
author | Batuhan Taskaya <batuhan@python.org> | 2021-07-28 17:14:45 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-07-28 17:14:45 (GMT) |
commit | 31bec6f1b178dadec3cb43353274b4e958a8f015 (patch) | |
tree | 007bb1ecbe4c69a0719a42076a18aa2fb8e098b6 | |
parent | 53b9458f2e9314703a5406ca817d757f1509882a (diff) | |
download | cpython-31bec6f1b178dadec3cb43353274b4e958a8f015.zip cpython-31bec6f1b178dadec3cb43353274b4e958a8f015.tar.gz cpython-31bec6f1b178dadec3cb43353274b4e958a8f015.tar.bz2 |
bpo-43897: AST validation for pattern matching nodes (GH24771)
-rw-r--r-- | Lib/test/test_ast.py | 143 |
-rw-r--r-- | Python/ast.c | 154 |
2 files changed, 265 insertions, 32 deletions
diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py index 5f1ee75..925bb88 100644 --- a/Lib/test/test_ast.py +++ b/Lib/test/test_ast.py @@ -696,7 +696,7 @@ class AST_Tests(unittest.TestCase): for constant in "True", "False", "None": expr = ast.Expression(ast.Name(constant, ast.Load())) ast.fix_missing_locations(expr) - with self.assertRaisesRegex(ValueError, f"Name node can't be used with '{constant}' constant"): + with self.assertRaisesRegex(ValueError, f"identifier field can't represent '{constant}' constant"): compile(expr, "<test>", "eval") def test_precedence_enum(self): @@ -1507,6 +1507,147 @@ class ASTValidatorTests(unittest.TestCase): mod = ast.parse(source, fn) compile(mod, fn, "exec") + constant_1 = ast.Constant(1) + pattern_1 = ast.MatchValue(constant_1) + + constant_x = ast.Constant('x') + pattern_x = ast.MatchValue(constant_x) + + constant_true = ast.Constant(True) + pattern_true = ast.MatchSingleton(True) + + name_carter = ast.Name('carter', ast.Load()) + + _MATCH_PATTERNS = [ + ast.MatchValue( + ast.Attribute( + ast.Attribute( + ast.Name('x', ast.Store()), + 'y', ast.Load() + ), + 'z', ast.Load() + ) + ), + ast.MatchValue( + ast.Attribute( + ast.Attribute( + ast.Name('x', ast.Load()), + 'y', ast.Store() + ), + 'z', ast.Load() + ) + ), + ast.MatchValue( + ast.Constant(...) 
+ ), + ast.MatchValue( + ast.Constant(True) + ), + ast.MatchValue( + ast.Constant((1,2,3)) + ), + ast.MatchSingleton('string'), + ast.MatchSequence([ + ast.MatchSingleton('string') + ]), + ast.MatchSequence( + [ + ast.MatchSequence( + [ + ast.MatchSingleton('string') + ] + ) + ] + ), + ast.MatchMapping( + [constant_1, constant_true], + [pattern_x] + ), + ast.MatchMapping( + [constant_true, constant_1], + [pattern_x, pattern_1], + rest='True' + ), + ast.MatchMapping( + [constant_true, ast.Starred(ast.Name('lol', ast.Load()), ast.Load())], + [pattern_x, pattern_1], + rest='legit' + ), + ast.MatchClass( + ast.Attribute( + ast.Attribute( + constant_x, + 'y', ast.Load()), + 'z', ast.Load()), + patterns=[], kwd_attrs=[], kwd_patterns=[] + ), + ast.MatchClass( + name_carter, + patterns=[], + kwd_attrs=['True'], + kwd_patterns=[pattern_1] + ), + ast.MatchClass( + name_carter, + patterns=[], + kwd_attrs=[], + kwd_patterns=[pattern_1] + ), + ast.MatchClass( + name_carter, + patterns=[ast.MatchSingleton('string')], + kwd_attrs=[], + kwd_patterns=[] + ), + ast.MatchClass( + name_carter, + patterns=[ast.MatchStar()], + kwd_attrs=[], + kwd_patterns=[] + ), + ast.MatchClass( + name_carter, + patterns=[], + kwd_attrs=[], + kwd_patterns=[ast.MatchStar()] + ), + ast.MatchSequence( + [ + ast.MatchStar("True") + ] + ), + ast.MatchAs( + name='False' + ), + ast.MatchOr( + [] + ), + ast.MatchOr( + [pattern_1] + ), + ast.MatchOr( + [pattern_1, pattern_x, ast.MatchSingleton('xxx')] + ) + ] + + def test_match_validation_pattern(self): + name_x = ast.Name('x', ast.Load()) + for pattern in self._MATCH_PATTERNS: + with self.subTest(ast.dump(pattern, indent=4)): + node = ast.Match( + subject=name_x, + cases = [ + ast.match_case( + pattern=pattern, + body = [ast.Pass()] + ) + ] + ) + node = ast.fix_missing_locations(node) + module = ast.Module([node], []) + with self.assertRaises(ValueError): + compile(module, "<test>", "exec") + class ConstantTests(unittest.TestCase): """Tests on the 
ast.Constant node type.""" diff --git a/Python/ast.c b/Python/ast.c index 1fc83f6..0a306c0 100644 --- a/Python/ast.c +++ b/Python/ast.c @@ -15,7 +15,8 @@ struct validator { }; static int validate_stmts(struct validator *, asdl_stmt_seq *); -static int validate_exprs(struct validator *, asdl_expr_seq*, expr_context_ty, int); +static int validate_exprs(struct validator *, asdl_expr_seq *, expr_context_ty, int); +static int validate_patterns(struct validator *, asdl_pattern_seq *, int); static int _validate_nonempty_seq(asdl_seq *, const char *, const char *); static int validate_stmt(struct validator *, stmt_ty); static int validate_expr(struct validator *, expr_ty, expr_context_ty); @@ -33,7 +34,7 @@ validate_name(PyObject *name) }; for (int i = 0; forbidden[i] != NULL; i++) { if (_PyUnicode_EqualToASCIIString(name, forbidden[i])) { - PyErr_Format(PyExc_ValueError, "Name node can't be used with '%s' constant", forbidden[i]); + PyErr_Format(PyExc_ValueError, "identifier field can't represent '%s' constant", forbidden[i]); return 0; } } @@ -448,6 +449,21 @@ validate_pattern_match_value(struct validator *state, expr_ty exp) switch (exp->kind) { case Constant_kind: + /* Ellipsis and immutable sequences are not allowed. 
+ For True, False and None, MatchSingleton() should + be used */ + if (!validate_expr(state, exp, Load)) { + return 0; + } + PyObject *literal = exp->v.Constant.value; + if (PyLong_CheckExact(literal) || PyFloat_CheckExact(literal) || + PyBytes_CheckExact(literal) || PyComplex_CheckExact(literal) || + PyUnicode_CheckExact(literal)) { + return 1; + } + PyErr_SetString(PyExc_ValueError, + "unexpected constant inside of a literal pattern"); + return 0; case Attribute_kind: // Constants and attribute lookups are always permitted return 1; @@ -465,11 +481,14 @@ validate_pattern_match_value(struct validator *state, expr_ty exp) return 1; } break; + case JoinedStr_kind: + // Handled in the later stages + return 1; default: break; } - PyErr_SetString(PyExc_SyntaxError, - "patterns may only match literals and attribute lookups"); + PyErr_SetString(PyExc_ValueError, + "patterns may only match literals and attribute lookups"); return 0; } @@ -489,51 +508,101 @@ validate_pattern(struct validator *state, pattern_ty p) ret = validate_pattern_match_value(state, p->v.MatchValue.value); break; case MatchSingleton_kind: - // TODO: Check constant is specifically None, True, or False - ret = validate_constant(state, p->v.MatchSingleton.value); + ret = p->v.MatchSingleton.value == Py_None || PyBool_Check(p->v.MatchSingleton.value); + if (!ret) { + PyErr_SetString(PyExc_ValueError, + "MatchSingleton can only contain True, False and None"); + } break; case MatchSequence_kind: - // TODO: Validate all subpatterns - // return validate_patterns(state, p->v.MatchSequence.patterns); - ret = 1; + ret = validate_patterns(state, p->v.MatchSequence.patterns, /*star_ok=*/1); break; case MatchMapping_kind: - // TODO: check "rest" target name is valid if (asdl_seq_LEN(p->v.MatchMapping.keys) != asdl_seq_LEN(p->v.MatchMapping.patterns)) { PyErr_SetString(PyExc_ValueError, "MatchMapping doesn't have the same number of keys as patterns"); - return 0; + ret = 0; + break; } - // null_ok=0 for key 
expressions, as rest-of-mapping is captured in "rest" - // TODO: replace with more restrictive expression validator, as per MatchValue above - if (!validate_exprs(state, p->v.MatchMapping.keys, Load, /*null_ok=*/ 0)) { - return 0; + + if (p->v.MatchMapping.rest && !validate_name(p->v.MatchMapping.rest)) { + ret = 0; + break; } - // TODO: Validate all subpatterns - // ret = validate_patterns(state, p->v.MatchMapping.patterns); - ret = 1; + + asdl_expr_seq *keys = p->v.MatchMapping.keys; + for (Py_ssize_t i = 0; i < asdl_seq_LEN(keys); i++) { + expr_ty key = asdl_seq_GET(keys, i); + if (key->kind == Constant_kind) { + PyObject *literal = key->v.Constant.value; + if (literal == Py_None || PyBool_Check(literal)) { + /* validate_pattern_match_value will ensure the key + doesn't contain True, False and None but it is + syntactically valid, so we will pass those on in + a special case. */ + continue; + } + } + if (!validate_pattern_match_value(state, key)) { + ret = 0; + break; + } + } + + ret = validate_patterns(state, p->v.MatchMapping.patterns, /*star_ok=*/0); break; case MatchClass_kind: if (asdl_seq_LEN(p->v.MatchClass.kwd_attrs) != asdl_seq_LEN(p->v.MatchClass.kwd_patterns)) { PyErr_SetString(PyExc_ValueError, "MatchClass doesn't have the same number of keyword attributes as patterns"); - return 0; + ret = 0; + break; } - // TODO: Restrict cls lookup to being a name or attribute if (!validate_expr(state, p->v.MatchClass.cls, Load)) { - return 0; + ret = 0; + break; } - // TODO: Validate all subpatterns - // return validate_patterns(state, p->v.MatchClass.patterns) && - // validate_patterns(state, p->v.MatchClass.kwd_patterns); - ret = 1; + + expr_ty cls = p->v.MatchClass.cls; + while (1) { + if (cls->kind == Name_kind) { + break; + } + else if (cls->kind == Attribute_kind) { + cls = cls->v.Attribute.value; + continue; + } + else { + PyErr_SetString(PyExc_ValueError, + "MatchClass cls field can only contain Name or Attribute nodes."); + state->recursion_depth--; + 
return 0; + } + } + + for (Py_ssize_t i = 0; i < asdl_seq_LEN(p->v.MatchClass.kwd_attrs); i++) { + PyObject *identifier = asdl_seq_GET(p->v.MatchClass.kwd_attrs, i); + if (!validate_name(identifier)) { + state->recursion_depth--; + return 0; + } + } + + if (!validate_patterns(state, p->v.MatchClass.patterns, /*star_ok=*/0)) { + ret = 0; + break; + } + + ret = validate_patterns(state, p->v.MatchClass.kwd_patterns, /*star_ok=*/0); break; case MatchStar_kind: - // TODO: check target name is valid - ret = 1; + ret = p->v.MatchStar.name == NULL || validate_name(p->v.MatchStar.name); break; case MatchAs_kind: - // TODO: check target name is valid + if (p->v.MatchAs.name && !validate_name(p->v.MatchAs.name)) { + ret = 0; + break; + } if (p->v.MatchAs.pattern == NULL) { ret = 1; } @@ -547,9 +616,13 @@ validate_pattern(struct validator *state, pattern_ty p) } break; case MatchOr_kind: - // TODO: Validate all subpatterns - // return validate_patterns(state, p->v.MatchOr.patterns); - ret = 1; + if (asdl_seq_LEN(p->v.MatchOr.patterns) < 2) { + PyErr_SetString(PyExc_ValueError, + "MatchOr requires at least 2 patterns"); + ret = 0; + break; + } + ret = validate_patterns(state, p->v.MatchOr.patterns, /*star_ok=*/0); break; // No default case, so the compiler will emit a warning if new pattern // kinds are added without being handled here @@ -815,6 +888,25 @@ validate_exprs(struct validator *state, asdl_expr_seq *exprs, expr_context_ty ct return 1; } +static int +validate_patterns(struct validator *state, asdl_pattern_seq *patterns, int star_ok) +{ + Py_ssize_t i; + for (i = 0; i < asdl_seq_LEN(patterns); i++) { + pattern_ty pattern = asdl_seq_GET(patterns, i); + if (pattern->kind == MatchStar_kind && !star_ok) { + PyErr_SetString(PyExc_ValueError, + "Can't use MatchStar within this sequence of patterns"); + return 0; + } + if (!validate_pattern(state, pattern)) { + return 0; + } + } + return 1; +} + + /* See comments in symtable.c. */ #define COMPILER_STACK_FRAME_SCALE 3 |