diff options
Diffstat (limited to 'Lib/lib2to3/fixes/fix_imports.py')
| -rw-r--r-- | Lib/lib2to3/fixes/fix_imports.py | 64 |
1 files changed, 37 insertions, 27 deletions
diff --git a/Lib/lib2to3/fixes/fix_imports.py b/Lib/lib2to3/fixes/fix_imports.py index 38e868b..e48c4f0 100644 --- a/Lib/lib2to3/fixes/fix_imports.py +++ b/Lib/lib2to3/fixes/fix_imports.py @@ -1,9 +1,9 @@ """Fix incompatible imports and module references.""" -# Author: Collin Winter +# Authors: Collin Winter, Nick Edds # Local imports from .. import fixer_base -from ..fixer_util import Name, attr_chain, any, set +from ..fixer_util import Name, attr_chain MAPPING = {'StringIO': 'io', 'cStringIO': 'io', @@ -61,36 +61,49 @@ def alternates(members): def build_pattern(mapping=MAPPING): - mod_list = ' | '.join(["module='" + key + "'" for key in mapping.keys()]) - mod_name_list = ' | '.join(["module_name='" + key + "'" for key in mapping.keys()]) - yield """import_name< 'import' ((%s) + mod_list = ' | '.join(["module_name='%s'" % key for key in mapping]) + bare_names = alternates(mapping.keys()) + + yield """name_import=import_name< 'import' ((%s) | dotted_as_names< any* (%s) any* >) > """ % (mod_list, mod_list) yield """import_from< 'from' (%s) 'import' ['('] ( any | import_as_name< any 'as' any > | import_as_names< any* >) [')'] > - """ % mod_name_list + """ % mod_list yield """import_name< 'import' dotted_as_name< (%s) 'as' any > > - """ % mod_name_list - # Find usages of module members in code e.g. urllib.foo(bar) - yield """power< (%s) - trailer<'.' any > any* > - """ % mod_name_list - yield """bare_name=%s""" % alternates(mapping.keys()) + """ % mod_list + + # Find usages of module members in code e.g. thread.foo(bar) + yield "power< bare_with_attr=(%s) trailer<'.' any > any* >" % bare_names + class FixImports(fixer_base.BaseFix): - PATTERN = "|".join(build_pattern()) + order = "pre" # Pre-order tree traversal + # This is overridden in fix_imports2. mapping = MAPPING - # Don't match the node if it's within another match + def build_pattern(self): + return "|".join(build_pattern(self.mapping)) + + def compile_pattern(self): + # We override this, so MAPPING can be pragmatically altered and the + # changes will be reflected in PATTERN. + self.PATTERN = self.build_pattern() + super(FixImports, self).compile_pattern() + + # Don't match the node if it's within another match. def match(self, node): match = super(FixImports, self).match results = match(node) if results: - if any([match(obj) for obj in attr_chain(node, "parent")]): + # Module usage could be in the trailier of an attribute lookup, so + # we might have nested matches when "bare_with_attr" is present. + if "bare_with_attr" not in results and \ + any([match(obj) for obj in attr_chain(node, "parent")]): return False return results return False @@ -100,20 +113,17 @@ class FixImports(fixer_base.BaseFix): self.replace = {} def transform(self, node, results): - import_mod = results.get("module") - mod_name = results.get("module_name") - bare_name = results.get("bare_name") - - if import_mod or mod_name: - new_name = self.mapping[(import_mod or mod_name).value] - + import_mod = results.get("module_name") if import_mod: - self.replace[import_mod.value] = new_name + new_name = self.mapping[(import_mod or mod_name).value] + if "name_import" in results: + # If it's not a "from x import x, y" or "import x as y" import, + # marked its usage to be replaced. + self.replace[import_mod.value] = new_name import_mod.replace(Name(new_name, prefix=import_mod.get_prefix())) - elif mod_name: - mod_name.replace(Name(new_name, prefix=mod_name.get_prefix())) - elif bare_name: - bare_name = bare_name[0] + else: + # Replace usage of the module. + bare_name = results["bare_with_attr"][0] new_name = self.replace.get(bare_name.value) if new_name: bare_name.replace(Name(new_name, prefix=bare_name.get_prefix())) |
