summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rwxr-xr-xParser/asdl_c.py99
1 files changed, 75 insertions, 24 deletions
diff --git a/Parser/asdl_c.py b/Parser/asdl_c.py
index 4e5c5c8..371730a 100755
--- a/Parser/asdl_c.py
+++ b/Parser/asdl_c.py
@@ -4,6 +4,7 @@
import os
import sys
import textwrap
+import types
from argparse import ArgumentParser
from contextlib import contextmanager
@@ -100,11 +101,12 @@ def asdl_of(name, obj):
class EmitVisitor(asdl.VisitorBase):
"""Visit that emits lines"""
- def __init__(self, file):
+ def __init__(self, file, metadata = None):
self.file = file
self.identifiers = set()
self.singletons = set()
self.types = set()
+ self._metadata = metadata
super(EmitVisitor, self).__init__()
def emit_identifier(self, name):
@@ -127,6 +129,42 @@ class EmitVisitor(asdl.VisitorBase):
line = (" " * TABSIZE * depth) + line
self.file.write(line + "\n")
+ @property
+ def metadata(self):
+ if self._metadata is None:
+ raise ValueError(
+ "%s was expecting to be annnotated with metadata"
+ % type(self).__name__
+ )
+ return self._metadata
+
+ @metadata.setter
+ def metadata(self, value):
+ self._metadata = value
+
+class MetadataVisitor(asdl.VisitorBase):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # Metadata:
+ # - simple_sums: Tracks the list of compound type
+ # names where all the constructors
+ # belonging to that type lack of any
+ # fields.
+ self.metadata = types.SimpleNamespace(
+ simple_sums=set()
+ )
+
+ def visitModule(self, mod):
+ for dfn in mod.dfns:
+ self.visit(dfn)
+
+ def visitType(self, type):
+ self.visit(type.value, type.name)
+
+ def visitSum(self, sum, name):
+ if is_simple(sum):
+ self.metadata.simple_sums.add(name)
class TypeDefVisitor(EmitVisitor):
def visitModule(self, mod):
@@ -244,7 +282,7 @@ class StructVisitor(EmitVisitor):
ctype = get_c_type(field.type)
name = field.name
if field.seq:
- if field.type == 'cmpop':
+ if field.type in self.metadata.simple_sums:
self.emit("asdl_int_seq *%(name)s;" % locals(), depth)
else:
_type = field.type
@@ -304,7 +342,7 @@ class PrototypeVisitor(EmitVisitor):
name = f.name
# XXX should extend get_c_type() to handle this
if f.seq:
- if f.type == 'cmpop':
+ if f.type in self.metadata.simple_sums:
ctype = "asdl_int_seq *"
else:
ctype = f"asdl_{f.type}_seq *"
@@ -549,16 +587,11 @@ class Obj2ModVisitor(PickleVisitor):
ctype = get_c_type(field.type)
self.emit("%s %s;" % (ctype, field.name), depth)
- def isSimpleSum(self, field):
- # XXX can the members of this list be determined automatically?
- return field.type in ('expr_context', 'boolop', 'operator',
- 'unaryop', 'cmpop')
-
def isNumeric(self, field):
return get_c_type(field.type) in ("int", "bool")
def isSimpleType(self, field):
- return self.isSimpleSum(field) or self.isNumeric(field)
+ return field.type in self.metadata.simple_sums or self.isNumeric(field)
def visitField(self, field, name, sum=None, prod=None, depth=0):
ctype = get_c_type(field.type)
@@ -1282,18 +1315,23 @@ class ObjVisitor(PickleVisitor):
def set(self, field, value, depth):
if field.seq:
- # XXX should really check for is_simple, but that requires a symbol table
- if field.type == "cmpop":
+ if field.type in self.metadata.simple_sums:
# While the sequence elements are stored as void*,
- # ast2obj_cmpop expects an enum
+ # simple sums expects an enum
self.emit("{", depth)
self.emit("Py_ssize_t i, n = asdl_seq_LEN(%s);" % value, depth+1)
self.emit("value = PyList_New(n);", depth+1)
self.emit("if (!value) goto failed;", depth+1)
self.emit("for(i = 0; i < n; i++)", depth+1)
# This cannot fail, so no need for error handling
- self.emit("PyList_SET_ITEM(value, i, ast2obj_cmpop(state, (cmpop_ty)asdl_seq_GET(%s, i)));" % value,
- depth+2, reflow=False)
+ self.emit(
+ "PyList_SET_ITEM(value, i, ast2obj_{0}(state, ({0}_ty)asdl_seq_GET({1}, i)));".format(
+ field.type,
+ value
+ ),
+ depth + 2,
+ reflow=False,
+ )
self.emit("}", depth)
else:
self.emit("value = ast2obj_list(state, (asdl_seq*)%s, ast2obj_%s);" % (value, field.type), depth)
@@ -1362,11 +1400,13 @@ int PyAST_Check(PyObject* obj)
"""
class ChainOfVisitors:
- def __init__(self, *visitors):
+ def __init__(self, *visitors, metadata = None):
self.visitors = visitors
+ self.metadata = metadata
def visit(self, object):
for v in self.visitors:
+ v.metadata = self.metadata
v.visit(object)
v.emit("", 0)
@@ -1468,7 +1508,7 @@ def generate_module_def(mod, f, internal_h):
f.write(' return 1;\n')
f.write('};\n\n')
-def write_header(mod, f):
+def write_header(mod, metadata, f):
f.write(textwrap.dedent("""
#ifndef Py_INTERNAL_AST_H
#define Py_INTERNAL_AST_H
@@ -1483,12 +1523,19 @@ def write_header(mod, f):
#include "pycore_asdl.h"
""").lstrip())
- c = ChainOfVisitors(TypeDefVisitor(f),
- SequenceDefVisitor(f),
- StructVisitor(f))
+
+ c = ChainOfVisitors(
+ TypeDefVisitor(f),
+ SequenceDefVisitor(f),
+ StructVisitor(f),
+ metadata=metadata
+ )
c.visit(mod)
+
f.write("// Note: these macros affect function definitions, not only call sites.\n")
- PrototypeVisitor(f).visit(mod)
+ prototype_visitor = PrototypeVisitor(f, metadata=metadata)
+ prototype_visitor.visit(mod)
+
f.write(textwrap.dedent("""
PyObject* PyAST_mod2obj(mod_ty t);
@@ -1535,8 +1582,7 @@ def write_internal_h_footer(mod, f):
#endif /* !Py_INTERNAL_AST_STATE_H */
"""), file=f)
-
-def write_source(mod, f, internal_h_file):
+def write_source(mod, metadata, f, internal_h_file):
generate_module_def(mod, f, internal_h_file)
v = ChainOfVisitors(
@@ -1549,6 +1595,7 @@ def write_source(mod, f, internal_h_file):
Obj2ModVisitor(f),
ASTModuleVisitor(f),
PartingShots(f),
+ metadata=metadata
)
v.visit(mod)
@@ -1561,6 +1608,10 @@ def main(input_filename, c_filename, h_filename, internal_h_filename, dump_modul
if not asdl.check(mod):
sys.exit(1)
+ metadata_visitor = MetadataVisitor()
+ metadata_visitor.visit(mod)
+ metadata = metadata_visitor.metadata
+
with c_filename.open("w") as c_file, \
h_filename.open("w") as h_file, \
internal_h_filename.open("w") as internal_h_file:
@@ -1569,8 +1620,8 @@ def main(input_filename, c_filename, h_filename, internal_h_filename, dump_modul
internal_h_file.write(auto_gen_msg)
write_internal_h_header(mod, internal_h_file)
- write_source(mod, c_file, internal_h_file)
- write_header(mod, h_file)
+ write_source(mod, metadata, c_file, internal_h_file)
+ write_header(mod, metadata, h_file)
write_internal_h_footer(mod, internal_h_file)
print(f"{c_filename}, {h_filename}, {internal_h_filename} regenerated.")