diff options
Diffstat (limited to 'Lib/lib2to3/fixer_util.py')
-rw-r--r-- | Lib/lib2to3/fixer_util.py | 74 |
1 files changed, 67 insertions, 7 deletions
diff --git a/Lib/lib2to3/fixer_util.py b/Lib/lib2to3/fixer_util.py index ea394e8..0c485f0 100644 --- a/Lib/lib2to3/fixer_util.py +++ b/Lib/lib2to3/fixer_util.py @@ -158,6 +158,9 @@ def is_list(node): ### Misc ########################################################### +def parenthesize(node): + return Node(syms.atom, [LParen(), node, RParen()]) + consuming_calls = set(["sorted", "list", "set", "any", "all", "tuple", "sum", "min", "max"]) @@ -232,20 +235,77 @@ def make_suite(node): suite.parent = parent return suite -def does_tree_import(package, name, node): - """ Returns true if name is imported from package at the - top level of the tree which node belongs to. - To cover the case of an import like 'import foo', use - Null for the package and 'foo' for the name. """ +def find_root(node): + """Find the top level namespace.""" # Scamper up to the top level namespace while node.type != syms.file_input: assert node.parent, "Tree is insane! root found before "\ "file_input node was found." node = node.parent + return node - binding = find_binding(name, node, package) +def does_tree_import(package, name, node): + """ Returns true if name is imported from package at the + top level of the tree which node belongs to. + To cover the case of an import like 'import foo', use + None for the package and 'foo' for the name. """ + binding = find_binding(name, find_root(node), package) return bool(binding) +def is_import(node): + """Returns true if the node is an import statement.""" + return node.type in (syms.import_name, syms.import_from) + +def touch_import(package, name, node): + """ Works like `does_tree_import` but adds an import statement + if it was not imported. """ + def is_import_stmt(node): + return node.type == syms.simple_stmt and node.children and \ + is_import(node.children[0]) + + root = find_root(node) + + if does_tree_import(package, name, root): + return + + add_newline_before = False + + # figure out where to insert the new import. First try to find + # the first import and then skip to the last one. + insert_pos = offset = 0 + for idx, node in enumerate(root.children): + if not is_import_stmt(node): + continue + for offset, node2 in enumerate(root.children[idx:]): + if not is_import_stmt(node2): + break + insert_pos = idx + offset + break + + # if there are no imports where we can insert, find the docstring. + # if that also fails, we stick to the beginning of the file + if insert_pos == 0: + for idx, node in enumerate(root.children): + if node.type == syms.simple_stmt and node.children and \ + node.children[0].type == token.STRING: + insert_pos = idx + 1 + add_newline_before + break + + if package is None: + import_ = Node(syms.import_name, [ + Leaf(token.NAME, 'import'), + Leaf(token.NAME, name, prefix=' ') + ]) + else: + import_ = FromImport(package, [Leaf(token.NAME, name, prefix=' ')]) + + children = [import_, Newline()] + if add_newline_before: + children.insert(0, Newline()) + root.insert_child(insert_pos, Node(syms.simple_stmt, children)) + + _def_syms = set([syms.classdef, syms.funcdef]) def find_binding(name, node, package=None): """ Returns the node which binds variable name, otherwise None. @@ -285,7 +345,7 @@ def find_binding(name, node, package=None): if ret: if not package: return ret - if ret.type in (syms.import_name, syms.import_from): + if is_import(ret): return ret return None |