diff options
Diffstat (limited to 'Lib/lib2to3')
-rw-r--r-- | Lib/lib2to3/main.py | 20 | ||||
-rwxr-xr-x | Lib/lib2to3/refactor.py | 66 | ||||
-rw-r--r-- | Lib/lib2to3/tests/data/fixers/bad_order.py | 5 | ||||
-rw-r--r-- | Lib/lib2to3/tests/data/fixers/myfixes/__init__.py | 0 | ||||
-rw-r--r-- | Lib/lib2to3/tests/data/fixers/myfixes/fix_explicit.py | 6 | ||||
-rw-r--r-- | Lib/lib2to3/tests/data/fixers/myfixes/fix_first.py | 6 | ||||
-rw-r--r-- | Lib/lib2to3/tests/data/fixers/myfixes/fix_last.py | 7 | ||||
-rw-r--r-- | Lib/lib2to3/tests/data/fixers/myfixes/fix_parrot.py | 13 | ||||
-rw-r--r-- | Lib/lib2to3/tests/data/fixers/myfixes/fix_preorder.py | 6 | ||||
-rw-r--r-- | Lib/lib2to3/tests/data/fixers/no_fixer_cls.py | 1 | ||||
-rw-r--r-- | Lib/lib2to3/tests/data/fixers/parrot_example.py | 2 | ||||
-rw-r--r-- | Lib/lib2to3/tests/test_refactor.py | 175 |
12 files changed, 267 insertions, 40 deletions
diff --git a/Lib/lib2to3/main.py b/Lib/lib2to3/main.py index 939f4f7..b286de8 100644 --- a/Lib/lib2to3/main.py +++ b/Lib/lib2to3/main.py @@ -10,6 +10,20 @@ import optparse from . import refactor +class StdoutRefactoringTool(refactor.RefactoringTool): + """ + Prints output to stdout. + """ + + def log_error(self, msg, *args, **kwargs): + self.errors.append((msg, args, kwargs)) + self.logger.error(msg, *args, **kwargs) + + def print_output(self, lines): + for line in lines: + print(line) + + def main(fixer_pkg, args=None): """Main program. @@ -68,7 +82,7 @@ def main(fixer_pkg, args=None): fixer_names = avail_names if "all" in options.fix else explicit else: fixer_names = avail_names - rt = refactor.RefactoringTool(fixer_names, rt_opts, explicit=explicit) + rt = StdoutRefactoringTool(fixer_names, rt_opts, explicit=explicit) # Refactor all files and directories passed as arguments if not rt.errors: @@ -80,7 +94,3 @@ def main(fixer_pkg, args=None): # Return error status (0 if rt.errors is zero) return int(bool(rt.errors)) - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/Lib/lib2to3/refactor.py b/Lib/lib2to3/refactor.py index c318045..82a2d45 100755 --- a/Lib/lib2to3/refactor.py +++ b/Lib/lib2to3/refactor.py @@ -90,11 +90,18 @@ def get_fixers_from_package(pkg_name): for fix_name in get_all_fix_names(pkg_name, False)] +class FixerError(Exception): + """A fixer could not be loaded.""" + + class RefactoringTool(object): _default_options = {"print_function": False} - def __init__(self, fixer_names, options=None, explicit=[]): + CLASS_PREFIX = "Fix" # The prefix for fixer classes + FILE_PREFIX = "fix_" # The prefix for modules with a fixer within + + def __init__(self, fixer_names, options=None, explicit=None): """Initializer. Args: @@ -103,7 +110,7 @@ class RefactoringTool(object): explicit: a list of fixers to run even if they are explicit. """ self.fixers = fixer_names - self.explicit = explicit + self.explicit = explicit or [] self.options = self._default_options.copy() if options is not None: self.options.update(options) @@ -134,29 +141,17 @@ class RefactoringTool(object): pre_order_fixers = [] post_order_fixers = [] for fix_mod_path in self.fixers: - try: - mod = __import__(fix_mod_path, {}, {}, ["*"]) - except ImportError: - self.log_error("Can't load transformation module %s", - fix_mod_path) - continue + mod = __import__(fix_mod_path, {}, {}, ["*"]) fix_name = fix_mod_path.rsplit(".", 1)[-1] - if fix_name.startswith("fix_"): - fix_name = fix_name[4:] + if fix_name.startswith(self.FILE_PREFIX): + fix_name = fix_name[len(self.FILE_PREFIX):] parts = fix_name.split("_") - class_name = "Fix" + "".join([p.title() for p in parts]) + class_name = self.CLASS_PREFIX + "".join([p.title() for p in parts]) try: fix_class = getattr(mod, class_name) except AttributeError: - self.log_error("Can't find %s.%s", - fix_name, class_name) - continue - try: - fixer = fix_class(self.options, self.fixer_log) - except Exception as err: - self.log_error("Can't instantiate fixes.fix_%s.%s()", - fix_name, class_name, exc_info=True) - continue + raise FixerError("Can't find %s.%s" % (fix_name, class_name)) + fixer = fix_class(self.options, self.fixer_log) if fixer.explicit and self.explicit is not True and \ fix_mod_path not in self.explicit: self.log_message("Skipping implicit fixer: %s", fix_name) @@ -168,7 +163,7 @@ class RefactoringTool(object): elif fixer.order == "post": post_order_fixers.append(fixer) else: - raise ValueError("Illegal fixer order: %r" % fixer.order) + raise FixerError("Illegal fixer order: %r" % fixer.order) key_func = operator.attrgetter("run_order") pre_order_fixers.sort(key=key_func) @@ -176,9 +171,8 @@ class RefactoringTool(object): return (pre_order_fixers, post_order_fixers) def log_error(self, msg, *args, **kwds): - """Increments error count and log a message.""" - self.errors.append((msg, args, kwds)) - self.logger.error(msg, *args, **kwds) + """Called when an error occurs.""" + raise def log_message(self, msg, *args): """Hook to log a message.""" @@ -191,13 +185,17 @@ class RefactoringTool(object): msg = msg % args self.logger.debug(msg) + def print_output(self, lines): + """Called with lines of output to give to the user.""" + pass + def refactor(self, items, write=False, doctests_only=False): """Refactor a list of files and directories.""" for dir_or_file in items: if os.path.isdir(dir_or_file): - self.refactor_dir(dir_or_file, write) + self.refactor_dir(dir_or_file, write, doctests_only) else: - self.refactor_file(dir_or_file, write) + self.refactor_file(dir_or_file, write, doctests_only) def refactor_dir(self, dir_name, write=False, doctests_only=False): """Descends down a directory and refactor every Python file found. @@ -348,12 +346,11 @@ class RefactoringTool(object): if old_text == new_text: self.log_debug("No changes to %s", filename) return - diff_texts(old_text, new_text, filename) - if not write: - self.log_debug("Not writing changes to %s", filename) - return + self.print_output(diff_texts(old_text, new_text, filename)) if write: self.write_file(new_text, filename, old_text) + else: + self.log_debug("Not writing changes to %s", filename) def write_file(self, new_text, filename, old_text=None): """Writes a string to a file. @@ -528,10 +525,9 @@ class RefactoringTool(object): def diff_texts(a, b, filename): - """Prints a unified diff of two strings.""" + """Return a unified diff of two strings.""" a = a.splitlines() b = b.splitlines() - for line in difflib.unified_diff(a, b, filename, filename, - "(original)", "(refactored)", - lineterm=""): - print(line) + return difflib.unified_diff(a, b, filename, filename, + "(original)", "(refactored)", + lineterm="") diff --git a/Lib/lib2to3/tests/data/fixers/bad_order.py b/Lib/lib2to3/tests/data/fixers/bad_order.py new file mode 100644 index 0000000..061bbf2 --- /dev/null +++ b/Lib/lib2to3/tests/data/fixers/bad_order.py @@ -0,0 +1,5 @@ +from lib2to3.fixer_base import BaseFix + +class FixBadOrder(BaseFix): + + order = "crazy" diff --git a/Lib/lib2to3/tests/data/fixers/myfixes/__init__.py b/Lib/lib2to3/tests/data/fixers/myfixes/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/Lib/lib2to3/tests/data/fixers/myfixes/__init__.py diff --git a/Lib/lib2to3/tests/data/fixers/myfixes/fix_explicit.py b/Lib/lib2to3/tests/data/fixers/myfixes/fix_explicit.py new file mode 100644 index 0000000..cbe16f6 --- /dev/null +++ b/Lib/lib2to3/tests/data/fixers/myfixes/fix_explicit.py @@ -0,0 +1,6 @@ +from lib2to3.fixer_base import BaseFix + +class FixExplicit(BaseFix): + explicit = True + + def match(self): return False diff --git a/Lib/lib2to3/tests/data/fixers/myfixes/fix_first.py b/Lib/lib2to3/tests/data/fixers/myfixes/fix_first.py new file mode 100644 index 0000000..a88821f --- /dev/null +++ b/Lib/lib2to3/tests/data/fixers/myfixes/fix_first.py @@ -0,0 +1,6 @@ +from lib2to3.fixer_base import BaseFix + +class FixFirst(BaseFix): + run_order = 1 + + def match(self, node): return False diff --git a/Lib/lib2to3/tests/data/fixers/myfixes/fix_last.py b/Lib/lib2to3/tests/data/fixers/myfixes/fix_last.py new file mode 100644 index 0000000..9a077d4 --- /dev/null +++ b/Lib/lib2to3/tests/data/fixers/myfixes/fix_last.py @@ -0,0 +1,7 @@ +from lib2to3.fixer_base import BaseFix + +class FixLast(BaseFix): + + run_order = 10 + + def match(self, node): return False diff --git a/Lib/lib2to3/tests/data/fixers/myfixes/fix_parrot.py b/Lib/lib2to3/tests/data/fixers/myfixes/fix_parrot.py new file mode 100644 index 0000000..6bd2f49 --- /dev/null +++ b/Lib/lib2to3/tests/data/fixers/myfixes/fix_parrot.py @@ -0,0 +1,13 @@ +from lib2to3.fixer_base import BaseFix +from lib2to3.fixer_util import Name + +class FixParrot(BaseFix): + """ + Change functions named 'parrot' to 'cheese'. + """ + + PATTERN = """funcdef < 'def' name='parrot' any* >""" + + def transform(self, node, results): + name = results["name"] + name.replace(Name("cheese", name.get_prefix())) diff --git a/Lib/lib2to3/tests/data/fixers/myfixes/fix_preorder.py b/Lib/lib2to3/tests/data/fixers/myfixes/fix_preorder.py new file mode 100644 index 0000000..b9bfbba --- /dev/null +++ b/Lib/lib2to3/tests/data/fixers/myfixes/fix_preorder.py @@ -0,0 +1,6 @@ +from lib2to3.fixer_base import BaseFix + +class FixPreorder(BaseFix): + order = "pre" + + def match(self, node): return False diff --git a/Lib/lib2to3/tests/data/fixers/no_fixer_cls.py b/Lib/lib2to3/tests/data/fixers/no_fixer_cls.py new file mode 100644 index 0000000..506f794 --- /dev/null +++ b/Lib/lib2to3/tests/data/fixers/no_fixer_cls.py @@ -0,0 +1 @@ +# This is empty so trying to fetch the fixer class gives an AttributeError diff --git a/Lib/lib2to3/tests/data/fixers/parrot_example.py b/Lib/lib2to3/tests/data/fixers/parrot_example.py new file mode 100644 index 0000000..0852928 --- /dev/null +++ b/Lib/lib2to3/tests/data/fixers/parrot_example.py @@ -0,0 +1,2 @@ +def parrot(): + pass diff --git a/Lib/lib2to3/tests/test_refactor.py b/Lib/lib2to3/tests/test_refactor.py new file mode 100644 index 0000000..0717479 --- /dev/null +++ b/Lib/lib2to3/tests/test_refactor.py @@ -0,0 +1,175 @@ +""" +Unit tests for refactor.py. +""" + +import sys +import os +import operator +import io +import tempfile +import unittest + +from lib2to3 import refactor, pygram, fixer_base + +from . import support + + +FIXER_DIR = os.path.join(os.path.dirname(__file__), "data/fixers") + +sys.path.append(FIXER_DIR) +try: + _DEFAULT_FIXERS = refactor.get_fixers_from_package("myfixes") +finally: + sys.path.pop() + +class TestRefactoringTool(unittest.TestCase): + + def setUp(self): + sys.path.append(FIXER_DIR) + + def tearDown(self): + sys.path.pop() + + def check_instances(self, instances, classes): + for inst, cls in zip(instances, classes): + if not isinstance(inst, cls): + self.fail("%s are not instances of %s" % instances, classes) + + def rt(self, options=None, fixers=_DEFAULT_FIXERS, explicit=None): + return refactor.RefactoringTool(fixers, options, explicit) + + def test_print_function_option(self): + gram = pygram.python_grammar + save = gram.keywords["print"] + try: + rt = self.rt({"print_function" : True}) + self.assertRaises(KeyError, operator.itemgetter("print"), + gram.keywords) + finally: + gram.keywords["print"] = save + + def test_fixer_loading_helpers(self): + contents = ["explicit", "first", "last", "parrot", "preorder"] + non_prefixed = refactor.get_all_fix_names("myfixes") + prefixed = refactor.get_all_fix_names("myfixes", False) + full_names = refactor.get_fixers_from_package("myfixes") + self.assertEqual(prefixed, ["fix_" + name for name in contents]) + self.assertEqual(non_prefixed, contents) + self.assertEqual(full_names, + ["myfixes.fix_" + name for name in contents]) + + def test_get_headnode_dict(self): + class NoneFix(fixer_base.BaseFix): + PATTERN = None + + class FileInputFix(fixer_base.BaseFix): + PATTERN = "file_input< any * >" + + no_head = NoneFix({}, []) + with_head = FileInputFix({}, []) + d = refactor.get_headnode_dict([no_head, with_head]) + expected = {None: [no_head], + pygram.python_symbols.file_input : [with_head]} + self.assertEqual(d, expected) + + def test_fixer_loading(self): + from myfixes.fix_first import FixFirst + from myfixes.fix_last import FixLast + from myfixes.fix_parrot import FixParrot + from myfixes.fix_preorder import FixPreorder + + rt = self.rt() + pre, post = rt.get_fixers() + + self.check_instances(pre, [FixPreorder]) + self.check_instances(post, [FixFirst, FixParrot, FixLast]) + + def test_naughty_fixers(self): + self.assertRaises(ImportError, self.rt, fixers=["not_here"]) + self.assertRaises(refactor.FixerError, self.rt, fixers=["no_fixer_cls"]) + self.assertRaises(refactor.FixerError, self.rt, fixers=["bad_order"]) + + def test_refactor_string(self): + rt = self.rt() + input = "def parrot(): pass\n\n" + tree = rt.refactor_string(input, "<test>") + self.assertNotEqual(str(tree), input) + + input = "def f(): pass\n\n" + tree = rt.refactor_string(input, "<test>") + self.assertEqual(str(tree), input) + + def test_refactor_stdin(self): + + class MyRT(refactor.RefactoringTool): + + def print_output(self, lines): + diff_lines.extend(lines) + + diff_lines = [] + rt = MyRT(_DEFAULT_FIXERS) + save = sys.stdin + sys.stdin = io.StringIO("def parrot(): pass\n\n") + try: + rt.refactor_stdin() + finally: + sys.stdin = save + expected = """--- <stdin> (original) ++++ <stdin> (refactored) +@@ -1,2 +1,2 @@ +-def parrot(): pass ++def cheese(): pass""".splitlines() + self.assertEqual(diff_lines[:-1], expected) + + def test_refactor_file(self): + test_file = os.path.join(FIXER_DIR, "parrot_example.py") + backup = test_file + ".bak" + old_contents = open(test_file, "r").read() + rt = self.rt() + + rt.refactor_file(test_file) + self.assertEqual(old_contents, open(test_file, "r").read()) + + rt.refactor_file(test_file, True) + try: + self.assertNotEqual(old_contents, open(test_file, "r").read()) + self.assertTrue(os.path.exists(backup)) + self.assertEqual(old_contents, open(backup, "r").read()) + finally: + open(test_file, "w").write(old_contents) + try: + os.unlink(backup) + except OSError: + pass + + def test_refactor_docstring(self): + rt = self.rt() + + def example(): + """ + >>> example() + 42 + """ + out = rt.refactor_docstring(example.__doc__, "<test>") + self.assertEqual(out, example.__doc__) + + def parrot(): + """ + >>> def parrot(): + ... return 43 + """ + out = rt.refactor_docstring(parrot.__doc__, "<test>") + self.assertNotEqual(out, parrot.__doc__) + + def test_explicit(self): + from myfixes.fix_explicit import FixExplicit + + rt = self.rt(fixers=["myfixes.fix_explicit"]) + self.assertEqual(len(rt.post_order), 0) + + rt = self.rt(explicit=["myfixes.fix_explicit"]) + for fix in rt.post_order[None]: + if isinstance(fix, FixExplicit): + break + else: + self.fail("explicit fixer not loaded") |