summaryrefslogtreecommitdiffstats
path: root/Lib/lib2to3/refactor.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/lib2to3/refactor.py')
-rwxr-xr-xLib/lib2to3/refactor.py86
1 files changed, 56 insertions, 30 deletions
diff --git a/Lib/lib2to3/refactor.py b/Lib/lib2to3/refactor.py
index 32aabfc..9a4ef6a 100755
--- a/Lib/lib2to3/refactor.py
+++ b/Lib/lib2to3/refactor.py
@@ -22,8 +22,7 @@ from collections import defaultdict
from itertools import chain
# Local imports
-from .pgen2 import driver
-from .pgen2 import tokenize
+from .pgen2 import driver, tokenize
from . import pytree
from . import patcomp
@@ -87,6 +86,25 @@ def get_fixers_from_package(pkg_name):
return [pkg_name + "." + fix_name
for fix_name in get_all_fix_names(pkg_name, False)]
+def _identity(obj):
+ return obj
+
+if sys.version_info < (3, 0):
+ import codecs
+ _open_with_encoding = codecs.open
+ # codecs.open doesn't translate newlines sadly.
+ def _from_system_newlines(input):
+ return input.replace(u"\r\n", u"\n")
+ def _to_system_newlines(input):
+ if os.linesep != "\n":
+ return input.replace(u"\n", os.linesep)
+ else:
+ return input
+else:
+ _open_with_encoding = open
+ _from_system_newlines = _identity
+ _to_system_newlines = _identity
+
class FixerError(Exception):
"""A fixer could not be loaded."""
@@ -213,29 +231,42 @@ class RefactoringTool(object):
# Modify dirnames in-place to remove subdirs with leading dots
dirnames[:] = [dn for dn in dirnames if not dn.startswith(".")]
- def refactor_file(self, filename, write=False, doctests_only=False):
- """Refactors a file."""
+ def _read_python_source(self, filename):
+ """
+ Do our best to decode a Python source file correctly.
+ """
try:
- f = open(filename)
+ f = open(filename, "rb")
except IOError, err:
self.log_error("Can't open %s: %s", filename, err)
- return
+ return None, None
try:
- input = f.read() + "\n" # Silence certain parse errors
+ encoding = tokenize.detect_encoding(f.readline)[0]
finally:
f.close()
+ with _open_with_encoding(filename, "r", encoding=encoding) as f:
+ return _from_system_newlines(f.read()), encoding
+
+ def refactor_file(self, filename, write=False, doctests_only=False):
+ """Refactors a file."""
+ input, encoding = self._read_python_source(filename)
+ if input is None:
+ # Reading the file failed.
+ return
+ input += u"\n" # Silence certain parse errors
if doctests_only:
self.log_debug("Refactoring doctests in %s", filename)
output = self.refactor_docstring(input, filename)
if output != input:
- self.processed_file(output, filename, input, write=write)
+ self.processed_file(output, filename, input, write, encoding)
else:
self.log_debug("No doctest changes in %s", filename)
else:
tree = self.refactor_string(input, filename)
if tree and tree.was_changed:
# The [:-1] is to take off the \n we added earlier
- self.processed_file(str(tree)[:-1], filename, write=write)
+ self.processed_file(unicode(tree)[:-1], filename,
+ write=write, encoding=encoding)
else:
self.log_debug("No changes in %s", filename)
@@ -321,31 +352,26 @@ class RefactoringTool(object):
node.replace(new)
node = new
- def processed_file(self, new_text, filename, old_text=None, write=False):
+ def processed_file(self, new_text, filename, old_text=None, write=False,
+ encoding=None):
"""
Called when a file has been refactored, and there are changes.
"""
self.files.append(filename)
if old_text is None:
- try:
- f = open(filename, "r")
- except IOError, err:
- self.log_error("Can't read %s: %s", filename, err)
+ old_text = self._read_python_source(filename)[0]
+ if old_text is None:
return
- try:
- old_text = f.read()
- finally:
- f.close()
if old_text == new_text:
self.log_debug("No changes to %s", filename)
return
self.print_output(diff_texts(old_text, new_text, filename))
if write:
- self.write_file(new_text, filename, old_text)
+ self.write_file(new_text, filename, old_text, encoding)
else:
self.log_debug("Not writing changes to %s", filename)
- def write_file(self, new_text, filename, old_text):
+ def write_file(self, new_text, filename, old_text, encoding=None):
"""Writes a string to a file.
It first shows a unified diff between the old text and the new text, and
@@ -353,12 +379,12 @@ class RefactoringTool(object):
set.
"""
try:
- f = open(filename, "w")
+ f = _open_with_encoding(filename, "w", encoding=encoding)
except os.error, err:
self.log_error("Can't create %s: %s", filename, err)
return
try:
- f.write(new_text)
+ f.write(_to_system_newlines(new_text))
except os.error, err:
self.log_error("Can't write %s: %s", filename, err)
finally:
@@ -398,7 +424,7 @@ class RefactoringTool(object):
indent = line[:i]
elif (indent is not None and
(line.startswith(indent + self.PS2) or
- line == indent + self.PS2.rstrip() + "\n")):
+ line == indent + self.PS2.rstrip() + u"\n")):
block.append(line)
else:
if block is not None:
@@ -410,7 +436,7 @@ class RefactoringTool(object):
if block is not None:
result.extend(self.refactor_doctest(block, block_lineno,
indent, filename))
- return "".join(result)
+ return u"".join(result)
def refactor_doctest(self, block, lineno, indent, filename):
"""Refactors one doctest.
@@ -425,7 +451,7 @@ class RefactoringTool(object):
except Exception, err:
if self.log.isEnabledFor(logging.DEBUG):
for line in block:
- self.log_debug("Source: %s", line.rstrip("\n"))
+ self.log_debug("Source: %s", line.rstrip(u"\n"))
self.log_error("Can't parse docstring in %s line %s: %s: %s",
filename, lineno, err.__class__.__name__, err)
return block
@@ -433,9 +459,9 @@ class RefactoringTool(object):
new = str(tree).splitlines(True)
# Undo the adjustment of the line numbers in wrap_toks() below.
clipped, new = new[:lineno-1], new[lineno-1:]
- assert clipped == ["\n"] * (lineno-1), clipped
- if not new[-1].endswith("\n"):
- new[-1] += "\n"
+ assert clipped == [u"\n"] * (lineno-1), clipped
+ if not new[-1].endswith(u"\n"):
+ new[-1] += u"\n"
block = [indent + self.PS1 + new.pop(0)]
if new:
block += [indent + self.PS2 + line for line in new]
@@ -497,8 +523,8 @@ class RefactoringTool(object):
for line in block:
if line.startswith(prefix):
yield line[len(prefix):]
- elif line == prefix.rstrip() + "\n":
- yield "\n"
+ elif line == prefix.rstrip() + u"\n":
+ yield u"\n"
else:
raise AssertionError("line=%r, prefix=%r" % (line, prefix))
prefix = prefix2