summaryrefslogtreecommitdiffstats
path: root/Lib/ast.py
diff options
context:
space:
mode:
authorBatuhan Taskaya <isidentical@gmail.com>2021-05-08 23:32:04 (GMT)
committerGitHub <noreply@github.com>2021-05-08 23:32:04 (GMT)
commit3d98ececda1335c7ed2a6c6a2b0d3bb46f2d3c55 (patch)
tree06e8fc589778b1456e87080873238e64f445dd45 /Lib/ast.py
parenta0bd9e9c11f5f52c7ddd19144c8230da016b53c6 (diff)
downloadcpython-3d98ececda1335c7ed2a6c6a2b0d3bb46f2d3c55.zip
cpython-3d98ececda1335c7ed2a6c6a2b0d3bb46f2d3c55.tar.gz
cpython-3d98ececda1335c7ed2a6c6a2b0d3bb46f2d3c55.tar.bz2
bpo-43417: Better buffer handling for ast.unparse (GH-24772)
Diffstat (limited to 'Lib/ast.py')
-rw-r--r--Lib/ast.py116
1 files changed, 59 insertions, 57 deletions
diff --git a/Lib/ast.py b/Lib/ast.py
index 66bcee8..18163d6 100644
--- a/Lib/ast.py
+++ b/Lib/ast.py
@@ -678,7 +678,6 @@ class _Unparser(NodeVisitor):
def __init__(self, *, _avoid_backslashes=False):
self._source = []
- self._buffer = []
self._precedences = {}
self._type_ignores = {}
self._indent = 0
@@ -721,14 +720,15 @@ class _Unparser(NodeVisitor):
"""Append a piece of text"""
self._source.append(text)
- def buffer_writer(self, text):
- self._buffer.append(text)
+ @contextmanager
+ def buffered(self, buffer = None):
+ if buffer is None:
+ buffer = []
- @property
- def buffer(self):
- value = "".join(self._buffer)
- self._buffer.clear()
- return value
+ original_source = self._source
+ self._source = buffer
+ yield buffer
+ self._source = original_source
@contextmanager
def block(self, *, extra = None):
@@ -1127,9 +1127,9 @@ class _Unparser(NodeVisitor):
def visit_JoinedStr(self, node):
self.write("f")
if self._avoid_backslashes:
- self._fstring_JoinedStr(node, self.buffer_writer)
- self._write_str_avoiding_backslashes(self.buffer)
- return
+ with self.buffered() as buffer:
+ self._write_fstring_inner(node)
+ return self._write_str_avoiding_backslashes("".join(buffer))
# If we don't need to avoid backslashes globally (i.e., we only need
# to avoid them inside FormattedValues), it's cosmetically preferred
@@ -1137,60 +1137,62 @@ class _Unparser(NodeVisitor):
# for cases like: f"{x}\n". To accomplish this, we keep track of what
# in our buffer corresponds to FormattedValues and what corresponds to
# Constant parts of the f-string, and allow escapes accordingly.
- buffer = []
+ fstring_parts = []
for value in node.values:
- meth = getattr(self, "_fstring_" + type(value).__name__)
- meth(value, self.buffer_writer)
- buffer.append((self.buffer, isinstance(value, Constant)))
- new_buffer = []
- quote_types = _ALL_QUOTES
- for value, is_constant in buffer:
- # Repeatedly narrow down the list of possible quote_types
+ with self.buffered() as buffer:
+ self._write_fstring_inner(value)
+ fstring_parts.append(
+ ("".join(buffer), isinstance(value, Constant))
+ )
+
+ new_fstring_parts = []
+ quote_types = list(_ALL_QUOTES)
+ for value, is_constant in fstring_parts:
value, quote_types = self._str_literal_helper(
- value, quote_types=quote_types,
- escape_special_whitespace=is_constant
+ value,
+ quote_types=quote_types,
+ escape_special_whitespace=is_constant,
)
- new_buffer.append(value)
- value = "".join(new_buffer)
+ new_fstring_parts.append(value)
+
+ value = "".join(new_fstring_parts)
quote_type = quote_types[0]
self.write(f"{quote_type}{value}{quote_type}")
+ def _write_fstring_inner(self, node):
+ if isinstance(node, JoinedStr):
+ # for both the f-string itself, and format_spec
+ for value in node.values:
+ self._write_fstring_inner(value)
+ elif isinstance(node, Constant) and isinstance(node.value, str):
+ value = node.value.replace("{", "{{").replace("}", "}}")
+ self.write(value)
+ elif isinstance(node, FormattedValue):
+ self.visit_FormattedValue(node)
+ else:
+ raise ValueError(f"Unexpected node inside JoinedStr, {node!r}")
+
def visit_FormattedValue(self, node):
- self.write("f")
- self._fstring_FormattedValue(node, self.buffer_writer)
- self._write_str_avoiding_backslashes(self.buffer)
+ def unparse_inner(inner):
+ unparser = type(self)(_avoid_backslashes=True)
+ unparser.set_precedence(_Precedence.TEST.next(), inner)
+ return unparser.visit(inner)
- def _fstring_JoinedStr(self, node, write):
- for value in node.values:
- meth = getattr(self, "_fstring_" + type(value).__name__)
- meth(value, write)
-
- def _fstring_Constant(self, node, write):
- if not isinstance(node.value, str):
- raise ValueError("Constants inside JoinedStr should be a string.")
- value = node.value.replace("{", "{{").replace("}", "}}")
- write(value)
-
- def _fstring_FormattedValue(self, node, write):
- write("{")
- unparser = type(self)(_avoid_backslashes=True)
- unparser.set_precedence(_Precedence.TEST.next(), node.value)
- expr = unparser.visit(node.value)
- if expr.startswith("{"):
- write(" ") # Separate pair of opening brackets as "{ {"
- if "\\" in expr:
- raise ValueError("Unable to avoid backslash in f-string expression part")
- write(expr)
- if node.conversion != -1:
- conversion = chr(node.conversion)
- if conversion not in "sra":
- raise ValueError("Unknown f-string conversion.")
- write(f"!{conversion}")
- if node.format_spec:
- write(":")
- meth = getattr(self, "_fstring_" + type(node.format_spec).__name__)
- meth(node.format_spec, write)
- write("}")
+ with self.delimit("{", "}"):
+ expr = unparse_inner(node.value)
+ if "\\" in expr:
+ raise ValueError(
+ "Unable to avoid backslash in f-string expression part"
+ )
+ if expr.startswith("{"):
+ # Separate pair of opening brackets as "{ {"
+ self.write(" ")
+ self.write(expr)
+ if node.conversion != -1:
+ self.write(f"!{chr(node.conversion)}")
+ if node.format_spec:
+ self.write(":")
+ self._write_fstring_inner(node.format_spec)
def visit_Name(self, node):
self.write(node.id)