summaryrefslogtreecommitdiffstats
path: root/Parser/asdl_c.py
diff options
context:
space:
mode:
authorJelle Zijlstra <jelle.zijlstra@gmail.com>2024-02-28 02:13:03 (GMT)
committerGitHub <noreply@github.com>2024-02-28 02:13:03 (GMT)
commited4dfd8825b49e16a0fcb9e67baf1b58bb8d438f (patch)
tree88935b427cd68a5a249f0876f3a1cbe5ce241ce8 /Parser/asdl_c.py
parent5a1559d9493dd298a08c4be32b52295aa3eb89e5 (diff)
downloadcpython-ed4dfd8825b49e16a0fcb9e67baf1b58bb8d438f.zip
cpython-ed4dfd8825b49e16a0fcb9e67baf1b58bb8d438f.tar.gz
cpython-ed4dfd8825b49e16a0fcb9e67baf1b58bb8d438f.tar.bz2
gh-105858: Improve AST node constructors (#105880)
Demonstration: >>> ast.FunctionDef.__annotations__ {'name': <class 'str'>, 'args': <class 'ast.arguments'>, 'body': list[ast.stmt], 'decorator_list': list[ast.expr], 'returns': ast.expr | None, 'type_comment': str | None, 'type_params': list[ast.type_param]} >>> ast.FunctionDef() <stdin>:1: DeprecationWarning: FunctionDef.__init__ missing 1 required positional argument: 'name'. This will become an error in Python 3.15. <stdin>:1: DeprecationWarning: FunctionDef.__init__ missing 1 required positional argument: 'args'. This will become an error in Python 3.15. <ast.FunctionDef object at 0x101959460> >>> node = ast.FunctionDef(name="foo", args=ast.arguments()) >>> node.decorator_list [] >>> ast.FunctionDef(whatever="you want", name="x", args=ast.arguments()) <stdin>:1: DeprecationWarning: FunctionDef.__init__ got an unexpected keyword argument 'whatever'. Support for arbitrary keyword arguments is deprecated and will be removed in Python 3.15. <ast.FunctionDef object at 0x1019581f0>
Diffstat (limited to 'Parser/asdl_c.py')
-rwxr-xr-xParser/asdl_c.py238
1 files changed, 231 insertions, 7 deletions
diff --git a/Parser/asdl_c.py b/Parser/asdl_c.py
index ce92672..865fd76 100755
--- a/Parser/asdl_c.py
+++ b/Parser/asdl_c.py
@@ -15,6 +15,13 @@ TABSIZE = 4
MAX_COL = 80
AUTOGEN_MESSAGE = "// File automatically generated by {}.\n\n"
+builtin_type_to_c_type = {
+ "identifier": "PyUnicode_Type",
+ "string": "PyUnicode_Type",
+ "int": "PyLong_Type",
+ "constant": "PyBaseObject_Type",
+}
+
def get_c_type(name):
"""Return a string for the C name of the type.
@@ -764,6 +771,67 @@ class PyTypesDeclareVisitor(PickleVisitor):
self.emit("};",0)
+class AnnotationsVisitor(PickleVisitor):
+ def visitModule(self, mod):
+ self.file.write(textwrap.dedent('''
+ static int
+ add_ast_annotations(struct ast_state *state)
+ {
+ bool cond;
+ '''))
+ for dfn in mod.dfns:
+ self.visit(dfn)
+ self.file.write(textwrap.dedent('''
+ return 1;
+ }
+ '''))
+
+ def visitProduct(self, prod, name):
+ self.emit_annotations(name, prod.fields)
+
+ def visitSum(self, sum, name):
+ for t in sum.types:
+ self.visitConstructor(t, name)
+
+ def visitConstructor(self, cons, name):
+ self.emit_annotations(cons.name, cons.fields)
+
+ def emit_annotations(self, name, fields):
+ self.emit(f"PyObject *{name}_annotations = PyDict_New();", 1)
+ self.emit(f"if (!{name}_annotations) return 0;", 1)
+ for field in fields:
+ self.emit("{", 1)
+ if field.type in builtin_type_to_c_type:
+ self.emit(f"PyObject *type = (PyObject *)&{builtin_type_to_c_type[field.type]};", 2)
+ else:
+ self.emit(f"PyObject *type = state->{field.type}_type;", 2)
+ if field.opt:
+ self.emit("type = _Py_union_type_or(type, Py_None);", 2)
+ self.emit("cond = type != NULL;", 2)
+ self.emit_annotations_error(name, 2)
+ elif field.seq:
+ self.emit("type = Py_GenericAlias((PyObject *)&PyList_Type, type);", 2)
+ self.emit("cond = type != NULL;", 2)
+ self.emit_annotations_error(name, 2)
+ else:
+ self.emit("Py_INCREF(type);", 2)
+ self.emit(f"cond = PyDict_SetItemString({name}_annotations, \"{field.name}\", type) == 0;", 2)
+ self.emit("Py_DECREF(type);", 2)
+ self.emit_annotations_error(name, 2)
+ self.emit("}", 1)
+ self.emit(f'cond = PyObject_SetAttrString(state->{name}_type, "_field_types", {name}_annotations) == 0;', 1)
+ self.emit_annotations_error(name, 1)
+ self.emit(f'cond = PyObject_SetAttrString(state->{name}_type, "__annotations__", {name}_annotations) == 0;', 1)
+ self.emit_annotations_error(name, 1)
+ self.emit(f"Py_DECREF({name}_annotations);", 1)
+
+ def emit_annotations_error(self, name, depth):
+ self.emit("if (!cond) {", depth)
+ self.emit(f"Py_DECREF({name}_annotations);", depth + 1)
+ self.emit("return 0;", depth + 1)
+ self.emit("}", depth)
+
+
class PyTypesVisitor(PickleVisitor):
def visitModule(self, mod):
@@ -812,7 +880,7 @@ ast_type_init(PyObject *self, PyObject *args, PyObject *kw)
Py_ssize_t i, numfields = 0;
int res = -1;
- PyObject *key, *value, *fields;
+ PyObject *key, *value, *fields, *remaining_fields = NULL;
if (PyObject_GetOptionalAttr((PyObject*)Py_TYPE(self), state->_fields, &fields) < 0) {
goto cleanup;
}
@@ -821,6 +889,13 @@ ast_type_init(PyObject *self, PyObject *args, PyObject *kw)
if (numfields == -1) {
goto cleanup;
}
+ remaining_fields = PySet_New(fields);
+ }
+ else {
+ remaining_fields = PySet_New(NULL);
+ }
+ if (remaining_fields == NULL) {
+ goto cleanup;
}
res = 0; /* if no error occurs, this stays 0 to the end */
@@ -840,6 +915,11 @@ ast_type_init(PyObject *self, PyObject *args, PyObject *kw)
goto cleanup;
}
res = PyObject_SetAttr(self, name, PyTuple_GET_ITEM(args, i));
+ if (PySet_Discard(remaining_fields, name) < 0) {
+ res = -1;
+ Py_DECREF(name);
+ goto cleanup;
+ }
Py_DECREF(name);
if (res < 0) {
goto cleanup;
@@ -852,13 +932,14 @@ ast_type_init(PyObject *self, PyObject *args, PyObject *kw)
if (contains == -1) {
res = -1;
goto cleanup;
- } else if (contains == 1) {
- Py_ssize_t p = PySequence_Index(fields, key);
+ }
+ else if (contains == 1) {
+ int p = PySet_Discard(remaining_fields, key);
if (p == -1) {
res = -1;
goto cleanup;
}
- if (p < PyTuple_GET_SIZE(args)) {
+ if (p == 0) {
PyErr_Format(PyExc_TypeError,
"%.400s got multiple values for argument '%U'",
Py_TYPE(self)->tp_name, key);
@@ -866,15 +947,91 @@ ast_type_init(PyObject *self, PyObject *args, PyObject *kw)
goto cleanup;
}
}
+ else if (
+ PyUnicode_CompareWithASCIIString(key, "lineno") != 0 &&
+ PyUnicode_CompareWithASCIIString(key, "col_offset") != 0 &&
+ PyUnicode_CompareWithASCIIString(key, "end_lineno") != 0 &&
+ PyUnicode_CompareWithASCIIString(key, "end_col_offset") != 0
+ ) {
+ if (PyErr_WarnFormat(
+ PyExc_DeprecationWarning, 1,
+ "%.400s.__init__ got an unexpected keyword argument '%U'. "
+ "Support for arbitrary keyword arguments is deprecated "
+ "and will be removed in Python 3.15.",
+ Py_TYPE(self)->tp_name, key
+ ) < 0) {
+ res = -1;
+ goto cleanup;
+ }
+ }
res = PyObject_SetAttr(self, key, value);
if (res < 0) {
goto cleanup;
}
}
}
+ Py_ssize_t size = PySet_Size(remaining_fields);
+ PyObject *field_types = NULL, *remaining_list = NULL;
+ if (size > 0) {
+ if (!PyObject_GetOptionalAttr((PyObject*)Py_TYPE(self), &_Py_ID(_field_types),
+ &field_types)) {
+ res = -1;
+ goto cleanup;
+ }
+ remaining_list = PySequence_List(remaining_fields);
+ if (!remaining_list) {
+ goto set_remaining_cleanup;
+ }
+ for (Py_ssize_t i = 0; i < size; i++) {
+ PyObject *name = PyList_GET_ITEM(remaining_list, i);
+ PyObject *type = PyDict_GetItemWithError(field_types, name);
+ if (!type) {
+ if (!PyErr_Occurred()) {
+ PyErr_SetObject(PyExc_KeyError, name);
+ }
+ goto set_remaining_cleanup;
+ }
+ if (_PyUnion_Check(type)) {
+ // optional field
+ // do nothing, we'll have set a None default on the class
+ }
+ else if (Py_IS_TYPE(type, &Py_GenericAliasType)) {
+ // list field
+ PyObject *empty = PyList_New(0);
+ if (!empty) {
+ goto set_remaining_cleanup;
+ }
+ res = PyObject_SetAttr(self, name, empty);
+ Py_DECREF(empty);
+ if (res < 0) {
+ goto set_remaining_cleanup;
+ }
+ }
+ else {
+ // simple field (e.g., identifier)
+ if (PyErr_WarnFormat(
+ PyExc_DeprecationWarning, 1,
+ "%.400s.__init__ missing 1 required positional argument: '%U'. "
+ "This will become an error in Python 3.15.",
+ Py_TYPE(self)->tp_name, name
+ ) < 0) {
+ res = -1;
+ goto cleanup;
+ }
+ }
+ }
+ Py_DECREF(remaining_list);
+ Py_DECREF(field_types);
+ }
cleanup:
Py_XDECREF(fields);
+ Py_XDECREF(remaining_fields);
return res;
+ set_remaining_cleanup:
+ Py_XDECREF(remaining_list);
+ Py_XDECREF(field_types);
+ res = -1;
+ goto cleanup;
}
/* Pickling support */
@@ -886,14 +1043,75 @@ ast_type_reduce(PyObject *self, PyObject *unused)
return NULL;
}
- PyObject *dict;
+ PyObject *dict = NULL, *fields = NULL, *remaining_fields = NULL,
+ *remaining_dict = NULL, *positional_args = NULL;
if (PyObject_GetOptionalAttr(self, state->__dict__, &dict) < 0) {
return NULL;
}
+ PyObject *result = NULL;
if (dict) {
- return Py_BuildValue("O()N", Py_TYPE(self), dict);
+ // Serialize the fields as positional args if possible, because if we
+ // serialize them as a dict, during unpickling they are set only *after*
+ // the object is constructed, which will now trigger a DeprecationWarning
+ // if the AST type has required fields.
+ if (PyObject_GetOptionalAttr((PyObject*)Py_TYPE(self), state->_fields, &fields) < 0) {
+ goto cleanup;
+ }
+ if (fields) {
+ Py_ssize_t numfields = PySequence_Size(fields);
+ if (numfields == -1) {
+ Py_DECREF(dict);
+ goto cleanup;
+ }
+ remaining_dict = PyDict_Copy(dict);
+ Py_DECREF(dict);
+ if (!remaining_dict) {
+ goto cleanup;
+ }
+ positional_args = PyList_New(0);
+ if (!positional_args) {
+ goto cleanup;
+ }
+ for (Py_ssize_t i = 0; i < numfields; i++) {
+ PyObject *name = PySequence_GetItem(fields, i);
+ if (!name) {
+ goto cleanup;
+ }
+ PyObject *value = PyDict_GetItemWithError(remaining_dict, name);
+ if (!value) {
+ if (PyErr_Occurred()) {
+ goto cleanup;
+ }
+ break;
+ }
+ if (PyList_Append(positional_args, value) < 0) {
+ goto cleanup;
+ }
+ if (PyDict_DelItem(remaining_dict, name) < 0) {
+ goto cleanup;
+ }
+ Py_DECREF(name);
+ }
+ PyObject *args_tuple = PyList_AsTuple(positional_args);
+ if (!args_tuple) {
+ goto cleanup;
+ }
+ result = Py_BuildValue("ONO", Py_TYPE(self), args_tuple,
+ remaining_dict);
+ }
+ else {
+ result = Py_BuildValue("O()N", Py_TYPE(self), dict);
+ }
+ }
+ else {
+ result = Py_BuildValue("O()", Py_TYPE(self));
}
- return Py_BuildValue("O()", Py_TYPE(self));
+cleanup:
+ Py_XDECREF(fields);
+ Py_XDECREF(remaining_fields);
+ Py_XDECREF(remaining_dict);
+ Py_XDECREF(positional_args);
+ return result;
}
static PyMemberDef ast_type_members[] = {
@@ -1117,6 +1335,9 @@ static int add_ast_fields(struct ast_state *state)
for dfn in mod.dfns:
self.visit(dfn)
self.file.write(textwrap.dedent('''
+ if (!add_ast_annotations(state)) {
+ return -1;
+ }
return 0;
}
'''))
@@ -1534,6 +1755,8 @@ def generate_module_def(mod, metadata, f, internal_h):
#include "pycore_lock.h" // _PyOnceFlag
#include "pycore_interp.h" // _PyInterpreterState.ast
#include "pycore_pystate.h" // _PyInterpreterState_GET()
+ #include "pycore_unionobject.h" // _Py_union_type_or
+ #include "structmember.h"
#include <stddef.h>
struct validator {
@@ -1651,6 +1874,7 @@ def write_source(mod, metadata, f, internal_h_file):
v = ChainOfVisitors(
SequenceConstructorVisitor(f),
PyTypesDeclareVisitor(f),
+ AnnotationsVisitor(f),
PyTypesVisitor(f),
Obj2ModPrototypeVisitor(f),
FunctionVisitor(f),