diff options
Diffstat (limited to 'Lib/lib2to3/refactor.py')
-rwxr-xr-x | Lib/lib2to3/refactor.py | 51 |
1 files changed, 49 insertions, 2 deletions
diff --git a/Lib/lib2to3/refactor.py b/Lib/lib2to3/refactor.py index ae7d3cdf..ae9d7d0 100755 --- a/Lib/lib2to3/refactor.py +++ b/Lib/lib2to3/refactor.py @@ -18,6 +18,8 @@ import sys import difflib import optparse import logging +from collections import defaultdict +from itertools import chain # Local imports from .pgen2 import driver @@ -96,6 +98,43 @@ def get_all_fix_names(): fix_names.sort() return fix_names +def get_head_types(pat): + """ Accepts a pytree Pattern Node and returns a set + of the pattern types which will match first. """ + + if isinstance(pat, (pytree.NodePattern, pytree.LeafPattern)): + # NodePatters must either have no type and no content + # or a type and content -- so they don't get any farther + # Always return leafs + return set([pat.type]) + + if isinstance(pat, pytree.NegatedPattern): + if pat.content: + return get_head_types(pat.content) + return set([None]) # Negated Patterns don't have a type + + if isinstance(pat, pytree.WildcardPattern): + # Recurse on each node in content + r = set() + for p in pat.content: + for x in p: + r.update(get_head_types(x)) + return r + + raise Exception("Oh no! I don't understand pattern %s" %(pat)) + +def get_headnode_dict(fixer_list): + """ Accepts a list of fixers and returns a dictionary + of head node type --> fixer list. """ + head_nodes = defaultdict(list) + for fixer in fixer_list: + if not fixer.pattern: + head_nodes[None].append(fixer) + continue + for t in get_head_types(fixer.pattern): + head_nodes[t].append(fixer) + return head_nodes + class RefactoringTool(object): @@ -114,6 +153,10 @@ class RefactoringTool(object): convert=pytree.convert, logger=self.logger) self.pre_order, self.post_order = self.get_fixers() + + self.pre_order = get_headnode_dict(self.pre_order) + self.post_order = get_headnode_dict(self.post_order) + self.files = [] # List of files that were or should be modified def get_fixers(self): @@ -286,7 +329,11 @@ class RefactoringTool(object): Returns: True if the tree was modified, False otherwise. """ - all_fixers = self.pre_order + self.post_order + # Two calls to chain are required because pre_order.values() + # will be a list of lists of fixers: + # [[<fixer ...>, <fixer ...>], [<fixer ...>]] + all_fixers = chain(chain(*self.pre_order.values()),\ + chain(*self.post_order.values())) for fixer in all_fixers: fixer.start_tree(tree, name) @@ -312,7 +359,7 @@ class RefactoringTool(object): if not fixers: return for node in traversal: - for fixer in fixers: + for fixer in fixers[node.type] + fixers[None]: results = fixer.match(node) if results: new = fixer.transform(node, results) |