From 405f5c54914483607194a3ba6d4e50533d92bad1 Mon Sep 17 00:00:00 2001
From: "Miss Islington (bot)"
 <31488909+miss-islington@users.noreply.github.com>
Date: Wed, 28 Jul 2021 18:02:14 -0700
Subject: [3.10] bpo-43897: Reject "_" captures and top-level MatchStar in the
 AST validator (GH-27432) (GH-27435)

(cherry picked from commit 8d0647485db5af2a0f0929d6509479ca45f1281b)


Co-authored-by: Brandt Bucher <brandt@python.org>

Automerge-Triggered-By: GH:brandtbucher
---
 Lib/test/test_ast.py |  6 +++++-
 Python/ast.c         | 48 ++++++++++++++++++++++++++++--------------------
 2 files changed, 33 insertions(+), 21 deletions(-)

diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py
index ac0669a..326f3ab 100644
--- a/Lib/test/test_ast.py
+++ b/Lib/test/test_ast.py
@@ -1596,7 +1596,11 @@ class ASTValidatorTests(unittest.TestCase):
         ),
         ast.MatchOr(
             [pattern_1, pattern_x, ast.MatchSingleton('xxx')]
-        )
+        ),
+        ast.MatchAs(name="_"),
+        ast.MatchStar(name="x"),
+        ast.MatchSequence([ast.MatchStar("_")]),
+        ast.MatchMapping([], [], rest="_"),
     ]
 
     def test_match_validation_pattern(self):
diff --git a/Python/ast.c b/Python/ast.c
index 0a306c0..2113124 100644
--- a/Python/ast.c
+++ b/Python/ast.c
@@ -20,7 +20,7 @@ static int validate_patterns(struct validator *, asdl_pattern_seq *, int);
 static int _validate_nonempty_seq(asdl_seq *, const char *, const char *);
 static int validate_stmt(struct validator *, stmt_ty);
 static int validate_expr(struct validator *, expr_ty, expr_context_ty);
-static int validate_pattern(struct validator *, pattern_ty);
+static int validate_pattern(struct validator *, pattern_ty, int);
 
 static int
 validate_name(PyObject *name)
@@ -493,7 +493,17 @@ validate_pattern_match_value(struct validator *state, expr_ty exp)
 }
 
 static int
-validate_pattern(struct validator *state, pattern_ty p)
+validate_capture(PyObject *name)
+{
+    if (_PyUnicode_EqualToASCIIString(name, "_")) {
+        PyErr_Format(PyExc_ValueError, "can't capture name '_' in patterns");
+        return 0;
+    }
+    return validate_name(name);
+}
+
+static int
+validate_pattern(struct validator *state, pattern_ty p, int star_ok)
 {
     int ret = -1;
     if (++state->recursion_depth > state->recursion_limit) {
@@ -501,8 +511,6 @@ validate_pattern(struct validator *state, pattern_ty p)
                         "maximum recursion depth exceeded during compilation");
         return 0;
     }
-    // Coming soon: https://bugs.python.org/issue43897 (thanks Batuhan)!
-    // TODO: Ensure no subnodes use "_" as an ordinary identifier
     switch (p->kind) {
         case MatchValue_kind:
             ret = validate_pattern_match_value(state, p->v.MatchValue.value);
@@ -525,7 +533,7 @@ validate_pattern(struct validator *state, pattern_ty p)
                 break;
             }
 
-            if (p->v.MatchMapping.rest && !validate_name(p->v.MatchMapping.rest)) {
+            if (p->v.MatchMapping.rest && !validate_capture(p->v.MatchMapping.rest)) {
                 ret = 0;
                 break;
             }
@@ -575,16 +583,16 @@ validate_pattern(struct validator *state, pattern_ty p)
                 else {
                     PyErr_SetString(PyExc_ValueError,
                                     "MatchClass cls field can only contain Name or Attribute nodes.");
-                    state->recursion_depth--;
-                    return 0;
+                    ret = 0;
+                    break;
                 }
             }
 
             for (Py_ssize_t i = 0; i < asdl_seq_LEN(p->v.MatchClass.kwd_attrs); i++) {
                 PyObject *identifier = asdl_seq_GET(p->v.MatchClass.kwd_attrs, i);
                 if (!validate_name(identifier)) {
-                    state->recursion_depth--;
-                    return 0;
+                    ret = 0;
+                    break;
                 }
             }
 
@@ -596,10 +604,15 @@ validate_pattern(struct validator *state, pattern_ty p)
             ret = validate_patterns(state, p->v.MatchClass.kwd_patterns, /*star_ok=*/0);
             break;
         case MatchStar_kind:
-            ret = p->v.MatchStar.name == NULL || validate_name(p->v.MatchStar.name);
+            if (!star_ok) {
+                PyErr_SetString(PyExc_ValueError, "can't use MatchStar here");
+                ret = 0;
+                break;
+            }
+            ret = p->v.MatchStar.name == NULL || validate_capture(p->v.MatchStar.name);
             break;
         case MatchAs_kind:
-            if (p->v.MatchAs.name && !validate_name(p->v.MatchAs.name)) {
+            if (p->v.MatchAs.name && !validate_capture(p->v.MatchAs.name)) {
                 ret = 0;
                 break;
             }
@@ -609,10 +622,10 @@ validate_pattern(struct validator *state, pattern_ty p)
             else if (p->v.MatchAs.name == NULL) {
                 PyErr_SetString(PyExc_ValueError,
                                 "MatchAs must specify a target name if a pattern is given");
-                return 0;
+                ret = 0;
             }
             else {
-                ret = validate_pattern(state, p->v.MatchAs.pattern);
+                ret = validate_pattern(state, p->v.MatchAs.pattern, /*star_ok=*/0);
             }
             break;
         case MatchOr_kind:
@@ -759,7 +772,7 @@ validate_stmt(struct validator *state, stmt_ty stmt)
         }
         for (i = 0; i < asdl_seq_LEN(stmt->v.Match.cases); i++) {
             match_case_ty m = asdl_seq_GET(stmt->v.Match.cases, i);
-            if (!validate_pattern(state, m->pattern)
+            if (!validate_pattern(state, m->pattern, /*star_ok=*/0)
                 || (m->guard && !validate_expr(state, m->guard, Load))
                 || !validate_body(state, m->body, "match_case")) {
                 return 0;
@@ -894,12 +907,7 @@ validate_patterns(struct validator *state, asdl_pattern_seq *patterns, int star_
     Py_ssize_t i;
     for (i = 0; i < asdl_seq_LEN(patterns); i++) {
         pattern_ty pattern = asdl_seq_GET(patterns, i);
-        if (pattern->kind == MatchStar_kind && !star_ok) {
-            PyErr_SetString(PyExc_ValueError,
-                            "Can't use MatchStar within this sequence of patterns");
-            return 0;
-        }
-        if (!validate_pattern(state, pattern)) {
+        if (!validate_pattern(state, pattern, star_ok)) {
             return 0;
         }
     }
-- 
cgit v0.12