# Path: Lib/lib2to3/fixes/fix_imports.py
# Blob: fc1d7bb0fe1d86b818d38d7d6b59e5a10d9c808f
"""Fix incompatible imports and module references."""
# Author: Collin Winter

# Local imports
from .. import fixer_base
from ..fixer_util import Name, attr_chain, any, set

# Old (Python 2) module name -> new (Python 3) module name.  FixImports
# replaces each matched old name with a single Name leaf holding the new
# value, so dotted values such as 'tkinter.filedialog' are written verbatim.
MAPPING = {'StringIO': 'io',
           'cStringIO': 'io',
           'cPickle': 'pickle',
           '__builtin__': 'builtins',
           'copy_reg': 'copyreg',
           'Queue': 'queue',
           'SocketServer': 'socketserver',
           'ConfigParser': 'configparser',
           'repr': 'reprlib',
           'FileDialog': 'tkinter.filedialog',
           'tkFileDialog': 'tkinter.filedialog',
           'SimpleDialog': 'tkinter.simpledialog',
           'tkSimpleDialog': 'tkinter.simpledialog',
           'tkColorChooser': 'tkinter.colorchooser',
           'tkCommonDialog': 'tkinter.commondialog',
           'Dialog': 'tkinter.dialog',
           'Tkdnd': 'tkinter.dnd',
           'tkFont': 'tkinter.font',
           'tkMessageBox': 'tkinter.messagebox',
           'ScrolledText': 'tkinter.scrolledtext',
           # 'turtle' is intentionally not listed: it remains a top-level
           # module in Python 3, and renaming it to 'tkinter.turtle' would
           # produce imports that fail.
           'Tkconstants': 'tkinter.constants',
           'Tix': 'tkinter.tix',
           'Tkinter': 'tkinter',
           'markupbase': '_markupbase',
           '_winreg': 'winreg',
           'thread': '_thread',
           'dummy_thread': '_dummy_thread',
           # anydbm and whichdb are handled by fix_imports2
           'dbhash': 'dbm.bsd',
           'dumbdbm': 'dbm.dumb',
           'dbm': 'dbm.ndbm',
           'gdbm': 'dbm.gnu',
           'xmlrpclib': 'xmlrpc.client',
           'DocXMLRPCServer': 'xmlrpc.server',
           'SimpleXMLRPCServer': 'xmlrpc.server',
           'httplib': 'http.client',
           'Cookie': 'http.cookies',
           'cookielib': 'http.cookiejar',
           'BaseHTTPServer': 'http.server',
           'SimpleHTTPServer': 'http.server',
           'CGIHTTPServer': 'http.server',
           # NOTE(review): 'test.test_support' -> 'test.support' is disabled;
           # presumably because build_pattern only matches single-name
           # modules, not dotted ones -- TODO confirm before enabling.
           'commands': 'subprocess',
           'UserString': 'collections',
           'UserList': 'collections',
           'urlparse': 'urllib.parse',
           'robotparser': 'urllib.robotparser',
}


def alternates(members):
    """Build a pattern alternation such as ``('a'|'b')`` from *members*.

    Every member is passed through ``repr`` so that it appears as a quoted
    literal inside the lib2to3 pattern language.
    """
    quoted = [repr(member) for member in members]
    return "(%s)" % "|".join(quoted)


def build_pattern(mapping=MAPPING):
    """Yield lib2to3 pattern alternatives matching uses of *mapping*'s keys.

    For each old module name four patterns are yielded:

    * ``import old`` or ``import a, old`` -- binds ``module``;
    * ``from old import ...``, with or without surrounding parentheses --
      binds ``module_name``;
    * ``import old as x``, alone or inside a comma-separated list -- binds
      ``module_name``;
    * attribute access such as ``old.attr(...)`` -- binds ``module_name``.

    A final pattern binds ``bare_name`` to any bare NAME equal to one of the
    old module names.  Callers join the yielded strings with ``|``.
    """
    for old_module in mapping:
        yield """import_name< 'import' (module=%r
                              | dotted_as_names< any* module=%r any* >) >
              """ % (old_module, old_module)
        # The optional parentheses cover "from old import (a, b)".
        yield """import_from< 'from' module_name=%r 'import' ['(']
                  ( any | import_as_name< any 'as' any > |
                    import_as_names< any* >) [')'] >
              """ % old_module
        # "import old as x", either alone or as one item of a list such as
        # "import old as x, other".
        yield """import_name< 'import'
                  ( dotted_as_name< module_name=%r 'as' any >
                  | dotted_as_names< any*
                        dotted_as_name< module_name=%r 'as' any >
                        any* > ) >
              """ % (old_module, old_module)
        # Find usages of module members in code e.g. urllib.foo(bar)
        yield """power< module_name=%r
                 trailer<'.' any > any* >
              """ % old_module
    yield """bare_name=%s""" % alternates(mapping.keys())

class FixImports(fixer_base.BaseFix):
    """Rename modules listed in ``mapping`` to their Python 3 names.

    Matched import statements and ``module.attr`` expressions have the old
    module name replaced in place.  For a plain ``import old`` the rename is
    also recorded in ``self.replace`` so that later bare uses of ``old`` in
    the same tree are rewritten as well.
    """
    PATTERN = "|".join(build_pattern())
    order = "pre" # Pre-order tree traversal

    mapping = MAPPING

    # Don't match the node if it's within another match
    def match(self, node):
        """Return the match results for *node*, or False.

        A node is rejected when any of its ancestors also matches, so a name
        inside an already-matched statement or expression is not handled a
        second time.
        """
        base_match = super(FixImports, self).match
        results = base_match(node)
        if results:
            # A generator lets any() stop at the first matching ancestor
            # instead of running the matcher against every ancestor.
            if any(base_match(obj) for obj in attr_chain(node, "parent")):
                return False
            return results
        return False

    def start_tree(self, tree, filename):
        """Reset the per-file table of renamed modules, then defer to the base."""
        super(FixImports, self).start_tree(tree, filename)
        # old module name -> new name, filled in by transform() for modules
        # imported with a plain "import old" so bare uses can be renamed.
        self.replace = {}

    def transform(self, node, results):
        """Replace the old module name bound in *results* with its new name.

        ``results`` carries one of three bindings: ``module`` (a plain import
        of the old module), ``module_name`` (other matched statements and
        attribute access), or ``bare_name`` (a one-element list holding a
        bare NAME leaf).  Returns None; the tree is edited in place.
        """
        import_mod = results.get("module")
        mod_name = results.get("module_name")
        bare_name = results.get("bare_name")

        if import_mod or mod_name:
            old_leaf = import_mod or mod_name
            old_name = old_leaf.value
            new_name = self.mapping[old_name]
            if import_mod:
                # Only "import old" binds the name "old" in the module, so
                # only then should later bare uses be renamed.
                self.replace[old_name] = new_name
            old_leaf.replace(Name(new_name, prefix=old_leaf.get_prefix()))
            # A statement like "import StringIO, cPickle" can name several
            # old modules, but each match binds only one of them and the
            # node is not revisited.  Re-match the same node so the rest are
            # converted too.  The guard avoids looping on an identity entry;
            # this assumes no new name is itself a key of the mapping.
            if new_name != old_name:
                results = self.match(node)
                if results:
                    self.transform(node, results)
        elif bare_name:
            bare_name = bare_name[0]
            new_name = self.replace.get(bare_name.value)
            if new_name:
                bare_name.replace(Name(new_name, prefix=bare_name.get_prefix()))