diff options
author | Jelle Zijlstra <jelle.zijlstra@gmail.com> | 2024-02-28 02:13:03 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-28 02:13:03 (GMT) |
commit | ed4dfd8825b49e16a0fcb9e67baf1b58bb8d438f (patch) | |
tree | 88935b427cd68a5a249f0876f3a1cbe5ce241ce8 /Parser/asdl_c.py | |
parent | 5a1559d9493dd298a08c4be32b52295aa3eb89e5 (diff) | |
download | cpython-ed4dfd8825b49e16a0fcb9e67baf1b58bb8d438f.zip cpython-ed4dfd8825b49e16a0fcb9e67baf1b58bb8d438f.tar.gz cpython-ed4dfd8825b49e16a0fcb9e67baf1b58bb8d438f.tar.bz2 |
gh-105858: Improve AST node constructors (#105880)
Demonstration:
>>> ast.FunctionDef.__annotations__
{'name': <class 'str'>, 'args': <class 'ast.arguments'>, 'body': list[ast.stmt], 'decorator_list': list[ast.expr], 'returns': ast.expr | None, 'type_comment': str | None, 'type_params': list[ast.type_param]}
>>> ast.FunctionDef()
<stdin>:1: DeprecationWarning: FunctionDef.__init__ missing 1 required positional argument: 'name'. This will become an error in Python 3.15.
<stdin>:1: DeprecationWarning: FunctionDef.__init__ missing 1 required positional argument: 'args'. This will become an error in Python 3.15.
<ast.FunctionDef object at 0x101959460>
>>> node = ast.FunctionDef(name="foo", args=ast.arguments())
>>> node.decorator_list
[]
>>> ast.FunctionDef(whatever="you want", name="x", args=ast.arguments())
<stdin>:1: DeprecationWarning: FunctionDef.__init__ got an unexpected keyword argument 'whatever'. Support for arbitrary keyword arguments is deprecated and will be removed in Python 3.15.
<ast.FunctionDef object at 0x1019581f0>
Diffstat (limited to 'Parser/asdl_c.py')
-rwxr-xr-x | Parser/asdl_c.py | 238 |
1 files changed, 231 insertions, 7 deletions
diff --git a/Parser/asdl_c.py b/Parser/asdl_c.py index ce92672..865fd76 100755 --- a/Parser/asdl_c.py +++ b/Parser/asdl_c.py @@ -15,6 +15,13 @@ TABSIZE = 4 MAX_COL = 80 AUTOGEN_MESSAGE = "// File automatically generated by {}.\n\n" +builtin_type_to_c_type = { + "identifier": "PyUnicode_Type", + "string": "PyUnicode_Type", + "int": "PyLong_Type", + "constant": "PyBaseObject_Type", +} + def get_c_type(name): """Return a string for the C name of the type. @@ -764,6 +771,67 @@ class PyTypesDeclareVisitor(PickleVisitor): self.emit("};",0) +class AnnotationsVisitor(PickleVisitor): + def visitModule(self, mod): + self.file.write(textwrap.dedent(''' + static int + add_ast_annotations(struct ast_state *state) + { + bool cond; + ''')) + for dfn in mod.dfns: + self.visit(dfn) + self.file.write(textwrap.dedent(''' + return 1; + } + ''')) + + def visitProduct(self, prod, name): + self.emit_annotations(name, prod.fields) + + def visitSum(self, sum, name): + for t in sum.types: + self.visitConstructor(t, name) + + def visitConstructor(self, cons, name): + self.emit_annotations(cons.name, cons.fields) + + def emit_annotations(self, name, fields): + self.emit(f"PyObject *{name}_annotations = PyDict_New();", 1) + self.emit(f"if (!{name}_annotations) return 0;", 1) + for field in fields: + self.emit("{", 1) + if field.type in builtin_type_to_c_type: + self.emit(f"PyObject *type = (PyObject *)&{builtin_type_to_c_type[field.type]};", 2) + else: + self.emit(f"PyObject *type = state->{field.type}_type;", 2) + if field.opt: + self.emit("type = _Py_union_type_or(type, Py_None);", 2) + self.emit("cond = type != NULL;", 2) + self.emit_annotations_error(name, 2) + elif field.seq: + self.emit("type = Py_GenericAlias((PyObject *)&PyList_Type, type);", 2) + self.emit("cond = type != NULL;", 2) + self.emit_annotations_error(name, 2) + else: + self.emit("Py_INCREF(type);", 2) + self.emit(f"cond = PyDict_SetItemString({name}_annotations, \"{field.name}\", type) == 0;", 2) + self.emit("Py_DECREF(type);", 2) + self.emit_annotations_error(name, 2) + self.emit("}", 1) + self.emit(f'cond = PyObject_SetAttrString(state->{name}_type, "_field_types", {name}_annotations) == 0;', 1) + self.emit_annotations_error(name, 1) + self.emit(f'cond = PyObject_SetAttrString(state->{name}_type, "__annotations__", {name}_annotations) == 0;', 1) + self.emit_annotations_error(name, 1) + self.emit(f"Py_DECREF({name}_annotations);", 1) + + def emit_annotations_error(self, name, depth): + self.emit("if (!cond) {", depth) + self.emit(f"Py_DECREF({name}_annotations);", depth + 1) + self.emit("return 0;", depth + 1) + self.emit("}", depth) + + class PyTypesVisitor(PickleVisitor): def visitModule(self, mod): @@ -812,7 +880,7 @@ ast_type_init(PyObject *self, PyObject *args, PyObject *kw) Py_ssize_t i, numfields = 0; int res = -1; - PyObject *key, *value, *fields; + PyObject *key, *value, *fields, *remaining_fields = NULL; if (PyObject_GetOptionalAttr((PyObject*)Py_TYPE(self), state->_fields, &fields) < 0) { goto cleanup; } @@ -821,6 +889,13 @@ ast_type_init(PyObject *self, PyObject *args, PyObject *kw) if (numfields == -1) { goto cleanup; } + remaining_fields = PySet_New(fields); + } + else { + remaining_fields = PySet_New(NULL); + } + if (remaining_fields == NULL) { + goto cleanup; } res = 0; /* if no error occurs, this stays 0 to the end */ @@ -840,6 +915,11 @@ ast_type_init(PyObject *self, PyObject *args, PyObject *kw) goto cleanup; } res = PyObject_SetAttr(self, name, PyTuple_GET_ITEM(args, i)); + if (PySet_Discard(remaining_fields, name) < 0) { + res = -1; + Py_DECREF(name); + goto cleanup; + } Py_DECREF(name); if (res < 0) { goto cleanup; @@ -852,13 +932,14 @@ ast_type_init(PyObject *self, PyObject *args, PyObject *kw) if (contains == -1) { res = -1; goto cleanup; - } else if (contains == 1) { - Py_ssize_t p = PySequence_Index(fields, key); + } + else if (contains == 1) { + int p = PySet_Discard(remaining_fields, key); if (p == -1) { res = -1; goto cleanup; } - if (p < PyTuple_GET_SIZE(args)) { + if (p == 0) { PyErr_Format(PyExc_TypeError, "%.400s got multiple values for argument '%U'", Py_TYPE(self)->tp_name, key); @@ -866,15 +947,91 @@ ast_type_init(PyObject *self, PyObject *args, PyObject *kw) goto cleanup; } } + else if ( + PyUnicode_CompareWithASCIIString(key, "lineno") != 0 && + PyUnicode_CompareWithASCIIString(key, "col_offset") != 0 && + PyUnicode_CompareWithASCIIString(key, "end_lineno") != 0 && + PyUnicode_CompareWithASCIIString(key, "end_col_offset") != 0 + ) { + if (PyErr_WarnFormat( + PyExc_DeprecationWarning, 1, + "%.400s.__init__ got an unexpected keyword argument '%U'. " + "Support for arbitrary keyword arguments is deprecated " + "and will be removed in Python 3.15.", + Py_TYPE(self)->tp_name, key + ) < 0) { + res = -1; + goto cleanup; + } + } res = PyObject_SetAttr(self, key, value); if (res < 0) { goto cleanup; } } } + Py_ssize_t size = PySet_Size(remaining_fields); + PyObject *field_types = NULL, *remaining_list = NULL; + if (size > 0) { + if (!PyObject_GetOptionalAttr((PyObject*)Py_TYPE(self), &_Py_ID(_field_types), + &field_types)) { + res = -1; + goto cleanup; + } + remaining_list = PySequence_List(remaining_fields); + if (!remaining_list) { + goto set_remaining_cleanup; + } + for (Py_ssize_t i = 0; i < size; i++) { + PyObject *name = PyList_GET_ITEM(remaining_list, i); + PyObject *type = PyDict_GetItemWithError(field_types, name); + if (!type) { + if (!PyErr_Occurred()) { + PyErr_SetObject(PyExc_KeyError, name); + } + goto set_remaining_cleanup; + } + if (_PyUnion_Check(type)) { + // optional field + // do nothing, we'll have set a None default on the class + } + else if (Py_IS_TYPE(type, &Py_GenericAliasType)) { + // list field + PyObject *empty = PyList_New(0); + if (!empty) { + goto set_remaining_cleanup; + } + res = PyObject_SetAttr(self, name, empty); + Py_DECREF(empty); + if (res < 0) { + goto set_remaining_cleanup; + } + } + else { + // simple field (e.g., identifier) + if (PyErr_WarnFormat( + PyExc_DeprecationWarning, 1, + "%.400s.__init__ missing 1 required positional argument: '%U'. " + "This will become an error in Python 3.15.", + Py_TYPE(self)->tp_name, name + ) < 0) { + res = -1; + goto cleanup; + } + } + } + Py_DECREF(remaining_list); + Py_DECREF(field_types); + } cleanup: Py_XDECREF(fields); + Py_XDECREF(remaining_fields); return res; + set_remaining_cleanup: + Py_XDECREF(remaining_list); + Py_XDECREF(field_types); + res = -1; + goto cleanup; } /* Pickling support */ @@ -886,14 +1043,75 @@ ast_type_reduce(PyObject *self, PyObject *unused) return NULL; } - PyObject *dict; + PyObject *dict = NULL, *fields = NULL, *remaining_fields = NULL, + *remaining_dict = NULL, *positional_args = NULL; if (PyObject_GetOptionalAttr(self, state->__dict__, &dict) < 0) { return NULL; } + PyObject *result = NULL; if (dict) { - return Py_BuildValue("O()N", Py_TYPE(self), dict); + // Serialize the fields as positional args if possible, because if we + // serialize them as a dict, during unpickling they are set only *after* + // the object is constructed, which will now trigger a DeprecationWarning + // if the AST type has required fields. + if (PyObject_GetOptionalAttr((PyObject*)Py_TYPE(self), state->_fields, &fields) < 0) { + goto cleanup; + } + if (fields) { + Py_ssize_t numfields = PySequence_Size(fields); + if (numfields == -1) { + Py_DECREF(dict); + goto cleanup; + } + remaining_dict = PyDict_Copy(dict); + Py_DECREF(dict); + if (!remaining_dict) { + goto cleanup; + } + positional_args = PyList_New(0); + if (!positional_args) { + goto cleanup; + } + for (Py_ssize_t i = 0; i < numfields; i++) { + PyObject *name = PySequence_GetItem(fields, i); + if (!name) { + goto cleanup; + } + PyObject *value = PyDict_GetItemWithError(remaining_dict, name); + if (!value) { + if (PyErr_Occurred()) { + goto cleanup; + } + break; + } + if (PyList_Append(positional_args, value) < 0) { + goto cleanup; + } + if (PyDict_DelItem(remaining_dict, name) < 0) { + goto cleanup; + } + Py_DECREF(name); + } + PyObject *args_tuple = PyList_AsTuple(positional_args); + if (!args_tuple) { + goto cleanup; + } + result = Py_BuildValue("ONO", Py_TYPE(self), args_tuple, + remaining_dict); + } + else { + result = Py_BuildValue("O()N", Py_TYPE(self), dict); + } + } + else { + result = Py_BuildValue("O()", Py_TYPE(self)); } - return Py_BuildValue("O()", Py_TYPE(self)); +cleanup: + Py_XDECREF(fields); + Py_XDECREF(remaining_fields); + Py_XDECREF(remaining_dict); + Py_XDECREF(positional_args); + return result; } static PyMemberDef ast_type_members[] = { @@ -1117,6 +1335,9 @@ static int add_ast_fields(struct ast_state *state) for dfn in mod.dfns: self.visit(dfn) self.file.write(textwrap.dedent(''' + if (!add_ast_annotations(state)) { + return -1; + } return 0; } ''')) @@ -1534,6 +1755,8 @@ def generate_module_def(mod, metadata, f, internal_h): #include "pycore_lock.h" // _PyOnceFlag #include "pycore_interp.h" // _PyInterpreterState.ast #include "pycore_pystate.h" // _PyInterpreterState_GET() + #include "pycore_unionobject.h" // _Py_union_type_or + #include "structmember.h" #include <stddef.h> struct validator { @@ -1651,6 +1874,7 @@ def write_source(mod, metadata, f, internal_h_file): v = ChainOfVisitors( SequenceConstructorVisitor(f), PyTypesDeclareVisitor(f), + AnnotationsVisitor(f), PyTypesVisitor(f), Obj2ModPrototypeVisitor(f), FunctionVisitor(f), |