summaryrefslogtreecommitdiffstats
path: root/Lib/ast.py
diff options
context:
space:
mode:
authorBatuhan Taskaya <batuhanosmantaskaya@gmail.com>2020-05-16 23:04:12 (GMT)
committerGitHub <noreply@github.com>2020-05-16 23:04:12 (GMT)
commitdff92bb31f7db1a80ac431811f8108bd0ef9be43 (patch)
tree89f9511753cbf1f7b91ef4f16e86fce4fa4ac1bd /Lib/ast.py
parente966af7cff78e14e1d289db587433504b4b53533 (diff)
downloadcpython-dff92bb31f7db1a80ac431811f8108bd0ef9be43.zip
cpython-dff92bb31f7db1a80ac431811f8108bd0ef9be43.tar.gz
cpython-dff92bb31f7db1a80ac431811f8108bd0ef9be43.tar.bz2
bpo-38870: Implement round tripping support for typed AST in ast.unparse (GH-17797)
Diffstat (limited to 'Lib/ast.py')
-rw-r--r--Lib/ast.py29
1 files changed, 23 insertions, 6 deletions
diff --git a/Lib/ast.py b/Lib/ast.py
index 5d0171f..61fbe03 100644
--- a/Lib/ast.py
+++ b/Lib/ast.py
@@ -648,6 +648,7 @@ class _Unparser(NodeVisitor):
self._source = []
self._buffer = []
self._precedences = {}
+ self._type_ignores = {}
self._indent = 0
def interleave(self, inter, f, seq):
@@ -697,11 +698,15 @@ class _Unparser(NodeVisitor):
return value
@contextmanager
- def block(self):
+ def block(self, *, extra = None):
"""A context manager for preparing the source for blocks. It adds
the character':', increases the indentation on enter and decreases
- the indentation on exit."""
+ the indentation on exit. If *extra* is given, it will be directly
+ appended after the colon character.
+ """
self.write(":")
+ if extra:
+ self.write(extra)
self._indent += 1
yield
self._indent -= 1
@@ -748,6 +753,11 @@ class _Unparser(NodeVisitor):
if isinstance(node, Constant) and isinstance(node.value, str):
return node
+ def get_type_comment(self, node):
+ comment = self._type_ignores.get(node.lineno) or node.type_comment
+ if comment is not None:
+ return f" # type: {comment}"
+
def traverse(self, node):
if isinstance(node, list):
for item in node:
@@ -770,7 +780,12 @@ class _Unparser(NodeVisitor):
self.traverse(node.body)
def visit_Module(self, node):
+ self._type_ignores = {
+ ignore.lineno: f"ignore{ignore.tag}"
+ for ignore in node.type_ignores
+ }
self._write_docstring_and_traverse_body(node)
+ self._type_ignores.clear()
def visit_FunctionType(self, node):
with self.delimit("(", ")"):
@@ -811,6 +826,8 @@ class _Unparser(NodeVisitor):
self.traverse(target)
self.write(" = ")
self.traverse(node.value)
+ if type_comment := self.get_type_comment(node):
+ self.write(type_comment)
def visit_AugAssign(self, node):
self.fill()
@@ -966,7 +983,7 @@ class _Unparser(NodeVisitor):
if node.returns:
self.write(" -> ")
self.traverse(node.returns)
- with self.block():
+ with self.block(extra=self.get_type_comment(node)):
self._write_docstring_and_traverse_body(node)
def visit_For(self, node):
@@ -980,7 +997,7 @@ class _Unparser(NodeVisitor):
self.traverse(node.target)
self.write(" in ")
self.traverse(node.iter)
- with self.block():
+ with self.block(extra=self.get_type_comment(node)):
self.traverse(node.body)
if node.orelse:
self.fill("else")
@@ -1018,13 +1035,13 @@ class _Unparser(NodeVisitor):
def visit_With(self, node):
self.fill("with ")
self.interleave(lambda: self.write(", "), self.traverse, node.items)
- with self.block():
+ with self.block(extra=self.get_type_comment(node)):
self.traverse(node.body)
def visit_AsyncWith(self, node):
self.fill("async with ")
self.interleave(lambda: self.write(", "), self.traverse, node.items)
- with self.block():
+ with self.block(extra=self.get_type_comment(node)):
self.traverse(node.body)
def visit_JoinedStr(self, node):