summaryrefslogtreecommitdiffstats
path: root/Lib/lib2to3/refactor.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/lib2to3/refactor.py')
-rwxr-xr-xLib/lib2to3/refactor.py51
1 files changed, 49 insertions, 2 deletions
diff --git a/Lib/lib2to3/refactor.py b/Lib/lib2to3/refactor.py
index ae7d3cdf..ae9d7d0 100755
--- a/Lib/lib2to3/refactor.py
+++ b/Lib/lib2to3/refactor.py
@@ -18,6 +18,8 @@ import sys
import difflib
import optparse
import logging
+from collections import defaultdict
+from itertools import chain
# Local imports
from .pgen2 import driver
@@ -96,6 +98,43 @@ def get_all_fix_names():
fix_names.sort()
return fix_names
+def get_head_types(pat):
+ """ Accepts a pytree Pattern Node and returns a set
+ of the pattern types which will match first. """
+
+ if isinstance(pat, (pytree.NodePattern, pytree.LeafPattern)):
+ # NodePatters must either have no type and no content
+ # or a type and content -- so they don't get any farther
+ # Always return leafs
+ return set([pat.type])
+
+ if isinstance(pat, pytree.NegatedPattern):
+ if pat.content:
+ return get_head_types(pat.content)
+ return set([None]) # Negated Patterns don't have a type
+
+ if isinstance(pat, pytree.WildcardPattern):
+ # Recurse on each node in content
+ r = set()
+ for p in pat.content:
+ for x in p:
+ r.update(get_head_types(x))
+ return r
+
+ raise Exception("Oh no! I don't understand pattern %s" %(pat))
+
+def get_headnode_dict(fixer_list):
+ """ Accepts a list of fixers and returns a dictionary
+ of head node type --> fixer list. """
+ head_nodes = defaultdict(list)
+ for fixer in fixer_list:
+ if not fixer.pattern:
+ head_nodes[None].append(fixer)
+ continue
+ for t in get_head_types(fixer.pattern):
+ head_nodes[t].append(fixer)
+ return head_nodes
+
class RefactoringTool(object):
@@ -114,6 +153,10 @@ class RefactoringTool(object):
convert=pytree.convert,
logger=self.logger)
self.pre_order, self.post_order = self.get_fixers()
+
+ self.pre_order = get_headnode_dict(self.pre_order)
+ self.post_order = get_headnode_dict(self.post_order)
+
self.files = [] # List of files that were or should be modified
def get_fixers(self):
@@ -286,7 +329,11 @@ class RefactoringTool(object):
Returns:
True if the tree was modified, False otherwise.
"""
- all_fixers = self.pre_order + self.post_order
+ # Two calls to chain are required because pre_order.values()
+ # will be a list of lists of fixers:
+ # [[<fixer ...>, <fixer ...>], [<fixer ...>]]
+ all_fixers = chain(chain(*self.pre_order.values()),\
+ chain(*self.post_order.values()))
for fixer in all_fixers:
fixer.start_tree(tree, name)
@@ -312,7 +359,7 @@ class RefactoringTool(object):
if not fixers:
return
for node in traversal:
- for fixer in fixers:
+ for fixer in fixers[node.type] + fixers[None]:
results = fixer.match(node)
if results:
new = fixer.transform(node, results)