diff options
-rw-r--r-- | Lib/dis.py | 36 | ||||
-rw-r--r-- | Lib/modulefinder.py | 33 | ||||
-rw-r--r-- | Lib/test/test_dis.py | 33 |
3 files changed, 76 insertions, 26 deletions
@@ -535,6 +535,42 @@ def findlinestarts(code): yield start, line return +def _find_imports(co): + """Find import statements in the code + + Generate triplets (name, level, fromlist) where + name is the imported module and level, fromlist are + the corresponding args to __import__. + """ + IMPORT_NAME = opmap['IMPORT_NAME'] + LOAD_CONST = opmap['LOAD_CONST'] + + consts = co.co_consts + names = co.co_names + opargs = [(op, arg) for _, op, arg in _unpack_opargs(co.co_code) + if op != EXTENDED_ARG] + for i, (op, oparg) in enumerate(opargs): + if (op == IMPORT_NAME and i >= 2 + and opargs[i-1][0] == opargs[i-2][0] == LOAD_CONST): + level = consts[opargs[i-2][1]] + fromlist = consts[opargs[i-1][1]] + yield (names[oparg], level, fromlist) + +def _find_store_names(co): + """Find names of variables which are written in the code + + Generate sequence of strings + """ + STORE_OPS = { + opmap['STORE_NAME'], + opmap['STORE_GLOBAL'] + } + + names = co.co_names + for _, op, arg in _unpack_opargs(co.co_code): + if op in STORE_OPS: + yield names[arg] + class Bytecode: """The bytecode operations of a piece of code diff --git a/Lib/modulefinder.py b/Lib/modulefinder.py index cb455f4..a0a020f 100644 --- a/Lib/modulefinder.py +++ b/Lib/modulefinder.py @@ -8,14 +8,6 @@ import os import io import sys - -LOAD_CONST = dis.opmap['LOAD_CONST'] -IMPORT_NAME = dis.opmap['IMPORT_NAME'] -STORE_NAME = dis.opmap['STORE_NAME'] -STORE_GLOBAL = dis.opmap['STORE_GLOBAL'] -STORE_OPS = STORE_NAME, STORE_GLOBAL -EXTENDED_ARG = dis.EXTENDED_ARG - # Old imp constants: _SEARCH_ERROR = 0 @@ -394,24 +386,13 @@ class ModuleFinder: def scan_opcodes(self, co): # Scan the code, and yield 'interesting' opcode combinations - code = co.co_code - names = co.co_names - consts = co.co_consts - opargs = [(op, arg) for _, op, arg in dis._unpack_opargs(code) - if op != EXTENDED_ARG] - for i, (op, oparg) in enumerate(opargs): - if op in STORE_OPS: - yield "store", (names[oparg],) - continue - if (op == IMPORT_NAME and i >= 2 - and opargs[i-1][0] == opargs[i-2][0] == LOAD_CONST): - level = consts[opargs[i-2][1]] - fromlist = consts[opargs[i-1][1]] - if level == 0: # absolute import - yield "absolute_import", (fromlist, names[oparg]) - else: # relative import - yield "relative_import", (level, fromlist, names[oparg]) - continue + for name in dis._find_store_names(co): + yield "store", (name,) + for name, level, fromlist in dis._find_imports(co): + if level == 0: # absolute import + yield "absolute_import", (fromlist, name) + else: # relative import + yield "relative_import", (level, fromlist, name) def scan_code(self, co, m): code = co.co_code diff --git a/Lib/test/test_dis.py b/Lib/test/test_dis.py index b97e41c..a140a89 100644 --- a/Lib/test/test_dis.py +++ b/Lib/test/test_dis.py @@ -1326,5 +1326,38 @@ class TestBytecodeTestCase(BytecodeTestCase): with self.assertRaises(AssertionError): self.assertNotInBytecode(code, "LOAD_CONST", 1) +class TestFinderMethods(unittest.TestCase): + def test__find_imports(self): + cases = [ + ("import a.b.c", ('a.b.c', 0, None)), + ("from a.b import c", ('a.b', 0, ('c',))), + ("from a.b import c as d", ('a.b', 0, ('c',))), + ("from a.b import *", ('a.b', 0, ('*',))), + ("from ...a.b import c as d", ('a.b', 3, ('c',))), + ("from ..a.b import c as d, e as f", ('a.b', 2, ('c', 'e'))), + ("from ..a.b import *", ('a.b', 2, ('*',))), + ] + for src, expected in cases: + with self.subTest(src=src): + code = compile(src, "<string>", "exec") + res = tuple(dis._find_imports(code)) + self.assertEqual(len(res), 1) + self.assertEqual(res[0], expected) + + def test__find_store_names(self): + cases = [ + ("x+y", ()), + ("x=y=1", ('x', 'y')), + ("x+=y", ('x',)), + ("global x\nx=y=1", ('x', 'y')), + ("global x\nz=x", ('z',)), + ] + for src, expected in cases: + with self.subTest(src=src): + code = compile(src, "<string>", "exec") + res = tuple(dis._find_store_names(code)) + self.assertEqual(res, expected) + + if __name__ == "__main__": unittest.main() |