diff options
author | Batuhan Taşkaya <47358913+isidentical@users.noreply.github.com> | 2020-03-01 20:12:17 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-03-01 20:12:17 (GMT) |
commit | 397b96f6d7a89f778ebc0591e32216a8183fe667 (patch) | |
tree | bca7cb5940e1a5cb1cdc80b3578a238dfac318d1 | |
parent | 185903de12de8837bf0dc0008a16e5e56c66a019 (diff) | |
download | cpython-397b96f6d7a89f778ebc0591e32216a8183fe667.zip cpython-397b96f6d7a89f778ebc0591e32216a8183fe667.tar.gz cpython-397b96f6d7a89f778ebc0591e32216a8183fe667.tar.bz2 |
bpo-38870: Implement a precedence algorithm in ast.unparse (GH-17377)
Implement a simple precedence algorithm for ast.unparse in order to avoid redundant
parentheses for nested structures in the final output.
-rw-r--r-- | Lib/ast.py | 138 | ||||
-rw-r--r-- | Lib/test/test_ast.py | 9 | ||||
-rw-r--r-- | Lib/test/test_unparse.py | 41 |
3 files changed, 172 insertions, 16 deletions
@@ -27,6 +27,7 @@ import sys from _ast import * from contextlib import contextmanager, nullcontext +from enum import IntEnum, auto def parse(source, filename='<unknown>', mode='exec', *, @@ -560,6 +561,35 @@ _const_node_type_names = { # We unparse those infinities to INFSTR. _INFSTR = "1e" + repr(sys.float_info.max_10_exp + 1) +class _Precedence(IntEnum): + """Precedence table that originated from python grammar.""" + + TUPLE = auto() + YIELD = auto() # 'yield', 'yield from' + TEST = auto() # 'if'-'else', 'lambda' + OR = auto() # 'or' + AND = auto() # 'and' + NOT = auto() # 'not' + CMP = auto() # '<', '>', '==', '>=', '<=', '!=', + # 'in', 'not in', 'is', 'is not' + EXPR = auto() + BOR = EXPR # '|' + BXOR = auto() # '^' + BAND = auto() # '&' + SHIFT = auto() # '<<', '>>' + ARITH = auto() # '+', '-' + TERM = auto() # '*', '@', '/', '%', '//' + FACTOR = auto() # unary '+', '-', '~' + POWER = auto() # '**' + AWAIT = auto() # 'await' + ATOM = auto() + + def next(self): + try: + return self.__class__(self + 1) + except ValueError: + return self + class _Unparser(NodeVisitor): """Methods in this class recursively traverse an AST and output source code for the abstract syntax; original formatting @@ -568,6 +598,7 @@ class _Unparser(NodeVisitor): def __init__(self): self._source = [] self._buffer = [] + self._precedences = {} self._indent = 0 def interleave(self, inter, f, seq): @@ -625,6 +656,17 @@ class _Unparser(NodeVisitor): else: return nullcontext() + def require_parens(self, precedence, node): + """Shortcut to adding precedence related parens""" + return self.delimit_if("(", ")", self.get_precedence(node) > precedence) + + def get_precedence(self, node): + return self._precedences.get(node, _Precedence.TEST) + + def set_precedence(self, precedence, *nodes): + for node in nodes: + self._precedences[node] = precedence + def traverse(self, node): if isinstance(node, list): for item in node: @@ -645,10 +687,12 @@ class _Unparser(NodeVisitor): def visit_Expr(self, node): 
self.fill() + self.set_precedence(_Precedence.YIELD, node.value) self.traverse(node.value) def visit_NamedExpr(self, node): - with self.delimit("(", ")"): + with self.require_parens(_Precedence.TUPLE, node): + self.set_precedence(_Precedence.ATOM, node.target, node.value) self.traverse(node.target) self.write(" := ") self.traverse(node.value) @@ -723,24 +767,27 @@ class _Unparser(NodeVisitor): self.interleave(lambda: self.write(", "), self.write, node.names) def visit_Await(self, node): - with self.delimit("(", ")"): + with self.require_parens(_Precedence.AWAIT, node): self.write("await") if node.value: self.write(" ") + self.set_precedence(_Precedence.ATOM, node.value) self.traverse(node.value) def visit_Yield(self, node): - with self.delimit("(", ")"): + with self.require_parens(_Precedence.YIELD, node): self.write("yield") if node.value: self.write(" ") + self.set_precedence(_Precedence.ATOM, node.value) self.traverse(node.value) def visit_YieldFrom(self, node): - with self.delimit("(", ")"): + with self.require_parens(_Precedence.YIELD, node): self.write("yield from ") if not node.value: raise ValueError("Node can't be used without a value attribute.") + self.set_precedence(_Precedence.ATOM, node.value) self.traverse(node.value) def visit_Raise(self, node): @@ -907,7 +954,9 @@ class _Unparser(NodeVisitor): def _fstring_FormattedValue(self, node, write): write("{") - expr = type(self)().visit(node.value).rstrip("\n") + unparser = type(self)() + unparser.set_precedence(_Precedence.TEST.next(), node.value) + expr = unparser.visit(node.value).rstrip("\n") if expr.startswith("{"): write(" ") # Separate pair of opening brackets as "{ {" write(expr) @@ -983,19 +1032,23 @@ class _Unparser(NodeVisitor): self.write(" async for ") else: self.write(" for ") + self.set_precedence(_Precedence.TUPLE, node.target) self.traverse(node.target) self.write(" in ") + self.set_precedence(_Precedence.TEST.next(), node.iter, *node.ifs) self.traverse(node.iter) for if_clause in 
node.ifs: self.write(" if ") self.traverse(if_clause) def visit_IfExp(self, node): - with self.delimit("(", ")"): + with self.require_parens(_Precedence.TEST, node): + self.set_precedence(_Precedence.TEST.next(), node.body, node.test) self.traverse(node.body) self.write(" if ") self.traverse(node.test) self.write(" else ") + self.set_precedence(_Precedence.TEST, node.orelse) self.traverse(node.orelse) def visit_Set(self, node): @@ -1016,6 +1069,7 @@ class _Unparser(NodeVisitor): # for dictionary unpacking operator in dicts {**{'y': 2}} # see PEP 448 for details self.write("**") + self.set_precedence(_Precedence.EXPR, v) self.traverse(v) else: write_key_value_pair(k, v) @@ -1035,11 +1089,20 @@ class _Unparser(NodeVisitor): self.interleave(lambda: self.write(", "), self.traverse, node.elts) unop = {"Invert": "~", "Not": "not", "UAdd": "+", "USub": "-"} + unop_precedence = { + "~": _Precedence.FACTOR, + "not": _Precedence.NOT, + "+": _Precedence.FACTOR, + "-": _Precedence.FACTOR + } def visit_UnaryOp(self, node): - with self.delimit("(", ")"): - self.write(self.unop[node.op.__class__.__name__]) + operator = self.unop[node.op.__class__.__name__] + operator_precedence = self.unop_precedence[operator] + with self.require_parens(operator_precedence, node): + self.write(operator) self.write(" ") + self.set_precedence(operator_precedence, node.operand) self.traverse(node.operand) binop = { @@ -1058,10 +1121,38 @@ class _Unparser(NodeVisitor): "Pow": "**", } + binop_precedence = { + "+": _Precedence.ARITH, + "-": _Precedence.ARITH, + "*": _Precedence.TERM, + "@": _Precedence.TERM, + "/": _Precedence.TERM, + "%": _Precedence.TERM, + "<<": _Precedence.SHIFT, + ">>": _Precedence.SHIFT, + "|": _Precedence.BOR, + "^": _Precedence.BXOR, + "&": _Precedence.BAND, + "//": _Precedence.TERM, + "**": _Precedence.POWER, + } + + binop_rassoc = frozenset(("**",)) def visit_BinOp(self, node): - with self.delimit("(", ")"): + operator = self.binop[node.op.__class__.__name__] + 
operator_precedence = self.binop_precedence[operator] + with self.require_parens(operator_precedence, node): + if operator in self.binop_rassoc: + left_precedence = operator_precedence.next() + right_precedence = operator_precedence + else: + left_precedence = operator_precedence + right_precedence = operator_precedence.next() + + self.set_precedence(left_precedence, node.left) self.traverse(node.left) - self.write(" " + self.binop[node.op.__class__.__name__] + " ") + self.write(f" {operator} ") + self.set_precedence(right_precedence, node.right) self.traverse(node.right) cmpops = { @@ -1078,20 +1169,32 @@ class _Unparser(NodeVisitor): } def visit_Compare(self, node): - with self.delimit("(", ")"): + with self.require_parens(_Precedence.CMP, node): + self.set_precedence(_Precedence.CMP.next(), node.left, *node.comparators) self.traverse(node.left) for o, e in zip(node.ops, node.comparators): self.write(" " + self.cmpops[o.__class__.__name__] + " ") self.traverse(e) boolops = {"And": "and", "Or": "or"} + boolop_precedence = {"and": _Precedence.AND, "or": _Precedence.OR} def visit_BoolOp(self, node): - with self.delimit("(", ")"): - s = " %s " % self.boolops[node.op.__class__.__name__] - self.interleave(lambda: self.write(s), self.traverse, node.values) + operator = self.boolops[node.op.__class__.__name__] + operator_precedence = self.boolop_precedence[operator] + + def increasing_level_traverse(node): + nonlocal operator_precedence + operator_precedence = operator_precedence.next() + self.set_precedence(operator_precedence, node) + self.traverse(node) + + with self.require_parens(operator_precedence, node): + s = f" {operator} " + self.interleave(lambda: self.write(s), increasing_level_traverse, node.values) def visit_Attribute(self, node): + self.set_precedence(_Precedence.ATOM, node.value) self.traverse(node.value) # Special case: 3.__abs__() is a syntax error, so if node.value # is an integer literal then we need to either parenthesize @@ -1102,6 +1205,7 @@ class 
_Unparser(NodeVisitor): self.write(node.attr) def visit_Call(self, node): + self.set_precedence(_Precedence.ATOM, node.func) self.traverse(node.func) with self.delimit("(", ")"): comma = False @@ -1119,18 +1223,21 @@ class _Unparser(NodeVisitor): self.traverse(e) def visit_Subscript(self, node): + self.set_precedence(_Precedence.ATOM, node.value) self.traverse(node.value) with self.delimit("[", "]"): self.traverse(node.slice) def visit_Starred(self, node): self.write("*") + self.set_precedence(_Precedence.EXPR, node.value) self.traverse(node.value) def visit_Ellipsis(self, node): self.write("...") def visit_Index(self, node): + self.set_precedence(_Precedence.TUPLE, node.value) self.traverse(node.value) def visit_Slice(self, node): @@ -1212,10 +1319,11 @@ class _Unparser(NodeVisitor): self.traverse(node.value) def visit_Lambda(self, node): - with self.delimit("(", ")"): + with self.require_parens(_Precedence.TEST, node): self.write("lambda ") self.traverse(node.args) self.write(": ") + self.set_precedence(_Precedence.TEST, node.body) self.traverse(node.body) def visit_alias(self, node): diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py index 2ed4657..e788485 100644 --- a/Lib/test/test_ast.py +++ b/Lib/test/test_ast.py @@ -247,6 +247,13 @@ eval_tests = [ class AST_Tests(unittest.TestCase): + def _is_ast_node(self, name, node): + if not isinstance(node, type): + return False + if "ast" not in node.__module__: + return False + return name != 'AST' and name[0].isupper() + def _assertTrueorder(self, ast_node, parent_pos): if not isinstance(ast_node, ast.AST) or ast_node._fields is None: return @@ -335,7 +342,7 @@ class AST_Tests(unittest.TestCase): def test_field_attr_existence(self): for name, item in ast.__dict__.items(): - if isinstance(item, type) and name != 'AST' and name[0].isupper(): + if self._is_ast_node(name, item): x = item() if isinstance(x, ast.AST): self.assertEqual(type(x._fields), tuple) diff --git a/Lib/test/test_unparse.py 
b/Lib/test/test_unparse.py index e8b0d4b..f7fcb2b 100644 --- a/Lib/test/test_unparse.py +++ b/Lib/test/test_unparse.py @@ -125,6 +125,13 @@ class ASTTestCase(unittest.TestCase): def check_invalid(self, node, raises=ValueError): self.assertRaises(raises, ast.unparse, node) + def check_src_roundtrip(self, code1, code2=None, strip=True): + code2 = code2 or code1 + code1 = ast.unparse(ast.parse(code1)) + if strip: + code1 = code1.strip() + self.assertEqual(code2, code1) + class UnparseTestCase(ASTTestCase): # Tests for specific bugs found in earlier versions of unparse @@ -281,6 +288,40 @@ class UnparseTestCase(ASTTestCase): def test_invalid_yield_from(self): self.check_invalid(ast.YieldFrom(value=None)) +class CosmeticTestCase(ASTTestCase): + """Test if there are cosmetic issues caused by unnecesary additions""" + + def test_simple_expressions_parens(self): + self.check_src_roundtrip("(a := b)") + self.check_src_roundtrip("await x") + self.check_src_roundtrip("x if x else y") + self.check_src_roundtrip("lambda x: x") + self.check_src_roundtrip("1 + 1") + self.check_src_roundtrip("1 + 2 / 3") + self.check_src_roundtrip("(1 + 2) / 3") + self.check_src_roundtrip("(1 + 2) * 3 + 4 * (5 + 2)") + self.check_src_roundtrip("(1 + 2) * 3 + 4 * (5 + 2) ** 2") + self.check_src_roundtrip("~ x") + self.check_src_roundtrip("x and y") + self.check_src_roundtrip("x and y and z") + self.check_src_roundtrip("x and (y and x)") + self.check_src_roundtrip("(x and y) and z") + self.check_src_roundtrip("(x ** y) ** z ** q") + self.check_src_roundtrip("x >> y") + self.check_src_roundtrip("x << y") + self.check_src_roundtrip("x >> y and x >> z") + self.check_src_roundtrip("x + y - z * q ^ t ** k") + self.check_src_roundtrip("P * V if P and V else n * R * T") + self.check_src_roundtrip("lambda P, V, n: P * V == n * R * T") + self.check_src_roundtrip("flag & (other | foo)") + self.check_src_roundtrip("not x == y") + self.check_src_roundtrip("x == (not y)") + self.check_src_roundtrip("yield x") + 
self.check_src_roundtrip("yield from x") + self.check_src_roundtrip("call((yield x))") + self.check_src_roundtrip("return x + (yield x)") + + class DirectoryTestCase(ASTTestCase): """Test roundtrip behaviour on all files in Lib and Lib/test.""" |