diff options
author | Sam Gross <colesbury@gmail.com> | 2023-11-16 19:19:54 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-16 19:19:54 (GMT) |
commit | 446f18a911916eabd2c0ceed0c2a109fc8480727 (patch) | |
tree | dfe463feacf3aed8db9f622d649e8221b4101e03 /Parser | |
parent | f66afa395a6d06097ad1ca222ed076e18a7a8126 (diff) | |
download | cpython-446f18a911916eabd2c0ceed0c2a109fc8480727.zip cpython-446f18a911916eabd2c0ceed0c2a109fc8480727.tar.gz cpython-446f18a911916eabd2c0ceed0c2a109fc8480727.tar.bz2 |
gh-111956: Add thread-safe one-time initialization. (gh-111960)
Diffstat (limited to 'Parser')
-rwxr-xr-x | Parser/asdl_c.py | 89 |
1 files changed, 42 insertions, 47 deletions
diff --git a/Parser/asdl_c.py b/Parser/asdl_c.py index ae642e8..c9bf08e 100755 --- a/Parser/asdl_c.py +++ b/Parser/asdl_c.py @@ -518,7 +518,7 @@ class Obj2ModVisitor(PickleVisitor): if add_label: self.emit("failed:", 1) self.emit("Py_XDECREF(tmp);", 1) - self.emit("return 1;", 1) + self.emit("return -1;", 1) self.emit("}", 0) self.emit("", 0) @@ -529,7 +529,7 @@ class Obj2ModVisitor(PickleVisitor): "state->%s_type);") self.emit(line % (t.name,), 1) self.emit("if (isinstance == -1) {", 1) - self.emit("return 1;", 2) + self.emit("return -1;", 2) self.emit("}", 1) self.emit("if (isinstance) {", 1) self.emit("*out = %s;" % t.name, 2) @@ -558,7 +558,7 @@ class Obj2ModVisitor(PickleVisitor): self.emit("tp = state->%s_type;" % (t.name,), 1) self.emit("isinstance = PyObject_IsInstance(obj, tp);", 1) self.emit("if (isinstance == -1) {", 1) - self.emit("return 1;", 2) + self.emit("return -1;", 2) self.emit("}", 1) self.emit("if (isinstance) {", 1) for f in t.fields: @@ -605,7 +605,7 @@ class Obj2ModVisitor(PickleVisitor): self.emit("return 0;", 1) self.emit("failed:", 0) self.emit("Py_XDECREF(tmp);", 1) - self.emit("return 1;", 1) + self.emit("return -1;", 1) self.emit("}", 0) self.emit("", 0) @@ -631,13 +631,13 @@ class Obj2ModVisitor(PickleVisitor): ctype = get_c_type(field.type) line = "if (PyObject_GetOptionalAttr(obj, state->%s, &tmp) < 0) {" self.emit(line % field.name, depth) - self.emit("return 1;", depth+1) + self.emit("return -1;", depth+1) self.emit("}", depth) if field.seq: self.emit("if (tmp == NULL) {", depth) self.emit("tmp = PyList_New(0);", depth+1) self.emit("if (tmp == NULL) {", depth+1) - self.emit("return 1;", depth+2) + self.emit("return -1;", depth+2) self.emit("}", depth+1) self.emit("}", depth) self.emit("{", depth) @@ -647,7 +647,7 @@ class Obj2ModVisitor(PickleVisitor): message = "required field \\\"%s\\\" missing from %s" % (field.name, name) format = "PyErr_SetString(PyExc_TypeError, \"%s\");" self.emit(format % message, depth+1, reflow=False) - self.emit("return 1;", depth+1) + self.emit("return -1;", depth+1) else: self.emit("if (tmp == NULL || tmp == Py_None) {", depth) self.emit("Py_CLEAR(tmp);", depth+1) @@ -968,16 +968,16 @@ add_attributes(struct ast_state *state, PyObject *type, const char * const *attr int i, result; PyObject *s, *l = PyTuple_New(num_fields); if (!l) - return 0; + return -1; for (i = 0; i < num_fields; i++) { s = PyUnicode_InternFromString(attrs[i]); if (!s) { Py_DECREF(l); - return 0; + return -1; } PyTuple_SET_ITEM(l, i, s); } - result = PyObject_SetAttr(type, state->_attributes, l) >= 0; + result = PyObject_SetAttr(type, state->_attributes, l); Py_DECREF(l); return result; } @@ -1052,7 +1052,7 @@ static int obj2ast_identifier(struct ast_state *state, PyObject* obj, PyObject** { if (!PyUnicode_CheckExact(obj) && obj != Py_None) { PyErr_SetString(PyExc_TypeError, "AST identifier must be of type str"); - return 1; + return -1; } return obj2ast_object(state, obj, out, arena); } @@ -1061,7 +1061,7 @@ static int obj2ast_string(struct ast_state *state, PyObject* obj, PyObject** out { if (!PyUnicode_CheckExact(obj) && !PyBytes_CheckExact(obj)) { PyErr_SetString(PyExc_TypeError, "AST string must be of type str"); - return 1; + return -1; } return obj2ast_object(state, obj, out, arena); } @@ -1071,12 +1071,12 @@ static int obj2ast_int(struct ast_state* Py_UNUSED(state), PyObject* obj, int* o int i; if (!PyLong_Check(obj)) { PyErr_Format(PyExc_ValueError, "invalid integer value: %R", obj); - return 1; + return -1; } i = PyLong_AsInt(obj); if (i == -1 && PyErr_Occurred()) - return 1; + return -1; *out = i; return 0; } @@ -1102,22 +1102,15 @@ static int add_ast_fields(struct ast_state *state) static int init_types(struct ast_state *state) { - // init_types() must not be called after _PyAST_Fini() - // has been called - assert(state->initialized >= 0); - - if (state->initialized) { - return 1; - } if (init_identifiers(state) < 0) { - return 0; + return -1; } state->AST_type = PyType_FromSpec(&AST_type_spec); if (!state->AST_type) { - return 0; + return -1; } if (add_ast_fields(state) < 0) { - return 0; + return -1; } ''')) for dfn in mod.dfns: @@ -1125,8 +1118,7 @@ static int add_ast_fields(struct ast_state *state) self.file.write(textwrap.dedent(''' state->recursion_depth = 0; state->recursion_limit = 0; - state->initialized = 1; - return 1; + return 0; } ''')) @@ -1138,12 +1130,12 @@ static int add_ast_fields(struct ast_state *state) self.emit('state->%s_type = make_type(state, "%s", state->AST_type, %s, %d,' % (name, name, fields, len(prod.fields)), 1) self.emit('%s);' % reflow_c_string(asdl_of(name, prod), 2), 2, reflow=False) - self.emit("if (!state->%s_type) return 0;" % name, 1) + self.emit("if (!state->%s_type) return -1;" % name, 1) if prod.attributes: - self.emit("if (!add_attributes(state, state->%s_type, %s_attributes, %d)) return 0;" % + self.emit("if (add_attributes(state, state->%s_type, %s_attributes, %d) < 0) return -1;" % (name, name, len(prod.attributes)), 1) else: - self.emit("if (!add_attributes(state, state->%s_type, NULL, 0)) return 0;" % name, 1) + self.emit("if (add_attributes(state, state->%s_type, NULL, 0) < 0) return -1;" % name, 1) self.emit_defaults(name, prod.fields, 1) self.emit_defaults(name, prod.attributes, 1) @@ -1151,12 +1143,12 @@ static int add_ast_fields(struct ast_state *state) self.emit('state->%s_type = make_type(state, "%s", state->AST_type, NULL, 0,' % (name, name), 1) self.emit('%s);' % reflow_c_string(asdl_of(name, sum), 2), 2, reflow=False) - self.emit("if (!state->%s_type) return 0;" % name, 1) + self.emit("if (!state->%s_type) return -1;" % name, 1) if sum.attributes: - self.emit("if (!add_attributes(state, state->%s_type, %s_attributes, %d)) return 0;" % + self.emit("if (add_attributes(state, state->%s_type, %s_attributes, %d) < 0) return -1;" % (name, name, len(sum.attributes)), 1) else: - self.emit("if (!add_attributes(state, state->%s_type, NULL, 0)) return 0;" % name, 1) + self.emit("if (add_attributes(state, state->%s_type, NULL, 0) < 0) return -1;" % name, 1) self.emit_defaults(name, sum.attributes, 1) simple = is_simple(sum) for t in sum.types: @@ -1170,20 +1162,20 @@ static int add_ast_fields(struct ast_state *state) self.emit('state->%s_type = make_type(state, "%s", state->%s_type, %s, %d,' % (cons.name, cons.name, name, fields, len(cons.fields)), 1) self.emit('%s);' % reflow_c_string(asdl_of(cons.name, cons), 2), 2, reflow=False) - self.emit("if (!state->%s_type) return 0;" % cons.name, 1) + self.emit("if (!state->%s_type) return -1;" % cons.name, 1) self.emit_defaults(cons.name, cons.fields, 1) if simple: self.emit("state->%s_singleton = PyType_GenericNew((PyTypeObject *)" "state->%s_type, NULL, NULL);" % (cons.name, cons.name), 1) - self.emit("if (!state->%s_singleton) return 0;" % cons.name, 1) + self.emit("if (!state->%s_singleton) return -1;" % cons.name, 1) def emit_defaults(self, name, fields, depth): for field in fields: if field.opt: self.emit('if (PyObject_SetAttr(state->%s_type, state->%s, Py_None) == -1)' % (name, field.name), depth) - self.emit("return 0;", depth+1) + self.emit("return -1;", depth+1) class ASTModuleVisitor(PickleVisitor): @@ -1279,7 +1271,7 @@ class ObjVisitor(PickleVisitor): self.emit("if (++state->recursion_depth > state->recursion_limit) {", 1) self.emit("PyErr_SetString(PyExc_RecursionError,", 2) self.emit('"maximum recursion depth exceeded during ast construction");', 3) - self.emit("return 0;", 2) + self.emit("return NULL;", 2) self.emit("}", 1) def func_end(self): @@ -1400,7 +1392,7 @@ PyObject* PyAST_mod2obj(mod_ty t) int COMPILER_STACK_FRAME_SCALE = 2; PyThreadState *tstate = _PyThreadState_GET(); if (!tstate) { - return 0; + return NULL; } state->recursion_limit = Py_C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE; int recursion_depth = Py_C_RECURSION_LIMIT - tstate->c_recursion_remaining; @@ -1414,7 +1406,7 @@ PyObject* PyAST_mod2obj(mod_ty t) PyErr_Format(PyExc_SystemError, "AST constructor recursion depth mismatch (before=%d, after=%d)", starting_recursion_depth, state->recursion_depth); - return 0; + return NULL; } return result; } @@ -1481,7 +1473,8 @@ class ChainOfVisitors: def generate_ast_state(module_state, f): f.write('struct ast_state {\n') - f.write(' int initialized;\n') + f.write(' _PyOnceFlag once;\n') + f.write(' int finalized;\n') f.write(' int recursion_depth;\n') f.write(' int recursion_limit;\n') for s in module_state: @@ -1501,11 +1494,8 @@ def generate_ast_fini(module_state, f): f.write(textwrap.dedent(""" Py_CLEAR(_Py_INTERP_CACHED_OBJECT(interp, str_replace_inf)); - #if !defined(NDEBUG) - state->initialized = -1; - #else - state->initialized = 0; - #endif + state->finalized = 1; + state->once = (_PyOnceFlag){0}; } """)) @@ -1544,6 +1534,7 @@ def generate_module_def(mod, metadata, f, internal_h): #include "pycore_ast.h" #include "pycore_ast_state.h" // struct ast_state #include "pycore_ceval.h" // _Py_EnterRecursiveCall + #include "pycore_lock.h" // _PyOnceFlag #include "pycore_interp.h" // _PyInterpreterState.ast #include "pycore_pystate.h" // _PyInterpreterState_GET() #include <stddef.h> @@ -1556,7 +1547,8 @@ def generate_module_def(mod, metadata, f, internal_h): { PyInterpreterState *interp = _PyInterpreterState_GET(); struct ast_state *state = &interp->ast; - if (!init_types(state)) { + assert(!state->finalized); + if (_PyOnceFlag_CallOnce(&state->once, (_Py_once_fn_t *)&init_types, state) < 0) { return NULL; } return state; @@ -1570,8 +1562,8 @@ def generate_module_def(mod, metadata, f, internal_h): for identifier in state_strings: f.write(' if ((state->' + identifier) f.write(' = PyUnicode_InternFromString("') - f.write(identifier + '")) == NULL) return 0;\n') - f.write(' return 1;\n') + f.write(identifier + '")) == NULL) return -1;\n') + f.write(' return 0;\n') f.write('};\n\n') def write_header(mod, metadata, f): @@ -1629,6 +1621,9 @@ def write_internal_h_header(mod, f): print(textwrap.dedent(""" #ifndef Py_INTERNAL_AST_STATE_H #define Py_INTERNAL_AST_STATE_H + + #include "pycore_lock.h" // _PyOnceFlag + #ifdef __cplusplus extern "C" { #endif |