diff options
author | Alexandre Vassalotti <alexandre@peadrop.com> | 2010-01-11 22:36:12 (GMT) |
---|---|---|
committer | Alexandre Vassalotti <alexandre@peadrop.com> | 2010-01-11 22:36:12 (GMT) |
commit | b646547bb45fe1df6abefd94f892c633798d91d2 (patch) | |
tree | ef1add045741d309129266726f5ba45562184091 /Lib/compiler | |
parent | 0ca7452794bef03b66f56cc996a73cac066d0ec1 (diff) | |
download | cpython-b646547bb45fe1df6abefd94f892c633798d91d2.zip cpython-b646547bb45fe1df6abefd94f892c633798d91d2.tar.gz cpython-b646547bb45fe1df6abefd94f892c633798d91d2.tar.bz2 |
Issue #2333: Backport set and dict comprehensions syntax.
Diffstat (limited to 'Lib/compiler')
-rw-r--r-- | Lib/compiler/ast.py | 45 | ||||
-rw-r--r-- | Lib/compiler/pyassem.py | 4 | ||||
-rw-r--r-- | Lib/compiler/pycodegen.py | 49 | ||||
-rw-r--r-- | Lib/compiler/transformer.py | 232 |
4 files changed, 225 insertions, 105 deletions
diff --git a/Lib/compiler/ast.py b/Lib/compiler/ast.py index f47434c..4c3fc16 100644 --- a/Lib/compiler/ast.py +++ b/Lib/compiler/ast.py @@ -890,6 +890,51 @@ class ListCompIf(Node): def __repr__(self): return "ListCompIf(%s)" % (repr(self.test),) +class SetComp(Node): + def __init__(self, expr, quals, lineno=None): + self.expr = expr + self.quals = quals + self.lineno = lineno + + def getChildren(self): + children = [] + children.append(self.expr) + children.extend(flatten(self.quals)) + return tuple(children) + + def getChildNodes(self): + nodelist = [] + nodelist.append(self.expr) + nodelist.extend(flatten_nodes(self.quals)) + return tuple(nodelist) + + def __repr__(self): + return "SetComp(%s, %s)" % (repr(self.expr), repr(self.quals)) + +class DictComp(Node): + def __init__(self, key, value, quals, lineno=None): + self.key = key + self.value = value + self.quals = quals + self.lineno = lineno + + def getChildren(self): + children = [] + children.append(self.key) + children.append(self.value) + children.extend(flatten(self.quals)) + return tuple(children) + + def getChildNodes(self): + nodelist = [] + nodelist.append(self.key) + nodelist.append(self.value) + nodelist.extend(flatten_nodes(self.quals)) + return tuple(nodelist) + + def __repr__(self): + return "DictComp(%s, %s, %s)" % (repr(self.key), repr(self.value), repr(self.quals)) + class Mod(Node): def __init__(self, leftright, lineno=None): self.left = leftright[0] diff --git a/Lib/compiler/pyassem.py b/Lib/compiler/pyassem.py index 88510de..286be0c 100644 --- a/Lib/compiler/pyassem.py +++ b/Lib/compiler/pyassem.py @@ -685,7 +685,9 @@ class StackDepthTracker: effect = { 'POP_TOP': -1, 'DUP_TOP': 1, - 'LIST_APPEND': -2, + 'LIST_APPEND': -1, + 'SET_ADD': -1, + 'MAP_ADD': -2, 'SLICE+1': -1, 'SLICE+2': -1, 'SLICE+3': -2, diff --git a/Lib/compiler/pycodegen.py b/Lib/compiler/pycodegen.py index bef9c70..4f2ecf2 100644 --- a/Lib/compiler/pycodegen.py +++ b/Lib/compiler/pycodegen.py @@ -589,6 +589,55 @@ class CodeGenerator: self.emit('JUMP_ABSOLUTE', start) self.startBlock(anchor) + def visitSetComp(self, node): + self.set_lineno(node) + # setup list + self.emit('BUILD_SET', 0) + + stack = [] + for i, for_ in zip(range(len(node.quals)), node.quals): + start, anchor = self.visit(for_) + cont = None + for if_ in for_.ifs: + if cont is None: + cont = self.newBlock() + self.visit(if_, cont) + stack.insert(0, (start, cont, anchor)) + + self.visit(node.expr) + self.emit('SET_ADD', len(node.quals) + 1) + + for start, cont, anchor in stack: + if cont: + self.nextBlock(cont) + self.emit('JUMP_ABSOLUTE', start) + self.startBlock(anchor) + + def visitDictComp(self, node): + self.set_lineno(node) + # setup list + self.emit('BUILD_MAP', 0) + + stack = [] + for i, for_ in zip(range(len(node.quals)), node.quals): + start, anchor = self.visit(for_) + cont = None + for if_ in for_.ifs: + if cont is None: + cont = self.newBlock() + self.visit(if_, cont) + stack.insert(0, (start, cont, anchor)) + + self.visit(node.value) + self.visit(node.key) + self.emit('MAP_ADD', len(node.quals) + 1) + + for start, cont, anchor in stack: + if cont: + self.nextBlock(cont) + self.emit('JUMP_ABSOLUTE', start) + self.startBlock(anchor) + def visitListCompFor(self, node): start = self.newBlock() anchor = self.newBlock() diff --git a/Lib/compiler/transformer.py b/Lib/compiler/transformer.py index 816f13b..d4f4613 100644 --- a/Lib/compiler/transformer.py +++ b/Lib/compiler/transformer.py @@ -581,8 +581,10 @@ class Transformer: testlist1 = testlist exprlist = testlist - def testlist_gexp(self, nodelist): - if len(nodelist) == 2 and nodelist[1][0] == symbol.gen_for: + def testlist_comp(self, nodelist): + # test ( comp_for | (',' test)* [','] ) + assert nodelist[0][0] == symbol.test + if len(nodelist) == 2 and nodelist[1][0] == symbol.comp_for: test = self.com_node(nodelist[0]) return self.com_generator_expression(test, nodelist[1]) return self.testlist(nodelist) @@ -1001,7 +1003,7 @@ class Transformer: # loop to avoid trivial recursion while 1: t = node[0] - if t in (symbol.exprlist, symbol.testlist, symbol.testlist_safe, symbol.testlist_gexp): + if t in (symbol.exprlist, symbol.testlist, symbol.testlist_safe, symbol.testlist_comp): if len(node) > 2: return self.com_assign_tuple(node, assigning) node = node[1] @@ -1099,116 +1101,138 @@ class Transformer: else: stmts.append(result) - if hasattr(symbol, 'list_for'): - def com_list_constructor(self, nodelist): - # listmaker: test ( list_for | (',' test)* [','] ) - values = [] - for i in range(1, len(nodelist)): - if nodelist[i][0] == symbol.list_for: - assert len(nodelist[i:]) == 1 - return self.com_list_comprehension(values[0], - nodelist[i]) - elif nodelist[i][0] == token.COMMA: - continue - values.append(self.com_node(nodelist[i])) - return List(values, lineno=values[0].lineno) - - def com_list_comprehension(self, expr, node): - # list_iter: list_for | list_if - # list_for: 'for' exprlist 'in' testlist [list_iter] - # list_if: 'if' test [list_iter] - - # XXX should raise SyntaxError for assignment - - lineno = node[1][2] - fors = [] - while node: - t = node[1][1] - if t == 'for': - assignNode = self.com_assign(node[2], OP_ASSIGN) - listNode = self.com_node(node[4]) - newfor = ListCompFor(assignNode, listNode, []) - newfor.lineno = node[1][2] - fors.append(newfor) - if len(node) == 5: - node = None - else: - node = self.com_list_iter(node[5]) - elif t == 'if': - test = self.com_node(node[2]) - newif = ListCompIf(test, lineno=node[1][2]) - newfor.ifs.append(newif) - if len(node) == 3: - node = None - else: - node = self.com_list_iter(node[3]) + def com_list_constructor(self, nodelist): + # listmaker: test ( list_for | (',' test)* [','] ) + values = [] + for i in range(1, len(nodelist)): + if nodelist[i][0] == symbol.list_for: + assert len(nodelist[i:]) == 1 + return self.com_list_comprehension(values[0], + nodelist[i]) + elif nodelist[i][0] == token.COMMA: + continue + values.append(self.com_node(nodelist[i])) + return List(values, lineno=values[0].lineno) + + def com_list_comprehension(self, expr, node): + return self.com_comprehension(expr, None, node, 'list') + + def com_comprehension(self, expr1, expr2, node, type): + # list_iter: list_for | list_if + # list_for: 'for' exprlist 'in' testlist [list_iter] + # list_if: 'if' test [list_iter] + + # XXX should raise SyntaxError for assignment + # XXX(avassalotti) Set and dict comprehensions should have generator + # semantics. In other words, they shouldn't leak + # variables outside of the comprehension's scope. + + lineno = node[1][2] + fors = [] + while node: + t = node[1][1] + if t == 'for': + assignNode = self.com_assign(node[2], OP_ASSIGN) + compNode = self.com_node(node[4]) + newfor = ListCompFor(assignNode, compNode, []) + newfor.lineno = node[1][2] + fors.append(newfor) + if len(node) == 5: + node = None + elif type == 'list': + node = self.com_list_iter(node[5]) else: - raise SyntaxError, \ - ("unexpected list comprehension element: %s %d" - % (node, lineno)) - return ListComp(expr, fors, lineno=lineno) - - def com_list_iter(self, node): - assert node[0] == symbol.list_iter - return node[1] - else: - def com_list_constructor(self, nodelist): - values = [] - for i in range(1, len(nodelist), 2): - values.append(self.com_node(nodelist[i])) - return List(values, lineno=values[0].lineno) - - if hasattr(symbol, 'gen_for'): - def com_generator_expression(self, expr, node): - # gen_iter: gen_for | gen_if - # gen_for: 'for' exprlist 'in' test [gen_iter] - # gen_if: 'if' test [gen_iter] - - lineno = node[1][2] - fors = [] - while node: - t = node[1][1] - if t == 'for': - assignNode = self.com_assign(node[2], OP_ASSIGN) - genNode = self.com_node(node[4]) - newfor = GenExprFor(assignNode, genNode, [], - lineno=node[1][2]) - fors.append(newfor) - if (len(node)) == 5: - node = None - else: - node = self.com_gen_iter(node[5]) - elif t == 'if': - test = self.com_node(node[2]) - newif = GenExprIf(test, lineno=node[1][2]) - newfor.ifs.append(newif) - if len(node) == 3: - node = None - else: - node = self.com_gen_iter(node[3]) + node = self.com_comp_iter(node[5]) + elif t == 'if': + test = self.com_node(node[2]) + newif = ListCompIf(test, lineno=node[1][2]) + newfor.ifs.append(newif) + if len(node) == 3: + node = None + elif type == 'list': + node = self.com_list_iter(node[3]) else: - raise SyntaxError, \ - ("unexpected generator expression element: %s %d" - % (node, lineno)) - fors[0].is_outmost = True - return GenExpr(GenExprInner(expr, fors), lineno=lineno) + node = self.com_comp_iter(node[3]) + else: + raise SyntaxError, \ + ("unexpected comprehension element: %s %d" + % (node, lineno)) + if type == 'list': + return ListComp(expr1, fors, lineno=lineno) + elif type == 'set': + return SetComp(expr1, fors, lineno=lineno) + elif type == 'dict': + return DictComp(expr1, expr2, fors, lineno=lineno) + else: + raise ValueError("unexpected comprehension type: " + repr(type)) + + def com_list_iter(self, node): + assert node[0] == symbol.list_iter + return node[1] + + def com_comp_iter(self, node): + assert node[0] == symbol.comp_iter + return node[1] - def com_gen_iter(self, node): - assert node[0] == symbol.gen_iter - return node[1] + def com_generator_expression(self, expr, node): + # comp_iter: comp_for | comp_if + # comp_for: 'for' exprlist 'in' test [comp_iter] + # comp_if: 'if' test [comp_iter] + + lineno = node[1][2] + fors = [] + while node: + t = node[1][1] + if t == 'for': + assignNode = self.com_assign(node[2], OP_ASSIGN) + genNode = self.com_node(node[4]) + newfor = GenExprFor(assignNode, genNode, [], + lineno=node[1][2]) + fors.append(newfor) + if (len(node)) == 5: + node = None + else: + node = self.com_comp_iter(node[5]) + elif t == 'if': + test = self.com_node(node[2]) + newif = GenExprIf(test, lineno=node[1][2]) + newfor.ifs.append(newif) + if len(node) == 3: + node = None + else: + node = self.com_comp_iter(node[3]) + else: + raise SyntaxError, \ + ("unexpected generator expression element: %s %d" + % (node, lineno)) + fors[0].is_outmost = True + return GenExpr(GenExprInner(expr, fors), lineno=lineno) def com_dictorsetmaker(self, nodelist): - # dictorsetmaker: ( (test ':' test (',' test ':' test)* [',']) | - # (test (',' test)* [',']) ) + # dictorsetmaker: ( (test ':' test (comp_for | (',' test ':' test)* [','])) | + # (test (comp_for | (',' test)* [','])) ) assert nodelist[0] == symbol.dictorsetmaker - if len(nodelist) == 2 or nodelist[2][0] == token.COMMA: + nodelist = nodelist[1:] + if len(nodelist) == 1 or nodelist[1][0] == token.COMMA: + # set literal items = [] - for i in range(1, len(nodelist), 2): + for i in range(0, len(nodelist), 2): items.append(self.com_node(nodelist[i])) return Set(items, lineno=items[0].lineno) + elif nodelist[1][0] == symbol.comp_for: + # set comprehension + expr = self.com_node(nodelist[0]) + return self.com_comprehension(expr, None, nodelist[1], 'set') + elif len(nodelist) > 3 and nodelist[3][0] == symbol.comp_for: + # dict comprehension + assert nodelist[1][0] == token.COLON + key = self.com_node(nodelist[0]) + value = self.com_node(nodelist[2]) + return self.com_comprehension(key, value, nodelist[3], 'dict') else: + # dict literal items = [] - for i in range(1, len(nodelist), 4): + for i in range(0, len(nodelist), 4): items.append((self.com_node(nodelist[i]), self.com_node(nodelist[i+2]))) return Dict(items, lineno=items[0][0].lineno) @@ -1257,7 +1281,7 @@ class Transformer: kw, result = self.com_argument(node, kw, star_node) if len_nodelist != 2 and isinstance(result, GenExpr) \ - and len(node) == 3 and node[2][0] == symbol.gen_for: + and len(node) == 3 and node[2][0] == symbol.comp_for: # allow f(x for x in y), but reject f(x for x in y, 1) # should use f((x for x in y), 1) instead of f(x for x in y, 1) raise SyntaxError, 'generator expression needs parenthesis' @@ -1269,7 +1293,7 @@ class Transformer: lineno=extractLineNo(nodelist)) def com_argument(self, nodelist, kw, star_node): - if len(nodelist) == 3 and nodelist[2][0] == symbol.gen_for: + if len(nodelist) == 3 and nodelist[2][0] == symbol.comp_for: test = self.com_node(nodelist[1]) return 0, self.com_generator_expression(test, nodelist[2]) if len(nodelist) == 2: |