diff options
Diffstat (limited to 'Parser/asdl_c.py')
-rwxr-xr-x | Parser/asdl_c.py | 61 |
1 files changed, 55 insertions, 6 deletions
diff --git a/Parser/asdl_c.py b/Parser/asdl_c.py index 0c05339..242eccf 100755 --- a/Parser/asdl_c.py +++ b/Parser/asdl_c.py @@ -163,6 +163,32 @@ class TypeDefVisitor(EmitVisitor): self.emit(s, depth) self.emit("", depth) +class SequenceDefVisitor(EmitVisitor): + def visitModule(self, mod): + for dfn in mod.dfns: + self.visit(dfn) + + def visitType(self, type, depth=0): + self.visit(type.value, type.name, depth) + + def visitSum(self, sum, name, depth): + if is_simple(sum): + return + self.emit_sequence_constructor(name, depth) + + def emit_sequence_constructor(self, name,depth): + ctype = get_c_type(name) + self.emit("""\ +typedef struct { + _ASDL_SEQ_HEAD + %(ctype)s typed_elements[1]; +} asdl_%(name)s_seq;""" % locals(), reflow=False, depth=depth) + self.emit("", depth) + self.emit("asdl_%(name)s_seq *_Py_asdl_%(name)s_seq_new(Py_ssize_t size, PyArena *arena);" % locals(), depth) + self.emit("", depth) + + def visitProduct(self, product, name, depth): + self.emit_sequence_constructor(name, depth) class StructVisitor(EmitVisitor): """Visitor to generate typedefs for AST.""" @@ -219,7 +245,8 @@ class StructVisitor(EmitVisitor): if field.type == 'cmpop': self.emit("asdl_int_seq *%(name)s;" % locals(), depth) else: - self.emit("asdl_seq *%(name)s;" % locals(), depth) + _type = field.type + self.emit("asdl_%(_type)s_seq *%(name)s;" % locals(), depth) else: self.emit("%(ctype)s %(name)s;" % locals(), depth) @@ -274,7 +301,7 @@ class PrototypeVisitor(EmitVisitor): if f.type == 'cmpop': ctype = "asdl_int_seq *" else: - ctype = "asdl_seq *" + ctype = f"asdl_{f.type}_seq *" else: ctype = get_c_type(f.type) args.append((ctype, name, f.opt or f.seq)) @@ -507,7 +534,8 @@ class Obj2ModVisitor(PickleVisitor): if self.isSimpleType(field): self.emit("asdl_int_seq* %s;" % field.name, depth) else: - self.emit("asdl_seq* %s;" % field.name, depth) + _type = field.type + self.emit(f"asdl_{field.type}_seq* {field.name};", depth) else: ctype = get_c_type(field.type) self.emit("%s %s;" % (ctype, field.name), depth) @@ -562,7 +590,7 @@ class Obj2ModVisitor(PickleVisitor): if self.isSimpleType(field): self.emit("%s = _Py_asdl_int_seq_new(len, arena);" % field.name, depth+1) else: - self.emit("%s = _Py_asdl_seq_new(len, arena);" % field.name, depth+1) + self.emit("%s = _Py_asdl_%s_seq_new(len, arena);" % (field.name, field.type), depth+1) self.emit("if (%s == NULL) goto failed;" % field.name, depth+1) self.emit("for (i = 0; i < len; i++) {", depth+1) self.emit("%s val;" % ctype, depth+2) @@ -600,6 +628,24 @@ class MarshalPrototypeVisitor(PickleVisitor): visitProduct = visitSum = prototype +class SequenceConstructorVisitor(EmitVisitor): + def visitModule(self, mod): + for dfn in mod.dfns: + self.visit(dfn) + + def visitType(self, type): + self.visit(type.value, type.name) + + def visitProduct(self, prod, name): + self.emit_sequence_constructor(name, get_c_type(name)) + + def visitSum(self, sum, name): + if not is_simple(sum): + self.emit_sequence_constructor(name, get_c_type(name)) + + def emit_sequence_constructor(self, name, type): + self.emit(f"GENERATE_ASDL_SEQ_CONSTRUCTOR({name}, {type})", depth=0) + class PyTypesDeclareVisitor(PickleVisitor): def visitProduct(self, prod, name): @@ -647,6 +693,7 @@ class PyTypesDeclareVisitor(PickleVisitor): self.emit('"%s",' % t.name, 1) self.emit("};",0) + class PyTypesVisitor(PickleVisitor): def visitModule(self, mod): @@ -874,7 +921,7 @@ static PyObject* ast2obj_list(astmodulestate *state, asdl_seq *seq, PyObject* (* if (!result) return NULL; for (i = 0; i < n; i++) { - value = func(state, asdl_seq_GET(seq, i)); + value = func(state, asdl_seq_GET_UNTYPED(seq, i)); if (!value) { Py_DECREF(result); return NULL; @@ -1264,7 +1311,7 @@ class ObjVisitor(PickleVisitor): depth+2, reflow=False) self.emit("}", depth) else: - self.emit("value = ast2obj_list(state, %s, ast2obj_%s);" % (value, field.type), depth) + self.emit("value = ast2obj_list(state, (asdl_seq*)%s, ast2obj_%s);" % (value, field.type), depth) else: ctype = get_c_type(field.type) self.emit("value = ast2obj_%s(state, %s);" % (field.type, value), depth, reflow=False) @@ -1431,6 +1478,7 @@ def write_header(f, mod): f.write('#undef Yield /* undefine macro conflicting with <winbase.h> */\n') f.write('\n') c = ChainOfVisitors(TypeDefVisitor(f), + SequenceDefVisitor(f), StructVisitor(f)) c.visit(mod) f.write("// Note: these macros affect function definitions, not only call sites.\n") @@ -1457,6 +1505,7 @@ def write_source(f, mod): generate_module_def(f, mod) v = ChainOfVisitors( + SequenceConstructorVisitor(f), PyTypesDeclareVisitor(f), PyTypesVisitor(f), Obj2ModPrototypeVisitor(f), |