diff options
author | Jeremy Hylton <jeremy@alum.mit.edu> | 2000-03-16 20:06:59 (GMT) |
---|---|---|
committer | Jeremy Hylton <jeremy@alum.mit.edu> | 2000-03-16 20:06:59 (GMT) |
commit | 36cc6a21973c38376c6a6fb646ca53e079950586 (patch) | |
tree | e58e97b683ddaf7abc6fcaf93ba56d6103e46406 | |
parent | f635abee3a3c56aa614ca4c0edb963ec9747bf82 (diff) | |
download | cpython-36cc6a21973c38376c6a6fb646ca53e079950586.zip cpython-36cc6a21973c38376c6a6fb646ca53e079950586.tar.gz cpython-36cc6a21973c38376c6a6fb646ca53e079950586.tar.bz2 |
complete rewrite
The code generator now uses a flow graph as its intermediate representation. The old
representation used a list with explicit "StackRefs" to indicate the targets
of jumps.
pyassem converts the flow graph to bytecode and breaks bytecode generation up into
individual steps.
-rw-r--r-- | Lib/compiler/pyassem.py | 624 | ||||
-rw-r--r-- | Lib/compiler/pycodegen.py | 1007 | ||||
-rw-r--r-- | Tools/compiler/compiler/pyassem.py | 624 | ||||
-rw-r--r-- | Tools/compiler/compiler/pycodegen.py | 1007 |
4 files changed, 1662 insertions, 1600 deletions
diff --git a/Lib/compiler/pyassem.py b/Lib/compiler/pyassem.py index 4cb910c..3272419 100644 --- a/Lib/compiler/pyassem.py +++ b/Lib/compiler/pyassem.py @@ -1,40 +1,127 @@ -"""Assembler for Python bytecode - -The new module is used to create the code object. The following -attribute definitions are included from the reference manual: - -co_name gives the function name -co_argcount is the number of positional arguments (including - arguments with default values) -co_nlocals is the number of local variables used by the function - (including arguments) -co_varnames is a tuple containing the names of the local variables - (starting with the argument names) -co_code is a string representing the sequence of bytecode instructions -co_consts is a tuple containing the literals used by the bytecode -co_names is a tuple containing the names used by the bytecode -co_filename is the filename from which the code was compiled -co_firstlineno is the first line number of the function -co_lnotab is a string encoding the mapping from byte code offsets - to line numbers. see LineAddrTable below. -co_stacksize is the required stack size (including local variables) -co_flags is an integer encoding a number of flags for the - interpreter. There are four flags: - CO_OPTIMIZED -- uses load fast - CO_NEWLOCALS -- everything? - CO_VARARGS -- use *args - CO_VARKEYWORDS -- uses **args - -If a code object represents a function, the first item in co_consts is -the documentation string of the function, or None if undefined. 
-""" - -import sys +"""A flow graph representation for Python bytecode""" + import dis import new import string +import types + +from compiler import misc + +class FlowGraph: + def __init__(self): + self.current = self.entry = Block() + self.exit = Block("exit") + self.blocks = misc.Set() + self.blocks.add(self.entry) + self.blocks.add(self.exit) + + def startBlock(self, block): + self.current = block + + def nextBlock(self, block=None): + if block is None: + block = self.newBlock() + # XXX think we need to specify when there is implicit transfer + # from one block to the next + # + # I think this strategy works: each block has a child + # designated as "next" which is returned as the last of the + # children. because the nodes in a graph are emitted in + # reverse post order, the "next" block will always be emitted + # immediately after its parent. + # Worry: maintaining this invariant could be tricky + self.current.addNext(block) + self.startBlock(block) + + def newBlock(self): + b = Block() + self.blocks.add(b) + return b + + def startExitBlock(self): + self.startBlock(self.exit) + + def emit(self, *inst): + # XXX should jump instructions implicitly call nextBlock? + if inst[0] == 'RETURN_VALUE': + self.current.addOutEdge(self.exit) + self.current.emit(inst) + + def getBlocks(self): + """Return the blocks in reverse postorder + + i.e. 
each node appears before all of its successors + """ + # XXX make sure every node that doesn't have an explicit next + # is set so that next points to exit + for b in self.blocks.elements(): + if b is self.exit: + continue + if not b.next: + b.addNext(self.exit) + order = dfs_postorder(self.entry, {}) + order.reverse() + # hack alert + if not self.exit in order: + order.append(self.exit) + return order + +def dfs_postorder(b, seen): + """Depth-first search of tree rooted at b, return in postorder""" + order = [] + seen[b] = b + for c in b.children(): + if seen.has_key(c): + continue + order = order + dfs_postorder(c, seen) + order.append(b) + return order + +class Block: + _count = 0 + + def __init__(self, label=''): + self.insts = [] + self.inEdges = misc.Set() + self.outEdges = misc.Set() + self.label = label + self.bid = Block._count + self.next = [] + Block._count = Block._count + 1 + + def __repr__(self): + if self.label: + return "<block %s id=%d len=%d>" % (self.label, self.bid, + len(self.insts)) + else: + return "<block id=%d len=%d>" % (self.bid, len(self.insts)) + + def __str__(self): + insts = map(str, self.insts) + return "<block %s %d:\n%s>" % (self.label, self.bid, + string.join(insts, '\n')) + + def emit(self, inst): + op = inst[0] + if op[:4] == 'JUMP': + self.outEdges.add(inst[1]) + self.insts.append(inst) + + def getInstructions(self): + return self.insts + + def addInEdge(self, block): + self.inEdges.add(block) + + def addOutEdge(self, block): + self.outEdges.add(block) + + def addNext(self, block): + self.next.append(block) + assert len(self.next) == 1, map(str, self.next) -import misc + def children(self): + return self.outEdges.elements() + self.next # flags for code objects CO_OPTIMIZED = 0x0001 @@ -42,224 +129,128 @@ CO_NEWLOCALS = 0x0002 CO_VARARGS = 0x0004 CO_VARKEYWORDS = 0x0008 -class TupleArg: - def __init__(self, count, names): - self.count = count - self.names = names - def __repr__(self): - return "TupleArg(%s, %s)" % (self.count, 
self.names) - def getName(self): - return ".nested%d" % self.count - -class PyAssembler: - """Creates Python code objects - """ - - # XXX this class needs to major refactoring - - def __init__(self, args=(), name='?', filename='<?>', - docstring=None): - # XXX why is the default value for flags 3? - self.insts = [] - # used by makeCodeObject - self._getArgCount(args) - self.code = '' - self.consts = [docstring] - self.filename = filename - self.flags = CO_NEWLOCALS - self.name = name - self.names = [] +# the FlowGraph is transformed in place; it exists in one of these states +RAW = "RAW" +FLAT = "FLAT" +CONV = "CONV" +DONE = "DONE" + +class PyFlowGraph(FlowGraph): + super_init = FlowGraph.__init__ + + def __init__(self, name, filename, args=(), optimized=0): + self.super_init() + self.name = name + self.filename = filename + self.docstring = None + self.args = args # XXX + self.argcount = getArgCount(args) + if optimized: + self.flags = CO_OPTIMIZED | CO_NEWLOCALS + else: + self.flags = 0 + self.firstlineno = None + self.consts = [] + self.names = [] self.varnames = list(args) or [] for i in range(len(self.varnames)): var = self.varnames[i] if isinstance(var, TupleArg): self.varnames[i] = var.getName() - # lnotab support - self.firstlineno = 0 - self.lastlineno = 0 - self.last_addr = 0 - self.lnotab = '' - - def _getArgCount(self, args): - self.argcount = len(args) - if args: - for arg in args: - if isinstance(arg, TupleArg): - numNames = len(misc.flatten(arg.names)) - self.argcount = self.argcount - numNames + self.stage = RAW - def __repr__(self): - return "<bytecode: %d instrs>" % len(self.insts) - - def setFlags(self, val): - """XXX for module's function""" - self.flags = val - - def setOptimized(self): - self.flags = self.flags | CO_OPTIMIZED - - def setVarArgs(self): - if not self.flags & CO_VARARGS: - self.flags = self.flags | CO_VARARGS - self.argcount = self.argcount - 1 - - def setKWArgs(self): - self.flags = self.flags | CO_VARKEYWORDS - - def 
getCurInst(self): - return len(self.insts) + def setDocstring(self, doc): + self.docstring = doc + self.consts.insert(0, doc) - def getNextInst(self): - return len(self.insts) + 1 + def setFlag(self, flag): + self.flags = self.flags | flag + if flag == CO_VARARGS: + self.argcount = self.argcount - 1 - def dump(self, io=sys.stdout): - i = 0 - for inst in self.insts: - if inst[0] == 'SET_LINENO': - io.write("\n") - io.write(" %3d " % i) - if len(inst) == 1: - io.write("%s\n" % inst) - else: - io.write("%-15.15s\t%s\n" % inst) - i = i + 1 - - def makeCodeObject(self): - """Make a Python code object - - This creates a Python code object using the new module. This - seems simpler than reverse-engineering the way marshal dumps - code objects into .pyc files. One of the key difficulties is - figuring out how to layout references to code objects that - appear on the VM stack; e.g. - 3 SET_LINENO 1 - 6 LOAD_CONST 0 (<code object fact at 8115878 [...] - 9 MAKE_FUNCTION 0 - 12 STORE_NAME 0 (fact) - """ - - self._findOffsets() - lnotab = LineAddrTable() + def getCode(self): + """Get a Python code object""" + if self.stage == RAW: + self.flattenGraph() + if self.stage == FLAT: + self.convertArgs() + if self.stage == CONV: + self.makeByteCode() + if self.stage == DONE: + return self.newCodeObject() + raise RuntimeError, "inconsistent PyFlowGraph state" + + def dump(self, io=None): + if io: + save = sys.stdout + sys.stdout = io + pc = 0 for t in self.insts: opname = t[0] + if opname == "SET_LINENO": + print if len(t) == 1: - lnotab.addCode(self.opnum[opname]) - elif len(t) == 2: - if opname == 'SET_LINENO': - oparg = t[1] - lnotab.nextLine(oparg) + print "\t", "%3d" % pc, opname + pc = pc + 1 + else: + print "\t", "%3d" % pc, opname, t[1] + pc = pc + 3 + if io: + sys.stdout = save + + def flattenGraph(self): + """Arrange the blocks in order and resolve jumps""" + assert self.stage == RAW + self.insts = insts = [] + pc = 0 + begin = {} + end = {} + for b in self.getBlocks(): + 
begin[b] = pc + for inst in b.getInstructions(): + insts.append(inst) + if len(inst) == 1: + pc = pc + 1 else: - oparg = self._convertArg(opname, t[1]) - try: - hi, lo = divmod(oparg, 256) - except TypeError: - raise TypeError, "untranslated arg: %s, %s" % (opname, oparg) - lnotab.addCode(self.opnum[opname], lo, hi) - - # why is a module a special case? - if self.flags == 0: - nlocals = 0 - else: - nlocals = len(self.varnames) - # XXX danger! can't pass through here twice - if self.flags & CO_VARKEYWORDS: - self.argcount = self.argcount - 1 - stacksize = findDepth(self.insts) - try: - co = new.code(self.argcount, nlocals, stacksize, - self.flags, lnotab.getCode(), self._getConsts(), - tuple(self.names), tuple(self.varnames), - self.filename, self.name, self.firstlineno, - lnotab.getTable()) - except SystemError, err: - print err - print repr(self.argcount) - print repr(nlocals) - print repr(stacksize) - print repr(self.flags) - print repr(lnotab.getCode()) - print repr(self._getConsts()) - print repr(self.names) - print repr(self.varnames) - print repr(self.filename) - print repr(self.name) - print repr(self.firstlineno) - print repr(lnotab.getTable()) - raise - return co - - def _getConsts(self): - """Return a tuple for the const slot of a code object - - Converts PythonVMCode objects to code objects - """ - l = [] - for elt in self.consts: - # XXX might be clearer to just as isinstance(CodeGen) - if hasattr(elt, 'asConst'): - l.append(elt.asConst()) + # arg takes 2 bytes + pc = pc + 3 + end[b] = pc + pc = 0 + for i in range(len(insts)): + inst = insts[i] + if len(inst) == 1: + pc = pc + 1 else: - l.append(elt) - return tuple(l) + pc = pc + 3 + opname = inst[0] + if self.hasjrel.has_elt(opname): + oparg = inst[1] + offset = begin[oparg] - pc + insts[i] = opname, offset + elif self.hasjabs.has_elt(opname): + insts[i] = opname, begin[inst[1]] + self.stacksize = findDepth(self.insts) + self.stage = FLAT - def _findOffsets(self): - """Find offsets for use in resolving 
StackRefs""" - self.offsets = [] - cur = 0 - for t in self.insts: - self.offsets.append(cur) - l = len(t) - if l == 1: - cur = cur + 1 - elif l == 2: - cur = cur + 3 - arg = t[1] - # XXX this is a total hack: for a reference used - # multiple times, we create a list of offsets and - # expect that we when we pass through the code again - # to actually generate the offsets, we'll pass in the - # same order. - if isinstance(arg, StackRef): - try: - arg.__offset.append(cur) - except AttributeError: - arg.__offset = [cur] - - def _convertArg(self, op, arg): - """Convert the string representation of an arg to a number - - The specific handling depends on the opcode. - - XXX This first implementation isn't going to be very - efficient. - """ - if op == 'SET_LINENO': - return arg - if op == 'LOAD_CONST': - return self._lookupName(arg, self.consts) - if op in self.localOps: - # make sure it's in self.names, but use the bytecode offset - self._lookupName(arg, self.names) - return self._lookupName(arg, self.varnames) - if op in self.globalOps: - return self._lookupName(arg, self.names) - if op in self.nameOps: - return self._lookupName(arg, self.names) - if op == 'COMPARE_OP': - return self.cmp_op.index(arg) - if self.hasjrel.has_elt(op): - offset = arg.__offset[0] - del arg.__offset[0] - return self.offsets[arg.resolve()] - offset - if self.hasjabs.has_elt(op): - return self.offsets[arg.resolve()] - return arg - - nameOps = ('STORE_NAME', 'IMPORT_NAME', 'IMPORT_FROM', - 'STORE_ATTR', 'LOAD_ATTR', 'LOAD_NAME', 'DELETE_NAME', - 'DELETE_ATTR') - localOps = ('LOAD_FAST', 'STORE_FAST', 'DELETE_FAST') - globalOps = ('LOAD_GLOBAL', 'STORE_GLOBAL', 'DELETE_GLOBAL') + hasjrel = misc.Set() + for i in dis.hasjrel: + hasjrel.add(dis.opname[i]) + hasjabs = misc.Set() + for i in dis.hasjabs: + hasjabs.add(dis.opname[i]) + + def convertArgs(self): + """Convert arguments from symbolic to concrete form""" + assert self.stage == FLAT + for i in range(len(self.insts)): + t = self.insts[i] + if 
len(t) == 2: + opname = t[0] + oparg = t[1] + conv = self._converters.get(opname, None) + if conv: + self.insts[i] = opname, conv(self, oparg) + self.stage = CONV def _lookupName(self, name, list): """Return index of name in list, appending if necessary""" @@ -276,32 +267,124 @@ class PyAssembler: list.append(name) return end - # Convert some stuff from the dis module for local use - - cmp_op = list(dis.cmp_op) - hasjrel = misc.Set() - for i in dis.hasjrel: - hasjrel.add(dis.opname[i]) - hasjabs = misc.Set() - for i in dis.hasjabs: - hasjabs.add(dis.opname[i]) - + _converters = {} + def _convert_LOAD_CONST(self, arg): + return self._lookupName(arg, self.consts) + + def _convert_LOAD_FAST(self, arg): + self._lookupName(arg, self.names) + return self._lookupName(arg, self.varnames) + _convert_STORE_FAST = _convert_LOAD_FAST + _convert_DELETE_FAST = _convert_LOAD_FAST + + def _convert_NAME(self, arg): + return self._lookupName(arg, self.names) + _convert_LOAD_NAME = _convert_NAME + _convert_STORE_NAME = _convert_NAME + _convert_DELETE_NAME = _convert_NAME + _convert_IMPORT_NAME = _convert_NAME + _convert_IMPORT_FROM = _convert_NAME + _convert_STORE_ATTR = _convert_NAME + _convert_LOAD_ATTR = _convert_NAME + _convert_DELETE_ATTR = _convert_NAME + _convert_LOAD_GLOBAL = _convert_NAME + _convert_STORE_GLOBAL = _convert_NAME + _convert_DELETE_GLOBAL = _convert_NAME + + _cmp = list(dis.cmp_op) + def _convert_COMPARE_OP(self, arg): + return self._cmp.index(arg) + + # similarly for other opcodes... 
+ + for name, obj in locals().items(): + if name[:9] == "_convert_": + opname = name[9:] + _converters[opname] = obj + del name, obj, opname + + def makeByteCode(self): + assert self.stage == CONV + self.lnotab = lnotab = LineAddrTable() + for t in self.insts: + opname = t[0] + if len(t) == 1: + lnotab.addCode(self.opnum[opname]) + else: + oparg = t[1] + if opname == "SET_LINENO": + lnotab.nextLine(oparg) + if self.firstlineno is None: + self.firstlineno = oparg + hi, lo = twobyte(oparg) + try: + lnotab.addCode(self.opnum[opname], lo, hi) + except ValueError: + print opname, oparg + print self.opnum[opname], lo, hi + raise + self.stage = DONE + opnum = {} for num in range(len(dis.opname)): opnum[dis.opname[num]] = num + del num - # this version of emit + arbitrary hooks might work, but it's damn - # messy. + def newCodeObject(self): + assert self.stage == DONE + if self.flags == 0: + nlocals = 0 + else: + nlocals = len(self.varnames) + argcount = self.argcount + if self.flags & CO_VARKEYWORDS: + argcount = argcount - 1 + return new.code(argcount, nlocals, self.stacksize, self.flags, + self.lnotab.getCode(), self.getConsts(), + tuple(self.names), tuple(self.varnames), + self.filename, self.name, self.firstlineno, + self.lnotab.getTable()) + + def getConsts(self): + """Return a tuple for the const slot of the code object + + Must convert references to code (MAKE_FUNCTION) to code + objects recursively. 
+ """ + l = [] + for elt in self.consts: + if isinstance(elt, PyFlowGraph): + elt = elt.getCode() + l.append(elt) + return tuple(l) + +def isJump(opname): + if opname[:4] == 'JUMP': + return 1 - def emit(self, *args): - self._emitDispatch(args[0], args[1:]) - self.insts.append(args) +class TupleArg: + """Helper for marking func defs with nested tuples in arglist""" + def __init__(self, count, names): + self.count = count + self.names = names + def __repr__(self): + return "TupleArg(%s, %s)" % (self.count, self.names) + def getName(self): + return ".nested%d" % self.count - def _emitDispatch(self, type, args): - for func in self._emit_hooks.get(type, []): - func(self, args) +def getArgCount(args): + argcount = len(args) + if args: + for arg in args: + if isinstance(arg, TupleArg): + numNames = len(misc.flatten(arg.names)) + argcount = argcount - numNames + return argcount - _emit_hooks = {} +def twobyte(val): + """Convert an int argument into high and low bytes""" + assert type(val) == types.IntType + return divmod(val, 256) class LineAddrTable: """lnotab @@ -361,34 +444,9 @@ class LineAddrTable: def getTable(self): return string.join(map(chr, self.lnotab), '') -class StackRef: - """Manage stack locations for jumps, loops, etc.""" - count = 0 - - def __init__(self, id=None, val=None): - if id is None: - id = StackRef.count - StackRef.count = StackRef.count + 1 - self.id = id - self.val = val - - def __repr__(self): - if self.val: - return "StackRef(val=%d)" % self.val - else: - return "StackRef(id=%d)" % self.id - - def bind(self, inst): - self.val = inst - - def resolve(self): - if self.val is None: - print "UNRESOLVE REF", self - return 0 - return self.val - class StackDepthTracker: - # XXX need to keep track of stack depth on jumps + # XXX 1. need to keep track of stack depth on jumps + # XXX 2. 
at least partly as a result, this code is broken def findDepth(self, insts): depth = 0 diff --git a/Lib/compiler/pycodegen.py b/Lib/compiler/pycodegen.py index b2d55d9..2e98d4e 100644 --- a/Lib/compiler/pycodegen.py +++ b/Lib/compiler/pycodegen.py @@ -1,132 +1,84 @@ -"""Python bytecode generator - -Currently contains generic ASTVisitor code, a LocalNameFinder, and a -CodeGenerator. Eventually, this will get split into the ASTVisitor as -a generic tool and CodeGenerator as a specific tool. -""" - -from compiler import parseFile, ast, visitor, walk, parse -from pyassem import StackRef, PyAssembler, TupleArg -import dis -import misc -import marshal -import new -import string -import sys import os +import marshal import stat import struct import types +from cStringIO import StringIO -class CodeGenerator: - """Generate bytecode for the Python VM""" - - OPTIMIZED = 1 - - # XXX should clean up initialization and generateXXX funcs - def __init__(self, filename="<?>"): - self.filename = filename - self.code = PyAssembler() - self.code.setFlags(0) - self.locals = misc.Stack() - self.loops = misc.Stack() - self.namespace = 0 - self.curStack = 0 - self.maxStack = 0 - - def emit(self, *args): - # XXX could just use self.emit = self.code.emit - apply(self.code.emit, args) - - def _generateFunctionOrLambdaCode(self, func): - self.name = func.name - - # keep a lookout for 'def foo((x,y)):' - args, hasTupleArg = self.generateArglist(func.argnames) - - self.code = PyAssembler(args=args, name=func.name, - filename=self.filename) - self.namespace = self.OPTIMIZED - if func.varargs: - self.code.setVarArgs() - if func.kwargs: - self.code.setKWArgs() - lnf = walk(func.code, LocalNameFinder(args), 0) - self.locals.push(lnf.getLocals()) - self.emit('SET_LINENO', func.lineno) - if hasTupleArg: - self.generateArgUnpack(func.argnames) - walk(func.code, self) +from compiler import ast, parse, walk +from compiler import pyassem, misc +from compiler.pyassem import CO_VARARGS, CO_VARKEYWORDS, 
TupleArg - def generateArglist(self, arglist): - args = [] - extra = [] - count = 0 - for elt in arglist: - if type(elt) == types.StringType: - args.append(elt) - elif type(elt) == types.TupleType: - args.append(TupleArg(count, elt)) - count = count + 1 - extra.extend(misc.flatten(elt)) - else: - raise ValueError, "unexpect argument type:", elt - return args + extra, count +def compile(filename): + f = open(filename) + buf = f.read() + f.close() + mod = Module(buf, filename) + mod.compile() + f = open(filename + "c", "wb") + mod.dump(f) + f.close() - def generateArgUnpack(self, args): - count = 0 - for arg in args: - if type(arg) == types.TupleType: - self.emit('LOAD_FAST', '.nested%d' % count) - count = count + 1 - self.unpackTuple(arg) - - def unpackTuple(self, tup): - self.emit('UNPACK_TUPLE', len(tup)) - for elt in tup: - if type(elt) == types.TupleType: - self.unpackTuple(elt) - else: - self.emit('STORE_FAST', elt) +class Module: + def __init__(self, source, filename): + self.filename = filename + self.source = source + self.code = None - def generateFunctionCode(self, func): - """Generate code for a function body""" - self._generateFunctionOrLambdaCode(func) - self.emit('LOAD_CONST', None) - self.emit('RETURN_VALUE') + def compile(self): + ast = parse(self.source) + root, filename = os.path.split(self.filename) + gen = ModuleCodeGenerator(filename) + walk(ast, gen, 1) + self.code = gen.getCode() - def generateLambdaCode(self, func): - self._generateFunctionOrLambdaCode(func) - self.emit('RETURN_VALUE') + def dump(self, f): + f.write(self.getPycHeader()) + marshal.dump(self.code, f) - def generateClassCode(self, klass): - self.code = PyAssembler(name=klass.name, - filename=self.filename) - self.emit('SET_LINENO', klass.lineno) - lnf = walk(klass.code, LocalNameFinder(), 0) - self.locals.push(lnf.getLocals()) - walk(klass.code, self) - self.emit('LOAD_LOCALS') - self.emit('RETURN_VALUE') + MAGIC = (20121 | (ord('\r')<<16) | (ord('\n')<<24)) + + def 
getPycHeader(self): + # compile.c uses marshal to write a long directly, with + # calling the interface that would also generate a 1-byte code + # to indicate the type of the value. simplest way to get the + # same effect is to call marshal and then skip the code. + magic = marshal.dumps(self.MAGIC)[1:] + mtime = os.stat(self.filename)[stat.ST_MTIME] + mtime = struct.pack('i', mtime) + return magic + mtime + +class CodeGenerator: + + optimized = 0 # is namespace access optimized? + + def __init__(self, filename): +## Subclasses must define a constructor that intializes self.graph +## before calling this init function +## self.graph = pyassem.PyFlowGraph() + self.filename = filename + self.locals = misc.Stack() + self.loops = misc.Stack() + self.curStack = 0 + self.maxStack = 0 + self._setupGraphDelegation() + + def _setupGraphDelegation(self): + self.emit = self.graph.emit + self.newBlock = self.graph.newBlock + self.startBlock = self.graph.startBlock + self.nextBlock = self.graph.nextBlock + self.setDocstring = self.graph.setDocstring - def asConst(self): - """Create a Python code object.""" - if self.namespace == self.OPTIMIZED: - self.code.setOptimized() - return self.code.makeCodeObject() + def getCode(self): + """Return a code object""" + return self.graph.getCode() + + # Next five methods handle name access def isLocalName(self, name): return self.locals.top().has_elt(name) - def _nameOp(self, prefix, name): - if self.isLocalName(name): - if self.namespace == self.OPTIMIZED: - self.emit(prefix + '_FAST', name) - else: - self.emit(prefix + '_NAME', name) - else: - self.emit(prefix + '_GLOBAL', name) - def storeName(self, name): self._nameOp('STORE', name) @@ -136,195 +88,237 @@ class CodeGenerator: def delName(self, name): self._nameOp('DELETE', name) - def visitNULL(self, node): - """Method exists only to stop warning in -v mode""" - pass - - visitStmt = visitNULL - visitGlobal = visitNULL - - def visitDiscard(self, node): - self.visit(node.expr) - 
self.emit('POP_TOP') - return 1 + def _nameOp(self, prefix, name): + if not self.optimized: + self.emit(prefix + '_NAME', name) + return + if self.isLocalName(name): + self.emit(prefix + '_FAST', name) + else: + self.emit(prefix + '_GLOBAL', name) - def visitPass(self, node): - self.emit('SET_LINENO', node.lineno) + # The first few visitor methods handle nodes that generator new + # code objects def visitModule(self, node): - lnf = walk(node.node, LocalNameFinder(), 0) - self.locals.push(lnf.getLocals()) - self.visit(node.node) - self.emit('LOAD_CONST', None) - self.emit('RETURN_VALUE') - return 1 + lnf = walk(node.node, LocalNameFinder(), 0) + self.locals.push(lnf.getLocals()) + self.setDocstring(node.doc) + self.visit(node.node) + self.emit('LOAD_CONST', None) + self.emit('RETURN_VALUE') - def visitImport(self, node): - self.emit('SET_LINENO', node.lineno) - for name in node.names: - self.emit('IMPORT_NAME', name) - self.storeName(name) + def visitFunction(self, node): + self._visitFuncOrLambda(node, isLambda=0) + self.storeName(node.name) - def visitFrom(self, node): - self.emit('SET_LINENO', node.lineno) - self.emit('IMPORT_NAME', node.modname) - for name in node.names: - if name == '*': - self.namespace = 0 - self.emit('IMPORT_FROM', name) - self.emit('POP_TOP') + def visitLambda(self, node): + self._visitFuncOrLambda(node, isLambda=1) +## self.storeName("<lambda>") + + def _visitFuncOrLambda(self, node, isLambda): + gen = FunctionCodeGenerator(node, self.filename, isLambda) + walk(node.code, gen) + gen.finish() + self.emit('SET_LINENO', node.lineno) + for default in node.defaults: + self.visit(default) + self.emit('LOAD_CONST', gen.getCode()) + self.emit('MAKE_FUNCTION', len(node.defaults)) def visitClass(self, node): + gen = ClassCodeGenerator(node, self.filename) + walk(node.code, gen) + gen.finish() self.emit('SET_LINENO', node.lineno) self.emit('LOAD_CONST', node.name) for base in node.bases: self.visit(base) self.emit('BUILD_TUPLE', len(node.bases)) - 
classBody = CodeGenerator(self.filename) - classBody.generateClassCode(node) - self.emit('LOAD_CONST', classBody) + self.emit('LOAD_CONST', gen.getCode()) self.emit('MAKE_FUNCTION', 0) self.emit('CALL_FUNCTION', 0) self.emit('BUILD_CLASS') self.storeName(node.name) - return 1 - def _visitFuncOrLambda(self, node, kind): - """Code common to Function and Lambda nodes""" - codeBody = CodeGenerator(self.filename) - getattr(codeBody, 'generate%sCode' % kind)(node) - self.emit('SET_LINENO', node.lineno) - for default in node.defaults: - self.visit(default) - self.emit('LOAD_CONST', codeBody) - self.emit('MAKE_FUNCTION', len(node.defaults)) + # The rest are standard visitor methods - def visitFunction(self, node): - self._visitFuncOrLambda(node, 'Function') - self.storeName(node.name) - return 1 + # The next few implement control-flow statements - def visitLambda(self, node): - node.name = '<lambda>' - self._visitFuncOrLambda(node, 'Lambda') - return 1 + def visitIf(self, node): + end = self.newBlock() + numtests = len(node.tests) + for i in range(numtests): + test, suite = node.tests[i] + if hasattr(test, 'lineno'): + self.emit('SET_LINENO', test.lineno) + self.visit(test) +## if i == numtests - 1 and not node.else_: +## nextTest = end +## else: +## nextTest = self.newBlock() + nextTest = self.newBlock() + self.emit('JUMP_IF_FALSE', nextTest) + self.nextBlock() + self.emit('POP_TOP') + self.visit(suite) + self.emit('JUMP_FORWARD', end) + self.nextBlock(nextTest) + self.emit('POP_TOP') + if node.else_: + self.visit(node.else_) + self.nextBlock(end) - def visitCallFunc(self, node): - pos = 0 - kw = 0 - if hasattr(node, 'lineno'): - self.emit('SET_LINENO', node.lineno) - self.visit(node.node) - for arg in node.args: - self.visit(arg) - if isinstance(arg, ast.Keyword): - kw = kw + 1 - else: - pos = pos + 1 - self.emit('CALL_FUNCTION', kw << 8 | pos) - return 1 + def visitWhile(self, node): + self.emit('SET_LINENO', node.lineno) - def visitKeyword(self, node): - 
self.emit('LOAD_CONST', node.name) - self.visit(node.expr) - return 1 + loop = self.newBlock() + else_ = self.newBlock() - def visitIf(self, node): - after = StackRef() - for test, suite in node.tests: - if hasattr(test, 'lineno'): - self.emit('SET_LINENO', test.lineno) - else: - print "warning", "no line number" - self.visit(test) - dest = StackRef() - self.emit('JUMP_IF_FALSE', dest) - self.emit('POP_TOP') - self.visit(suite) - self.emit('JUMP_FORWARD', after) - dest.bind(self.code.getCurInst()) - self.emit('POP_TOP') - if node.else_: - self.visit(node.else_) - after.bind(self.code.getCurInst()) - return 1 + after = self.newBlock() + self.emit('SETUP_LOOP', after) + + self.nextBlock(loop) + self.loops.push(loop) + + self.emit('SET_LINENO', node.lineno) + self.visit(node.test) + self.emit('JUMP_IF_FALSE', else_ or after) - def startLoop(self): - l = Loop() - self.loops.push(l) - self.emit('SETUP_LOOP', l.extentAnchor) - return l + self.nextBlock() + self.emit('POP_TOP') + self.visit(node.body) + self.emit('JUMP_ABSOLUTE', loop) - def finishLoop(self): - l = self.loops.pop() - i = self.code.getCurInst() - l.extentAnchor.bind(self.code.getCurInst()) + self.startBlock(else_) # or just the POPs if not else clause + self.emit('POP_TOP') + self.emit('POP_BLOCK') + if node.else_: + self.visit(node.else_) + self.loops.pop() + self.nextBlock(after) def visitFor(self, node): - # three refs needed - anchor = StackRef() + start = self.newBlock() + anchor = self.newBlock() + after = self.newBlock() + self.loops.push(start) self.emit('SET_LINENO', node.lineno) - l = self.startLoop() + self.emit('SETUP_LOOP', after) self.visit(node.list) self.visit(ast.Const(0)) - l.startAnchor.bind(self.code.getCurInst()) + self.nextBlock(start) self.emit('SET_LINENO', node.lineno) self.emit('FOR_LOOP', anchor) self.visit(node.assign) self.visit(node.body) - self.emit('JUMP_ABSOLUTE', l.startAnchor) - anchor.bind(self.code.getCurInst()) + self.emit('JUMP_ABSOLUTE', start) + 
self.nextBlock(anchor) self.emit('POP_BLOCK') if node.else_: self.visit(node.else_) - self.finishLoop() - return 1 - - def visitWhile(self, node): - self.emit('SET_LINENO', node.lineno) - l = self.startLoop() - if node.else_: - lElse = StackRef() - else: - lElse = l.breakAnchor - l.startAnchor.bind(self.code.getCurInst()) - if hasattr(node.test, 'lineno'): - self.emit('SET_LINENO', node.test.lineno) - self.visit(node.test) - self.emit('JUMP_IF_FALSE', lElse) - self.emit('POP_TOP') - self.visit(node.body) - self.emit('JUMP_ABSOLUTE', l.startAnchor) - # note that lElse may be an alias for l.breakAnchor - lElse.bind(self.code.getCurInst()) - self.emit('POP_TOP') - self.emit('POP_BLOCK') - if node.else_: - self.visit(node.else_) - self.finishLoop() - return 1 + self.loops.pop() + self.nextBlock(after) def visitBreak(self, node): - if not self.loops: - raise SyntaxError, "'break' outside loop" - self.emit('SET_LINENO', node.lineno) - self.emit('BREAK_LOOP') + if not self.loops: + raise SyntaxError, "'break' outside loop (%s, %d)" % \ + (self.filename, node.lineno) + self.emit('SET_LINENO', node.lineno) + self.emit('BREAK_LOOP') def visitContinue(self, node): if not self.loops: - raise SyntaxError, "'continue' outside loop" + raise SyntaxError, "'continue' outside loop (%s, %d)" % \ + (self.filename, node.lineno) l = self.loops.top() self.emit('SET_LINENO', node.lineno) - self.emit('JUMP_ABSOLUTE', l.startAnchor) + self.emit('JUMP_ABSOLUTE', l) + self.nextBlock() + + def visitTest(self, node, jump): + end = self.newBlock() + for child in node.nodes[:-1]: + self.visit(child) + self.emit(jump, end) + self.nextBlock() + self.emit('POP_TOP') + self.visit(node.nodes[-1]) + self.nextBlock(end) + + def visitAnd(self, node): + self.visitTest(node, 'JUMP_IF_FALSE') + + def visitOr(self, node): + self.visitTest(node, 'JUMP_IF_TRUE') + + def visitCompare(self, node): + self.visit(node.expr) + cleanup = self.newBlock() + for op, code in node.ops[:-1]: + self.visit(code) + 
self.emit('DUP_TOP') + self.emit('ROT_THREE') + self.emit('COMPARE_OP', op) + self.emit('JUMP_IF_FALSE', cleanup) + self.nextBlock() + self.emit('POP_TOP') + # now do the last comparison + if node.ops: + op, code = node.ops[-1] + self.visit(code) + self.emit('COMPARE_OP', op) + if len(node.ops) > 1: + end = self.newBlock() + self.emit('JUMP_FORWARD', end) + self.nextBlock(cleanup) + self.emit('ROT_TWO') + self.emit('POP_TOP') + self.nextBlock(end) + + # exception related + + def visitAssert(self, node): + # XXX would be interesting to implement this via a + # transformation of the AST before this stage + end = self.newBlock() + self.emit('SET_LINENO', node.lineno) + # XXX __debug__ and AssertionError appear to be special cases + # -- they are always loaded as globals even if there are local + # names. I guess this is a sort of renaming op. + self.emit('LOAD_GLOBAL', '__debug__') + self.emit('JUMP_IF_FALSE', end) + self.nextBlock() + self.emit('POP_TOP') + self.visit(node.test) + self.emit('JUMP_IF_TRUE', end) + self.nextBlock() + self.emit('LOAD_GLOBAL', 'AssertionError') + self.visit(node.fail) + self.emit('RAISE_VARARGS', 2) + self.nextBlock(end) + self.emit('POP_TOP') + + def visitRaise(self, node): + self.emit('SET_LINENO', node.lineno) + n = 0 + if node.expr1: + self.visit(node.expr1) + n = n + 1 + if node.expr2: + self.visit(node.expr2) + n = n + 1 + if node.expr3: + self.visit(node.expr3) + n = n + 1 + self.emit('RAISE_VARARGS', n) def visitTryExcept(self, node): - # XXX need to figure out exactly what is on the stack when an - # exception is raised and the first handler is checked - handlers = StackRef() - end = StackRef() + handlers = self.newBlock() + end = self.newBlock() if node.else_: - lElse = StackRef() + lElse = self.newBlock() else: lElse = end self.emit('SET_LINENO', node.lineno) @@ -332,7 +326,7 @@ class CodeGenerator: self.visit(node.body) self.emit('POP_BLOCK') self.emit('JUMP_FORWARD', lElse) - handlers.bind(self.code.getCurInst()) + 
self.nextBlock(handlers) last = len(node.handlers) - 1 for i in range(len(node.handlers)): @@ -342,9 +336,10 @@ class CodeGenerator: if expr: self.emit('DUP_TOP') self.visit(expr) - self.emit('COMPARE_OP', "exception match") - next = StackRef() + self.emit('COMPARE_OP', 'exception match') + next = self.newBlock() self.emit('JUMP_IF_FALSE', next) + self.nextBlock() self.emit('POP_TOP') self.emit('POP_TOP') if target: @@ -355,132 +350,72 @@ class CodeGenerator: self.visit(body) self.emit('JUMP_FORWARD', end) if expr: - next.bind(self.code.getCurInst()) + self.nextBlock(next) self.emit('POP_TOP') self.emit('END_FINALLY') if node.else_: - lElse.bind(self.code.getCurInst()) + self.nextBlock(lElse) self.visit(node.else_) - end.bind(self.code.getCurInst()) - return 1 + self.nextBlock(end) def visitTryFinally(self, node): - final = StackRef() + final = self.newBlock() self.emit('SET_LINENO', node.lineno) self.emit('SETUP_FINALLY', final) self.visit(node.body) self.emit('POP_BLOCK') self.emit('LOAD_CONST', None) - final.bind(self.code.getCurInst()) + self.nextBlock(final) self.visit(node.final) self.emit('END_FINALLY') - return 1 - def visitCompare(self, node): - """Comment from compile.c follows: - - The following code is generated for all but the last - comparison in a chain: - - label: on stack: opcode: jump to: - - a <code to load b> - a, b DUP_TOP - a, b, b ROT_THREE - b, a, b COMPARE_OP - b, 0-or-1 JUMP_IF_FALSE L1 - b, 1 POP_TOP - b - - We are now ready to repeat this sequence for the next - comparison in the chain. 
- - For the last we generate: - - b <code to load c> - b, c COMPARE_OP - 0-or-1 - - If there were any jumps to L1 (i.e., there was more than one - comparison), we generate: - - 0-or-1 JUMP_FORWARD L2 - L1: b, 0 ROT_TWO - 0, b POP_TOP - 0 - L2: 0-or-1 - """ - self.visit(node.expr) - # if refs are never emitted, subsequent bind call has no effect - l1 = StackRef() - l2 = StackRef() - for op, code in node.ops[:-1]: - # emit every comparison except the last - self.visit(code) - self.emit('DUP_TOP') - self.emit('ROT_THREE') - self.emit('COMPARE_OP', op) - # dupTop and compareOp cancel stack effect - self.emit('JUMP_IF_FALSE', l1) - self.emit('POP_TOP') - if node.ops: - # emit the last comparison - op, code = node.ops[-1] - self.visit(code) - self.emit('COMPARE_OP', op) - if len(node.ops) > 1: - self.emit('JUMP_FORWARD', l2) - l1.bind(self.code.getCurInst()) - self.emit('ROT_TWO') - self.emit('POP_TOP') - l2.bind(self.code.getCurInst()) - return 1 + # misc - def visitGetattr(self, node): - self.visit(node.expr) - self.emit('LOAD_ATTR', node.attrname) - return 1 +## def visitStmt(self, node): +## # nothing to do except walk the children +## pass - def visitSubscript(self, node): + def visitDiscard(self, node): self.visit(node.expr) - for sub in node.subs: - self.visit(sub) - if len(node.subs) > 1: - self.emit('BUILD_TUPLE', len(node.subs)) - if node.flags == 'OP_APPLY': - self.emit('BINARY_SUBSCR') - elif node.flags == 'OP_ASSIGN': - self.emit('STORE_SUBSCR') - elif node.flags == 'OP_DELETE': - self.emit('DELETE_SUBSCR') - return 1 + self.emit('POP_TOP') - def visitSlice(self, node): + def visitConst(self, node): + self.emit('LOAD_CONST', node.value) + + def visitKeyword(self, node): + self.emit('LOAD_CONST', node.name) + self.visit(node.expr) + + def visitGlobal(self, node): + # no code to generate + pass + + def visitName(self, node): + self.loadName(node.name) + + def visitPass(self, node): + self.emit('SET_LINENO', node.lineno) + + def visitImport(self, node): + 
self.emit('SET_LINENO', node.lineno) + for name in node.names: + self.emit('IMPORT_NAME', name) + self.storeName(name) + + def visitFrom(self, node): + self.emit('SET_LINENO', node.lineno) + self.emit('IMPORT_NAME', node.modname) + for name in node.names: + if name == '*': + self.namespace = 0 + self.emit('IMPORT_FROM', name) + self.emit('POP_TOP') + + def visitGetattr(self, node): self.visit(node.expr) - slice = 0 - if node.lower: - self.visit(node.lower) - slice = slice | 1 - if node.upper: - self.visit(node.upper) - slice = slice | 2 - if node.flags == 'OP_APPLY': - self.emit('SLICE+%d' % slice) - elif node.flags == 'OP_ASSIGN': - self.emit('STORE_SLICE+%d' % slice) - elif node.flags == 'OP_DELETE': - self.emit('DELETE_SLICE+%d' % slice) - else: - print "weird slice", node.flags - raise - return 1 + self.emit('LOAD_ATTR', node.attrname) - def visitSliceobj(self, node): - for child in node.nodes: - print child - self.visit(child) - self.emit('BUILD_SLICE', len(node.nodes)) - return 1 + # next five implement assignments def visitAssign(self, node): self.emit('SET_LINENO', node.lineno) @@ -492,7 +427,6 @@ class CodeGenerator: self.emit('DUP_TOP') if isinstance(elt, ast.Node): self.visit(elt) - return 1 def visitAssName(self, node): if node.flags == 'OP_ASSIGN': @@ -501,7 +435,6 @@ class CodeGenerator: self.delName(node.name) else: print "oops", node.flags - return 1 def visitAssAttr(self, node): self.visit(node.expr) @@ -512,24 +445,96 @@ class CodeGenerator: else: print "warning: unexpected flags:", node.flags print node - return 1 def visitAssTuple(self, node): if findOp(node) != 'OP_DELETE': self.emit('UNPACK_TUPLE', len(node.nodes)) for child in node.nodes: self.visit(child) - return 1 visitAssList = visitAssTuple + def visitExec(self, node): + self.visit(node.expr) + if node.locals is None: + self.emit('LOAD_CONST', None) + else: + self.visit(node.locals) + if node.globals is None: + self.emit('DUP_TOP') + else: + self.visit(node.globals) + 
self.emit('EXEC_STMT') + + def visitCallFunc(self, node): + pos = 0 + kw = 0 + if hasattr(node, 'lineno'): + self.emit('SET_LINENO', node.lineno) + self.visit(node.node) + for arg in node.args: + self.visit(arg) + if isinstance(arg, ast.Keyword): + kw = kw + 1 + else: + pos = pos + 1 + self.emit('CALL_FUNCTION', kw << 8 | pos) + + def visitPrint(self, node): + self.emit('SET_LINENO', node.lineno) + for child in node.nodes: + self.visit(child) + self.emit('PRINT_ITEM') + + def visitPrintnl(self, node): + self.visitPrint(node) + self.emit('PRINT_NEWLINE') + + def visitReturn(self, node): + self.emit('SET_LINENO', node.lineno) + self.visit(node.value) + self.emit('RETURN_VALUE') + + # slice and subscript stuff + + def visitSlice(self, node): + self.visit(node.expr) + slice = 0 + if node.lower: + self.visit(node.lower) + slice = slice | 1 + if node.upper: + self.visit(node.upper) + slice = slice | 2 + if node.flags == 'OP_APPLY': + self.emit('SLICE+%d' % slice) + elif node.flags == 'OP_ASSIGN': + self.emit('STORE_SLICE+%d' % slice) + elif node.flags == 'OP_DELETE': + self.emit('DELETE_SLICE+%d' % slice) + else: + print "weird slice", node.flags + raise + + def visitSubscript(self, node): + self.visit(node.expr) + for sub in node.subs: + self.visit(sub) + if len(node.subs) > 1: + self.emit('BUILD_TUPLE', len(node.subs)) + if node.flags == 'OP_APPLY': + self.emit('BINARY_SUBSCR') + elif node.flags == 'OP_ASSIGN': + self.emit('STORE_SUBSCR') + elif node.flags == 'OP_DELETE': + self.emit('DELETE_SUBSCR') + # binary ops def binaryOp(self, node, op): self.visit(node.left) self.visit(node.right) self.emit(op) - return 1 def visitAdd(self, node): return self.binaryOp(node, 'BINARY_ADD') @@ -560,7 +565,6 @@ class CodeGenerator: def unaryOp(self, node, op): self.visit(node.expr) self.emit(op) - return 1 def visitInvert(self, node): return self.unaryOp(node, 'UNARY_INVERT') @@ -587,7 +591,6 @@ class CodeGenerator: for node in nodes[1:]: self.visit(node) self.emit(op) - return 1 
def visitBitand(self, node): return self.bitOp(node.nodes, 'BINARY_AND') @@ -598,121 +601,137 @@ class CodeGenerator: def visitBitxor(self, node): return self.bitOp(node.nodes, 'BINARY_XOR') - def visitAssert(self, node): - # XXX __debug__ and AssertionError appear to be special cases - # -- they are always loaded as globals even if there are local - # names. I guess this is a sort of renaming op. - skip = StackRef() - self.emit('SET_LINENO', node.lineno) - self.emit('LOAD_GLOBAL', '__debug__') - self.emit('JUMP_IF_FALSE', skip) - self.emit('POP_TOP') - self.visit(node.test) - self.emit('JUMP_IF_TRUE', skip) - self.emit('LOAD_GLOBAL', 'AssertionError') - self.visit(node.fail) - self.emit('RAISE_VARARGS', 2) - skip.bind(self.code.getCurInst()) - self.emit('POP_TOP') - return 1 - - def visitTest(self, node, jump): - end = StackRef() - for child in node.nodes[:-1]: - self.visit(child) - self.emit(jump, end) - self.emit('POP_TOP') - self.visit(node.nodes[-1]) - end.bind(self.code.getCurInst()) - return 1 - - def visitAnd(self, node): - return self.visitTest(node, 'JUMP_IF_FALSE') - - def visitOr(self, node): - return self.visitTest(node, 'JUMP_IF_TRUE') - - def visitName(self, node): - self.loadName(node.name) - - def visitConst(self, node): - self.emit('LOAD_CONST', node.value) - return 1 + # object constructors def visitEllipsis(self, node): self.emit('LOAD_CONST', Ellipsis) - return 1 def visitTuple(self, node): for elt in node.nodes: self.visit(elt) self.emit('BUILD_TUPLE', len(node.nodes)) - return 1 def visitList(self, node): for elt in node.nodes: self.visit(elt) self.emit('BUILD_LIST', len(node.nodes)) - return 1 + + def visitSliceobj(self, node): + for child in node.nodes: + self.visit(child) + self.emit('BUILD_SLICE', len(node.nodes)) def visitDict(self, node): + # XXX is this a good general strategy? 
could it be done + # separately from the general visitor + lineno = getattr(node, 'lineno', None) + if lineno: + self.emit('SET_LINENO', lineno) self.emit('BUILD_MAP', 0) for k, v in node.items: - # XXX need to add set lineno when there aren't constants + lineno2 = getattr(node, 'lineno', None) + if lineno != lineno2: + self.emit('SET_LINENO', lineno2) + lineno = lineno2 self.emit('DUP_TOP') self.visit(v) self.emit('ROT_TWO') self.visit(k) self.emit('STORE_SUBSCR') - return 1 - def visitReturn(self, node): - self.emit('SET_LINENO', node.lineno) - self.visit(node.value) - self.emit('RETURN_VALUE') - return 1 +class ModuleCodeGenerator(CodeGenerator): + super_init = CodeGenerator.__init__ + + def __init__(self, filename): + # XXX <module> is ? in compile.c + self.graph = pyassem.PyFlowGraph("<module>", filename) + self.super_init(filename) + +class FunctionCodeGenerator(CodeGenerator): + super_init = CodeGenerator.__init__ + + optimized = 1 + lambdaCount = 0 + + def __init__(self, func, filename, isLambda=0): + if isLambda: + klass = FunctionCodeGenerator + name = "<lambda.%d>" % klass.lambdaCount + klass.lambdaCount = klass.lambdaCount + 1 + else: + name = func.name + args, hasTupleArg = generateArgList(func.argnames) + self.graph = pyassem.PyFlowGraph(name, filename, args, + optimized=1) + self.isLambda = isLambda + self.super_init(filename) - def visitRaise(self, node): - self.emit('SET_LINENO', node.lineno) - n = 0 - if node.expr1: - self.visit(node.expr1) - n = n + 1 - if node.expr2: - self.visit(node.expr2) - n = n + 1 - if node.expr3: - self.visit(node.expr3) - n = n + 1 - self.emit('RAISE_VARARGS', n) - return 1 + lnf = walk(func.code, LocalNameFinder(args), 0) + self.locals.push(lnf.getLocals()) + if func.varargs: + self.graph.setFlag(CO_VARARGS) + if func.kwargs: + self.graph.setFlag(CO_VARKEYWORDS) + self.emit('SET_LINENO', func.lineno) + if hasTupleArg: + self.generateArgUnpack(func.argnames) - def visitPrint(self, node): - self.emit('SET_LINENO', 
node.lineno) - for child in node.nodes: - self.visit(child) - self.emit('PRINT_ITEM') - return 1 + def finish(self): + self.graph.startExitBlock() + if not self.isLambda: + self.emit('LOAD_CONST', None) + self.emit('RETURN_VALUE') - def visitPrintnl(self, node): - self.visitPrint(node) - self.emit('PRINT_NEWLINE') - return 1 + def generateArgUnpack(self, args): + count = 0 + for arg in args: + if type(arg) == types.TupleType: + self.emit('LOAD_FAST', '.nested%d' % count) + count = count + 1 + self.unpackTuple(arg) + + def unpackTuple(self, tup): + self.emit('UNPACK_TUPLE', len(tup)) + for elt in tup: + if type(elt) == types.TupleType: + self.unpackTuple(elt) + else: + self.emit('STORE_FAST', elt) - def visitExec(self, node): - self.visit(node.expr) - if node.locals is None: - self.emit('LOAD_CONST', None) - else: - self.visit(node.locals) - if node.globals is None: - self.emit('DUP_TOP') - else: - self.visit(node.globals) - self.emit('EXEC_STMT') - return 1 +class ClassCodeGenerator(CodeGenerator): + super_init = CodeGenerator.__init__ + + def __init__(self, klass, filename): + self.graph = pyassem.PyFlowGraph(klass.name, filename, + optimized=0) + self.super_init(filename) + lnf = walk(klass.code, LocalNameFinder(), 0) + self.locals.push(lnf.getLocals()) + + def finish(self): + self.graph.startExitBlock() + self.emit('LOAD_LOCALS') + self.emit('RETURN_VALUE') + + +def generateArgList(arglist): + """Generate an arg list marking TupleArgs""" + args = [] + extra = [] + count = 0 + for elt in arglist: + if type(elt) == types.StringType: + args.append(elt) + elif type(elt) == types.TupleType: + args.append(TupleArg(count, elt)) + count = count + 1 + extra.extend(misc.flatten(elt)) + else: + raise ValueError, "unexpect argument type:", elt + return args + extra, count class LocalNameFinder: + """Find local names in scope""" def __init__(self, names=()): self.names = misc.Set() self.globals = misc.Set() @@ -720,25 +739,23 @@ class LocalNameFinder: self.names.add(name) 
def getLocals(self): - for elt in self.globals.items(): + for elt in self.globals.elements(): if self.names.has_elt(elt): self.names.remove(elt) return self.names def visitDict(self, node): - return 1 + pass def visitGlobal(self, node): for name in node.names: self.globals.add(name) - return 1 def visitFunction(self, node): self.names.add(node.name) - return 1 def visitLambda(self, node): - return 1 + pass def visitImport(self, node): for name in node.names: @@ -750,11 +767,16 @@ class LocalNameFinder: def visitClass(self, node): self.names.add(node.name) - return 1 def visitAssName(self, node): self.names.add(node.name) +def findOp(node): + """Find the op (DELETE, LOAD, STORE) in an AssTuple tree""" + v = OpFinder() + walk(node, v, 0) + return v.op + class OpFinder: def __init__(self): self.op = None @@ -764,57 +786,8 @@ class OpFinder: elif self.op != node.flags: raise ValueError, "mixed ops in stmt" -def findOp(node): - v = OpFinder() - walk(node, v) - return v.op - -class Loop: - def __init__(self): - self.startAnchor = StackRef() - self.breakAnchor = StackRef() - self.extentAnchor = StackRef() - -class CompiledModule: - """Store the code object for a compiled module - - XXX Not clear how the code objects will be stored. Seems possible - that a single code attribute is sufficient, because it will - contains references to all the need code objects. That might be - messy, though. 
- """ - MAGIC = (20121 | (ord('\r')<<16) | (ord('\n')<<24)) - - def __init__(self, source, filename): - self.source = source - self.filename = filename +if __name__ == "__main__": + import sys - def compile(self): - self.ast = parse(self.source) - cg = CodeGenerator(self.filename) - walk(self.ast, cg) - self.code = cg.asConst() - - def dump(self, path): - """create a .pyc file""" - f = open(path, 'wb') - f.write(self._pyc_header()) - marshal.dump(self.code, f) - f.close() - - def _pyc_header(self): - # compile.c uses marshal to write a long directly, with - # calling the interface that would also generate a 1-byte code - # to indicate the type of the value. simplest way to get the - # same effect is to call marshal and then skip the code. - magic = marshal.dumps(self.MAGIC)[1:] - mtime = os.stat(self.filename)[stat.ST_MTIME] - mtime = struct.pack('i', mtime) - return magic + mtime - -def compile(filename): - buf = open(filename).read() - mod = CompiledModule(buf, filename) - mod.compile() - mod.dump(filename + 'c') - + for file in sys.argv[1:]: + compile(file) diff --git a/Tools/compiler/compiler/pyassem.py b/Tools/compiler/compiler/pyassem.py index 4cb910c..3272419 100644 --- a/Tools/compiler/compiler/pyassem.py +++ b/Tools/compiler/compiler/pyassem.py @@ -1,40 +1,127 @@ -"""Assembler for Python bytecode - -The new module is used to create the code object. 
The following -attribute definitions are included from the reference manual: - -co_name gives the function name -co_argcount is the number of positional arguments (including - arguments with default values) -co_nlocals is the number of local variables used by the function - (including arguments) -co_varnames is a tuple containing the names of the local variables - (starting with the argument names) -co_code is a string representing the sequence of bytecode instructions -co_consts is a tuple containing the literals used by the bytecode -co_names is a tuple containing the names used by the bytecode -co_filename is the filename from which the code was compiled -co_firstlineno is the first line number of the function -co_lnotab is a string encoding the mapping from byte code offsets - to line numbers. see LineAddrTable below. -co_stacksize is the required stack size (including local variables) -co_flags is an integer encoding a number of flags for the - interpreter. There are four flags: - CO_OPTIMIZED -- uses load fast - CO_NEWLOCALS -- everything? - CO_VARARGS -- use *args - CO_VARKEYWORDS -- uses **args - -If a code object represents a function, the first item in co_consts is -the documentation string of the function, or None if undefined. -""" - -import sys +"""A flow graph representation for Python bytecode""" + import dis import new import string +import types + +from compiler import misc + +class FlowGraph: + def __init__(self): + self.current = self.entry = Block() + self.exit = Block("exit") + self.blocks = misc.Set() + self.blocks.add(self.entry) + self.blocks.add(self.exit) + + def startBlock(self, block): + self.current = block + + def nextBlock(self, block=None): + if block is None: + block = self.newBlock() + # XXX think we need to specify when there is implicit transfer + # from one block to the next + # + # I think this strategy works: each block has a child + # designated as "next" which is returned as the last of the + # children. 
because the nodes in a graph are emitted in + # reverse post order, the "next" block will always be emitted + # immediately after its parent. + # Worry: maintaining this invariant could be tricky + self.current.addNext(block) + self.startBlock(block) + + def newBlock(self): + b = Block() + self.blocks.add(b) + return b + + def startExitBlock(self): + self.startBlock(self.exit) + + def emit(self, *inst): + # XXX should jump instructions implicitly call nextBlock? + if inst[0] == 'RETURN_VALUE': + self.current.addOutEdge(self.exit) + self.current.emit(inst) + + def getBlocks(self): + """Return the blocks in reverse postorder + + i.e. each node appears before all of its successors + """ + # XXX make sure every node that doesn't have an explicit next + # is set so that next points to exit + for b in self.blocks.elements(): + if b is self.exit: + continue + if not b.next: + b.addNext(self.exit) + order = dfs_postorder(self.entry, {}) + order.reverse() + # hack alert + if not self.exit in order: + order.append(self.exit) + return order + +def dfs_postorder(b, seen): + """Depth-first search of tree rooted at b, return in postorder""" + order = [] + seen[b] = b + for c in b.children(): + if seen.has_key(c): + continue + order = order + dfs_postorder(c, seen) + order.append(b) + return order + +class Block: + _count = 0 + + def __init__(self, label=''): + self.insts = [] + self.inEdges = misc.Set() + self.outEdges = misc.Set() + self.label = label + self.bid = Block._count + self.next = [] + Block._count = Block._count + 1 + + def __repr__(self): + if self.label: + return "<block %s id=%d len=%d>" % (self.label, self.bid, + len(self.insts)) + else: + return "<block id=%d len=%d>" % (self.bid, len(self.insts)) + + def __str__(self): + insts = map(str, self.insts) + return "<block %s %d:\n%s>" % (self.label, self.bid, + string.join(insts, '\n')) + + def emit(self, inst): + op = inst[0] + if op[:4] == 'JUMP': + self.outEdges.add(inst[1]) + self.insts.append(inst) + + def 
getInstructions(self): + return self.insts + + def addInEdge(self, block): + self.inEdges.add(block) + + def addOutEdge(self, block): + self.outEdges.add(block) + + def addNext(self, block): + self.next.append(block) + assert len(self.next) == 1, map(str, self.next) -import misc + def children(self): + return self.outEdges.elements() + self.next # flags for code objects CO_OPTIMIZED = 0x0001 @@ -42,224 +129,128 @@ CO_NEWLOCALS = 0x0002 CO_VARARGS = 0x0004 CO_VARKEYWORDS = 0x0008 -class TupleArg: - def __init__(self, count, names): - self.count = count - self.names = names - def __repr__(self): - return "TupleArg(%s, %s)" % (self.count, self.names) - def getName(self): - return ".nested%d" % self.count - -class PyAssembler: - """Creates Python code objects - """ - - # XXX this class needs to major refactoring - - def __init__(self, args=(), name='?', filename='<?>', - docstring=None): - # XXX why is the default value for flags 3? - self.insts = [] - # used by makeCodeObject - self._getArgCount(args) - self.code = '' - self.consts = [docstring] - self.filename = filename - self.flags = CO_NEWLOCALS - self.name = name - self.names = [] +# the FlowGraph is transformed in place; it exists in one of these states +RAW = "RAW" +FLAT = "FLAT" +CONV = "CONV" +DONE = "DONE" + +class PyFlowGraph(FlowGraph): + super_init = FlowGraph.__init__ + + def __init__(self, name, filename, args=(), optimized=0): + self.super_init() + self.name = name + self.filename = filename + self.docstring = None + self.args = args # XXX + self.argcount = getArgCount(args) + if optimized: + self.flags = CO_OPTIMIZED | CO_NEWLOCALS + else: + self.flags = 0 + self.firstlineno = None + self.consts = [] + self.names = [] self.varnames = list(args) or [] for i in range(len(self.varnames)): var = self.varnames[i] if isinstance(var, TupleArg): self.varnames[i] = var.getName() - # lnotab support - self.firstlineno = 0 - self.lastlineno = 0 - self.last_addr = 0 - self.lnotab = '' - - def _getArgCount(self, 
args): - self.argcount = len(args) - if args: - for arg in args: - if isinstance(arg, TupleArg): - numNames = len(misc.flatten(arg.names)) - self.argcount = self.argcount - numNames + self.stage = RAW - def __repr__(self): - return "<bytecode: %d instrs>" % len(self.insts) - - def setFlags(self, val): - """XXX for module's function""" - self.flags = val - - def setOptimized(self): - self.flags = self.flags | CO_OPTIMIZED - - def setVarArgs(self): - if not self.flags & CO_VARARGS: - self.flags = self.flags | CO_VARARGS - self.argcount = self.argcount - 1 - - def setKWArgs(self): - self.flags = self.flags | CO_VARKEYWORDS - - def getCurInst(self): - return len(self.insts) + def setDocstring(self, doc): + self.docstring = doc + self.consts.insert(0, doc) - def getNextInst(self): - return len(self.insts) + 1 + def setFlag(self, flag): + self.flags = self.flags | flag + if flag == CO_VARARGS: + self.argcount = self.argcount - 1 - def dump(self, io=sys.stdout): - i = 0 - for inst in self.insts: - if inst[0] == 'SET_LINENO': - io.write("\n") - io.write(" %3d " % i) - if len(inst) == 1: - io.write("%s\n" % inst) - else: - io.write("%-15.15s\t%s\n" % inst) - i = i + 1 - - def makeCodeObject(self): - """Make a Python code object - - This creates a Python code object using the new module. This - seems simpler than reverse-engineering the way marshal dumps - code objects into .pyc files. One of the key difficulties is - figuring out how to layout references to code objects that - appear on the VM stack; e.g. - 3 SET_LINENO 1 - 6 LOAD_CONST 0 (<code object fact at 8115878 [...] 
- 9 MAKE_FUNCTION 0 - 12 STORE_NAME 0 (fact) - """ - - self._findOffsets() - lnotab = LineAddrTable() + def getCode(self): + """Get a Python code object""" + if self.stage == RAW: + self.flattenGraph() + if self.stage == FLAT: + self.convertArgs() + if self.stage == CONV: + self.makeByteCode() + if self.stage == DONE: + return self.newCodeObject() + raise RuntimeError, "inconsistent PyFlowGraph state" + + def dump(self, io=None): + if io: + save = sys.stdout + sys.stdout = io + pc = 0 for t in self.insts: opname = t[0] + if opname == "SET_LINENO": + print if len(t) == 1: - lnotab.addCode(self.opnum[opname]) - elif len(t) == 2: - if opname == 'SET_LINENO': - oparg = t[1] - lnotab.nextLine(oparg) + print "\t", "%3d" % pc, opname + pc = pc + 1 + else: + print "\t", "%3d" % pc, opname, t[1] + pc = pc + 3 + if io: + sys.stdout = save + + def flattenGraph(self): + """Arrange the blocks in order and resolve jumps""" + assert self.stage == RAW + self.insts = insts = [] + pc = 0 + begin = {} + end = {} + for b in self.getBlocks(): + begin[b] = pc + for inst in b.getInstructions(): + insts.append(inst) + if len(inst) == 1: + pc = pc + 1 else: - oparg = self._convertArg(opname, t[1]) - try: - hi, lo = divmod(oparg, 256) - except TypeError: - raise TypeError, "untranslated arg: %s, %s" % (opname, oparg) - lnotab.addCode(self.opnum[opname], lo, hi) - - # why is a module a special case? - if self.flags == 0: - nlocals = 0 - else: - nlocals = len(self.varnames) - # XXX danger! 
can't pass through here twice - if self.flags & CO_VARKEYWORDS: - self.argcount = self.argcount - 1 - stacksize = findDepth(self.insts) - try: - co = new.code(self.argcount, nlocals, stacksize, - self.flags, lnotab.getCode(), self._getConsts(), - tuple(self.names), tuple(self.varnames), - self.filename, self.name, self.firstlineno, - lnotab.getTable()) - except SystemError, err: - print err - print repr(self.argcount) - print repr(nlocals) - print repr(stacksize) - print repr(self.flags) - print repr(lnotab.getCode()) - print repr(self._getConsts()) - print repr(self.names) - print repr(self.varnames) - print repr(self.filename) - print repr(self.name) - print repr(self.firstlineno) - print repr(lnotab.getTable()) - raise - return co - - def _getConsts(self): - """Return a tuple for the const slot of a code object - - Converts PythonVMCode objects to code objects - """ - l = [] - for elt in self.consts: - # XXX might be clearer to just as isinstance(CodeGen) - if hasattr(elt, 'asConst'): - l.append(elt.asConst()) + # arg takes 2 bytes + pc = pc + 3 + end[b] = pc + pc = 0 + for i in range(len(insts)): + inst = insts[i] + if len(inst) == 1: + pc = pc + 1 else: - l.append(elt) - return tuple(l) + pc = pc + 3 + opname = inst[0] + if self.hasjrel.has_elt(opname): + oparg = inst[1] + offset = begin[oparg] - pc + insts[i] = opname, offset + elif self.hasjabs.has_elt(opname): + insts[i] = opname, begin[inst[1]] + self.stacksize = findDepth(self.insts) + self.stage = FLAT - def _findOffsets(self): - """Find offsets for use in resolving StackRefs""" - self.offsets = [] - cur = 0 - for t in self.insts: - self.offsets.append(cur) - l = len(t) - if l == 1: - cur = cur + 1 - elif l == 2: - cur = cur + 3 - arg = t[1] - # XXX this is a total hack: for a reference used - # multiple times, we create a list of offsets and - # expect that we when we pass through the code again - # to actually generate the offsets, we'll pass in the - # same order. 
- if isinstance(arg, StackRef): - try: - arg.__offset.append(cur) - except AttributeError: - arg.__offset = [cur] - - def _convertArg(self, op, arg): - """Convert the string representation of an arg to a number - - The specific handling depends on the opcode. - - XXX This first implementation isn't going to be very - efficient. - """ - if op == 'SET_LINENO': - return arg - if op == 'LOAD_CONST': - return self._lookupName(arg, self.consts) - if op in self.localOps: - # make sure it's in self.names, but use the bytecode offset - self._lookupName(arg, self.names) - return self._lookupName(arg, self.varnames) - if op in self.globalOps: - return self._lookupName(arg, self.names) - if op in self.nameOps: - return self._lookupName(arg, self.names) - if op == 'COMPARE_OP': - return self.cmp_op.index(arg) - if self.hasjrel.has_elt(op): - offset = arg.__offset[0] - del arg.__offset[0] - return self.offsets[arg.resolve()] - offset - if self.hasjabs.has_elt(op): - return self.offsets[arg.resolve()] - return arg - - nameOps = ('STORE_NAME', 'IMPORT_NAME', 'IMPORT_FROM', - 'STORE_ATTR', 'LOAD_ATTR', 'LOAD_NAME', 'DELETE_NAME', - 'DELETE_ATTR') - localOps = ('LOAD_FAST', 'STORE_FAST', 'DELETE_FAST') - globalOps = ('LOAD_GLOBAL', 'STORE_GLOBAL', 'DELETE_GLOBAL') + hasjrel = misc.Set() + for i in dis.hasjrel: + hasjrel.add(dis.opname[i]) + hasjabs = misc.Set() + for i in dis.hasjabs: + hasjabs.add(dis.opname[i]) + + def convertArgs(self): + """Convert arguments from symbolic to concrete form""" + assert self.stage == FLAT + for i in range(len(self.insts)): + t = self.insts[i] + if len(t) == 2: + opname = t[0] + oparg = t[1] + conv = self._converters.get(opname, None) + if conv: + self.insts[i] = opname, conv(self, oparg) + self.stage = CONV def _lookupName(self, name, list): """Return index of name in list, appending if necessary""" @@ -276,32 +267,124 @@ class PyAssembler: list.append(name) return end - # Convert some stuff from the dis module for local use - - cmp_op = 
list(dis.cmp_op) - hasjrel = misc.Set() - for i in dis.hasjrel: - hasjrel.add(dis.opname[i]) - hasjabs = misc.Set() - for i in dis.hasjabs: - hasjabs.add(dis.opname[i]) - + _converters = {} + def _convert_LOAD_CONST(self, arg): + return self._lookupName(arg, self.consts) + + def _convert_LOAD_FAST(self, arg): + self._lookupName(arg, self.names) + return self._lookupName(arg, self.varnames) + _convert_STORE_FAST = _convert_LOAD_FAST + _convert_DELETE_FAST = _convert_LOAD_FAST + + def _convert_NAME(self, arg): + return self._lookupName(arg, self.names) + _convert_LOAD_NAME = _convert_NAME + _convert_STORE_NAME = _convert_NAME + _convert_DELETE_NAME = _convert_NAME + _convert_IMPORT_NAME = _convert_NAME + _convert_IMPORT_FROM = _convert_NAME + _convert_STORE_ATTR = _convert_NAME + _convert_LOAD_ATTR = _convert_NAME + _convert_DELETE_ATTR = _convert_NAME + _convert_LOAD_GLOBAL = _convert_NAME + _convert_STORE_GLOBAL = _convert_NAME + _convert_DELETE_GLOBAL = _convert_NAME + + _cmp = list(dis.cmp_op) + def _convert_COMPARE_OP(self, arg): + return self._cmp.index(arg) + + # similarly for other opcodes... + + for name, obj in locals().items(): + if name[:9] == "_convert_": + opname = name[9:] + _converters[opname] = obj + del name, obj, opname + + def makeByteCode(self): + assert self.stage == CONV + self.lnotab = lnotab = LineAddrTable() + for t in self.insts: + opname = t[0] + if len(t) == 1: + lnotab.addCode(self.opnum[opname]) + else: + oparg = t[1] + if opname == "SET_LINENO": + lnotab.nextLine(oparg) + if self.firstlineno is None: + self.firstlineno = oparg + hi, lo = twobyte(oparg) + try: + lnotab.addCode(self.opnum[opname], lo, hi) + except ValueError: + print opname, oparg + print self.opnum[opname], lo, hi + raise + self.stage = DONE + opnum = {} for num in range(len(dis.opname)): opnum[dis.opname[num]] = num + del num - # this version of emit + arbitrary hooks might work, but it's damn - # messy. 
+ def newCodeObject(self): + assert self.stage == DONE + if self.flags == 0: + nlocals = 0 + else: + nlocals = len(self.varnames) + argcount = self.argcount + if self.flags & CO_VARKEYWORDS: + argcount = argcount - 1 + return new.code(argcount, nlocals, self.stacksize, self.flags, + self.lnotab.getCode(), self.getConsts(), + tuple(self.names), tuple(self.varnames), + self.filename, self.name, self.firstlineno, + self.lnotab.getTable()) + + def getConsts(self): + """Return a tuple for the const slot of the code object + + Must convert references to code (MAKE_FUNCTION) to code + objects recursively. + """ + l = [] + for elt in self.consts: + if isinstance(elt, PyFlowGraph): + elt = elt.getCode() + l.append(elt) + return tuple(l) + +def isJump(opname): + if opname[:4] == 'JUMP': + return 1 - def emit(self, *args): - self._emitDispatch(args[0], args[1:]) - self.insts.append(args) +class TupleArg: + """Helper for marking func defs with nested tuples in arglist""" + def __init__(self, count, names): + self.count = count + self.names = names + def __repr__(self): + return "TupleArg(%s, %s)" % (self.count, self.names) + def getName(self): + return ".nested%d" % self.count - def _emitDispatch(self, type, args): - for func in self._emit_hooks.get(type, []): - func(self, args) +def getArgCount(args): + argcount = len(args) + if args: + for arg in args: + if isinstance(arg, TupleArg): + numNames = len(misc.flatten(arg.names)) + argcount = argcount - numNames + return argcount - _emit_hooks = {} +def twobyte(val): + """Convert an int argument into high and low bytes""" + assert type(val) == types.IntType + return divmod(val, 256) class LineAddrTable: """lnotab @@ -361,34 +444,9 @@ class LineAddrTable: def getTable(self): return string.join(map(chr, self.lnotab), '') -class StackRef: - """Manage stack locations for jumps, loops, etc.""" - count = 0 - - def __init__(self, id=None, val=None): - if id is None: - id = StackRef.count - StackRef.count = StackRef.count + 1 - self.id = 
id - self.val = val - - def __repr__(self): - if self.val: - return "StackRef(val=%d)" % self.val - else: - return "StackRef(id=%d)" % self.id - - def bind(self, inst): - self.val = inst - - def resolve(self): - if self.val is None: - print "UNRESOLVE REF", self - return 0 - return self.val - class StackDepthTracker: - # XXX need to keep track of stack depth on jumps + # XXX 1. need to keep track of stack depth on jumps + # XXX 2. at least partly as a result, this code is broken def findDepth(self, insts): depth = 0 diff --git a/Tools/compiler/compiler/pycodegen.py b/Tools/compiler/compiler/pycodegen.py index b2d55d9..2e98d4e 100644 --- a/Tools/compiler/compiler/pycodegen.py +++ b/Tools/compiler/compiler/pycodegen.py @@ -1,132 +1,84 @@ -"""Python bytecode generator - -Currently contains generic ASTVisitor code, a LocalNameFinder, and a -CodeGenerator. Eventually, this will get split into the ASTVisitor as -a generic tool and CodeGenerator as a specific tool. -""" - -from compiler import parseFile, ast, visitor, walk, parse -from pyassem import StackRef, PyAssembler, TupleArg -import dis -import misc -import marshal -import new -import string -import sys import os +import marshal import stat import struct import types +from cStringIO import StringIO -class CodeGenerator: - """Generate bytecode for the Python VM""" - - OPTIMIZED = 1 - - # XXX should clean up initialization and generateXXX funcs - def __init__(self, filename="<?>"): - self.filename = filename - self.code = PyAssembler() - self.code.setFlags(0) - self.locals = misc.Stack() - self.loops = misc.Stack() - self.namespace = 0 - self.curStack = 0 - self.maxStack = 0 - - def emit(self, *args): - # XXX could just use self.emit = self.code.emit - apply(self.code.emit, args) - - def _generateFunctionOrLambdaCode(self, func): - self.name = func.name - - # keep a lookout for 'def foo((x,y)):' - args, hasTupleArg = self.generateArglist(func.argnames) - - self.code = PyAssembler(args=args, name=func.name, - 
filename=self.filename) - self.namespace = self.OPTIMIZED - if func.varargs: - self.code.setVarArgs() - if func.kwargs: - self.code.setKWArgs() - lnf = walk(func.code, LocalNameFinder(args), 0) - self.locals.push(lnf.getLocals()) - self.emit('SET_LINENO', func.lineno) - if hasTupleArg: - self.generateArgUnpack(func.argnames) - walk(func.code, self) +from compiler import ast, parse, walk +from compiler import pyassem, misc +from compiler.pyassem import CO_VARARGS, CO_VARKEYWORDS, TupleArg - def generateArglist(self, arglist): - args = [] - extra = [] - count = 0 - for elt in arglist: - if type(elt) == types.StringType: - args.append(elt) - elif type(elt) == types.TupleType: - args.append(TupleArg(count, elt)) - count = count + 1 - extra.extend(misc.flatten(elt)) - else: - raise ValueError, "unexpect argument type:", elt - return args + extra, count +def compile(filename): + f = open(filename) + buf = f.read() + f.close() + mod = Module(buf, filename) + mod.compile() + f = open(filename + "c", "wb") + mod.dump(f) + f.close() - def generateArgUnpack(self, args): - count = 0 - for arg in args: - if type(arg) == types.TupleType: - self.emit('LOAD_FAST', '.nested%d' % count) - count = count + 1 - self.unpackTuple(arg) - - def unpackTuple(self, tup): - self.emit('UNPACK_TUPLE', len(tup)) - for elt in tup: - if type(elt) == types.TupleType: - self.unpackTuple(elt) - else: - self.emit('STORE_FAST', elt) +class Module: + def __init__(self, source, filename): + self.filename = filename + self.source = source + self.code = None - def generateFunctionCode(self, func): - """Generate code for a function body""" - self._generateFunctionOrLambdaCode(func) - self.emit('LOAD_CONST', None) - self.emit('RETURN_VALUE') + def compile(self): + ast = parse(self.source) + root, filename = os.path.split(self.filename) + gen = ModuleCodeGenerator(filename) + walk(ast, gen, 1) + self.code = gen.getCode() - def generateLambdaCode(self, func): - self._generateFunctionOrLambdaCode(func) - 
self.emit('RETURN_VALUE') + def dump(self, f): + f.write(self.getPycHeader()) + marshal.dump(self.code, f) - def generateClassCode(self, klass): - self.code = PyAssembler(name=klass.name, - filename=self.filename) - self.emit('SET_LINENO', klass.lineno) - lnf = walk(klass.code, LocalNameFinder(), 0) - self.locals.push(lnf.getLocals()) - walk(klass.code, self) - self.emit('LOAD_LOCALS') - self.emit('RETURN_VALUE') + MAGIC = (20121 | (ord('\r')<<16) | (ord('\n')<<24)) + + def getPycHeader(self): + # compile.c uses marshal to write a long directly, with + # calling the interface that would also generate a 1-byte code + # to indicate the type of the value. simplest way to get the + # same effect is to call marshal and then skip the code. + magic = marshal.dumps(self.MAGIC)[1:] + mtime = os.stat(self.filename)[stat.ST_MTIME] + mtime = struct.pack('i', mtime) + return magic + mtime + +class CodeGenerator: + + optimized = 0 # is namespace access optimized? + + def __init__(self, filename): +## Subclasses must define a constructor that intializes self.graph +## before calling this init function +## self.graph = pyassem.PyFlowGraph() + self.filename = filename + self.locals = misc.Stack() + self.loops = misc.Stack() + self.curStack = 0 + self.maxStack = 0 + self._setupGraphDelegation() + + def _setupGraphDelegation(self): + self.emit = self.graph.emit + self.newBlock = self.graph.newBlock + self.startBlock = self.graph.startBlock + self.nextBlock = self.graph.nextBlock + self.setDocstring = self.graph.setDocstring - def asConst(self): - """Create a Python code object.""" - if self.namespace == self.OPTIMIZED: - self.code.setOptimized() - return self.code.makeCodeObject() + def getCode(self): + """Return a code object""" + return self.graph.getCode() + + # Next five methods handle name access def isLocalName(self, name): return self.locals.top().has_elt(name) - def _nameOp(self, prefix, name): - if self.isLocalName(name): - if self.namespace == self.OPTIMIZED: - 
self.emit(prefix + '_FAST', name) - else: - self.emit(prefix + '_NAME', name) - else: - self.emit(prefix + '_GLOBAL', name) - def storeName(self, name): self._nameOp('STORE', name) @@ -136,195 +88,237 @@ class CodeGenerator: def delName(self, name): self._nameOp('DELETE', name) - def visitNULL(self, node): - """Method exists only to stop warning in -v mode""" - pass - - visitStmt = visitNULL - visitGlobal = visitNULL - - def visitDiscard(self, node): - self.visit(node.expr) - self.emit('POP_TOP') - return 1 + def _nameOp(self, prefix, name): + if not self.optimized: + self.emit(prefix + '_NAME', name) + return + if self.isLocalName(name): + self.emit(prefix + '_FAST', name) + else: + self.emit(prefix + '_GLOBAL', name) - def visitPass(self, node): - self.emit('SET_LINENO', node.lineno) + # The first few visitor methods handle nodes that generator new + # code objects def visitModule(self, node): - lnf = walk(node.node, LocalNameFinder(), 0) - self.locals.push(lnf.getLocals()) - self.visit(node.node) - self.emit('LOAD_CONST', None) - self.emit('RETURN_VALUE') - return 1 + lnf = walk(node.node, LocalNameFinder(), 0) + self.locals.push(lnf.getLocals()) + self.setDocstring(node.doc) + self.visit(node.node) + self.emit('LOAD_CONST', None) + self.emit('RETURN_VALUE') - def visitImport(self, node): - self.emit('SET_LINENO', node.lineno) - for name in node.names: - self.emit('IMPORT_NAME', name) - self.storeName(name) + def visitFunction(self, node): + self._visitFuncOrLambda(node, isLambda=0) + self.storeName(node.name) - def visitFrom(self, node): - self.emit('SET_LINENO', node.lineno) - self.emit('IMPORT_NAME', node.modname) - for name in node.names: - if name == '*': - self.namespace = 0 - self.emit('IMPORT_FROM', name) - self.emit('POP_TOP') + def visitLambda(self, node): + self._visitFuncOrLambda(node, isLambda=1) +## self.storeName("<lambda>") + + def _visitFuncOrLambda(self, node, isLambda): + gen = FunctionCodeGenerator(node, self.filename, isLambda) + 
walk(node.code, gen) + gen.finish() + self.emit('SET_LINENO', node.lineno) + for default in node.defaults: + self.visit(default) + self.emit('LOAD_CONST', gen.getCode()) + self.emit('MAKE_FUNCTION', len(node.defaults)) def visitClass(self, node): + gen = ClassCodeGenerator(node, self.filename) + walk(node.code, gen) + gen.finish() self.emit('SET_LINENO', node.lineno) self.emit('LOAD_CONST', node.name) for base in node.bases: self.visit(base) self.emit('BUILD_TUPLE', len(node.bases)) - classBody = CodeGenerator(self.filename) - classBody.generateClassCode(node) - self.emit('LOAD_CONST', classBody) + self.emit('LOAD_CONST', gen.getCode()) self.emit('MAKE_FUNCTION', 0) self.emit('CALL_FUNCTION', 0) self.emit('BUILD_CLASS') self.storeName(node.name) - return 1 - def _visitFuncOrLambda(self, node, kind): - """Code common to Function and Lambda nodes""" - codeBody = CodeGenerator(self.filename) - getattr(codeBody, 'generate%sCode' % kind)(node) - self.emit('SET_LINENO', node.lineno) - for default in node.defaults: - self.visit(default) - self.emit('LOAD_CONST', codeBody) - self.emit('MAKE_FUNCTION', len(node.defaults)) + # The rest are standard visitor methods - def visitFunction(self, node): - self._visitFuncOrLambda(node, 'Function') - self.storeName(node.name) - return 1 + # The next few implement control-flow statements - def visitLambda(self, node): - node.name = '<lambda>' - self._visitFuncOrLambda(node, 'Lambda') - return 1 + def visitIf(self, node): + end = self.newBlock() + numtests = len(node.tests) + for i in range(numtests): + test, suite = node.tests[i] + if hasattr(test, 'lineno'): + self.emit('SET_LINENO', test.lineno) + self.visit(test) +## if i == numtests - 1 and not node.else_: +## nextTest = end +## else: +## nextTest = self.newBlock() + nextTest = self.newBlock() + self.emit('JUMP_IF_FALSE', nextTest) + self.nextBlock() + self.emit('POP_TOP') + self.visit(suite) + self.emit('JUMP_FORWARD', end) + self.nextBlock(nextTest) + self.emit('POP_TOP') + if 
node.else_: + self.visit(node.else_) + self.nextBlock(end) - def visitCallFunc(self, node): - pos = 0 - kw = 0 - if hasattr(node, 'lineno'): - self.emit('SET_LINENO', node.lineno) - self.visit(node.node) - for arg in node.args: - self.visit(arg) - if isinstance(arg, ast.Keyword): - kw = kw + 1 - else: - pos = pos + 1 - self.emit('CALL_FUNCTION', kw << 8 | pos) - return 1 + def visitWhile(self, node): + self.emit('SET_LINENO', node.lineno) - def visitKeyword(self, node): - self.emit('LOAD_CONST', node.name) - self.visit(node.expr) - return 1 + loop = self.newBlock() + else_ = self.newBlock() - def visitIf(self, node): - after = StackRef() - for test, suite in node.tests: - if hasattr(test, 'lineno'): - self.emit('SET_LINENO', test.lineno) - else: - print "warning", "no line number" - self.visit(test) - dest = StackRef() - self.emit('JUMP_IF_FALSE', dest) - self.emit('POP_TOP') - self.visit(suite) - self.emit('JUMP_FORWARD', after) - dest.bind(self.code.getCurInst()) - self.emit('POP_TOP') - if node.else_: - self.visit(node.else_) - after.bind(self.code.getCurInst()) - return 1 + after = self.newBlock() + self.emit('SETUP_LOOP', after) + + self.nextBlock(loop) + self.loops.push(loop) + + self.emit('SET_LINENO', node.lineno) + self.visit(node.test) + self.emit('JUMP_IF_FALSE', else_ or after) - def startLoop(self): - l = Loop() - self.loops.push(l) - self.emit('SETUP_LOOP', l.extentAnchor) - return l + self.nextBlock() + self.emit('POP_TOP') + self.visit(node.body) + self.emit('JUMP_ABSOLUTE', loop) - def finishLoop(self): - l = self.loops.pop() - i = self.code.getCurInst() - l.extentAnchor.bind(self.code.getCurInst()) + self.startBlock(else_) # or just the POPs if not else clause + self.emit('POP_TOP') + self.emit('POP_BLOCK') + if node.else_: + self.visit(node.else_) + self.loops.pop() + self.nextBlock(after) def visitFor(self, node): - # three refs needed - anchor = StackRef() + start = self.newBlock() + anchor = self.newBlock() + after = self.newBlock() + 
self.loops.push(start) self.emit('SET_LINENO', node.lineno) - l = self.startLoop() + self.emit('SETUP_LOOP', after) self.visit(node.list) self.visit(ast.Const(0)) - l.startAnchor.bind(self.code.getCurInst()) + self.nextBlock(start) self.emit('SET_LINENO', node.lineno) self.emit('FOR_LOOP', anchor) self.visit(node.assign) self.visit(node.body) - self.emit('JUMP_ABSOLUTE', l.startAnchor) - anchor.bind(self.code.getCurInst()) + self.emit('JUMP_ABSOLUTE', start) + self.nextBlock(anchor) self.emit('POP_BLOCK') if node.else_: self.visit(node.else_) - self.finishLoop() - return 1 - - def visitWhile(self, node): - self.emit('SET_LINENO', node.lineno) - l = self.startLoop() - if node.else_: - lElse = StackRef() - else: - lElse = l.breakAnchor - l.startAnchor.bind(self.code.getCurInst()) - if hasattr(node.test, 'lineno'): - self.emit('SET_LINENO', node.test.lineno) - self.visit(node.test) - self.emit('JUMP_IF_FALSE', lElse) - self.emit('POP_TOP') - self.visit(node.body) - self.emit('JUMP_ABSOLUTE', l.startAnchor) - # note that lElse may be an alias for l.breakAnchor - lElse.bind(self.code.getCurInst()) - self.emit('POP_TOP') - self.emit('POP_BLOCK') - if node.else_: - self.visit(node.else_) - self.finishLoop() - return 1 + self.loops.pop() + self.nextBlock(after) def visitBreak(self, node): - if not self.loops: - raise SyntaxError, "'break' outside loop" - self.emit('SET_LINENO', node.lineno) - self.emit('BREAK_LOOP') + if not self.loops: + raise SyntaxError, "'break' outside loop (%s, %d)" % \ + (self.filename, node.lineno) + self.emit('SET_LINENO', node.lineno) + self.emit('BREAK_LOOP') def visitContinue(self, node): if not self.loops: - raise SyntaxError, "'continue' outside loop" + raise SyntaxError, "'continue' outside loop (%s, %d)" % \ + (self.filename, node.lineno) l = self.loops.top() self.emit('SET_LINENO', node.lineno) - self.emit('JUMP_ABSOLUTE', l.startAnchor) + self.emit('JUMP_ABSOLUTE', l) + self.nextBlock() + + def visitTest(self, node, jump): + end = 
self.newBlock() + for child in node.nodes[:-1]: + self.visit(child) + self.emit(jump, end) + self.nextBlock() + self.emit('POP_TOP') + self.visit(node.nodes[-1]) + self.nextBlock(end) + + def visitAnd(self, node): + self.visitTest(node, 'JUMP_IF_FALSE') + + def visitOr(self, node): + self.visitTest(node, 'JUMP_IF_TRUE') + + def visitCompare(self, node): + self.visit(node.expr) + cleanup = self.newBlock() + for op, code in node.ops[:-1]: + self.visit(code) + self.emit('DUP_TOP') + self.emit('ROT_THREE') + self.emit('COMPARE_OP', op) + self.emit('JUMP_IF_FALSE', cleanup) + self.nextBlock() + self.emit('POP_TOP') + # now do the last comparison + if node.ops: + op, code = node.ops[-1] + self.visit(code) + self.emit('COMPARE_OP', op) + if len(node.ops) > 1: + end = self.newBlock() + self.emit('JUMP_FORWARD', end) + self.nextBlock(cleanup) + self.emit('ROT_TWO') + self.emit('POP_TOP') + self.nextBlock(end) + + # exception related + + def visitAssert(self, node): + # XXX would be interesting to implement this via a + # transformation of the AST before this stage + end = self.newBlock() + self.emit('SET_LINENO', node.lineno) + # XXX __debug__ and AssertionError appear to be special cases + # -- they are always loaded as globals even if there are local + # names. I guess this is a sort of renaming op. 
+ self.emit('LOAD_GLOBAL', '__debug__') + self.emit('JUMP_IF_FALSE', end) + self.nextBlock() + self.emit('POP_TOP') + self.visit(node.test) + self.emit('JUMP_IF_TRUE', end) + self.nextBlock() + self.emit('LOAD_GLOBAL', 'AssertionError') + self.visit(node.fail) + self.emit('RAISE_VARARGS', 2) + self.nextBlock(end) + self.emit('POP_TOP') + + def visitRaise(self, node): + self.emit('SET_LINENO', node.lineno) + n = 0 + if node.expr1: + self.visit(node.expr1) + n = n + 1 + if node.expr2: + self.visit(node.expr2) + n = n + 1 + if node.expr3: + self.visit(node.expr3) + n = n + 1 + self.emit('RAISE_VARARGS', n) def visitTryExcept(self, node): - # XXX need to figure out exactly what is on the stack when an - # exception is raised and the first handler is checked - handlers = StackRef() - end = StackRef() + handlers = self.newBlock() + end = self.newBlock() if node.else_: - lElse = StackRef() + lElse = self.newBlock() else: lElse = end self.emit('SET_LINENO', node.lineno) @@ -332,7 +326,7 @@ class CodeGenerator: self.visit(node.body) self.emit('POP_BLOCK') self.emit('JUMP_FORWARD', lElse) - handlers.bind(self.code.getCurInst()) + self.nextBlock(handlers) last = len(node.handlers) - 1 for i in range(len(node.handlers)): @@ -342,9 +336,10 @@ class CodeGenerator: if expr: self.emit('DUP_TOP') self.visit(expr) - self.emit('COMPARE_OP', "exception match") - next = StackRef() + self.emit('COMPARE_OP', 'exception match') + next = self.newBlock() self.emit('JUMP_IF_FALSE', next) + self.nextBlock() self.emit('POP_TOP') self.emit('POP_TOP') if target: @@ -355,132 +350,72 @@ class CodeGenerator: self.visit(body) self.emit('JUMP_FORWARD', end) if expr: - next.bind(self.code.getCurInst()) + self.nextBlock(next) self.emit('POP_TOP') self.emit('END_FINALLY') if node.else_: - lElse.bind(self.code.getCurInst()) + self.nextBlock(lElse) self.visit(node.else_) - end.bind(self.code.getCurInst()) - return 1 + self.nextBlock(end) def visitTryFinally(self, node): - final = StackRef() + final = 
self.newBlock() self.emit('SET_LINENO', node.lineno) self.emit('SETUP_FINALLY', final) self.visit(node.body) self.emit('POP_BLOCK') self.emit('LOAD_CONST', None) - final.bind(self.code.getCurInst()) + self.nextBlock(final) self.visit(node.final) self.emit('END_FINALLY') - return 1 - def visitCompare(self, node): - """Comment from compile.c follows: - - The following code is generated for all but the last - comparison in a chain: - - label: on stack: opcode: jump to: - - a <code to load b> - a, b DUP_TOP - a, b, b ROT_THREE - b, a, b COMPARE_OP - b, 0-or-1 JUMP_IF_FALSE L1 - b, 1 POP_TOP - b - - We are now ready to repeat this sequence for the next - comparison in the chain. - - For the last we generate: - - b <code to load c> - b, c COMPARE_OP - 0-or-1 - - If there were any jumps to L1 (i.e., there was more than one - comparison), we generate: - - 0-or-1 JUMP_FORWARD L2 - L1: b, 0 ROT_TWO - 0, b POP_TOP - 0 - L2: 0-or-1 - """ - self.visit(node.expr) - # if refs are never emitted, subsequent bind call has no effect - l1 = StackRef() - l2 = StackRef() - for op, code in node.ops[:-1]: - # emit every comparison except the last - self.visit(code) - self.emit('DUP_TOP') - self.emit('ROT_THREE') - self.emit('COMPARE_OP', op) - # dupTop and compareOp cancel stack effect - self.emit('JUMP_IF_FALSE', l1) - self.emit('POP_TOP') - if node.ops: - # emit the last comparison - op, code = node.ops[-1] - self.visit(code) - self.emit('COMPARE_OP', op) - if len(node.ops) > 1: - self.emit('JUMP_FORWARD', l2) - l1.bind(self.code.getCurInst()) - self.emit('ROT_TWO') - self.emit('POP_TOP') - l2.bind(self.code.getCurInst()) - return 1 + # misc - def visitGetattr(self, node): - self.visit(node.expr) - self.emit('LOAD_ATTR', node.attrname) - return 1 +## def visitStmt(self, node): +## # nothing to do except walk the children +## pass - def visitSubscript(self, node): + def visitDiscard(self, node): self.visit(node.expr) - for sub in node.subs: - self.visit(sub) - if len(node.subs) > 1: - 
self.emit('BUILD_TUPLE', len(node.subs)) - if node.flags == 'OP_APPLY': - self.emit('BINARY_SUBSCR') - elif node.flags == 'OP_ASSIGN': - self.emit('STORE_SUBSCR') - elif node.flags == 'OP_DELETE': - self.emit('DELETE_SUBSCR') - return 1 + self.emit('POP_TOP') - def visitSlice(self, node): + def visitConst(self, node): + self.emit('LOAD_CONST', node.value) + + def visitKeyword(self, node): + self.emit('LOAD_CONST', node.name) + self.visit(node.expr) + + def visitGlobal(self, node): + # no code to generate + pass + + def visitName(self, node): + self.loadName(node.name) + + def visitPass(self, node): + self.emit('SET_LINENO', node.lineno) + + def visitImport(self, node): + self.emit('SET_LINENO', node.lineno) + for name in node.names: + self.emit('IMPORT_NAME', name) + self.storeName(name) + + def visitFrom(self, node): + self.emit('SET_LINENO', node.lineno) + self.emit('IMPORT_NAME', node.modname) + for name in node.names: + if name == '*': + self.namespace = 0 + self.emit('IMPORT_FROM', name) + self.emit('POP_TOP') + + def visitGetattr(self, node): self.visit(node.expr) - slice = 0 - if node.lower: - self.visit(node.lower) - slice = slice | 1 - if node.upper: - self.visit(node.upper) - slice = slice | 2 - if node.flags == 'OP_APPLY': - self.emit('SLICE+%d' % slice) - elif node.flags == 'OP_ASSIGN': - self.emit('STORE_SLICE+%d' % slice) - elif node.flags == 'OP_DELETE': - self.emit('DELETE_SLICE+%d' % slice) - else: - print "weird slice", node.flags - raise - return 1 + self.emit('LOAD_ATTR', node.attrname) - def visitSliceobj(self, node): - for child in node.nodes: - print child - self.visit(child) - self.emit('BUILD_SLICE', len(node.nodes)) - return 1 + # next five implement assignments def visitAssign(self, node): self.emit('SET_LINENO', node.lineno) @@ -492,7 +427,6 @@ class CodeGenerator: self.emit('DUP_TOP') if isinstance(elt, ast.Node): self.visit(elt) - return 1 def visitAssName(self, node): if node.flags == 'OP_ASSIGN': @@ -501,7 +435,6 @@ class 
CodeGenerator: self.delName(node.name) else: print "oops", node.flags - return 1 def visitAssAttr(self, node): self.visit(node.expr) @@ -512,24 +445,96 @@ class CodeGenerator: else: print "warning: unexpected flags:", node.flags print node - return 1 def visitAssTuple(self, node): if findOp(node) != 'OP_DELETE': self.emit('UNPACK_TUPLE', len(node.nodes)) for child in node.nodes: self.visit(child) - return 1 visitAssList = visitAssTuple + def visitExec(self, node): + self.visit(node.expr) + if node.locals is None: + self.emit('LOAD_CONST', None) + else: + self.visit(node.locals) + if node.globals is None: + self.emit('DUP_TOP') + else: + self.visit(node.globals) + self.emit('EXEC_STMT') + + def visitCallFunc(self, node): + pos = 0 + kw = 0 + if hasattr(node, 'lineno'): + self.emit('SET_LINENO', node.lineno) + self.visit(node.node) + for arg in node.args: + self.visit(arg) + if isinstance(arg, ast.Keyword): + kw = kw + 1 + else: + pos = pos + 1 + self.emit('CALL_FUNCTION', kw << 8 | pos) + + def visitPrint(self, node): + self.emit('SET_LINENO', node.lineno) + for child in node.nodes: + self.visit(child) + self.emit('PRINT_ITEM') + + def visitPrintnl(self, node): + self.visitPrint(node) + self.emit('PRINT_NEWLINE') + + def visitReturn(self, node): + self.emit('SET_LINENO', node.lineno) + self.visit(node.value) + self.emit('RETURN_VALUE') + + # slice and subscript stuff + + def visitSlice(self, node): + self.visit(node.expr) + slice = 0 + if node.lower: + self.visit(node.lower) + slice = slice | 1 + if node.upper: + self.visit(node.upper) + slice = slice | 2 + if node.flags == 'OP_APPLY': + self.emit('SLICE+%d' % slice) + elif node.flags == 'OP_ASSIGN': + self.emit('STORE_SLICE+%d' % slice) + elif node.flags == 'OP_DELETE': + self.emit('DELETE_SLICE+%d' % slice) + else: + print "weird slice", node.flags + raise + + def visitSubscript(self, node): + self.visit(node.expr) + for sub in node.subs: + self.visit(sub) + if len(node.subs) > 1: + self.emit('BUILD_TUPLE', 
len(node.subs)) + if node.flags == 'OP_APPLY': + self.emit('BINARY_SUBSCR') + elif node.flags == 'OP_ASSIGN': + self.emit('STORE_SUBSCR') + elif node.flags == 'OP_DELETE': + self.emit('DELETE_SUBSCR') + # binary ops def binaryOp(self, node, op): self.visit(node.left) self.visit(node.right) self.emit(op) - return 1 def visitAdd(self, node): return self.binaryOp(node, 'BINARY_ADD') @@ -560,7 +565,6 @@ class CodeGenerator: def unaryOp(self, node, op): self.visit(node.expr) self.emit(op) - return 1 def visitInvert(self, node): return self.unaryOp(node, 'UNARY_INVERT') @@ -587,7 +591,6 @@ class CodeGenerator: for node in nodes[1:]: self.visit(node) self.emit(op) - return 1 def visitBitand(self, node): return self.bitOp(node.nodes, 'BINARY_AND') @@ -598,121 +601,137 @@ class CodeGenerator: def visitBitxor(self, node): return self.bitOp(node.nodes, 'BINARY_XOR') - def visitAssert(self, node): - # XXX __debug__ and AssertionError appear to be special cases - # -- they are always loaded as globals even if there are local - # names. I guess this is a sort of renaming op. 
- skip = StackRef() - self.emit('SET_LINENO', node.lineno) - self.emit('LOAD_GLOBAL', '__debug__') - self.emit('JUMP_IF_FALSE', skip) - self.emit('POP_TOP') - self.visit(node.test) - self.emit('JUMP_IF_TRUE', skip) - self.emit('LOAD_GLOBAL', 'AssertionError') - self.visit(node.fail) - self.emit('RAISE_VARARGS', 2) - skip.bind(self.code.getCurInst()) - self.emit('POP_TOP') - return 1 - - def visitTest(self, node, jump): - end = StackRef() - for child in node.nodes[:-1]: - self.visit(child) - self.emit(jump, end) - self.emit('POP_TOP') - self.visit(node.nodes[-1]) - end.bind(self.code.getCurInst()) - return 1 - - def visitAnd(self, node): - return self.visitTest(node, 'JUMP_IF_FALSE') - - def visitOr(self, node): - return self.visitTest(node, 'JUMP_IF_TRUE') - - def visitName(self, node): - self.loadName(node.name) - - def visitConst(self, node): - self.emit('LOAD_CONST', node.value) - return 1 + # object constructors def visitEllipsis(self, node): self.emit('LOAD_CONST', Ellipsis) - return 1 def visitTuple(self, node): for elt in node.nodes: self.visit(elt) self.emit('BUILD_TUPLE', len(node.nodes)) - return 1 def visitList(self, node): for elt in node.nodes: self.visit(elt) self.emit('BUILD_LIST', len(node.nodes)) - return 1 + + def visitSliceobj(self, node): + for child in node.nodes: + self.visit(child) + self.emit('BUILD_SLICE', len(node.nodes)) def visitDict(self, node): + # XXX is this a good general strategy? 
could it be done + # separately from the general visitor + lineno = getattr(node, 'lineno', None) + if lineno: + self.emit('SET_LINENO', lineno) self.emit('BUILD_MAP', 0) for k, v in node.items: - # XXX need to add set lineno when there aren't constants + lineno2 = getattr(node, 'lineno', None) + if lineno != lineno2: + self.emit('SET_LINENO', lineno2) + lineno = lineno2 self.emit('DUP_TOP') self.visit(v) self.emit('ROT_TWO') self.visit(k) self.emit('STORE_SUBSCR') - return 1 - def visitReturn(self, node): - self.emit('SET_LINENO', node.lineno) - self.visit(node.value) - self.emit('RETURN_VALUE') - return 1 +class ModuleCodeGenerator(CodeGenerator): + super_init = CodeGenerator.__init__ + + def __init__(self, filename): + # XXX <module> is ? in compile.c + self.graph = pyassem.PyFlowGraph("<module>", filename) + self.super_init(filename) + +class FunctionCodeGenerator(CodeGenerator): + super_init = CodeGenerator.__init__ + + optimized = 1 + lambdaCount = 0 + + def __init__(self, func, filename, isLambda=0): + if isLambda: + klass = FunctionCodeGenerator + name = "<lambda.%d>" % klass.lambdaCount + klass.lambdaCount = klass.lambdaCount + 1 + else: + name = func.name + args, hasTupleArg = generateArgList(func.argnames) + self.graph = pyassem.PyFlowGraph(name, filename, args, + optimized=1) + self.isLambda = isLambda + self.super_init(filename) - def visitRaise(self, node): - self.emit('SET_LINENO', node.lineno) - n = 0 - if node.expr1: - self.visit(node.expr1) - n = n + 1 - if node.expr2: - self.visit(node.expr2) - n = n + 1 - if node.expr3: - self.visit(node.expr3) - n = n + 1 - self.emit('RAISE_VARARGS', n) - return 1 + lnf = walk(func.code, LocalNameFinder(args), 0) + self.locals.push(lnf.getLocals()) + if func.varargs: + self.graph.setFlag(CO_VARARGS) + if func.kwargs: + self.graph.setFlag(CO_VARKEYWORDS) + self.emit('SET_LINENO', func.lineno) + if hasTupleArg: + self.generateArgUnpack(func.argnames) - def visitPrint(self, node): - self.emit('SET_LINENO', 
node.lineno) - for child in node.nodes: - self.visit(child) - self.emit('PRINT_ITEM') - return 1 + def finish(self): + self.graph.startExitBlock() + if not self.isLambda: + self.emit('LOAD_CONST', None) + self.emit('RETURN_VALUE') - def visitPrintnl(self, node): - self.visitPrint(node) - self.emit('PRINT_NEWLINE') - return 1 + def generateArgUnpack(self, args): + count = 0 + for arg in args: + if type(arg) == types.TupleType: + self.emit('LOAD_FAST', '.nested%d' % count) + count = count + 1 + self.unpackTuple(arg) + + def unpackTuple(self, tup): + self.emit('UNPACK_TUPLE', len(tup)) + for elt in tup: + if type(elt) == types.TupleType: + self.unpackTuple(elt) + else: + self.emit('STORE_FAST', elt) - def visitExec(self, node): - self.visit(node.expr) - if node.locals is None: - self.emit('LOAD_CONST', None) - else: - self.visit(node.locals) - if node.globals is None: - self.emit('DUP_TOP') - else: - self.visit(node.globals) - self.emit('EXEC_STMT') - return 1 +class ClassCodeGenerator(CodeGenerator): + super_init = CodeGenerator.__init__ + + def __init__(self, klass, filename): + self.graph = pyassem.PyFlowGraph(klass.name, filename, + optimized=0) + self.super_init(filename) + lnf = walk(klass.code, LocalNameFinder(), 0) + self.locals.push(lnf.getLocals()) + + def finish(self): + self.graph.startExitBlock() + self.emit('LOAD_LOCALS') + self.emit('RETURN_VALUE') + + +def generateArgList(arglist): + """Generate an arg list marking TupleArgs""" + args = [] + extra = [] + count = 0 + for elt in arglist: + if type(elt) == types.StringType: + args.append(elt) + elif type(elt) == types.TupleType: + args.append(TupleArg(count, elt)) + count = count + 1 + extra.extend(misc.flatten(elt)) + else: + raise ValueError, "unexpect argument type:", elt + return args + extra, count class LocalNameFinder: + """Find local names in scope""" def __init__(self, names=()): self.names = misc.Set() self.globals = misc.Set() @@ -720,25 +739,23 @@ class LocalNameFinder: self.names.add(name) 
def getLocals(self): - for elt in self.globals.items(): + for elt in self.globals.elements(): if self.names.has_elt(elt): self.names.remove(elt) return self.names def visitDict(self, node): - return 1 + pass def visitGlobal(self, node): for name in node.names: self.globals.add(name) - return 1 def visitFunction(self, node): self.names.add(node.name) - return 1 def visitLambda(self, node): - return 1 + pass def visitImport(self, node): for name in node.names: @@ -750,11 +767,16 @@ class LocalNameFinder: def visitClass(self, node): self.names.add(node.name) - return 1 def visitAssName(self, node): self.names.add(node.name) +def findOp(node): + """Find the op (DELETE, LOAD, STORE) in an AssTuple tree""" + v = OpFinder() + walk(node, v, 0) + return v.op + class OpFinder: def __init__(self): self.op = None @@ -764,57 +786,8 @@ class OpFinder: elif self.op != node.flags: raise ValueError, "mixed ops in stmt" -def findOp(node): - v = OpFinder() - walk(node, v) - return v.op - -class Loop: - def __init__(self): - self.startAnchor = StackRef() - self.breakAnchor = StackRef() - self.extentAnchor = StackRef() - -class CompiledModule: - """Store the code object for a compiled module - - XXX Not clear how the code objects will be stored. Seems possible - that a single code attribute is sufficient, because it will - contains references to all the need code objects. That might be - messy, though. 
- """ - MAGIC = (20121 | (ord('\r')<<16) | (ord('\n')<<24)) - - def __init__(self, source, filename): - self.source = source - self.filename = filename +if __name__ == "__main__": + import sys - def compile(self): - self.ast = parse(self.source) - cg = CodeGenerator(self.filename) - walk(self.ast, cg) - self.code = cg.asConst() - - def dump(self, path): - """create a .pyc file""" - f = open(path, 'wb') - f.write(self._pyc_header()) - marshal.dump(self.code, f) - f.close() - - def _pyc_header(self): - # compile.c uses marshal to write a long directly, with - # calling the interface that would also generate a 1-byte code - # to indicate the type of the value. simplest way to get the - # same effect is to call marshal and then skip the code. - magic = marshal.dumps(self.MAGIC)[1:] - mtime = os.stat(self.filename)[stat.ST_MTIME] - mtime = struct.pack('i', mtime) - return magic + mtime - -def compile(filename): - buf = open(filename).read() - mod = CompiledModule(buf, filename) - mod.compile() - mod.dump(filename + 'c') - + for file in sys.argv[1:]: + compile(file) |