summaryrefslogtreecommitdiffstats
path: root/Lib/lib2to3/fixes/fix_exitfunc.py
blob: b29b4640b25774c736466ce96fad6acd98b13222 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""
Convert use of sys.exitfunc to use the atexit module.
"""

# Author: Benjamin Peterson

from lib2to3 import pytree, fixer_base
from lib2to3.fixer_util import Name, Attr, Call, Comma, Newline, syms


class FixExitfunc(fixer_base.BaseFix):

    PATTERN = """
              (
                  sys_import=import_name<'import'
                      ('sys'
                      |
                      dotted_as_names< (any ',')* 'sys' (',' any)* >
                      )
                  >
              |
                  expr_stmt<
                      power< 'sys' trailer< '.' 'exitfunc' > >
                  '=' func=any >
              )
              """

    def __init__(self, *args):
        super(FixExitfunc, self).__init__(*args)

    def start_tree(self, tree, filename):
        super(FixExitfunc, self).start_tree(tree, filename)
        self.sys_import = None

    def transform(self, node, results):
        # First, find a the sys import. We'll just hope it's global scope.
        if "sys_import" in results:
            if self.sys_import is None:
                self.sys_import = results["sys_import"]
            return

        func = results["func"].clone()
        func.prefix = ""
        register = pytree.Node(syms.power,
                               Attr(Name("atexit"), Name("register"))
                               )
        call = Call(register, [func], node.prefix)
        node.replace(call)

        if self.sys_import is None:
            # That's interesting.
            self.warning(node, "Can't find sys import; Please add an atexit "
                             "import at the top of your file.")
            return

        # Now add an atexit import after the sys import.
        names = self.sys_import.children[1]
        if names.type == syms.dotted_as_names:
            names.append_child(Comma())
            names.append_child(Name("atexit", " "))
        else:
            containing_stmt = self.sys_import.parent
            position = containing_stmt.children.index(self.sys_import)
            stmt_container = containing_stmt.parent
            new_import = pytree.Node("import_name",
                              [Name("import"), Name("atexit", " ")]
                              )
            new = pytree.Node("simple_stmt", [new_import])
            containing_stmt.insert_child(position + 1, Newline())
            containing_stmt.insert_child(position + 2, new)