diff options
author | Batuhan Taskaya <isidentical@gmail.com> | 2021-05-08 23:32:04 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-05-08 23:32:04 (GMT) |
commit | 3d98ececda1335c7ed2a6c6a2b0d3bb46f2d3c55 (patch) | |
tree | 06e8fc589778b1456e87080873238e64f445dd45 /Lib/ast.py | |
parent | a0bd9e9c11f5f52c7ddd19144c8230da016b53c6 (diff) | |
download | cpython-3d98ececda1335c7ed2a6c6a2b0d3bb46f2d3c55.zip cpython-3d98ececda1335c7ed2a6c6a2b0d3bb46f2d3c55.tar.gz cpython-3d98ececda1335c7ed2a6c6a2b0d3bb46f2d3c55.tar.bz2 |
bpo-43417: Better buffer handling for ast.unparse (GH-24772)
Diffstat (limited to 'Lib/ast.py')
-rw-r--r-- | Lib/ast.py | 116 |
1 files changed, 59 insertions, 57 deletions
```diff
@@ -678,7 +678,6 @@ class _Unparser(NodeVisitor):
 
     def __init__(self, *, _avoid_backslashes=False):
         self._source = []
-        self._buffer = []
         self._precedences = {}
         self._type_ignores = {}
         self._indent = 0
@@ -721,14 +720,15 @@ class _Unparser(NodeVisitor):
         """Append a piece of text"""
         self._source.append(text)
 
-    def buffer_writer(self, text):
-        self._buffer.append(text)
+    @contextmanager
+    def buffered(self, buffer = None):
+        if buffer is None:
+            buffer = []
 
-    @property
-    def buffer(self):
-        value = "".join(self._buffer)
-        self._buffer.clear()
-        return value
+        original_source = self._source
+        self._source = buffer
+        yield buffer
+        self._source = original_source
 
     @contextmanager
     def block(self, *, extra = None):
@@ -1127,9 +1127,9 @@ class _Unparser(NodeVisitor):
     def visit_JoinedStr(self, node):
         self.write("f")
         if self._avoid_backslashes:
-            self._fstring_JoinedStr(node, self.buffer_writer)
-            self._write_str_avoiding_backslashes(self.buffer)
-            return
+            with self.buffered() as buffer:
+                self._write_fstring_inner(node)
+            return self._write_str_avoiding_backslashes("".join(buffer))
 
         # If we don't need to avoid backslashes globally (i.e., we only need
         # to avoid them inside FormattedValues), it's cosmetically preferred
@@ -1137,60 +1137,62 @@ class _Unparser(NodeVisitor):
         # for cases like: f"{x}\n". To accomplish this, we keep track of what
         # in our buffer corresponds to FormattedValues and what corresponds to
         # Constant parts of the f-string, and allow escapes accordingly.
-        buffer = []
+        fstring_parts = []
         for value in node.values:
-            meth = getattr(self, "_fstring_" + type(value).__name__)
-            meth(value, self.buffer_writer)
-            buffer.append((self.buffer, isinstance(value, Constant)))
-        new_buffer = []
-        quote_types = _ALL_QUOTES
-        for value, is_constant in buffer:
-            # Repeatedly narrow down the list of possible quote_types
+            with self.buffered() as buffer:
+                self._write_fstring_inner(value)
+            fstring_parts.append(
+                ("".join(buffer), isinstance(value, Constant))
+            )
+
+        new_fstring_parts = []
+        quote_types = list(_ALL_QUOTES)
+        for value, is_constant in fstring_parts:
             value, quote_types = self._str_literal_helper(
-                value, quote_types=quote_types,
-                escape_special_whitespace=is_constant
+                value,
+                quote_types=quote_types,
+                escape_special_whitespace=is_constant,
             )
-            new_buffer.append(value)
-        value = "".join(new_buffer)
+            new_fstring_parts.append(value)
+
+        value = "".join(new_fstring_parts)
         quote_type = quote_types[0]
         self.write(f"{quote_type}{value}{quote_type}")
 
+    def _write_fstring_inner(self, node):
+        if isinstance(node, JoinedStr):
+            # for both the f-string itself, and format_spec
+            for value in node.values:
+                self._write_fstring_inner(value)
+        elif isinstance(node, Constant) and isinstance(node.value, str):
+            value = node.value.replace("{", "{{").replace("}", "}}")
+            self.write(value)
+        elif isinstance(node, FormattedValue):
+            self.visit_FormattedValue(node)
+        else:
+            raise ValueError(f"Unexpected node inside JoinedStr, {node!r}")
+
     def visit_FormattedValue(self, node):
-        self.write("f")
-        self._fstring_FormattedValue(node, self.buffer_writer)
-        self._write_str_avoiding_backslashes(self.buffer)
+        def unparse_inner(inner):
+            unparser = type(self)(_avoid_backslashes=True)
+            unparser.set_precedence(_Precedence.TEST.next(), inner)
+            return unparser.visit(inner)
 
-    def _fstring_JoinedStr(self, node, write):
-        for value in node.values:
-            meth = getattr(self, "_fstring_" + type(value).__name__)
-            meth(value, write)
-
-    def _fstring_Constant(self, node, write):
-        if not isinstance(node.value, str):
-            raise ValueError("Constants inside JoinedStr should be a string.")
-        value = node.value.replace("{", "{{").replace("}", "}}")
-        write(value)
-
-    def _fstring_FormattedValue(self, node, write):
-        write("{")
-        unparser = type(self)(_avoid_backslashes=True)
-        unparser.set_precedence(_Precedence.TEST.next(), node.value)
-        expr = unparser.visit(node.value)
-        if expr.startswith("{"):
-            write(" ")  # Separate pair of opening brackets as "{ {"
-        if "\\" in expr:
-            raise ValueError("Unable to avoid backslash in f-string expression part")
-        write(expr)
-        if node.conversion != -1:
-            conversion = chr(node.conversion)
-            if conversion not in "sra":
-                raise ValueError("Unknown f-string conversion.")
-            write(f"!{conversion}")
-        if node.format_spec:
-            write(":")
-            meth = getattr(self, "_fstring_" + type(node.format_spec).__name__)
-            meth(node.format_spec, write)
-        write("}")
+        with self.delimit("{", "}"):
+            expr = unparse_inner(node.value)
+            if "\\" in expr:
+                raise ValueError(
+                    "Unable to avoid backslash in f-string expression part"
+                )
+            if expr.startswith("{"):
+                # Separate pair of opening brackets as "{ {"
+                self.write(" ")
+            self.write(expr)
+            if node.conversion != -1:
+                self.write(f"!{chr(node.conversion)}")
+            if node.format_spec:
+                self.write(":")
+                self._write_fstring_inner(node.format_spec)
 
     def visit_Name(self, node):
         self.write(node.id)
```