summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author: Batuhan Taskaya <batuhan@python.org>, 2021-07-28 17:14:45 (GMT)
committer: GitHub <noreply@github.com>, 2021-07-28 17:14:45 (GMT)
commit: 31bec6f1b178dadec3cb43353274b4e958a8f015 (patch)
tree: 007bb1ecbe4c69a0719a42076a18aa2fb8e098b6
parent: 53b9458f2e9314703a5406ca817d757f1509882a (diff)
download: cpython-31bec6f1b178dadec3cb43353274b4e958a8f015.zip
cpython-31bec6f1b178dadec3cb43353274b4e958a8f015.tar.gz
cpython-31bec6f1b178dadec3cb43353274b4e958a8f015.tar.bz2
bpo-43897: AST validation for pattern matching nodes (GH24771)
-rw-r--r-- Lib/test/test_ast.py 143
-rw-r--r-- Python/ast.c 154
2 files changed, 265 insertions, 32 deletions
diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py
index 5f1ee75..925bb88 100644
--- a/Lib/test/test_ast.py
+++ b/Lib/test/test_ast.py
@@ -696,7 +696,7 @@ class AST_Tests(unittest.TestCase):
for constant in "True", "False", "None":
expr = ast.Expression(ast.Name(constant, ast.Load()))
ast.fix_missing_locations(expr)
- with self.assertRaisesRegex(ValueError, f"Name node can't be used with '{constant}' constant"):
+ with self.assertRaisesRegex(ValueError, f"identifier field can't represent '{constant}' constant"):
compile(expr, "<test>", "eval")
def test_precedence_enum(self):
@@ -1507,6 +1507,147 @@ class ASTValidatorTests(unittest.TestCase):
mod = ast.parse(source, fn)
compile(mod, fn, "exec")
+ # Shared AST building blocks reused by the _MATCH_PATTERNS fixtures below.
+ # Integer literal 1 and a value pattern matching it.
+ constant_1 = ast.Constant(1)
+ pattern_1 = ast.MatchValue(constant_1)
+
+ # String literal 'x' and a value pattern matching it.
+ constant_x = ast.Constant('x')
+ pattern_x = ast.MatchValue(constant_x)
+
+ # The True singleton as a Constant (used as a mapping key below) and as a
+ # MatchSingleton pattern; pattern_true is not referenced by the fixtures
+ # visible in this diff.
+ constant_true = ast.Constant(True)
+ pattern_true = ast.MatchSingleton(True)
+
+ # A plain Name in Load context, used as the cls of the MatchClass fixtures.
+ name_carter = ast.Name('carter', ast.Load())
+
+ # Pattern nodes that AST validation is expected to reject.  Each entry is
+ # wrapped in a Match statement by test_match_validation_pattern, which
+ # asserts that compile() raises ValueError.
+ _MATCH_PATTERNS = [
+ # Attribute chain whose innermost Name uses a Store context.
+ # NOTE(review): presumably rejected by the expression-context check that
+ # runs outside the hunks shown in this diff -- confirm.
+ ast.MatchValue(
+ ast.Attribute(
+ ast.Attribute(
+ ast.Name('x', ast.Store()),
+ 'y', ast.Load()
+ ),
+ 'z', ast.Load()
+ )
+ ),
+ # Same chain, but the inner Attribute carries the Store context.
+ # NOTE(review): same assumption as the previous entry -- confirm.
+ ast.MatchValue(
+ ast.Attribute(
+ ast.Attribute(
+ ast.Name('x', ast.Load()),
+ 'y', ast.Store()
+ ),
+ 'z', ast.Load()
+ )
+ ),
+ # Ellipsis is not an int, float, bytes, complex or str literal.
+ ast.MatchValue(
+ ast.Constant(...)
+ ),
+ # True must be expressed with MatchSingleton, not MatchValue.
+ ast.MatchValue(
+ ast.Constant(True)
+ ),
+ # Tuple constants are not accepted as literal patterns.
+ ast.MatchValue(
+ ast.Constant((1,2,3))
+ ),
+ # MatchSingleton may only hold True, False or None.
+ ast.MatchSingleton('string'),
+ # Invalid subpattern nested one level inside a sequence pattern.
+ ast.MatchSequence([
+ ast.MatchSingleton('string')
+ ]),
+ # Invalid subpattern nested two levels inside sequence patterns.
+ ast.MatchSequence(
+ [
+ ast.MatchSequence(
+ [
+ ast.MatchSingleton('string')
+ ]
+ )
+ ]
+ ),
+ # Two keys but only one pattern.
+ ast.MatchMapping(
+ [constant_1, constant_true],
+ [pattern_x]
+ ),
+ # 'True' is a forbidden identifier for the rest capture.
+ ast.MatchMapping(
+ [constant_true, constant_1],
+ [pattern_x, pattern_1],
+ rest='True'
+ ),
+ # A Starred expression is neither a literal nor an attribute lookup,
+ # so it is not a valid mapping key.
+ ast.MatchMapping(
+ [constant_true, ast.Starred(ast.Name('lol', ast.Load()), ast.Load())],
+ [pattern_x, pattern_1],
+ rest='legit'
+ ),
+ # cls attribute chain is rooted at a Constant instead of a Name.
+ ast.MatchClass(
+ ast.Attribute(
+ ast.Attribute(
+ constant_x,
+ 'y', ast.Load()),
+ 'z', ast.Load()),
+ patterns=[], kwd_attrs=[], kwd_patterns=[]
+ ),
+ # 'True' is a forbidden identifier for a keyword attribute.
+ ast.MatchClass(
+ name_carter,
+ patterns=[],
+ kwd_attrs=['True'],
+ kwd_patterns=[pattern_1]
+ ),
+ # One keyword pattern but no keyword attribute names.
+ ast.MatchClass(
+ name_carter,
+ patterns=[],
+ kwd_attrs=[],
+ kwd_patterns=[pattern_1]
+ ),
+ # Invalid positional subpattern.
+ ast.MatchClass(
+ name_carter,
+ patterns=[ast.MatchSingleton('string')],
+ kwd_attrs=[],
+ kwd_patterns=[]
+ ),
+ # MatchStar is not allowed among class positional patterns.
+ ast.MatchClass(
+ name_carter,
+ patterns=[ast.MatchStar()],
+ kwd_attrs=[],
+ kwd_patterns=[]
+ ),
+ # MatchStar as a keyword pattern; kwd_attrs and kwd_patterns also
+ # differ in length.
+ ast.MatchClass(
+ name_carter,
+ patterns=[],
+ kwd_attrs=[],
+ kwd_patterns=[ast.MatchStar()]
+ ),
+ # 'True' is a forbidden capture name for a star pattern.
+ ast.MatchSequence(
+ [
+ ast.MatchStar("True")
+ ]
+ ),
+ # 'False' is a forbidden capture name.
+ ast.MatchAs(
+ name='False'
+ ),
+ # MatchOr needs at least two alternatives: none given.
+ ast.MatchOr(
+ []
+ ),
+ # MatchOr needs at least two alternatives: only one given.
+ ast.MatchOr(
+ [pattern_1]
+ ),
+ # Enough alternatives, but the last one is an invalid MatchSingleton.
+ ast.MatchOr(
+ [pattern_1, pattern_x, ast.MatchSingleton('xxx')]
+ )
+ ]
+
+ def test_match_validation_pattern(self):
+ name_x = ast.Name('x', ast.Load())
+ for pattern in self._MATCH_PATTERNS:
+ with self.subTest(ast.dump(pattern, indent=4)):
+ node = ast.Match(
+ subject=name_x,
+ cases = [
+ ast.match_case(
+ pattern=pattern,
+ body = [ast.Pass()]
+ )
+ ]
+ )
+ node = ast.fix_missing_locations(node)
+ module = ast.Module([node], [])
+ with self.assertRaises(ValueError):
+ compile(module, "<test>", "exec")
+
class ConstantTests(unittest.TestCase):
"""Tests on the ast.Constant node type."""
diff --git a/Python/ast.c b/Python/ast.c
index 1fc83f6..0a306c0 100644
--- a/Python/ast.c
+++ b/Python/ast.c
@@ -15,7 +15,8 @@ struct validator {
};
static int validate_stmts(struct validator *, asdl_stmt_seq *);
-static int validate_exprs(struct validator *, asdl_expr_seq*, expr_context_ty, int);
+static int validate_exprs(struct validator *, asdl_expr_seq *, expr_context_ty, int);
+static int validate_patterns(struct validator *, asdl_pattern_seq *, int);
static int _validate_nonempty_seq(asdl_seq *, const char *, const char *);
static int validate_stmt(struct validator *, stmt_ty);
static int validate_expr(struct validator *, expr_ty, expr_context_ty);
@@ -33,7 +34,7 @@ validate_name(PyObject *name)
};
for (int i = 0; forbidden[i] != NULL; i++) {
if (_PyUnicode_EqualToASCIIString(name, forbidden[i])) {
- PyErr_Format(PyExc_ValueError, "Name node can't be used with '%s' constant", forbidden[i]);
+ PyErr_Format(PyExc_ValueError, "identifier field can't represent '%s' constant", forbidden[i]);
return 0;
}
}
@@ -448,6 +449,21 @@ validate_pattern_match_value(struct validator *state, expr_ty exp)
switch (exp->kind)
{
case Constant_kind:
+ /* Ellipsis and immutable sequences are not allowed.
+ For True, False and None, MatchSingleton() should
+ be used */
+ if (!validate_expr(state, exp, Load)) {
+ return 0;
+ }
+ PyObject *literal = exp->v.Constant.value;
+ if (PyLong_CheckExact(literal) || PyFloat_CheckExact(literal) ||
+ PyBytes_CheckExact(literal) || PyComplex_CheckExact(literal) ||
+ PyUnicode_CheckExact(literal)) {
+ return 1;
+ }
+ PyErr_SetString(PyExc_ValueError,
+ "unexpected constant inside of a literal pattern");
+ return 0;
case Attribute_kind:
// Constants and attribute lookups are always permitted
return 1;
@@ -465,11 +481,14 @@ validate_pattern_match_value(struct validator *state, expr_ty exp)
return 1;
}
break;
+ case JoinedStr_kind:
+ // Handled in the later stages
+ return 1;
default:
break;
}
- PyErr_SetString(PyExc_SyntaxError,
- "patterns may only match literals and attribute lookups");
+ PyErr_SetString(PyExc_ValueError,
+ "patterns may only match literals and attribute lookups");
return 0;
}
@@ -489,51 +508,101 @@ validate_pattern(struct validator *state, pattern_ty p)
ret = validate_pattern_match_value(state, p->v.MatchValue.value);
break;
case MatchSingleton_kind:
- // TODO: Check constant is specifically None, True, or False
- ret = validate_constant(state, p->v.MatchSingleton.value);
+ ret = p->v.MatchSingleton.value == Py_None || PyBool_Check(p->v.MatchSingleton.value);
+ if (!ret) {
+ PyErr_SetString(PyExc_ValueError,
+ "MatchSingleton can only contain True, False and None");
+ }
break;
case MatchSequence_kind:
- // TODO: Validate all subpatterns
- // return validate_patterns(state, p->v.MatchSequence.patterns);
- ret = 1;
+ ret = validate_patterns(state, p->v.MatchSequence.patterns, /*star_ok=*/1);
break;
case MatchMapping_kind:
- // TODO: check "rest" target name is valid
if (asdl_seq_LEN(p->v.MatchMapping.keys) != asdl_seq_LEN(p->v.MatchMapping.patterns)) {
PyErr_SetString(PyExc_ValueError,
"MatchMapping doesn't have the same number of keys as patterns");
- return 0;
+ ret = 0;
+ break;
}
- // null_ok=0 for key expressions, as rest-of-mapping is captured in "rest"
- // TODO: replace with more restrictive expression validator, as per MatchValue above
- if (!validate_exprs(state, p->v.MatchMapping.keys, Load, /*null_ok=*/ 0)) {
- return 0;
+
+ if (p->v.MatchMapping.rest && !validate_name(p->v.MatchMapping.rest)) {
+ ret = 0;
+ break;
}
- // TODO: Validate all subpatterns
- // ret = validate_patterns(state, p->v.MatchMapping.patterns);
- ret = 1;
+
+ asdl_expr_seq *keys = p->v.MatchMapping.keys;
+ for (Py_ssize_t i = 0; i < asdl_seq_LEN(keys); i++) {
+ expr_ty key = asdl_seq_GET(keys, i);
+ if (key->kind == Constant_kind) {
+ PyObject *literal = key->v.Constant.value;
+ if (literal == Py_None || PyBool_Check(literal)) {
+ /* validate_pattern_match_value will ensure the key
+ doesn't contain True, False and None but it is
+ syntactically valid, so we will pass those on in
+ a special case. */
+ continue;
+ }
+ }
+ if (!validate_pattern_match_value(state, key)) {
+ ret = 0;
+ break;
+ }
+ }
+
+ ret = validate_patterns(state, p->v.MatchMapping.patterns, /*star_ok=*/0);
break;
case MatchClass_kind:
if (asdl_seq_LEN(p->v.MatchClass.kwd_attrs) != asdl_seq_LEN(p->v.MatchClass.kwd_patterns)) {
PyErr_SetString(PyExc_ValueError,
"MatchClass doesn't have the same number of keyword attributes as patterns");
- return 0;
+ ret = 0;
+ break;
}
- // TODO: Restrict cls lookup to being a name or attribute
if (!validate_expr(state, p->v.MatchClass.cls, Load)) {
- return 0;
+ ret = 0;
+ break;
}
- // TODO: Validate all subpatterns
- // return validate_patterns(state, p->v.MatchClass.patterns) &&
- // validate_patterns(state, p->v.MatchClass.kwd_patterns);
- ret = 1;
+
+ expr_ty cls = p->v.MatchClass.cls;
+ while (1) {
+ if (cls->kind == Name_kind) {
+ break;
+ }
+ else if (cls->kind == Attribute_kind) {
+ cls = cls->v.Attribute.value;
+ continue;
+ }
+ else {
+ PyErr_SetString(PyExc_ValueError,
+ "MatchClass cls field can only contain Name or Attribute nodes.");
+ state->recursion_depth--;
+ return 0;
+ }
+ }
+
+ for (Py_ssize_t i = 0; i < asdl_seq_LEN(p->v.MatchClass.kwd_attrs); i++) {
+ PyObject *identifier = asdl_seq_GET(p->v.MatchClass.kwd_attrs, i);
+ if (!validate_name(identifier)) {
+ state->recursion_depth--;
+ return 0;
+ }
+ }
+
+ if (!validate_patterns(state, p->v.MatchClass.patterns, /*star_ok=*/0)) {
+ ret = 0;
+ break;
+ }
+
+ ret = validate_patterns(state, p->v.MatchClass.kwd_patterns, /*star_ok=*/0);
break;
case MatchStar_kind:
- // TODO: check target name is valid
- ret = 1;
+ ret = p->v.MatchStar.name == NULL || validate_name(p->v.MatchStar.name);
break;
case MatchAs_kind:
- // TODO: check target name is valid
+ if (p->v.MatchAs.name && !validate_name(p->v.MatchAs.name)) {
+ ret = 0;
+ break;
+ }
if (p->v.MatchAs.pattern == NULL) {
ret = 1;
}
@@ -547,9 +616,13 @@ validate_pattern(struct validator *state, pattern_ty p)
}
break;
case MatchOr_kind:
- // TODO: Validate all subpatterns
- // return validate_patterns(state, p->v.MatchOr.patterns);
- ret = 1;
+ if (asdl_seq_LEN(p->v.MatchOr.patterns) < 2) {
+ PyErr_SetString(PyExc_ValueError,
+ "MatchOr requires at least 2 patterns");
+ ret = 0;
+ break;
+ }
+ ret = validate_patterns(state, p->v.MatchOr.patterns, /*star_ok=*/0);
break;
// No default case, so the compiler will emit a warning if new pattern
// kinds are added without being handled here
@@ -815,6 +888,25 @@ validate_exprs(struct validator *state, asdl_expr_seq *exprs, expr_context_ty ct
return 1;
}
+/* Validate every pattern in `patterns`.
+ *
+ * star_ok: non-zero when MatchStar nodes may appear directly in this
+ *          sequence; zero makes any MatchStar entry an error.
+ *
+ * Returns 1 when every entry passes validate_pattern(); otherwise returns 0
+ * with a ValueError set here or the error reported by validate_pattern().
+ */
+static int
+validate_patterns(struct validator *state, asdl_pattern_seq *patterns, int star_ok)
+{
+    for (Py_ssize_t i = 0; i < asdl_seq_LEN(patterns); i++) {
+        pattern_ty pattern = asdl_seq_GET(patterns, i);
+        /* Reject an empty slot instead of dereferencing it below.
+           NOTE(review): presumably a None element in a Python-level pattern
+           list arrives here as NULL -- confirm against the AST conversion
+           code. */
+        if (pattern == NULL) {
+            PyErr_SetString(PyExc_ValueError,
+                            "None disallowed in pattern list");
+            return 0;
+        }
+        /* The container decides whether a star pattern is legal, so this
+           is checked here rather than in validate_pattern(). */
+        if (pattern->kind == MatchStar_kind && !star_ok) {
+            PyErr_SetString(PyExc_ValueError,
+                            "Can't use MatchStar within this sequence of patterns");
+            return 0;
+        }
+        if (!validate_pattern(state, pattern)) {
+            return 0;
+        }
+    }
+    return 1;
+}
+
+
/* See comments in symtable.c. */
#define COMPILER_STACK_FRAME_SCALE 3