From e7cab7f780ac253999512ee86374fc3454342811 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Batuhan=20Ta=C5=9Fkaya?= <47358913+isidentical@users.noreply.github.com> Date: Mon, 9 Mar 2020 23:27:03 +0300 Subject: bpo-38870: Simplify sequence interleaves in ast.unparse (GH-17892) --- Lib/ast.py | 37 ++++++++++++++----------------------- Lib/test/test_unparse.py | 34 ++++++++++++++++++++-------------- 2 files changed, 34 insertions(+), 37 deletions(-) diff --git a/Lib/ast.py b/Lib/ast.py index 2719f6f..9a3d380 100644 --- a/Lib/ast.py +++ b/Lib/ast.py @@ -613,6 +613,16 @@ class _Unparser(NodeVisitor): inter() f(x) + def items_view(self, traverser, items): + """Traverse and separate the given *items* with a comma and append it to + the buffer. If *items* is a single item sequence, a trailing comma + will be added.""" + if len(items) == 1: + traverser(items[0]) + self.write(",") + else: + self.interleave(lambda: self.write(", "), traverser, items) + def fill(self, text=""): """Indent a piece of text and append it, according to the current indentation level""" @@ -1020,11 +1030,7 @@ class _Unparser(NodeVisitor): value = node.value if isinstance(value, tuple): with self.delimit("(", ")"): - if len(value) == 1: - self._write_constant(value[0]) - self.write(",") - else: - self.interleave(lambda: self.write(", "), self._write_constant, value) + self.items_view(self._write_constant, value) elif value is ...: self.write("...") else: @@ -1116,12 +1122,7 @@ class _Unparser(NodeVisitor): def visit_Tuple(self, node): with self.delimit("(", ")"): - if len(node.elts) == 1: - elt = node.elts[0] - self.traverse(elt) - self.write(",") - else: - self.interleave(lambda: self.write(", "), self.traverse, node.elts) + self.items_view(self.traverse, node.elts) unop = {"Invert": "~", "Not": "not", "UAdd": "+", "USub": "-"} unop_precedence = { @@ -1264,12 +1265,7 @@ class _Unparser(NodeVisitor): if (isinstance(node.slice, Index) and isinstance(node.slice.value, Tuple) and node.slice.value.elts): - if len(node.slice.value.elts) == 1: - elt = node.slice.value.elts[0] - self.traverse(elt) - self.write(",") - else: - self.interleave(lambda: self.write(", "), self.traverse, node.slice.value.elts) + self.items_view(self.traverse, node.slice.value.elts) else: self.traverse(node.slice) @@ -1296,12 +1292,7 @@ class _Unparser(NodeVisitor): self.traverse(node.step) def visit_ExtSlice(self, node): - if len(node.dims) == 1: - elt = node.dims[0] - self.traverse(elt) - self.write(",") - else: - self.interleave(lambda: self.write(", "), self.traverse, node.dims) + self.items_view(self.traverse, node.dims) def visit_arg(self, node): self.write(node.arg) diff --git a/Lib/test/test_unparse.py b/Lib/test/test_unparse.py index d33f32e..3d87cfb 100644 --- a/Lib/test/test_unparse.py +++ b/Lib/test/test_unparse.py @@ -280,6 +280,20 @@ class UnparseTestCase(ASTTestCase): self.check_ast_roundtrip(r"""{**{'y': 2}, 'x': 1}""") self.check_ast_roundtrip(r"""{**{'y': 2}, **{'x': 1}}""") + def test_ext_slices(self): + self.check_ast_roundtrip("a[i]") + self.check_ast_roundtrip("a[i,]") + self.check_ast_roundtrip("a[i, j]") + self.check_ast_roundtrip("a[()]") + self.check_ast_roundtrip("a[i:j]") + self.check_ast_roundtrip("a[:j]") + self.check_ast_roundtrip("a[i:]") + self.check_ast_roundtrip("a[i:j:k]") + self.check_ast_roundtrip("a[:j:k]") + self.check_ast_roundtrip("a[i::k]") + self.check_ast_roundtrip("a[i:j,]") + self.check_ast_roundtrip("a[i:j, k]") + def test_invalid_raise(self): self.check_invalid(ast.Raise(exc=None, cause=ast.Name(id="X"))) @@ -310,6 +324,12 @@ class UnparseTestCase(ASTTestCase): # check as Module docstrings for easy testing self.check_ast_roundtrip(f"'{docstring}'") + def test_constant_tuples(self): + self.check_src_roundtrip(ast.Constant(value=(1,), kind=None), "(1,)") + self.check_src_roundtrip( + ast.Constant(value=(1, 2, 3), kind=None), "(1, 2, 3)" + ) + class CosmeticTestCase(ASTTestCase): """Test if there are cosmetic issues caused by unnecesary additions""" @@ -344,20 +364,6 @@ class CosmeticTestCase(ASTTestCase): self.check_src_roundtrip("call((yield x))") self.check_src_roundtrip("return x + (yield x)") - def test_subscript(self): - self.check_src_roundtrip("a[i]") - self.check_src_roundtrip("a[i,]") - self.check_src_roundtrip("a[i, j]") - self.check_src_roundtrip("a[()]") - self.check_src_roundtrip("a[i:j]") - self.check_src_roundtrip("a[:j]") - self.check_src_roundtrip("a[i:]") - self.check_src_roundtrip("a[i:j:k]") - self.check_src_roundtrip("a[:j:k]") - self.check_src_roundtrip("a[i::k]") - self.check_src_roundtrip("a[i:j,]") - self.check_src_roundtrip("a[i:j, k]") - def test_docstrings(self): docstrings = ( '"""simple doc string"""', -- cgit v0.12