summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_import.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_import.py')
-rw-r--r--Lib/test/test_import.py95
1 files changed, 94 insertions, 1 deletions
diff --git a/Lib/test/test_import.py b/Lib/test/test_import.py
index 6598d4e..145ff9a 100644
--- a/Lib/test/test_import.py
+++ b/Lib/test/test_import.py
@@ -6,6 +6,7 @@ import sys
import py_compile
import warnings
import imp
+import marshal
from test.support import unlink, TESTFN, unload, run_unittest
@@ -230,6 +231,98 @@ class ImportTest(unittest.TestCase):
else:
self.fail("import by path didn't raise an exception")
+class TestPycRewriting(unittest.TestCase):
+ # Test that the `co_filename` attribute on code objects always points
+ # to the right file, even when various things happen (e.g. both the .py
+ # and the .pyc file are renamed).
+
+ module_name = "unlikely_module_name"
+ module_source = """
+import sys
+code_filename = sys._getframe().f_code.co_filename
+module_filename = __file__
+constant = 1
+def func():
+ pass
+func_filename = func.__code__.co_filename
+"""
+ dir_name = os.path.abspath(TESTFN)
+ file_name = os.path.join(dir_name, module_name) + os.extsep + "py"
+ compiled_name = file_name + ("c" if __debug__ else "o")
+
+ def setUp(self):
+ self.sys_path = sys.path[:]
+ self.orig_module = sys.modules.pop(self.module_name, None)
+ os.mkdir(self.dir_name)
+ with open(self.file_name, "w") as f:
+ f.write(self.module_source)
+ sys.path.insert(0, self.dir_name)
+
+ def tearDown(self):
+ sys.path[:] = self.sys_path
+ if self.orig_module is not None:
+ sys.modules[self.module_name] = self.orig_module
+ else:
+ del sys.modules[self.module_name]
+ for file_name in self.file_name, self.compiled_name:
+ if os.path.exists(file_name):
+ os.remove(file_name)
+ if os.path.exists(self.dir_name):
+ shutil.rmtree(self.dir_name)
+
+ def import_module(self):
+ ns = globals()
+ __import__(self.module_name, ns, ns)
+ return sys.modules[self.module_name]
+
+ def test_basics(self):
+ mod = self.import_module()
+ self.assertEqual(mod.module_filename, self.file_name)
+ self.assertEqual(mod.code_filename, self.file_name)
+ self.assertEqual(mod.func_filename, self.file_name)
+ del sys.modules[self.module_name]
+ mod = self.import_module()
+ self.assertEqual(mod.module_filename, self.file_name)
+ self.assertEqual(mod.code_filename, self.file_name)
+ self.assertEqual(mod.func_filename, self.file_name)
+
+ def test_incorrect_code_name(self):
+ py_compile.compile(self.file_name, dfile="another_module.py")
+ mod = self.import_module()
+ self.assertEqual(mod.module_filename, self.file_name)
+ self.assertEqual(mod.code_filename, self.file_name)
+ self.assertEqual(mod.func_filename, self.file_name)
+
+ def test_module_without_source(self):
+ target = "another_module.py"
+ py_compile.compile(self.file_name, dfile=target)
+ os.remove(self.file_name)
+ mod = self.import_module()
+ self.assertEqual(mod.module_filename, self.compiled_name)
+ self.assertEqual(mod.code_filename, target)
+ self.assertEqual(mod.func_filename, target)
+
+ def test_foreign_code(self):
+ py_compile.compile(self.file_name)
+ with open(self.compiled_name, "rb") as f:
+ header = f.read(8)
+ code = marshal.load(f)
+ constants = list(code.co_consts)
+ foreign_code = test_main.__code__
+ pos = constants.index(1)
+ constants[pos] = foreign_code
+ code = type(code)(code.co_argcount, code.co_kwonlyargcount,
+ code.co_nlocals, code.co_stacksize,
+ code.co_flags, code.co_code, tuple(constants),
+ code.co_names, code.co_varnames, code.co_filename,
+ code.co_name, code.co_firstlineno, code.co_lnotab,
+ code.co_freevars, code.co_cellvars)
+ with open(self.compiled_name, "wb") as f:
+ f.write(header)
+ marshal.dump(code, f)
+ mod = self.import_module()
+ self.assertEqual(mod.constant.co_filename, foreign_code.co_filename)
+
class PathsTests(unittest.TestCase):
SAMPLES = ('test', 'test\u00e4\u00f6\u00fc\u00df', 'test\u00e9\u00e8',
'test\u00b0\u00b3\u00b2')
@@ -288,7 +381,7 @@ class RelativeImport(unittest.TestCase):
self.assertRaises(ValueError, check_relative)
def test_main(verbose=None):
- run_unittest(ImportTest, PathsTests, RelativeImport)
+ run_unittest(ImportTest, TestPycRewriting, PathsTests, RelativeImport)
if __name__ == '__main__':
# test needs to be a package, so we can do relative import