diff options
Diffstat (limited to 'Lib/lib2to3/fixes/fix_import.py')
-rw-r--r-- | Lib/lib2to3/fixes/fix_import.py | 87 |
1 files changed, 55 insertions, 32 deletions
diff --git a/Lib/lib2to3/fixes/fix_import.py b/Lib/lib2to3/fixes/fix_import.py index c065f70..4c75133 100644 --- a/Lib/lib2to3/fixes/fix_import.py +++ b/Lib/lib2to3/fixes/fix_import.py @@ -13,55 +13,78 @@ Becomes: # Local imports from .. import fixer_base from os.path import dirname, join, exists, pathsep -from ..fixer_util import FromImport, syms +from ..fixer_util import FromImport, syms, token + + +def traverse_imports(names): + """ + Walks over all the names imported in a dotted_as_names node. + """ + pending = [names] + while pending: + node = pending.pop() + if node.type == token.NAME: + yield node.value + elif node.type == syms.dotted_name: + yield "".join([ch.value for ch in node.children]) + elif node.type == syms.dotted_as_name: + pending.append(node.children[0]) + elif node.type == syms.dotted_as_names: + pending.extend(node.children[::-2]) + else: + raise AssertionError("unkown node type") + class FixImport(fixer_base.BaseFix): PATTERN = """ - import_from< type='from' imp=any 'import' ['('] any [')'] > + import_from< 'from' imp=any 'import' ['('] any [')'] > | - import_name< type='import' imp=any > + import_name< 'import' imp=any > """ def transform(self, node, results): imp = results['imp'] - mod_name = str(imp.children[0] if imp.type == syms.dotted_as_name \ - else imp) - - if str(imp).startswith('.'): - # Already a new-style import - return - - if not probably_a_local_import(str(mod_name), self.filename): - # I guess this is a global import -- skip it! - return - - if results['type'].value == 'from': + if node.type == syms.import_from: # Some imps are top-level (eg: 'import ham') # some are first level (eg: 'import ham.eggs') # some are third level (eg: 'import ham.eggs as spam') # Hence, the loop while not hasattr(imp, 'value'): imp = imp.children[0] - imp.value = "." + imp.value - node.changed() + if self.probably_a_local_import(imp.value): + imp.value = "." + imp.value + imp.changed() + return node else: - new = FromImport('.', getattr(imp, 'content', None) or [imp]) + have_local = False + have_absolute = False + for mod_name in traverse_imports(imp): + if self.probably_a_local_import(mod_name): + have_local = True + else: + have_absolute = True + if have_absolute: + if have_local: + # We won't handle both sibling and absolute imports in the + # same statement at the moment. + self.warning(node, "absolute and local imports together") + return + + new = FromImport('.', [imp]) new.set_prefix(node.get_prefix()) - node = new - return node + return new -def probably_a_local_import(imp_name, file_path): - # Must be stripped because the right space is included by the parser - imp_name = imp_name.split('.', 1)[0].strip() - base_path = dirname(file_path) - base_path = join(base_path, imp_name) - # If there is no __init__.py next to the file its not in a package - # so can't be a relative import. - if not exists(join(dirname(base_path), '__init__.py')): + def probably_a_local_import(self, imp_name): + imp_name = imp_name.split('.', 1)[0] + base_path = dirname(self.filename) + base_path = join(base_path, imp_name) + # If there is no __init__.py next to the file its not in a package + # so can't be a relative import. + if not exists(join(dirname(base_path), '__init__.py')): + return False + for ext in ['.py', pathsep, '.pyc', '.so', '.sl', '.pyd']: + if exists(base_path + ext): + return True return False - for ext in ['.py', pathsep, '.pyc', '.so', '.sl', '.pyd']: - if exists(base_path + ext): - return True - return False |