diff options
Diffstat (limited to 'Lib/lib2to3/refactor.py')
-rwxr-xr-x | Lib/lib2to3/refactor.py | 66 |
1 files changed, 31 insertions, 35 deletions
diff --git a/Lib/lib2to3/refactor.py b/Lib/lib2to3/refactor.py index c318045..82a2d45 100755 --- a/Lib/lib2to3/refactor.py +++ b/Lib/lib2to3/refactor.py @@ -90,11 +90,18 @@ def get_fixers_from_package(pkg_name): for fix_name in get_all_fix_names(pkg_name, False)] +class FixerError(Exception): + """A fixer could not be loaded.""" + + class RefactoringTool(object): _default_options = {"print_function": False} - def __init__(self, fixer_names, options=None, explicit=[]): + CLASS_PREFIX = "Fix" # The prefix for fixer classes + FILE_PREFIX = "fix_" # The prefix for modules with a fixer within + + def __init__(self, fixer_names, options=None, explicit=None): """Initializer. Args: @@ -103,7 +110,7 @@ class RefactoringTool(object): explicit: a list of fixers to run even if they are explicit. """ self.fixers = fixer_names - self.explicit = explicit + self.explicit = explicit or [] self.options = self._default_options.copy() if options is not None: self.options.update(options) @@ -134,29 +141,17 @@ class RefactoringTool(object): pre_order_fixers = [] post_order_fixers = [] for fix_mod_path in self.fixers: - try: - mod = __import__(fix_mod_path, {}, {}, ["*"]) - except ImportError: - self.log_error("Can't load transformation module %s", - fix_mod_path) - continue + mod = __import__(fix_mod_path, {}, {}, ["*"]) fix_name = fix_mod_path.rsplit(".", 1)[-1] - if fix_name.startswith("fix_"): - fix_name = fix_name[4:] + if fix_name.startswith(self.FILE_PREFIX): + fix_name = fix_name[len(self.FILE_PREFIX):] parts = fix_name.split("_") - class_name = "Fix" + "".join([p.title() for p in parts]) + class_name = self.CLASS_PREFIX + "".join([p.title() for p in parts]) try: fix_class = getattr(mod, class_name) except AttributeError: - self.log_error("Can't find %s.%s", - fix_name, class_name) - continue - try: - fixer = fix_class(self.options, self.fixer_log) - except Exception as err: - self.log_error("Can't instantiate fixes.fix_%s.%s()", - fix_name, class_name, exc_info=True) - continue + raise FixerError("Can't find %s.%s" % (fix_name, class_name)) + fixer = fix_class(self.options, self.fixer_log) if fixer.explicit and self.explicit is not True and \ fix_mod_path not in self.explicit: self.log_message("Skipping implicit fixer: %s", fix_name) @@ -168,7 +163,7 @@ class RefactoringTool(object): elif fixer.order == "post": post_order_fixers.append(fixer) else: - raise ValueError("Illegal fixer order: %r" % fixer.order) + raise FixerError("Illegal fixer order: %r" % fixer.order) key_func = operator.attrgetter("run_order") pre_order_fixers.sort(key=key_func) @@ -176,9 +171,8 @@ class RefactoringTool(object): return (pre_order_fixers, post_order_fixers) def log_error(self, msg, *args, **kwds): - """Increments error count and log a message.""" - self.errors.append((msg, args, kwds)) - self.logger.error(msg, *args, **kwds) + """Called when an error occurs.""" + raise def log_message(self, msg, *args): """Hook to log a message.""" @@ -191,13 +185,17 @@ class RefactoringTool(object): msg = msg % args self.logger.debug(msg) + def print_output(self, lines): + """Called with lines of output to give to the user.""" + pass + def refactor(self, items, write=False, doctests_only=False): """Refactor a list of files and directories.""" for dir_or_file in items: if os.path.isdir(dir_or_file): - self.refactor_dir(dir_or_file, write) + self.refactor_dir(dir_or_file, write, doctests_only) else: - self.refactor_file(dir_or_file, write) + self.refactor_file(dir_or_file, write, doctests_only) def refactor_dir(self, dir_name, write=False, doctests_only=False): """Descends down a directory and refactor every Python file found. @@ -348,12 +346,11 @@ class RefactoringTool(object): if old_text == new_text: self.log_debug("No changes to %s", filename) return - diff_texts(old_text, new_text, filename) - if not write: - self.log_debug("Not writing changes to %s", filename) - return + self.print_output(diff_texts(old_text, new_text, filename)) if write: self.write_file(new_text, filename, old_text) + else: + self.log_debug("Not writing changes to %s", filename) def write_file(self, new_text, filename, old_text=None): """Writes a string to a file. @@ -528,10 +525,9 @@ class RefactoringTool(object): def diff_texts(a, b, filename): - """Prints a unified diff of two strings.""" + """Return a unified diff of two strings.""" a = a.splitlines() b = b.splitlines() - for line in difflib.unified_diff(a, b, filename, filename, - "(original)", "(refactored)", - lineterm=""): - print(line) + return difflib.unified_diff(a, b, filename, filename, + "(original)", "(refactored)", + lineterm="") |