From 89aa4694fc8c6d190325ef8ed6ce6a6b8efb3e50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Batuhan=20Ta=C5=9Fkaya?= <47358913+isidentical@users.noreply.github.com> Date: Mon, 2 Mar 2020 21:59:01 +0300 Subject: bpo-38870: Add docstring support to ast.unparse (GH-17760) Allow ast.unparse to detect docstrings in functions, modules and classes and produce nicely formatted unparsed output for said docstrings. Co-Authored-By: Pablo Galindo --- Lib/ast.py | 55 ++++++++++--- Lib/test/test_unparse.py | 196 ++++++++++++++++++++++++++++++----------------- 2 files changed, 171 insertions(+), 80 deletions(-) diff --git a/Lib/ast.py b/Lib/ast.py index 4839201..93ffa1e 100644 --- a/Lib/ast.py +++ b/Lib/ast.py @@ -667,6 +667,22 @@ class _Unparser(NodeVisitor): for node in nodes: self._precedences[node] = precedence + def get_raw_docstring(self, node): + """If a docstring node is found in the body of the *node* parameter, + return that docstring node, None otherwise. + + Logic mirrored from ``_PyAST_GetDocString``.""" + if not isinstance( + node, (AsyncFunctionDef, FunctionDef, ClassDef, Module) + ) or len(node.body) < 1: + return None + node = node.body[0] + if not isinstance(node, Expr): + return None + node = node.value + if isinstance(node, Constant) and isinstance(node.value, str): + return node + def traverse(self, node): if isinstance(node, list): for item in node: @@ -681,9 +697,15 @@ class _Unparser(NodeVisitor): self.traverse(node) return "".join(self._source) + def _write_docstring_and_traverse_body(self, node): + if (docstring := self.get_raw_docstring(node)): + self._write_docstring(docstring) + self.traverse(node.body[1:]) + else: + self.traverse(node.body) + def visit_Module(self, node): - for subnode in node.body: - self.traverse(subnode) + self._write_docstring_and_traverse_body(node) def visit_Expr(self, node): self.fill() @@ -850,15 +872,15 @@ class _Unparser(NodeVisitor): self.traverse(e) with self.block(): - self.traverse(node.body) + 
self._write_docstring_and_traverse_body(node) def visit_FunctionDef(self, node): - self.__FunctionDef_helper(node, "def") + self._function_helper(node, "def") def visit_AsyncFunctionDef(self, node): - self.__FunctionDef_helper(node, "async def") + self._function_helper(node, "async def") - def __FunctionDef_helper(self, node, fill_suffix): + def _function_helper(self, node, fill_suffix): self.write("\n") for deco in node.decorator_list: self.fill("@") @@ -871,15 +893,15 @@ class _Unparser(NodeVisitor): self.write(" -> ") self.traverse(node.returns) with self.block(): - self.traverse(node.body) + self._write_docstring_and_traverse_body(node) def visit_For(self, node): - self.__For_helper("for ", node) + self._for_helper("for ", node) def visit_AsyncFor(self, node): - self.__For_helper("async for ", node) + self._for_helper("async for ", node) - def __For_helper(self, fill, node): + def _for_helper(self, fill, node): self.fill(fill) self.traverse(node.target) self.write(" in ") @@ -974,6 +996,19 @@ class _Unparser(NodeVisitor): def visit_Name(self, node): self.write(node.id) + def _write_docstring(self, node): + self.fill() + if node.kind == "u": + self.write("u") + + # Preserve quotes in the docstring by escaping them + value = node.value.replace("\\", "\\\\") + value = value.replace('"""', '""\"') + if value[-1] == '"': + value = value.replace('"', '\\"', -1) + + self.write(f'"""{value}"""') + def _write_constant(self, value): if isinstance(value, (float, complex)): # Substitute overflowing decimal literal for AST infinities. 
diff --git a/Lib/test/test_unparse.py b/Lib/test/test_unparse.py index f7fcb2b..d04db4d 100644 --- a/Lib/test/test_unparse.py +++ b/Lib/test/test_unparse.py @@ -111,12 +111,18 @@ with f() as x, g() as y: suite1 """ +docstring_prefixes = [ + "", + "class foo():\n ", + "def foo():\n ", + "async def foo():\n ", +] class ASTTestCase(unittest.TestCase): def assertASTEqual(self, ast1, ast2): self.assertEqual(ast.dump(ast1), ast.dump(ast2)) - def check_roundtrip(self, code1): + def check_ast_roundtrip(self, code1): ast1 = ast.parse(code1) code2 = ast.unparse(ast1) ast2 = ast.parse(code2) @@ -125,147 +131,154 @@ class ASTTestCase(unittest.TestCase): def check_invalid(self, node, raises=ValueError): self.assertRaises(raises, ast.unparse, node) - def check_src_roundtrip(self, code1, code2=None, strip=True): + def get_source(self, code1, code2=None, strip=True): code2 = code2 or code1 code1 = ast.unparse(ast.parse(code1)) if strip: code1 = code1.strip() + return code1, code2 + + def check_src_roundtrip(self, code1, code2=None, strip=True): + code1, code2 = self.get_source(code1, code2, strip) self.assertEqual(code2, code1) + def check_src_dont_roundtrip(self, code1, code2=None, strip=True): + code1, code2 = self.get_source(code1, code2, strip) + self.assertNotEqual(code2, code1) class UnparseTestCase(ASTTestCase): # Tests for specific bugs found in earlier versions of unparse def test_fstrings(self): # See issue 25180 - self.check_roundtrip(r"""f'{f"{0}"*3}'""") - self.check_roundtrip(r"""f'{f"{y}"*3}'""") + self.check_ast_roundtrip(r"""f'{f"{0}"*3}'""") + self.check_ast_roundtrip(r"""f'{f"{y}"*3}'""") def test_strings(self): - self.check_roundtrip("u'foo'") - self.check_roundtrip("r'foo'") - self.check_roundtrip("b'foo'") + self.check_ast_roundtrip("u'foo'") + self.check_ast_roundtrip("r'foo'") + self.check_ast_roundtrip("b'foo'") def test_del_statement(self): - self.check_roundtrip("del x, y, z") + self.check_ast_roundtrip("del x, y, z") def test_shifts(self): - 
self.check_roundtrip("45 << 2") - self.check_roundtrip("13 >> 7") + self.check_ast_roundtrip("45 << 2") + self.check_ast_roundtrip("13 >> 7") def test_for_else(self): - self.check_roundtrip(for_else) + self.check_ast_roundtrip(for_else) def test_while_else(self): - self.check_roundtrip(while_else) + self.check_ast_roundtrip(while_else) def test_unary_parens(self): - self.check_roundtrip("(-1)**7") - self.check_roundtrip("(-1.)**8") - self.check_roundtrip("(-1j)**6") - self.check_roundtrip("not True or False") - self.check_roundtrip("True or not False") + self.check_ast_roundtrip("(-1)**7") + self.check_ast_roundtrip("(-1.)**8") + self.check_ast_roundtrip("(-1j)**6") + self.check_ast_roundtrip("not True or False") + self.check_ast_roundtrip("True or not False") def test_integer_parens(self): - self.check_roundtrip("3 .__abs__()") + self.check_ast_roundtrip("3 .__abs__()") def test_huge_float(self): - self.check_roundtrip("1e1000") - self.check_roundtrip("-1e1000") - self.check_roundtrip("1e1000j") - self.check_roundtrip("-1e1000j") + self.check_ast_roundtrip("1e1000") + self.check_ast_roundtrip("-1e1000") + self.check_ast_roundtrip("1e1000j") + self.check_ast_roundtrip("-1e1000j") def test_min_int(self): - self.check_roundtrip(str(-(2 ** 31))) - self.check_roundtrip(str(-(2 ** 63))) + self.check_ast_roundtrip(str(-(2 ** 31))) + self.check_ast_roundtrip(str(-(2 ** 63))) def test_imaginary_literals(self): - self.check_roundtrip("7j") - self.check_roundtrip("-7j") - self.check_roundtrip("0j") - self.check_roundtrip("-0j") + self.check_ast_roundtrip("7j") + self.check_ast_roundtrip("-7j") + self.check_ast_roundtrip("0j") + self.check_ast_roundtrip("-0j") def test_lambda_parentheses(self): - self.check_roundtrip("(lambda: int)()") + self.check_ast_roundtrip("(lambda: int)()") def test_chained_comparisons(self): - self.check_roundtrip("1 < 4 <= 5") - self.check_roundtrip("a is b is c is not d") + self.check_ast_roundtrip("1 < 4 <= 5") + self.check_ast_roundtrip("a is b is 
c is not d") def test_function_arguments(self): - self.check_roundtrip("def f(): pass") - self.check_roundtrip("def f(a): pass") - self.check_roundtrip("def f(b = 2): pass") - self.check_roundtrip("def f(a, b): pass") - self.check_roundtrip("def f(a, b = 2): pass") - self.check_roundtrip("def f(a = 5, b = 2): pass") - self.check_roundtrip("def f(*, a = 1, b = 2): pass") - self.check_roundtrip("def f(*, a = 1, b): pass") - self.check_roundtrip("def f(*, a, b = 2): pass") - self.check_roundtrip("def f(a, b = None, *, c, **kwds): pass") - self.check_roundtrip("def f(a=2, *args, c=5, d, **kwds): pass") - self.check_roundtrip("def f(*args, **kwargs): pass") + self.check_ast_roundtrip("def f(): pass") + self.check_ast_roundtrip("def f(a): pass") + self.check_ast_roundtrip("def f(b = 2): pass") + self.check_ast_roundtrip("def f(a, b): pass") + self.check_ast_roundtrip("def f(a, b = 2): pass") + self.check_ast_roundtrip("def f(a = 5, b = 2): pass") + self.check_ast_roundtrip("def f(*, a = 1, b = 2): pass") + self.check_ast_roundtrip("def f(*, a = 1, b): pass") + self.check_ast_roundtrip("def f(*, a, b = 2): pass") + self.check_ast_roundtrip("def f(a, b = None, *, c, **kwds): pass") + self.check_ast_roundtrip("def f(a=2, *args, c=5, d, **kwds): pass") + self.check_ast_roundtrip("def f(*args, **kwargs): pass") def test_relative_import(self): - self.check_roundtrip(relative_import) + self.check_ast_roundtrip(relative_import) def test_nonlocal(self): - self.check_roundtrip(nonlocal_ex) + self.check_ast_roundtrip(nonlocal_ex) def test_raise_from(self): - self.check_roundtrip(raise_from) + self.check_ast_roundtrip(raise_from) def test_bytes(self): - self.check_roundtrip("b'123'") + self.check_ast_roundtrip("b'123'") def test_annotations(self): - self.check_roundtrip("def f(a : int): pass") - self.check_roundtrip("def f(a: int = 5): pass") - self.check_roundtrip("def f(*args: [int]): pass") - self.check_roundtrip("def f(**kwargs: dict): pass") - self.check_roundtrip("def f() -> 
None: pass") + self.check_ast_roundtrip("def f(a : int): pass") + self.check_ast_roundtrip("def f(a: int = 5): pass") + self.check_ast_roundtrip("def f(*args: [int]): pass") + self.check_ast_roundtrip("def f(**kwargs: dict): pass") + self.check_ast_roundtrip("def f() -> None: pass") def test_set_literal(self): - self.check_roundtrip("{'a', 'b', 'c'}") + self.check_ast_roundtrip("{'a', 'b', 'c'}") def test_set_comprehension(self): - self.check_roundtrip("{x for x in range(5)}") + self.check_ast_roundtrip("{x for x in range(5)}") def test_dict_comprehension(self): - self.check_roundtrip("{x: x*x for x in range(10)}") + self.check_ast_roundtrip("{x: x*x for x in range(10)}") def test_class_decorators(self): - self.check_roundtrip(class_decorator) + self.check_ast_roundtrip(class_decorator) def test_class_definition(self): - self.check_roundtrip("class A(metaclass=type, *[], **{}): pass") + self.check_ast_roundtrip("class A(metaclass=type, *[], **{}): pass") def test_elifs(self): - self.check_roundtrip(elif1) - self.check_roundtrip(elif2) + self.check_ast_roundtrip(elif1) + self.check_ast_roundtrip(elif2) def test_try_except_finally(self): - self.check_roundtrip(try_except_finally) + self.check_ast_roundtrip(try_except_finally) def test_starred_assignment(self): - self.check_roundtrip("a, *b, c = seq") - self.check_roundtrip("a, (*b, c) = seq") - self.check_roundtrip("a, *b[0], c = seq") - self.check_roundtrip("a, *(b, c) = seq") + self.check_ast_roundtrip("a, *b, c = seq") + self.check_ast_roundtrip("a, (*b, c) = seq") + self.check_ast_roundtrip("a, *b[0], c = seq") + self.check_ast_roundtrip("a, *(b, c) = seq") def test_with_simple(self): - self.check_roundtrip(with_simple) + self.check_ast_roundtrip(with_simple) def test_with_as(self): - self.check_roundtrip(with_as) + self.check_ast_roundtrip(with_as) def test_with_two_items(self): - self.check_roundtrip(with_two_items) + self.check_ast_roundtrip(with_two_items) def test_dict_unpacking_in_dict(self): # See issue 
26489 - self.check_roundtrip(r"""{**{'y': 2}, 'x': 1}""") - self.check_roundtrip(r"""{**{'y': 2}, **{'x': 1}}""") + self.check_ast_roundtrip(r"""{**{'y': 2}, 'x': 1}""") + self.check_ast_roundtrip(r"""{**{'y': 2}, **{'x': 1}}""") def test_invalid_raise(self): self.check_invalid(ast.Raise(exc=None, cause=ast.Name(id="X"))) @@ -288,6 +301,16 @@ class UnparseTestCase(ASTTestCase): def test_invalid_yield_from(self): self.check_invalid(ast.YieldFrom(value=None)) + def test_docstrings(self): + docstrings = ( + 'this ends with double quote"', + 'this includes a """triple quote"""' + ) + for docstring in docstrings: + # check as Module docstrings for easy testing + self.check_ast_roundtrip(f"'{docstring}'") + + class CosmeticTestCase(ASTTestCase): """Test if there are cosmetic issues caused by unnecesary additions""" @@ -321,6 +344,39 @@ class CosmeticTestCase(ASTTestCase): self.check_src_roundtrip("call((yield x))") self.check_src_roundtrip("return x + (yield x)") + def test_docstrings(self): + docstrings = ( + '"""simple doc string"""', + '''"""A more complex one + with some newlines"""''', + '''"""Foo bar baz + + empty newline"""''', + '"""With some \t"""', + '"""Foo "bar" baz """', + ) + + for prefix in docstring_prefixes: + for docstring in docstrings: + self.check_src_roundtrip(f"{prefix}{docstring}") + + def test_docstrings_negative_cases(self): + # Test some cases that involve strings in the children of the + # first node but aren't docstrings to make sure we don't have + # False positives. 
+ docstrings_negative = ( + 'a = """false"""', + '"""false""" + """unless its optimized"""', + '1 + 1\n"""false"""', + 'f"""no, top level but f-fstring"""' + ) + for prefix in docstring_prefixes: + for negative in docstrings_negative: + # these cases should be rendered with single quotes + # rather than as a triple-quoted docstring + src = f"{prefix}{negative}" + self.check_ast_roundtrip(src) + self.check_src_dont_roundtrip(src) class DirectoryTestCase(ASTTestCase): """Test roundtrip behaviour on all files in Lib and Lib/test.""" @@ -379,7 +435,7 @@ class DirectoryTestCase(ASTTestCase): with self.subTest(filename=item): source = read_pyfile(item) - self.check_roundtrip(source) + self.check_ast_roundtrip(source) if __name__ == "__main__": -- cgit v0.12