diff options
Diffstat (limited to 'Lib/test/test_import.py')
-rw-r--r-- | Lib/test/test_import.py | 95 |
1 files changed, 94 insertions, 1 deletions
diff --git a/Lib/test/test_import.py b/Lib/test/test_import.py index 6598d4e..145ff9a 100644 --- a/Lib/test/test_import.py +++ b/Lib/test/test_import.py @@ -6,6 +6,7 @@ import sys import py_compile import warnings import imp +import marshal from test.support import unlink, TESTFN, unload, run_unittest @@ -230,6 +231,98 @@ class ImportTest(unittest.TestCase): else: self.fail("import by path didn't raise an exception") +class TestPycRewriting(unittest.TestCase): + # Test that the `co_filename` attribute on code objects always points + # to the right file, even when various things happen (e.g. both the .py + # and the .pyc file are renamed). + + module_name = "unlikely_module_name" + module_source = """ +import sys +code_filename = sys._getframe().f_code.co_filename +module_filename = __file__ +constant = 1 +def func(): + pass +func_filename = func.__code__.co_filename +""" + dir_name = os.path.abspath(TESTFN) + file_name = os.path.join(dir_name, module_name) + os.extsep + "py" + compiled_name = file_name + ("c" if __debug__ else "o") + + def setUp(self): + self.sys_path = sys.path[:] + self.orig_module = sys.modules.pop(self.module_name, None) + os.mkdir(self.dir_name) + with open(self.file_name, "w") as f: + f.write(self.module_source) + sys.path.insert(0, self.dir_name) + + def tearDown(self): + sys.path[:] = self.sys_path + if self.orig_module is not None: + sys.modules[self.module_name] = self.orig_module + else: + del sys.modules[self.module_name] + for file_name in self.file_name, self.compiled_name: + if os.path.exists(file_name): + os.remove(file_name) + if os.path.exists(self.dir_name): + shutil.rmtree(self.dir_name) + + def import_module(self): + ns = globals() + __import__(self.module_name, ns, ns) + return sys.modules[self.module_name] + + def test_basics(self): + mod = self.import_module() + self.assertEqual(mod.module_filename, self.file_name) + self.assertEqual(mod.code_filename, self.file_name) + self.assertEqual(mod.func_filename, self.file_name) + del sys.modules[self.module_name] + mod = self.import_module() + self.assertEqual(mod.module_filename, self.file_name) + self.assertEqual(mod.code_filename, self.file_name) + self.assertEqual(mod.func_filename, self.file_name) + + def test_incorrect_code_name(self): + py_compile.compile(self.file_name, dfile="another_module.py") + mod = self.import_module() + self.assertEqual(mod.module_filename, self.file_name) + self.assertEqual(mod.code_filename, self.file_name) + self.assertEqual(mod.func_filename, self.file_name) + + def test_module_without_source(self): + target = "another_module.py" + py_compile.compile(self.file_name, dfile=target) + os.remove(self.file_name) + mod = self.import_module() + self.assertEqual(mod.module_filename, self.compiled_name) + self.assertEqual(mod.code_filename, target) + self.assertEqual(mod.func_filename, target) + + def test_foreign_code(self): + py_compile.compile(self.file_name) + with open(self.compiled_name, "rb") as f: + header = f.read(8) + code = marshal.load(f) + constants = list(code.co_consts) + foreign_code = test_main.__code__ + pos = constants.index(1) + constants[pos] = foreign_code + code = type(code)(code.co_argcount, code.co_kwonlyargcount, + code.co_nlocals, code.co_stacksize, + code.co_flags, code.co_code, tuple(constants), + code.co_names, code.co_varnames, code.co_filename, + code.co_name, code.co_firstlineno, code.co_lnotab, + code.co_freevars, code.co_cellvars) + with open(self.compiled_name, "wb") as f: + f.write(header) + marshal.dump(code, f) + mod = self.import_module() + self.assertEqual(mod.constant.co_filename, foreign_code.co_filename) + class PathsTests(unittest.TestCase): SAMPLES = ('test', 'test\u00e4\u00f6\u00fc\u00df', 'test\u00e9\u00e8', 'test\u00b0\u00b3\u00b2') @@ -288,7 +381,7 @@ class RelativeImport(unittest.TestCase): self.assertRaises(ValueError, check_relative) def test_main(verbose=None): - run_unittest(ImportTest, PathsTests, RelativeImport) + run_unittest(ImportTest, TestPycRewriting, PathsTests, RelativeImport) if __name__ == '__main__': # test needs to be a package, so we can do relative import |