diff options
Diffstat (limited to 'Lib/lib2to3/refactor.py')
-rwxr-xr-x | Lib/lib2to3/refactor.py | 70 |
1 files changed, 48 insertions, 22 deletions
diff --git a/Lib/lib2to3/refactor.py b/Lib/lib2to3/refactor.py index b679db4..82a98d1 100755 --- a/Lib/lib2to3/refactor.py +++ b/Lib/lib2to3/refactor.py @@ -22,8 +22,7 @@ from collections import defaultdict from itertools import chain # Local imports -from .pgen2 import driver -from .pgen2 import tokenize +from .pgen2 import driver, tokenize from . import pytree from . import patcomp @@ -87,6 +86,25 @@ def get_fixers_from_package(pkg_name): return [pkg_name + "." + fix_name for fix_name in get_all_fix_names(pkg_name, False)] +def _identity(obj): + return obj + +if sys.version_info < (3, 0): + import codecs + _open_with_encoding = codecs.open + # codecs.open doesn't translate newlines sadly. + def _from_system_newlines(input): + return input.replace("\r\n", "\n") + def _to_system_newlines(input): + if os.linesep != "\n": + return input.replace("\n", os.linesep) + else: + return input +else: + _open_with_encoding = open + _from_system_newlines = _identity + _to_system_newlines = _identity + class FixerError(Exception): """A fixer could not be loaded.""" @@ -213,29 +231,42 @@ class RefactoringTool(object): # Modify dirnames in-place to remove subdirs with leading dots dirnames[:] = [dn for dn in dirnames if not dn.startswith(".")] - def refactor_file(self, filename, write=False, doctests_only=False): - """Refactors a file.""" + def _read_python_source(self, filename): + """ + Do our best to decode a Python source file correctly. + """ try: - f = open(filename) + f = open(filename, "rb") except IOError as err: self.log_error("Can't open %s: %s", filename, err) - return + return None, None try: - input = f.read() + "\n" # Silence certain parse errors + encoding = tokenize.detect_encoding(f.readline)[0] finally: f.close() + with _open_with_encoding(filename, "r", encoding=encoding) as f: + return _from_system_newlines(f.read()), encoding + + def refactor_file(self, filename, write=False, doctests_only=False): + """Refactors a file.""" + input, encoding = self._read_python_source(filename) + if input is None: + # Reading the file failed. + return + input += "\n" # Silence certain parse errors if doctests_only: self.log_debug("Refactoring doctests in %s", filename) output = self.refactor_docstring(input, filename) if output != input: - self.processed_file(output, filename, input, write=write) + self.processed_file(output, filename, input, write, encoding) else: self.log_debug("No doctest changes in %s", filename) else: tree = self.refactor_string(input, filename) if tree and tree.was_changed: # The [:-1] is to take off the \n we added earlier - self.processed_file(str(tree)[:-1], filename, write=write) + self.processed_file(str(tree)[:-1], filename, + write=write, encoding=encoding) else: self.log_debug("No changes in %s", filename) @@ -321,31 +352,26 @@ class RefactoringTool(object): node.replace(new) node = new - def processed_file(self, new_text, filename, old_text=None, write=False): + def processed_file(self, new_text, filename, old_text=None, write=False, + encoding=None): """ Called when a file has been refactored, and there are changes. """ self.files.append(filename) if old_text is None: - try: - f = open(filename, "r") - except IOError as err: - self.log_error("Can't read %s: %s", filename, err) + old_text = self._read_python_source(filename)[0] + if old_text is None: return - try: - old_text = f.read() - finally: - f.close() if old_text == new_text: self.log_debug("No changes to %s", filename) return self.print_output(diff_texts(old_text, new_text, filename)) if write: - self.write_file(new_text, filename, old_text) + self.write_file(new_text, filename, old_text, encoding) else: self.log_debug("Not writing changes to %s", filename) - def write_file(self, new_text, filename, old_text): + def write_file(self, new_text, filename, old_text, encoding=None): """Writes a string to a file. It first shows a unified diff between the old text and the new text, and @@ -353,12 +379,12 @@ class RefactoringTool(object): set. """ try: - f = open(filename, "w") + f = _open_with_encoding(filename, "w", encoding=encoding) except os.error as err: self.log_error("Can't create %s: %s", filename, err) return try: - f.write(new_text) + f.write(_to_system_newlines(new_text)) except os.error as err: self.log_error("Can't write %s: %s", filename, err) finally: |