diff options
author | Serhiy Storchaka <storchaka@gmail.com> | 2023-12-25 19:20:07 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-12-25 19:20:07 (GMT) |
commit | d58a5f453f59f44ccf09b1a9b11a0b879ac6f35b (patch) | |
tree | 8499a750b08b12fabc40c6a3be26cf96ca1d75a8 /Parser | |
parent | 2c07540e7d2a271d09bcca2f7731a8c60cd38350 (diff) | |
download | cpython-d58a5f453f59f44ccf09b1a9b11a0b879ac6f35b.zip cpython-d58a5f453f59f44ccf09b1a9b11a0b879ac6f35b.tar.gz cpython-d58a5f453f59f44ccf09b1a9b11a0b879ac6f35b.tar.bz2 |
[3.12] gh-106905: Use separate structs to track recursion depth in each PyAST_mod2obj call. (GH-113035) (GH-113472)
(cherry picked from commit 48c49739f5502fc7aa82f247ab2e4d7b55bdca62)
Co-authored-by: Yilei Yang <yileiyang@google.com>
Co-authored-by: Gregory P. Smith [Google LLC] <greg@krypto.org>
Diffstat (limited to 'Parser')
-rwxr-xr-x | Parser/asdl_c.py | 58 |
1 files changed, 32 insertions, 26 deletions
diff --git a/Parser/asdl_c.py b/Parser/asdl_c.py index aa093b3..d42c263 100755 --- a/Parser/asdl_c.py +++ b/Parser/asdl_c.py @@ -731,7 +731,7 @@ class SequenceConstructorVisitor(EmitVisitor): class PyTypesDeclareVisitor(PickleVisitor): def visitProduct(self, prod, name): - self.emit("static PyObject* ast2obj_%s(struct ast_state *state, void*);" % name, 0) + self.emit("static PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, void*);" % name, 0) if prod.attributes: self.emit("static const char * const %s_attributes[] = {" % name, 0) for a in prod.attributes: @@ -752,7 +752,7 @@ class PyTypesDeclareVisitor(PickleVisitor): ptype = "void*" if is_simple(sum): ptype = get_c_type(name) - self.emit("static PyObject* ast2obj_%s(struct ast_state *state, %s);" % (name, ptype), 0) + self.emit("static PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, %s);" % (name, ptype), 0) for t in sum.types: self.visitConstructor(t, name) @@ -984,7 +984,8 @@ add_attributes(struct ast_state *state, PyObject *type, const char * const *attr /* Conversion AST -> Python */ -static PyObject* ast2obj_list(struct ast_state *state, asdl_seq *seq, PyObject* (*func)(struct ast_state *state, void*)) +static PyObject* ast2obj_list(struct ast_state *state, struct validator *vstate, asdl_seq *seq, + PyObject* (*func)(struct ast_state *state, struct validator *vstate, void*)) { Py_ssize_t i, n = asdl_seq_LEN(seq); PyObject *result = PyList_New(n); @@ -992,7 +993,7 @@ static PyObject* ast2obj_list(struct ast_state *state, asdl_seq *seq, PyObject* if (!result) return NULL; for (i = 0; i < n; i++) { - value = func(state, asdl_seq_GET_UNTYPED(seq, i)); + value = func(state, vstate, asdl_seq_GET_UNTYPED(seq, i)); if (!value) { Py_DECREF(result); return NULL; @@ -1002,7 +1003,7 @@ static PyObject* ast2obj_list(struct ast_state *state, asdl_seq *seq, PyObject* return result; } -static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), void *o) +static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), struct validator *Py_UNUSED(vstate), void *o) { PyObject *op = (PyObject*)o; if (!op) { @@ -1014,7 +1015,7 @@ static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), void *o) #define ast2obj_identifier ast2obj_object #define ast2obj_string ast2obj_object -static PyObject* ast2obj_int(struct ast_state *Py_UNUSED(state), long b) +static PyObject* ast2obj_int(struct ast_state *Py_UNUSED(state), struct validator *Py_UNUSED(vstate), long b) { return PyLong_FromLong(b); } @@ -1123,8 +1124,6 @@ static int add_ast_fields(struct ast_state *state) for dfn in mod.dfns: self.visit(dfn) self.file.write(textwrap.dedent(''' - state->recursion_depth = 0; - state->recursion_limit = 0; state->initialized = 1; return 1; } @@ -1265,7 +1264,7 @@ class ObjVisitor(PickleVisitor): def func_begin(self, name): ctype = get_c_type(name) self.emit("PyObject*", 0) - self.emit("ast2obj_%s(struct ast_state *state, void* _o)" % (name), 0) + self.emit("ast2obj_%s(struct ast_state *state, struct validator *vstate, void* _o)" % (name), 0) self.emit("{", 0) self.emit("%s o = (%s)_o;" % (ctype, ctype), 1) self.emit("PyObject *result = NULL, *value = NULL;", 1) @@ -1273,16 +1272,17 @@ class ObjVisitor(PickleVisitor): self.emit('if (!o) {', 1) self.emit("Py_RETURN_NONE;", 2) self.emit("}", 1) - self.emit("if (++state->recursion_depth > state->recursion_limit) {", 1) + self.emit("if (++vstate->recursion_depth > vstate->recursion_limit) {", 1) self.emit("PyErr_SetString(PyExc_RecursionError,", 2) self.emit('"maximum recursion depth exceeded during ast construction");', 3) self.emit("return 0;", 2) self.emit("}", 1) def func_end(self): - self.emit("state->recursion_depth--;", 1) + self.emit("vstate->recursion_depth--;", 1) self.emit("return result;", 1) self.emit("failed:", 0) + self.emit("vstate->recursion_depth--;", 1) self.emit("Py_XDECREF(value);", 1) self.emit("Py_XDECREF(result);", 1) self.emit("return NULL;", 1) @@ -1300,7 +1300,7 @@ class ObjVisitor(PickleVisitor): self.visitConstructor(t, i + 1, name) self.emit("}", 1) for a in sum.attributes: - self.emit("value = ast2obj_%s(state, o->%s);" % (a.type, a.name), 1) + self.emit("value = ast2obj_%s(state, vstate, o->%s);" % (a.type, a.name), 1) self.emit("if (!value) goto failed;", 1) self.emit('if (PyObject_SetAttr(result, state->%s, value) < 0)' % a.name, 1) self.emit('goto failed;', 2) @@ -1308,7 +1308,7 @@ class ObjVisitor(PickleVisitor): self.func_end() def simpleSum(self, sum, name): - self.emit("PyObject* ast2obj_%s(struct ast_state *state, %s_ty o)" % (name, name), 0) + self.emit("PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, %s_ty o)" % (name, name), 0) self.emit("{", 0) self.emit("switch(o) {", 1) for t in sum.types: @@ -1326,7 +1326,7 @@ class ObjVisitor(PickleVisitor): for field in prod.fields: self.visitField(field, name, 1, True) for a in prod.attributes: - self.emit("value = ast2obj_%s(state, o->%s);" % (a.type, a.name), 1) + self.emit("value = ast2obj_%s(state, vstate, o->%s);" % (a.type, a.name), 1) self.emit("if (!value) goto failed;", 1) self.emit("if (PyObject_SetAttr(result, state->%s, value) < 0)" % a.name, 1) self.emit('goto failed;', 2) @@ -1367,7 +1367,7 @@ class ObjVisitor(PickleVisitor): self.emit("for(i = 0; i < n; i++)", depth+1) # This cannot fail, so no need for error handling self.emit( - "PyList_SET_ITEM(value, i, ast2obj_{0}(state, ({0}_ty)asdl_seq_GET({1}, i)));".format( + "PyList_SET_ITEM(value, i, ast2obj_{0}(state, vstate, ({0}_ty)asdl_seq_GET({1}, i)));".format( field.type, value ), @@ -1376,9 +1376,9 @@ class ObjVisitor(PickleVisitor): ) self.emit("}", depth) else: - self.emit("value = ast2obj_list(state, (asdl_seq*)%s, ast2obj_%s);" % (value, field.type), depth) + self.emit("value = ast2obj_list(state, vstate, (asdl_seq*)%s, ast2obj_%s);" % (value, field.type), depth) else: - self.emit("value = ast2obj_%s(state, %s);" % (field.type, value), depth, reflow=False) + self.emit("value = ast2obj_%s(state, vstate, %s);" % (field.type, value), depth, reflow=False) class PartingShots(StaticVisitor): @@ -1396,21 +1396,22 @@ PyObject* PyAST_mod2obj(mod_ty t) int COMPILER_STACK_FRAME_SCALE = 2; PyThreadState *tstate = _PyThreadState_GET(); if (!tstate) { - return 0; + return NULL; } - state->recursion_limit = C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE; + struct validator vstate; + vstate.recursion_limit = C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE; int recursion_depth = C_RECURSION_LIMIT - tstate->c_recursion_remaining; starting_recursion_depth = recursion_depth * COMPILER_STACK_FRAME_SCALE; - state->recursion_depth = starting_recursion_depth; + vstate.recursion_depth = starting_recursion_depth; - PyObject *result = ast2obj_mod(state, t); + PyObject *result = ast2obj_mod(state, &vstate, t); /* Check that the recursion depth counting balanced correctly */ - if (result && state->recursion_depth != starting_recursion_depth) { + if (result && vstate.recursion_depth != starting_recursion_depth) { PyErr_Format(PyExc_SystemError, "AST constructor recursion depth mismatch (before=%d, after=%d)", - starting_recursion_depth, state->recursion_depth); - return 0; + starting_recursion_depth, vstate.recursion_depth); + return NULL; } return result; } @@ -1478,8 +1479,8 @@ class ChainOfVisitors: def generate_ast_state(module_state, f): f.write('struct ast_state {\n') f.write(' int initialized;\n') - f.write(' int recursion_depth;\n') - f.write(' int recursion_limit;\n') + f.write(' int unused_recursion_depth;\n') + f.write(' int unused_recursion_limit;\n') for s in module_state: f.write(' PyObject *' + s + ';\n') f.write('};') @@ -1545,6 +1546,11 @@ def generate_module_def(mod, metadata, f, internal_h): #include "structmember.h" #include <stddef.h> + struct validator { + int recursion_depth; /* current recursion depth */ + int recursion_limit; /* recursion limit */ + }; + // Forward declaration static int init_types(struct ast_state *state); |