summaryrefslogtreecommitdiffstats
path: root/Lib/lib2to3/fixes/fix_idioms.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/lib2to3/fixes/fix_idioms.py')
-rw-r--r--Lib/lib2to3/fixes/fix_idioms.py25
1 files changed, 22 insertions, 3 deletions
diff --git a/Lib/lib2to3/fixes/fix_idioms.py b/Lib/lib2to3/fixes/fix_idioms.py
index 1f68faf..9bee99b 100644
--- a/Lib/lib2to3/fixes/fix_idioms.py
+++ b/Lib/lib2to3/fixes/fix_idioms.py
@@ -29,7 +29,7 @@ into
# Local imports
from .. import fixer_base
-from ..fixer_util import Call, Comma, Name, Node, syms
+from ..fixer_util import Call, Comma, Name, Node, BlankLine, syms
CMP = "(n='!=' | '==' | 'is' | n=comp_op< 'is' 'not' >)"
TYPE = "power< 'type' trailer< '(' x=any ')' > >"
@@ -130,5 +130,24 @@ class FixIdioms(fixer_base.BaseFix):
else:
raise RuntimeError("should not have reached here")
sort_stmt.remove()
- if next_stmt:
- next_stmt[0].prefix = sort_stmt.prefix
+
+ btwn = sort_stmt.prefix
+ # Keep any prefix lines between the sort_stmt and the list_call and
+ # shove them right after the sorted() call.
+ if "\n" in btwn:
+ if next_stmt:
+ # The new prefix should be everything from the sort_stmt's
+ # prefix up to the last newline, then the old prefix after a new
+ # line.
+ prefix_lines = (btwn.rpartition("\n")[0], next_stmt[0].prefix)
+ next_stmt[0].prefix = "\n".join(prefix_lines)
+ else:
+ assert list_call.parent
+ assert list_call.next_sibling is None
+ # Put a blank line after list_call and set its prefix.
+ end_line = BlankLine()
+ list_call.parent.append_child(end_line)
+ assert list_call.next_sibling is end_line
+ # The new prefix should be everything up to the first new line
+ # of sort_stmt's prefix.
+ end_line.prefix = btwn.rpartition("\n")[0]