summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/dis.py36
-rw-r--r--Lib/modulefinder.py33
-rw-r--r--Lib/test/test_dis.py33
3 files changed, 76 insertions, 26 deletions
diff --git a/Lib/dis.py b/Lib/dis.py
index 66487dc..a073572 100644
--- a/Lib/dis.py
+++ b/Lib/dis.py
@@ -535,6 +535,42 @@ def findlinestarts(code):
yield start, line
return
+def _find_imports(co):
+ """Find import statements in the code
+
+ Generate triplets (name, level, fromlist) where
+ name is the imported module and level, fromlist are
+ the corresponding args to __import__.
+ """
+ IMPORT_NAME = opmap['IMPORT_NAME']
+ LOAD_CONST = opmap['LOAD_CONST']
+
+ consts = co.co_consts
+ names = co.co_names
+ opargs = [(op, arg) for _, op, arg in _unpack_opargs(co.co_code)
+ if op != EXTENDED_ARG]
+ for i, (op, oparg) in enumerate(opargs):
+ if (op == IMPORT_NAME and i >= 2
+ and opargs[i-1][0] == opargs[i-2][0] == LOAD_CONST):
+ level = consts[opargs[i-2][1]]
+ fromlist = consts[opargs[i-1][1]]
+ yield (names[oparg], level, fromlist)
+
+def _find_store_names(co):
+ """Find names of variables which are written in the code
+
+ Generate sequence of strings
+ """
+ STORE_OPS = {
+ opmap['STORE_NAME'],
+ opmap['STORE_GLOBAL']
+ }
+
+ names = co.co_names
+ for _, op, arg in _unpack_opargs(co.co_code):
+ if op in STORE_OPS:
+ yield names[arg]
+
class Bytecode:
"""The bytecode operations of a piece of code
diff --git a/Lib/modulefinder.py b/Lib/modulefinder.py
index cb455f4..a0a020f 100644
--- a/Lib/modulefinder.py
+++ b/Lib/modulefinder.py
@@ -8,14 +8,6 @@ import os
import io
import sys
-
-LOAD_CONST = dis.opmap['LOAD_CONST']
-IMPORT_NAME = dis.opmap['IMPORT_NAME']
-STORE_NAME = dis.opmap['STORE_NAME']
-STORE_GLOBAL = dis.opmap['STORE_GLOBAL']
-STORE_OPS = STORE_NAME, STORE_GLOBAL
-EXTENDED_ARG = dis.EXTENDED_ARG
-
# Old imp constants:
_SEARCH_ERROR = 0
@@ -394,24 +386,13 @@ class ModuleFinder:
def scan_opcodes(self, co):
# Scan the code, and yield 'interesting' opcode combinations
- code = co.co_code
- names = co.co_names
- consts = co.co_consts
- opargs = [(op, arg) for _, op, arg in dis._unpack_opargs(code)
- if op != EXTENDED_ARG]
- for i, (op, oparg) in enumerate(opargs):
- if op in STORE_OPS:
- yield "store", (names[oparg],)
- continue
- if (op == IMPORT_NAME and i >= 2
- and opargs[i-1][0] == opargs[i-2][0] == LOAD_CONST):
- level = consts[opargs[i-2][1]]
- fromlist = consts[opargs[i-1][1]]
- if level == 0: # absolute import
- yield "absolute_import", (fromlist, names[oparg])
- else: # relative import
- yield "relative_import", (level, fromlist, names[oparg])
- continue
+ for name in dis._find_store_names(co):
+ yield "store", (name,)
+ for name, level, fromlist in dis._find_imports(co):
+ if level == 0: # absolute import
+ yield "absolute_import", (fromlist, name)
+ else: # relative import
+ yield "relative_import", (level, fromlist, name)
def scan_code(self, co, m):
code = co.co_code
diff --git a/Lib/test/test_dis.py b/Lib/test/test_dis.py
index b97e41c..a140a89 100644
--- a/Lib/test/test_dis.py
+++ b/Lib/test/test_dis.py
@@ -1326,5 +1326,38 @@ class TestBytecodeTestCase(BytecodeTestCase):
with self.assertRaises(AssertionError):
self.assertNotInBytecode(code, "LOAD_CONST", 1)
+class TestFinderMethods(unittest.TestCase):
+ def test__find_imports(self):
+ cases = [
+ ("import a.b.c", ('a.b.c', 0, None)),
+ ("from a.b import c", ('a.b', 0, ('c',))),
+ ("from a.b import c as d", ('a.b', 0, ('c',))),
+ ("from a.b import *", ('a.b', 0, ('*',))),
+ ("from ...a.b import c as d", ('a.b', 3, ('c',))),
+ ("from ..a.b import c as d, e as f", ('a.b', 2, ('c', 'e'))),
+ ("from ..a.b import *", ('a.b', 2, ('*',))),
+ ]
+ for src, expected in cases:
+ with self.subTest(src=src):
+ code = compile(src, "<string>", "exec")
+ res = tuple(dis._find_imports(code))
+ self.assertEqual(len(res), 1)
+ self.assertEqual(res[0], expected)
+
+ def test__find_store_names(self):
+ cases = [
+ ("x+y", ()),
+ ("x=y=1", ('x', 'y')),
+ ("x+=y", ('x',)),
+ ("global x\nx=y=1", ('x', 'y')),
+ ("global x\nz=x", ('z',)),
+ ]
+ for src, expected in cases:
+ with self.subTest(src=src):
+ code = compile(src, "<string>", "exec")
+ res = tuple(dis._find_store_names(code))
+ self.assertEqual(res, expected)
+
+
if __name__ == "__main__":
unittest.main()