summaryrefslogtreecommitdiffstats
path: root/Parser
diff options
context:
space:
mode:
authorSam Gross <colesbury@gmail.com>2023-11-16 19:19:54 (GMT)
committerGitHub <noreply@github.com>2023-11-16 19:19:54 (GMT)
commit446f18a911916eabd2c0ceed0c2a109fc8480727 (patch)
treedfe463feacf3aed8db9f622d649e8221b4101e03 /Parser
parentf66afa395a6d06097ad1ca222ed076e18a7a8126 (diff)
downloadcpython-446f18a911916eabd2c0ceed0c2a109fc8480727.zip
cpython-446f18a911916eabd2c0ceed0c2a109fc8480727.tar.gz
cpython-446f18a911916eabd2c0ceed0c2a109fc8480727.tar.bz2
gh-111956: Add thread-safe one-time initialization. (gh-111960)
Diffstat (limited to 'Parser')
-rwxr-xr-xParser/asdl_c.py89
1 files changed, 42 insertions, 47 deletions
diff --git a/Parser/asdl_c.py b/Parser/asdl_c.py
index ae642e8..c9bf08e 100755
--- a/Parser/asdl_c.py
+++ b/Parser/asdl_c.py
@@ -518,7 +518,7 @@ class Obj2ModVisitor(PickleVisitor):
if add_label:
self.emit("failed:", 1)
self.emit("Py_XDECREF(tmp);", 1)
- self.emit("return 1;", 1)
+ self.emit("return -1;", 1)
self.emit("}", 0)
self.emit("", 0)
@@ -529,7 +529,7 @@ class Obj2ModVisitor(PickleVisitor):
"state->%s_type);")
self.emit(line % (t.name,), 1)
self.emit("if (isinstance == -1) {", 1)
- self.emit("return 1;", 2)
+ self.emit("return -1;", 2)
self.emit("}", 1)
self.emit("if (isinstance) {", 1)
self.emit("*out = %s;" % t.name, 2)
@@ -558,7 +558,7 @@ class Obj2ModVisitor(PickleVisitor):
self.emit("tp = state->%s_type;" % (t.name,), 1)
self.emit("isinstance = PyObject_IsInstance(obj, tp);", 1)
self.emit("if (isinstance == -1) {", 1)
- self.emit("return 1;", 2)
+ self.emit("return -1;", 2)
self.emit("}", 1)
self.emit("if (isinstance) {", 1)
for f in t.fields:
@@ -605,7 +605,7 @@ class Obj2ModVisitor(PickleVisitor):
self.emit("return 0;", 1)
self.emit("failed:", 0)
self.emit("Py_XDECREF(tmp);", 1)
- self.emit("return 1;", 1)
+ self.emit("return -1;", 1)
self.emit("}", 0)
self.emit("", 0)
@@ -631,13 +631,13 @@ class Obj2ModVisitor(PickleVisitor):
ctype = get_c_type(field.type)
line = "if (PyObject_GetOptionalAttr(obj, state->%s, &tmp) < 0) {"
self.emit(line % field.name, depth)
- self.emit("return 1;", depth+1)
+ self.emit("return -1;", depth+1)
self.emit("}", depth)
if field.seq:
self.emit("if (tmp == NULL) {", depth)
self.emit("tmp = PyList_New(0);", depth+1)
self.emit("if (tmp == NULL) {", depth+1)
- self.emit("return 1;", depth+2)
+ self.emit("return -1;", depth+2)
self.emit("}", depth+1)
self.emit("}", depth)
self.emit("{", depth)
@@ -647,7 +647,7 @@ class Obj2ModVisitor(PickleVisitor):
message = "required field \\\"%s\\\" missing from %s" % (field.name, name)
format = "PyErr_SetString(PyExc_TypeError, \"%s\");"
self.emit(format % message, depth+1, reflow=False)
- self.emit("return 1;", depth+1)
+ self.emit("return -1;", depth+1)
else:
self.emit("if (tmp == NULL || tmp == Py_None) {", depth)
self.emit("Py_CLEAR(tmp);", depth+1)
@@ -968,16 +968,16 @@ add_attributes(struct ast_state *state, PyObject *type, const char * const *attr
int i, result;
PyObject *s, *l = PyTuple_New(num_fields);
if (!l)
- return 0;
+ return -1;
for (i = 0; i < num_fields; i++) {
s = PyUnicode_InternFromString(attrs[i]);
if (!s) {
Py_DECREF(l);
- return 0;
+ return -1;
}
PyTuple_SET_ITEM(l, i, s);
}
- result = PyObject_SetAttr(type, state->_attributes, l) >= 0;
+ result = PyObject_SetAttr(type, state->_attributes, l);
Py_DECREF(l);
return result;
}
@@ -1052,7 +1052,7 @@ static int obj2ast_identifier(struct ast_state *state, PyObject* obj, PyObject**
{
if (!PyUnicode_CheckExact(obj) && obj != Py_None) {
PyErr_SetString(PyExc_TypeError, "AST identifier must be of type str");
- return 1;
+ return -1;
}
return obj2ast_object(state, obj, out, arena);
}
@@ -1061,7 +1061,7 @@ static int obj2ast_string(struct ast_state *state, PyObject* obj, PyObject** out
{
if (!PyUnicode_CheckExact(obj) && !PyBytes_CheckExact(obj)) {
PyErr_SetString(PyExc_TypeError, "AST string must be of type str");
- return 1;
+ return -1;
}
return obj2ast_object(state, obj, out, arena);
}
@@ -1071,12 +1071,12 @@ static int obj2ast_int(struct ast_state* Py_UNUSED(state), PyObject* obj, int* o
int i;
if (!PyLong_Check(obj)) {
PyErr_Format(PyExc_ValueError, "invalid integer value: %R", obj);
- return 1;
+ return -1;
}
i = PyLong_AsInt(obj);
if (i == -1 && PyErr_Occurred())
- return 1;
+ return -1;
*out = i;
return 0;
}
@@ -1102,22 +1102,15 @@ static int add_ast_fields(struct ast_state *state)
static int
init_types(struct ast_state *state)
{
- // init_types() must not be called after _PyAST_Fini()
- // has been called
- assert(state->initialized >= 0);
-
- if (state->initialized) {
- return 1;
- }
if (init_identifiers(state) < 0) {
- return 0;
+ return -1;
}
state->AST_type = PyType_FromSpec(&AST_type_spec);
if (!state->AST_type) {
- return 0;
+ return -1;
}
if (add_ast_fields(state) < 0) {
- return 0;
+ return -1;
}
'''))
for dfn in mod.dfns:
@@ -1125,8 +1118,7 @@ static int add_ast_fields(struct ast_state *state)
self.file.write(textwrap.dedent('''
state->recursion_depth = 0;
state->recursion_limit = 0;
- state->initialized = 1;
- return 1;
+ return 0;
}
'''))
@@ -1138,12 +1130,12 @@ static int add_ast_fields(struct ast_state *state)
self.emit('state->%s_type = make_type(state, "%s", state->AST_type, %s, %d,' %
(name, name, fields, len(prod.fields)), 1)
self.emit('%s);' % reflow_c_string(asdl_of(name, prod), 2), 2, reflow=False)
- self.emit("if (!state->%s_type) return 0;" % name, 1)
+ self.emit("if (!state->%s_type) return -1;" % name, 1)
if prod.attributes:
- self.emit("if (!add_attributes(state, state->%s_type, %s_attributes, %d)) return 0;" %
+ self.emit("if (add_attributes(state, state->%s_type, %s_attributes, %d) < 0) return -1;" %
(name, name, len(prod.attributes)), 1)
else:
- self.emit("if (!add_attributes(state, state->%s_type, NULL, 0)) return 0;" % name, 1)
+ self.emit("if (add_attributes(state, state->%s_type, NULL, 0) < 0) return -1;" % name, 1)
self.emit_defaults(name, prod.fields, 1)
self.emit_defaults(name, prod.attributes, 1)
@@ -1151,12 +1143,12 @@ static int add_ast_fields(struct ast_state *state)
self.emit('state->%s_type = make_type(state, "%s", state->AST_type, NULL, 0,' %
(name, name), 1)
self.emit('%s);' % reflow_c_string(asdl_of(name, sum), 2), 2, reflow=False)
- self.emit("if (!state->%s_type) return 0;" % name, 1)
+ self.emit("if (!state->%s_type) return -1;" % name, 1)
if sum.attributes:
- self.emit("if (!add_attributes(state, state->%s_type, %s_attributes, %d)) return 0;" %
+ self.emit("if (add_attributes(state, state->%s_type, %s_attributes, %d) < 0) return -1;" %
(name, name, len(sum.attributes)), 1)
else:
- self.emit("if (!add_attributes(state, state->%s_type, NULL, 0)) return 0;" % name, 1)
+ self.emit("if (add_attributes(state, state->%s_type, NULL, 0) < 0) return -1;" % name, 1)
self.emit_defaults(name, sum.attributes, 1)
simple = is_simple(sum)
for t in sum.types:
@@ -1170,20 +1162,20 @@ static int add_ast_fields(struct ast_state *state)
self.emit('state->%s_type = make_type(state, "%s", state->%s_type, %s, %d,' %
(cons.name, cons.name, name, fields, len(cons.fields)), 1)
self.emit('%s);' % reflow_c_string(asdl_of(cons.name, cons), 2), 2, reflow=False)
- self.emit("if (!state->%s_type) return 0;" % cons.name, 1)
+ self.emit("if (!state->%s_type) return -1;" % cons.name, 1)
self.emit_defaults(cons.name, cons.fields, 1)
if simple:
self.emit("state->%s_singleton = PyType_GenericNew((PyTypeObject *)"
"state->%s_type, NULL, NULL);" %
(cons.name, cons.name), 1)
- self.emit("if (!state->%s_singleton) return 0;" % cons.name, 1)
+ self.emit("if (!state->%s_singleton) return -1;" % cons.name, 1)
def emit_defaults(self, name, fields, depth):
for field in fields:
if field.opt:
self.emit('if (PyObject_SetAttr(state->%s_type, state->%s, Py_None) == -1)' %
(name, field.name), depth)
- self.emit("return 0;", depth+1)
+ self.emit("return -1;", depth+1)
class ASTModuleVisitor(PickleVisitor):
@@ -1279,7 +1271,7 @@ class ObjVisitor(PickleVisitor):
self.emit("if (++state->recursion_depth > state->recursion_limit) {", 1)
self.emit("PyErr_SetString(PyExc_RecursionError,", 2)
self.emit('"maximum recursion depth exceeded during ast construction");', 3)
- self.emit("return 0;", 2)
+ self.emit("return NULL;", 2)
self.emit("}", 1)
def func_end(self):
@@ -1400,7 +1392,7 @@ PyObject* PyAST_mod2obj(mod_ty t)
int COMPILER_STACK_FRAME_SCALE = 2;
PyThreadState *tstate = _PyThreadState_GET();
if (!tstate) {
- return 0;
+ return NULL;
}
state->recursion_limit = Py_C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;
int recursion_depth = Py_C_RECURSION_LIMIT - tstate->c_recursion_remaining;
@@ -1414,7 +1406,7 @@ PyObject* PyAST_mod2obj(mod_ty t)
PyErr_Format(PyExc_SystemError,
"AST constructor recursion depth mismatch (before=%d, after=%d)",
starting_recursion_depth, state->recursion_depth);
- return 0;
+ return NULL;
}
return result;
}
@@ -1481,7 +1473,8 @@ class ChainOfVisitors:
def generate_ast_state(module_state, f):
f.write('struct ast_state {\n')
- f.write(' int initialized;\n')
+ f.write(' _PyOnceFlag once;\n')
+ f.write(' int finalized;\n')
f.write(' int recursion_depth;\n')
f.write(' int recursion_limit;\n')
for s in module_state:
@@ -1501,11 +1494,8 @@ def generate_ast_fini(module_state, f):
f.write(textwrap.dedent("""
Py_CLEAR(_Py_INTERP_CACHED_OBJECT(interp, str_replace_inf));
- #if !defined(NDEBUG)
- state->initialized = -1;
- #else
- state->initialized = 0;
- #endif
+ state->finalized = 1;
+ state->once = (_PyOnceFlag){0};
}
"""))
@@ -1544,6 +1534,7 @@ def generate_module_def(mod, metadata, f, internal_h):
#include "pycore_ast.h"
#include "pycore_ast_state.h" // struct ast_state
#include "pycore_ceval.h" // _Py_EnterRecursiveCall
+ #include "pycore_lock.h" // _PyOnceFlag
#include "pycore_interp.h" // _PyInterpreterState.ast
#include "pycore_pystate.h" // _PyInterpreterState_GET()
#include <stddef.h>
@@ -1556,7 +1547,8 @@ def generate_module_def(mod, metadata, f, internal_h):
{
PyInterpreterState *interp = _PyInterpreterState_GET();
struct ast_state *state = &interp->ast;
- if (!init_types(state)) {
+ assert(!state->finalized);
+ if (_PyOnceFlag_CallOnce(&state->once, (_Py_once_fn_t *)&init_types, state) < 0) {
return NULL;
}
return state;
@@ -1570,8 +1562,8 @@ def generate_module_def(mod, metadata, f, internal_h):
for identifier in state_strings:
f.write(' if ((state->' + identifier)
f.write(' = PyUnicode_InternFromString("')
- f.write(identifier + '")) == NULL) return 0;\n')
- f.write(' return 1;\n')
+ f.write(identifier + '")) == NULL) return -1;\n')
+ f.write(' return 0;\n')
f.write('};\n\n')
def write_header(mod, metadata, f):
@@ -1629,6 +1621,9 @@ def write_internal_h_header(mod, f):
print(textwrap.dedent("""
#ifndef Py_INTERNAL_AST_STATE_H
#define Py_INTERNAL_AST_STATE_H
+
+ #include "pycore_lock.h" // _PyOnceFlag
+
#ifdef __cplusplus
extern "C" {
#endif