diff options
-rwxr-xr-x | Parser/asdl_c.py | 99 |
1 files changed, 75 insertions, 24 deletions
diff --git a/Parser/asdl_c.py b/Parser/asdl_c.py index 4e5c5c8..371730a 100755 --- a/Parser/asdl_c.py +++ b/Parser/asdl_c.py @@ -4,6 +4,7 @@ import os import sys import textwrap +import types from argparse import ArgumentParser from contextlib import contextmanager @@ -100,11 +101,12 @@ def asdl_of(name, obj): class EmitVisitor(asdl.VisitorBase): """Visit that emits lines""" - def __init__(self, file): + def __init__(self, file, metadata = None): self.file = file self.identifiers = set() self.singletons = set() self.types = set() + self._metadata = metadata super(EmitVisitor, self).__init__() def emit_identifier(self, name): @@ -127,6 +129,42 @@ class EmitVisitor(asdl.VisitorBase): line = (" " * TABSIZE * depth) + line self.file.write(line + "\n") + @property + def metadata(self): + if self._metadata is None: + raise ValueError( + "%s was expecting to be annnotated with metadata" + % type(self).__name__ + ) + return self._metadata + + @metadata.setter + def metadata(self, value): + self._metadata = value + +class MetadataVisitor(asdl.VisitorBase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Metadata: + # - simple_sums: Tracks the list of compound type + # names where all the constructors + # belonging to that type lack of any + # fields. + self.metadata = types.SimpleNamespace( + simple_sums=set() + ) + + def visitModule(self, mod): + for dfn in mod.dfns: + self.visit(dfn) + + def visitType(self, type): + self.visit(type.value, type.name) + + def visitSum(self, sum, name): + if is_simple(sum): + self.metadata.simple_sums.add(name) class TypeDefVisitor(EmitVisitor): def visitModule(self, mod): @@ -244,7 +282,7 @@ class StructVisitor(EmitVisitor): ctype = get_c_type(field.type) name = field.name if field.seq: - if field.type == 'cmpop': + if field.type in self.metadata.simple_sums: self.emit("asdl_int_seq *%(name)s;" % locals(), depth) else: _type = field.type @@ -304,7 +342,7 @@ class PrototypeVisitor(EmitVisitor): name = f.name # XXX should extend get_c_type() to handle this if f.seq: - if f.type == 'cmpop': + if f.type in self.metadata.simple_sums: ctype = "asdl_int_seq *" else: ctype = f"asdl_{f.type}_seq *" @@ -549,16 +587,11 @@ class Obj2ModVisitor(PickleVisitor): ctype = get_c_type(field.type) self.emit("%s %s;" % (ctype, field.name), depth) - def isSimpleSum(self, field): - # XXX can the members of this list be determined automatically? - return field.type in ('expr_context', 'boolop', 'operator', - 'unaryop', 'cmpop') - def isNumeric(self, field): return get_c_type(field.type) in ("int", "bool") def isSimpleType(self, field): - return self.isSimpleSum(field) or self.isNumeric(field) + return field.type in self.metadata.simple_sums or self.isNumeric(field) def visitField(self, field, name, sum=None, prod=None, depth=0): ctype = get_c_type(field.type) @@ -1282,18 +1315,23 @@ class ObjVisitor(PickleVisitor): def set(self, field, value, depth): if field.seq: - # XXX should really check for is_simple, but that requires a symbol table - if field.type == "cmpop": + if field.type in self.metadata.simple_sums: # While the sequence elements are stored as void*, - # ast2obj_cmpop expects an enum + # simple sums expects an enum self.emit("{", depth) self.emit("Py_ssize_t i, n = asdl_seq_LEN(%s);" % value, depth+1) self.emit("value = PyList_New(n);", depth+1) self.emit("if (!value) goto failed;", depth+1) self.emit("for(i = 0; i < n; i++)", depth+1) # This cannot fail, so no need for error handling - self.emit("PyList_SET_ITEM(value, i, ast2obj_cmpop(state, (cmpop_ty)asdl_seq_GET(%s, i)));" % value, - depth+2, reflow=False) + self.emit( + "PyList_SET_ITEM(value, i, ast2obj_{0}(state, ({0}_ty)asdl_seq_GET({1}, i)));".format( + field.type, + value + ), + depth + 2, + reflow=False, + ) self.emit("}", depth) else: self.emit("value = ast2obj_list(state, (asdl_seq*)%s, ast2obj_%s);" % (value, field.type), depth) @@ -1362,11 +1400,13 @@ int PyAST_Check(PyObject* obj) """ class ChainOfVisitors: - def __init__(self, *visitors): + def __init__(self, *visitors, metadata = None): self.visitors = visitors + self.metadata = metadata def visit(self, object): for v in self.visitors: + v.metadata = self.metadata v.visit(object) v.emit("", 0) @@ -1468,7 +1508,7 @@ def generate_module_def(mod, f, internal_h): f.write(' return 1;\n') f.write('};\n\n') -def write_header(mod, f): +def write_header(mod, metadata, f): f.write(textwrap.dedent(""" #ifndef Py_INTERNAL_AST_H #define Py_INTERNAL_AST_H @@ -1483,12 +1523,19 @@ def write_header(mod, f): #include "pycore_asdl.h" """).lstrip()) - c = ChainOfVisitors(TypeDefVisitor(f), - SequenceDefVisitor(f), - StructVisitor(f)) + + c = ChainOfVisitors( + TypeDefVisitor(f), + SequenceDefVisitor(f), + StructVisitor(f), + metadata=metadata + ) c.visit(mod) + f.write("// Note: these macros affect function definitions, not only call sites.\n") - PrototypeVisitor(f).visit(mod) + prototype_visitor = PrototypeVisitor(f, metadata=metadata) + prototype_visitor.visit(mod) + f.write(textwrap.dedent(""" PyObject* PyAST_mod2obj(mod_ty t); @@ -1535,8 +1582,7 @@ def write_internal_h_footer(mod, f): #endif /* !Py_INTERNAL_AST_STATE_H */ """), file=f) - -def write_source(mod, f, internal_h_file): +def write_source(mod, metadata, f, internal_h_file): generate_module_def(mod, f, internal_h_file) v = ChainOfVisitors( @@ -1549,6 +1595,7 @@ def write_source(mod, f, internal_h_file): Obj2ModVisitor(f), ASTModuleVisitor(f), PartingShots(f), + metadata=metadata ) v.visit(mod) @@ -1561,6 +1608,10 @@ def main(input_filename, c_filename, h_filename, internal_h_filename, dump_modul if not asdl.check(mod): sys.exit(1) + metadata_visitor = MetadataVisitor() + metadata_visitor.visit(mod) + metadata = metadata_visitor.metadata + with c_filename.open("w") as c_file, \ h_filename.open("w") as h_file, \ internal_h_filename.open("w") as internal_h_file: @@ -1569,8 +1620,8 @@ def main(input_filename, c_filename, h_filename, internal_h_filename, dump_modul internal_h_file.write(auto_gen_msg) write_internal_h_header(mod, internal_h_file) - write_source(mod, c_file, internal_h_file) - write_header(mod, h_file) + write_source(mod, metadata, c_file, internal_h_file) + write_header(mod, metadata, h_file) write_internal_h_footer(mod, internal_h_file) print(f"{c_filename}, {h_filename}, {internal_h_filename} regenerated.") |