diff options
Diffstat (limited to 'Lib/lib2to3')
-rw-r--r-- | Lib/lib2to3/fixer_base.py | 4 | ||||
-rw-r--r-- | Lib/lib2to3/main.py | 86 | ||||
-rwxr-xr-x | Lib/lib2to3/refactor.py | 239 | ||||
-rw-r--r-- | Lib/lib2to3/tests/support.py | 16 | ||||
-rw-r--r-- | Lib/lib2to3/tests/test_all_fixers.py | 11 | ||||
-rwxr-xr-x | Lib/lib2to3/tests/test_fixers.py | 21 |
6 files changed, 210 insertions, 167 deletions
diff --git a/Lib/lib2to3/fixer_base.py b/Lib/lib2to3/fixer_base.py index 8d78548..5246b08 100644 --- a/Lib/lib2to3/fixer_base.py +++ b/Lib/lib2to3/fixer_base.py @@ -47,8 +47,8 @@ class BaseFix(object): """Initializer. Subclass may override. Args: - options: an optparse.Values instance which can be used - to inspect the command line options. + options: an dict containing the options passed to RefactoringTool + that could be used to customize the fixer through the command line. log: a list to append warnings and other messages to. """ self.options = options diff --git a/Lib/lib2to3/main.py b/Lib/lib2to3/main.py new file mode 100644 index 0000000..c092886 --- /dev/null +++ b/Lib/lib2to3/main.py @@ -0,0 +1,86 @@ +""" +Main program for 2to3. +""" + +import sys +import os +import logging +import optparse + +from . import refactor + + +def main(fixer_pkg, args=None): + """Main program. + + Args: + fixer_pkg: the name of a package where the fixers are located. + args: optional; a list of command line arguments. If omitted, + sys.argv[1:] is used. + + Returns a suggested exit status (0, 1, 2). + """ + # Set up option parser + parser = optparse.OptionParser(usage="refactor.py [options] file|dir ...") + parser.add_option("-d", "--doctests_only", action="store_true", + help="Fix up doctests only") + parser.add_option("-f", "--fix", action="append", default=[], + help="Each FIX specifies a transformation; default all") + parser.add_option("-l", "--list-fixes", action="store_true", + help="List available transformations (fixes/fix_*.py)") + parser.add_option("-p", "--print-function", action="store_true", + help="Modify the grammar so that print() is a function") + parser.add_option("-v", "--verbose", action="store_true", + help="More verbose logging") + parser.add_option("-w", "--write", action="store_true", + help="Write back modified files") + + # Parse command line arguments + refactor_stdin = False + options, args = parser.parse_args(args) + if options.list_fixes: + print "Available transformations for the -f/--fix option:" + for fixname in refactor.get_all_fix_names(fixer_pkg): + print fixname + if not args: + return 0 + if not args: + print >>sys.stderr, "At least one file or directory argument required." + print >>sys.stderr, "Use --help to show usage." + return 2 + if "-" in args: + refactor_stdin = True + if options.write: + print >>sys.stderr, "Can't write to stdin." + return 2 + + # Set up logging handler + level = logging.DEBUG if options.verbose else logging.INFO + logging.basicConfig(format='%(name)s: %(message)s', level=level) + + # Initialize the refactoring tool + rt_opts = {"print_function" : options.print_function} + avail_names = refactor.get_fixers_from_package(fixer_pkg) + explicit = [] + if options.fix: + explicit = [fixer_pkg + ".fix_" + fix + for fix in options.fix if fix != "all"] + fixer_names = avail_names if "all" in options.fix else explicit + else: + fixer_names = avail_names + rt = refactor.RefactoringTool(fixer_names, rt_opts, explicit=explicit) + + # Refactor all files and directories passed as arguments + if not rt.errors: + if refactor_stdin: + rt.refactor_stdin() + else: + rt.refactor(args, options.write, options.doctests_only) + rt.summarize() + + # Return error status (0 if rt.errors is zero) + return int(bool(rt.errors)) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/Lib/lib2to3/refactor.py b/Lib/lib2to3/refactor.py index 1691041..82c1507 100755 --- a/Lib/lib2to3/refactor.py +++ b/Lib/lib2to3/refactor.py @@ -16,8 +16,8 @@ __author__ = "Guido van Rossum <guido@python.org>" import os import sys import difflib -import optparse import logging +import operator from collections import defaultdict from itertools import chain @@ -30,68 +30,19 @@ from . import patcomp from . import fixes from . import pygram -def main(fixer_dir, args=None): - """Main program. - Args: - fixer_dir: directory where fixer modules are located. - args: optional; a list of command line arguments. If omitted, - sys.argv[1:] is used. - - Returns a suggested exit status (0, 1, 2). - """ - # Set up option parser - parser = optparse.OptionParser(usage="refactor.py [options] file|dir ...") - parser.add_option("-d", "--doctests_only", action="store_true", - help="Fix up doctests only") - parser.add_option("-f", "--fix", action="append", default=[], - help="Each FIX specifies a transformation; default all") - parser.add_option("-l", "--list-fixes", action="store_true", - help="List available transformations (fixes/fix_*.py)") - parser.add_option("-p", "--print-function", action="store_true", - help="Modify the grammar so that print() is a function") - parser.add_option("-v", "--verbose", action="store_true", - help="More verbose logging") - parser.add_option("-w", "--write", action="store_true", - help="Write back modified files") - - # Parse command line arguments - options, args = parser.parse_args(args) - if options.list_fixes: - print "Available transformations for the -f/--fix option:" - for fixname in get_all_fix_names(fixer_dir): - print fixname - if not args: - return 0 - if not args: - print >>sys.stderr, "At least one file or directory argument required." - print >>sys.stderr, "Use --help to show usage." - return 2 - - # Set up logging handler - logging.basicConfig(format='%(name)s: %(message)s', level=logging.INFO) - - # Initialize the refactoring tool - rt = RefactoringTool(fixer_dir, options) - - # Refactor all files and directories passed as arguments - if not rt.errors: - rt.refactor_args(args) - rt.summarize() - - # Return error status (0 if rt.errors is zero) - return int(bool(rt.errors)) - - -def get_all_fix_names(fixer_dir): - """Return a sorted list of all available fix names.""" +def get_all_fix_names(fixer_pkg, remove_prefix=True): + """Return a sorted list of all available fix names in the given package.""" + pkg = __import__(fixer_pkg, [], [], ["*"]) + fixer_dir = os.path.dirname(pkg.__file__) fix_names = [] names = os.listdir(fixer_dir) names.sort() for name in names: if name.startswith("fix_") and name.endswith(".py"): - fix_names.append(name[4:-3]) - fix_names.sort() + if remove_prefix: + name = name[4:] + fix_names.append(name[:-3]) return fix_names def get_head_types(pat): @@ -131,22 +82,36 @@ def get_headnode_dict(fixer_list): head_nodes[t].append(fixer) return head_nodes +def get_fixers_from_package(pkg_name): + """ + Return the fully qualified names for fixers in the package pkg_name. + """ + return [pkg_name + "." + fix_name + for fix_name in get_all_fix_names(pkg_name, False)] + class RefactoringTool(object): - def __init__(self, fixer_dir, options): + _default_options = {"print_function": False} + + def __init__(self, fixer_names, options=None, explicit=[]): """Initializer. Args: - fixer_dir: directory in which to find fixer modules. - options: an optparse.Values instance. + fixer_names: a list of fixers to import + options: an dict with configuration. + explicit: a list of fixers to run even if they are explicit. """ - self.fixer_dir = fixer_dir - self.options = options + self.fixers = fixer_names + self.explicit = explicit + self.options = self._default_options.copy() + if options is not None: + self.options.update(options) self.errors = [] self.logger = logging.getLogger("RefactoringTool") self.fixer_log = [] - if self.options.print_function: + self.wrote = False + if self.options["print_function"]: del pygram.python_grammar.keywords["print"] self.driver = driver.Driver(pygram.python_grammar, convert=pytree.convert, @@ -166,30 +131,24 @@ class RefactoringTool(object): want a pre-order AST traversal, and post_order is the list that want post-order traversal. """ - if os.path.isabs(self.fixer_dir): - fixer_pkg = os.path.relpath(self.fixer_dir, os.path.join(os.path.dirname(__file__), '..')) - else: - fixer_pkg = self.fixer_dir - fixer_pkg = fixer_pkg.replace(os.path.sep, ".") - if os.path.altsep: - fixer_pkg = self.fixer_dir.replace(os.path.altsep, ".") pre_order_fixers = [] post_order_fixers = [] - fix_names = self.options.fix - if not fix_names or "all" in fix_names: - fix_names = get_all_fix_names(self.fixer_dir) - for fix_name in fix_names: + for fix_mod_path in self.fixers: try: - mod = __import__(fixer_pkg + ".fix_" + fix_name, {}, {}, ["*"]) + mod = __import__(fix_mod_path, {}, {}, ["*"]) except ImportError: - self.log_error("Can't find transformation %s", fix_name) + self.log_error("Can't load transformation module %s", + fix_mod_path) continue + fix_name = fix_mod_path.rsplit(".", 1)[-1] + if fix_name.startswith("fix_"): + fix_name = fix_name[4:] parts = fix_name.split("_") class_name = "Fix" + "".join([p.title() for p in parts]) try: fix_class = getattr(mod, class_name) except AttributeError: - self.log_error("Can't find fixes.fix_%s.%s", + self.log_error("Can't find %s.%s", fix_name, class_name) continue try: @@ -198,12 +157,12 @@ class RefactoringTool(object): self.log_error("Can't instantiate fixes.fix_%s.%s()", fix_name, class_name, exc_info=True) continue - if fixer.explicit and fix_name not in self.options.fix: + if fixer.explicit and self.explicit is not True and \ + fix_mod_path not in self.explicit: self.log_message("Skipping implicit fixer: %s", fix_name) continue - if self.options.verbose: - self.log_message("Adding transformation: %s", fix_name) + self.log_debug("Adding transformation: %s", fix_name) if fixer.order == "pre": pre_order_fixers.append(fixer) elif fixer.order == "post": @@ -211,8 +170,9 @@ class RefactoringTool(object): else: raise ValueError("Illegal fixer order: %r" % fixer.order) - pre_order_fixers.sort(key=lambda x: x.run_order) - post_order_fixers.sort(key=lambda x: x.run_order) + key_func = operator.attrgetter("run_order") + pre_order_fixers.sort(key=key_func) + post_order_fixers.sort(key=key_func) return (pre_order_fixers, post_order_fixers) def log_error(self, msg, *args, **kwds): @@ -226,36 +186,38 @@ class RefactoringTool(object): msg = msg % args self.logger.info(msg) - def refactor_args(self, args): - """Refactors files and directories from an argument list.""" - for arg in args: - if arg == "-": - self.refactor_stdin() - elif os.path.isdir(arg): - self.refactor_dir(arg) + def log_debug(self, msg, *args): + if args: + msg = msg % args + self.logger.debug(msg) + + def refactor(self, items, write=False, doctests_only=False): + """Refactor a list of files and directories.""" + for dir_or_file in items: + if os.path.isdir(dir_or_file): + self.refactor_dir(dir_or_file, write) else: - self.refactor_file(arg) + self.refactor_file(dir_or_file, write) - def refactor_dir(self, arg): + def refactor_dir(self, dir_name, write=False, doctests_only=False): """Descends down a directory and refactor every Python file found. Python files are assumed to have a .py extension. Files and subdirectories starting with '.' are skipped. """ - for dirpath, dirnames, filenames in os.walk(arg): - if self.options.verbose: - self.log_message("Descending into %s", dirpath) + for dirpath, dirnames, filenames in os.walk(dir_name): + self.log_debug("Descending into %s", dirpath) dirnames.sort() filenames.sort() for name in filenames: if not name.startswith(".") and name.endswith("py"): fullname = os.path.join(dirpath, name) - self.refactor_file(fullname) + self.refactor_file(fullname, write, doctests_only) # Modify dirnames in-place to remove subdirs with leading dots dirnames[:] = [dn for dn in dirnames if not dn.startswith(".")] - def refactor_file(self, filename): + def refactor_file(self, filename, write=False, doctests_only=False): """Refactors a file.""" try: f = open(filename) @@ -266,21 +228,20 @@ class RefactoringTool(object): input = f.read() + "\n" # Silence certain parse errors finally: f.close() - if self.options.doctests_only: - if self.options.verbose: - self.log_message("Refactoring doctests in %s", filename) + if doctests_only: + self.log_debug("Refactoring doctests in %s", filename) output = self.refactor_docstring(input, filename) if output != input: - self.write_file(output, filename, input) - elif self.options.verbose: - self.log_message("No doctest changes in %s", filename) + self.processed_file(output, filename, input, write=write) + else: + self.log_debug("No doctest changes in %s", filename) else: tree = self.refactor_string(input, filename) if tree and tree.was_changed: # The [:-1] is to take off the \n we added earlier - self.write_file(str(tree)[:-1], filename) - elif self.options.verbose: - self.log_message("No changes in %s", filename) + self.processed_file(str(tree)[:-1], filename, write=write) + else: + self.log_debug("No changes in %s", filename) def refactor_string(self, data, name): """Refactor a given input string. @@ -299,30 +260,25 @@ class RefactoringTool(object): self.log_error("Can't parse %s: %s: %s", name, err.__class__.__name__, err) return - if self.options.verbose: - self.log_message("Refactoring %s", name) + self.log_debug("Refactoring %s", name) self.refactor_tree(tree, name) return tree - def refactor_stdin(self): - if self.options.write: - self.log_error("Can't write changes back to stdin") - return + def refactor_stdin(self, doctests_only=False): input = sys.stdin.read() - if self.options.doctests_only: - if self.options.verbose: - self.log_message("Refactoring doctests in stdin") + if doctests_only: + self.log_debug("Refactoring doctests in stdin") output = self.refactor_docstring(input, "<stdin>") if output != input: - self.write_file(output, "<stdin>", input) - elif self.options.verbose: - self.log_message("No doctest changes in stdin") + self.processed_file(output, "<stdin>", input) + else: + self.log_debug("No doctest changes in stdin") else: tree = self.refactor_string(input, "<stdin>") if tree and tree.was_changed: - self.write_file(str(tree), "<stdin>", input) - elif self.options.verbose: - self.log_message("No changes in stdin") + self.processed_file(str(tree), "<stdin>", input) + else: + self.log_debug("No changes in stdin") def refactor_tree(self, tree, name): """Refactors a parse tree (modifying the tree in place). @@ -374,14 +330,9 @@ class RefactoringTool(object): node.replace(new) node = new - def write_file(self, new_text, filename, old_text=None): - """Writes a string to a file. - - If there are no changes, this is a no-op. - - Otherwise, it first shows a unified diff between the old text - and the new text, and then rewrites the file; the latter is - only done if the write option is set. + def processed_file(self, new_text, filename, old_text=None, write=False): + """ + Called when a file has been refactored, and there are changes. """ self.files.append(filename) if old_text is None: @@ -395,14 +346,22 @@ class RefactoringTool(object): finally: f.close() if old_text == new_text: - if self.options.verbose: - self.log_message("No changes to %s", filename) + self.log_debug("No changes to %s", filename) return diff_texts(old_text, new_text, filename) - if not self.options.write: - if self.options.verbose: - self.log_message("Not writing changes to %s", filename) + if not write: + self.log_debug("Not writing changes to %s", filename) return + if write: + self.write_file(next_text, filename, old_text) + + def write_file(self, new_text, filename, old_text=None): + """Writes a string to a file. + + It first shows a unified diff between the old text and the new text, and + then rewrites the file; the latter is only done if the write option is + set. + """ backup = filename + ".bak" if os.path.lexists(backup): try: @@ -425,8 +384,8 @@ class RefactoringTool(object): self.log_error("Can't write %s: %s", filename, err) finally: f.close() - if self.options.verbose: - self.log_message("Wrote changes to %s", filename) + self.log_debug("Wrote changes to %s", filename) + self.wrote = True PS1 = ">>> " PS2 = "... " @@ -485,9 +444,9 @@ class RefactoringTool(object): try: tree = self.parse_block(block, lineno, indent) except Exception, err: - if self.options.verbose: + if self.log.isEnabledFor(logging.DEBUG): for line in block: - self.log_message("Source: %s", line.rstrip("\n")) + self.log_debug("Source: %s", line.rstrip("\n")) self.log_error("Can't parse docstring in %s line %s: %s: %s", filename, lineno, err.__class__.__name__, err) return block @@ -504,7 +463,7 @@ class RefactoringTool(object): return block def summarize(self): - if self.options.write: + if self.wrote: were = "were" else: were = "need to be" @@ -576,7 +535,3 @@ def diff_texts(a, b, filename): "(original)", "(refactored)", lineterm=""): print line - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/Lib/lib2to3/tests/support.py b/Lib/lib2to3/tests/support.py index 7789033..7abf2ef 100644 --- a/Lib/lib2to3/tests/support.py +++ b/Lib/lib2to3/tests/support.py @@ -13,6 +13,7 @@ from textwrap import dedent # Local imports from .. import pytree +from .. import refactor from ..pgen2 import driver test_dir = os.path.dirname(__file__) @@ -38,6 +39,21 @@ def run_all_tests(test_mod=None, tests=None): def reformat(string): return dedent(string) + "\n\n" +def get_refactorer(fixers=None, options=None): + """ + A convenience function for creating a RefactoringTool for tests. + + fixers is a list of fixers for the RefactoringTool to use. By default + "lib2to3.fixes.*" is used. options is an optional dictionary of options to + be passed to the RefactoringTool. + """ + if fixers is not None: + fixers = ["lib2to3.fixes.fix_" + fix for fix in fixers] + else: + fixers = refactor.get_fixers_from_package("lib2to3.fixes") + options = options or {} + return refactor.RefactoringTool(fixers, options, explicit=True) + def all_project_files(): for dirpath, dirnames, filenames in os.walk(proj_dir): for filename in filenames: diff --git a/Lib/lib2to3/tests/test_all_fixers.py b/Lib/lib2to3/tests/test_all_fixers.py index a7b7a19..39adaa9 100644 --- a/Lib/lib2to3/tests/test_all_fixers.py +++ b/Lib/lib2to3/tests/test_all_fixers.py @@ -19,17 +19,10 @@ import unittest from .. import pytree from .. import refactor -class Options: - def __init__(self, **kwargs): - for k, v in kwargs.items(): - setattr(self, k, v) - self.verbose = False - class Test_all(support.TestCase): def setUp(self): - options = Options(fix=["all", "idioms", "ws_comma", "buffer"], - print_function=False) - self.refactor = refactor.RefactoringTool("lib2to3/fixes", options) + options = {"print_function" : False} + self.refactor = support.get_refactorer(options=options) def test_all_project_files(self): for filepath in support.all_project_files(): diff --git a/Lib/lib2to3/tests/test_fixers.py b/Lib/lib2to3/tests/test_fixers.py index d86bb76..2dc65d3 100755 --- a/Lib/lib2to3/tests/test_fixers.py +++ b/Lib/lib2to3/tests/test_fixers.py @@ -21,19 +21,12 @@ from .. import refactor from .. import fixer_util -class Options: - def __init__(self, **kwargs): - for k, v in kwargs.items(): - setattr(self, k, v) - - self.verbose = False - class FixerTestCase(support.TestCase): def setUp(self, fix_list=None): - if not fix_list: + if fix_list is None: fix_list = [self.fixer] - options = Options(fix=fix_list, print_function=False) - self.refactor = refactor.RefactoringTool("lib2to3/fixes", options) + options = {"print_function" : False} + self.refactor = support.get_refactorer(fix_list, options) self.fixer_log = [] self.filename = "<string>" @@ -70,10 +63,10 @@ class FixerTestCase(support.TestCase): self.failUnlessEqual(self.fixer_log, []) def assert_runs_after(self, *names): - fix = [self.fixer] - fix.extend(names) - options = Options(fix=fix, print_function=False) - r = refactor.RefactoringTool("lib2to3/fixes", options) + fixes = [self.fixer] + fixes.extend(names) + options = {"print_function" : False} + r = support.get_refactorer(fixes, options) (pre, post) = r.get_fixers() n = "fix_" + self.fixer if post and post[-1].__class__.__module__.endswith(n): |