summaryrefslogtreecommitdiffstats
path: root/Lib/ast.py
diff options
context:
space:
mode:
authorBatuhan Taşkaya <47358913+isidentical@users.noreply.github.com>2020-03-02 18:59:01 (GMT)
committerGitHub <noreply@github.com>2020-03-02 18:59:01 (GMT)
commit89aa4694fc8c6d190325ef8ed6ce6a6b8efb3e50 (patch)
treece9a506b9121d6188986725bbbef7a678321833c /Lib/ast.py
parent66b7973c1b2e6aa6a2462c6b13971a08cd665af2 (diff)
downloadcpython-89aa4694fc8c6d190325ef8ed6ce6a6b8efb3e50.zip
cpython-89aa4694fc8c6d190325ef8ed6ce6a6b8efb3e50.tar.gz
cpython-89aa4694fc8c6d190325ef8ed6ce6a6b8efb3e50.tar.bz2
bpo-38870: Add docstring support to ast.unparse (GH-17760)
Allow ast.unparse to detect docstrings in functions, modules and classes and produce nicely formatted unparsed output for said docstrings. Co-Authored-By: Pablo Galindo <Pablogsal@gmail.com>
Diffstat (limited to 'Lib/ast.py')
-rw-r--r--Lib/ast.py55
1 files changed, 45 insertions, 10 deletions
diff --git a/Lib/ast.py b/Lib/ast.py
index 4839201..93ffa1e 100644
--- a/Lib/ast.py
+++ b/Lib/ast.py
@@ -667,6 +667,22 @@ class _Unparser(NodeVisitor):
for node in nodes:
self._precedences[node] = precedence
+ def get_raw_docstring(self, node):
+ """If a docstring node is found in the body of the *node* parameter,
+ return that docstring node, None otherwise.
+
+ Logic mirrored from ``_PyAST_GetDocString``."""
+ if not isinstance(
+ node, (AsyncFunctionDef, FunctionDef, ClassDef, Module)
+ ) or len(node.body) < 1:
+ return None
+ node = node.body[0]
+ if not isinstance(node, Expr):
+ return None
+ node = node.value
+ if isinstance(node, Constant) and isinstance(node.value, str):
+ return node
+
def traverse(self, node):
if isinstance(node, list):
for item in node:
@@ -681,9 +697,15 @@ class _Unparser(NodeVisitor):
self.traverse(node)
return "".join(self._source)
+ def _write_docstring_and_traverse_body(self, node):
+ if (docstring := self.get_raw_docstring(node)):
+ self._write_docstring(docstring)
+ self.traverse(node.body[1:])
+ else:
+ self.traverse(node.body)
+
def visit_Module(self, node):
- for subnode in node.body:
- self.traverse(subnode)
+ self._write_docstring_and_traverse_body(node)
def visit_Expr(self, node):
self.fill()
@@ -850,15 +872,15 @@ class _Unparser(NodeVisitor):
self.traverse(e)
with self.block():
- self.traverse(node.body)
+ self._write_docstring_and_traverse_body(node)
def visit_FunctionDef(self, node):
- self.__FunctionDef_helper(node, "def")
+ self._function_helper(node, "def")
def visit_AsyncFunctionDef(self, node):
- self.__FunctionDef_helper(node, "async def")
+ self._function_helper(node, "async def")
- def __FunctionDef_helper(self, node, fill_suffix):
+ def _function_helper(self, node, fill_suffix):
self.write("\n")
for deco in node.decorator_list:
self.fill("@")
@@ -871,15 +893,15 @@ class _Unparser(NodeVisitor):
self.write(" -> ")
self.traverse(node.returns)
with self.block():
- self.traverse(node.body)
+ self._write_docstring_and_traverse_body(node)
def visit_For(self, node):
- self.__For_helper("for ", node)
+ self._for_helper("for ", node)
def visit_AsyncFor(self, node):
- self.__For_helper("async for ", node)
+ self._for_helper("async for ", node)
- def __For_helper(self, fill, node):
+ def _for_helper(self, fill, node):
self.fill(fill)
self.traverse(node.target)
self.write(" in ")
@@ -974,6 +996,19 @@ class _Unparser(NodeVisitor):
def visit_Name(self, node):
self.write(node.id)
+ def _write_docstring(self, node):
+ self.fill()
+ if node.kind == "u":
+ self.write("u")
+
+ # Preserve quotes in the docstring by escaping them
+ value = node.value.replace("\\", "\\\\")
+ value = value.replace('"""', '""\"')
+ if value[-1] == '"':
+ value = value.replace('"', '\\"', -1)
+
+ self.write(f'"""{value}"""')
+
def _write_constant(self, value):
if isinstance(value, (float, complex)):
# Substitute overflowing decimal literal for AST infinities.