diff options
Diffstat (limited to 'Lib/test')
-rw-r--r-- | Lib/test/test_compileall.py | 21 | ||||
-rw-r--r-- | Lib/test/test_imp.py | 21 | ||||
-rw-r--r-- | Lib/test/test_import/__init__.py | 2 | ||||
-rw-r--r-- | Lib/test/test_importlib/source/test_file_loader.py | 210 | ||||
-rw-r--r-- | Lib/test/test_importlib/test_abc.py | 2 | ||||
-rw-r--r-- | Lib/test/test_py_compile.py | 18 | ||||
-rw-r--r-- | Lib/test/test_zipimport.py | 18 |
7 files changed, 270 insertions, 22 deletions
diff --git a/Lib/test/test_compileall.py b/Lib/test/test_compileall.py index 2356efc..38d7b99 100644 --- a/Lib/test/test_compileall.py +++ b/Lib/test/test_compileall.py @@ -48,9 +48,9 @@ class CompileallTests(unittest.TestCase): def data(self): with open(self.bc_path, 'rb') as file: - data = file.read(8) + data = file.read(12) mtime = int(os.stat(self.source_path).st_mtime) - compare = struct.pack('<4sl', importlib.util.MAGIC_NUMBER, mtime) + compare = struct.pack('<4sll', importlib.util.MAGIC_NUMBER, 0, mtime) return data, compare @unittest.skipUnless(hasattr(os, 'stat'), 'test needs os.stat()') @@ -70,8 +70,8 @@ class CompileallTests(unittest.TestCase): def test_mtime(self): # Test a change in mtime leads to a new .pyc. - self.recreation_check(struct.pack('<4sl', importlib.util.MAGIC_NUMBER, - 1)) + self.recreation_check(struct.pack('<4sll', importlib.util.MAGIC_NUMBER, + 0, 1)) def test_magic_number(self): # Test a change in mtime leads to a new .pyc. @@ -519,6 +519,19 @@ class CommandLineTests(unittest.TestCase): out = self.assertRunOK('badfilename') self.assertRegex(out, b"Can't list 'badfilename'") + def test_pyc_invalidation_mode(self): + script_helper.make_script(self.pkgdir, 'f1', '') + pyc = importlib.util.cache_from_source( + os.path.join(self.pkgdir, 'f1.py')) + self.assertRunOK('--invalidation-mode=checked-hash', self.pkgdir) + with open(pyc, 'rb') as fp: + data = fp.read() + self.assertEqual(int.from_bytes(data[4:8], 'little'), 0b11) + self.assertRunOK('--invalidation-mode=unchecked-hash', self.pkgdir) + with open(pyc, 'rb') as fp: + data = fp.read() + self.assertEqual(int.from_bytes(data[4:8], 'little'), 0b01) + @skipUnless(_have_multiprocessing, "requires multiprocessing") def test_workers(self): bar2fn = script_helper.make_script(self.directory, 'bar2', '') diff --git a/Lib/test/test_imp.py b/Lib/test/test_imp.py index b70ec7c..a115e60 100644 --- a/Lib/test/test_imp.py +++ b/Lib/test/test_imp.py @@ -4,11 +4,13 @@ import os import os.path import sys from test import support +from test.support import script_helper import unittest import warnings with warnings.catch_warnings(): warnings.simplefilter('ignore', DeprecationWarning) import imp +import _imp def requires_load_dynamic(meth): @@ -329,6 +331,25 @@ class ImportTests(unittest.TestCase): with self.assertRaises(TypeError): create_dynamic(BadSpec()) + def test_source_hash(self): + self.assertEqual(_imp.source_hash(42, b'hi'), b'\xc6\xe7Z\r\x03:}\xab') + self.assertEqual(_imp.source_hash(43, b'hi'), b'\x85\x9765\xf8\x9a\x8b9') + + def test_pyc_invalidation_mode_from_cmdline(self): + cases = [ + ([], "default"), + (["--check-hash-based-pycs", "default"], "default"), + (["--check-hash-based-pycs", "always"], "always"), + (["--check-hash-based-pycs", "never"], "never"), + ] + for interp_args, expected in cases: + args = interp_args + [ + "-c", + "import _imp; print(_imp.check_hash_based_pycs)", + ] + res = script_helper.assert_python_ok(*args) + self.assertEqual(res.out.strip().decode('utf-8'), expected) + class ReloadTests(unittest.TestCase): diff --git a/Lib/test/test_import/__init__.py b/Lib/test/test_import/__init__.py index 5a610ba..ceea79f 100644 --- a/Lib/test/test_import/__init__.py +++ b/Lib/test/test_import/__init__.py @@ -598,7 +598,7 @@ func_filename = func.__code__.co_filename def test_foreign_code(self): py_compile.compile(self.file_name) with open(self.compiled_name, "rb") as f: - header = f.read(12) + header = f.read(16) code = marshal.load(f) constants = list(code.co_consts) foreign_code = importlib.import_module.__code__ diff --git a/Lib/test/test_importlib/source/test_file_loader.py b/Lib/test/test_importlib/source/test_file_loader.py index a151149..643a02c 100644 --- a/Lib/test/test_importlib/source/test_file_loader.py +++ b/Lib/test/test_importlib/source/test_file_loader.py @@ -235,6 +235,123 @@ class SimpleTest(abc.LoaderTests): warnings.simplefilter('ignore', DeprecationWarning) loader.load_module('bad name') + @util.writes_bytecode_files + def test_checked_hash_based_pyc(self): + with util.create_modules('_temp') as mapping: + source = mapping['_temp'] + pyc = self.util.cache_from_source(source) + with open(source, 'wb') as fp: + fp.write(b'state = "old"') + os.utime(source, (50, 50)) + py_compile.compile( + source, + invalidation_mode=py_compile.PycInvalidationMode.CHECKED_HASH, + ) + loader = self.machinery.SourceFileLoader('_temp', source) + mod = types.ModuleType('_temp') + mod.__spec__ = self.util.spec_from_loader('_temp', loader) + loader.exec_module(mod) + self.assertEqual(mod.state, 'old') + # Write a new source with the same mtime and size as before. + with open(source, 'wb') as fp: + fp.write(b'state = "new"') + os.utime(source, (50, 50)) + loader.exec_module(mod) + self.assertEqual(mod.state, 'new') + with open(pyc, 'rb') as fp: + data = fp.read() + self.assertEqual(int.from_bytes(data[4:8], 'little'), 0b11) + self.assertEqual( + self.util.source_hash(b'state = "new"'), + data[8:16], + ) + + @util.writes_bytecode_files + def test_overriden_checked_hash_based_pyc(self): + with util.create_modules('_temp') as mapping, \ + unittest.mock.patch('_imp.check_hash_based_pycs', 'never'): + source = mapping['_temp'] + pyc = self.util.cache_from_source(source) + with open(source, 'wb') as fp: + fp.write(b'state = "old"') + os.utime(source, (50, 50)) + py_compile.compile( + source, + invalidation_mode=py_compile.PycInvalidationMode.CHECKED_HASH, + ) + loader = self.machinery.SourceFileLoader('_temp', source) + mod = types.ModuleType('_temp') + mod.__spec__ = self.util.spec_from_loader('_temp', loader) + loader.exec_module(mod) + self.assertEqual(mod.state, 'old') + # Write a new source with the same mtime and size as before. + with open(source, 'wb') as fp: + fp.write(b'state = "new"') + os.utime(source, (50, 50)) + loader.exec_module(mod) + self.assertEqual(mod.state, 'old') + + @util.writes_bytecode_files + def test_unchecked_hash_based_pyc(self): + with util.create_modules('_temp') as mapping: + source = mapping['_temp'] + pyc = self.util.cache_from_source(source) + with open(source, 'wb') as fp: + fp.write(b'state = "old"') + os.utime(source, (50, 50)) + py_compile.compile( + source, + invalidation_mode=py_compile.PycInvalidationMode.UNCHECKED_HASH, + ) + loader = self.machinery.SourceFileLoader('_temp', source) + mod = types.ModuleType('_temp') + mod.__spec__ = self.util.spec_from_loader('_temp', loader) + loader.exec_module(mod) + self.assertEqual(mod.state, 'old') + # Update the source file, which should be ignored. + with open(source, 'wb') as fp: + fp.write(b'state = "new"') + loader.exec_module(mod) + self.assertEqual(mod.state, 'old') + with open(pyc, 'rb') as fp: + data = fp.read() + self.assertEqual(int.from_bytes(data[4:8], 'little'), 0b1) + self.assertEqual( + self.util.source_hash(b'state = "old"'), + data[8:16], + ) + + @util.writes_bytecode_files + def test_overiden_unchecked_hash_based_pyc(self): + with util.create_modules('_temp') as mapping, \ + unittest.mock.patch('_imp.check_hash_based_pycs', 'always'): + source = mapping['_temp'] + pyc = self.util.cache_from_source(source) + with open(source, 'wb') as fp: + fp.write(b'state = "old"') + os.utime(source, (50, 50)) + py_compile.compile( + source, + invalidation_mode=py_compile.PycInvalidationMode.UNCHECKED_HASH, + ) + loader = self.machinery.SourceFileLoader('_temp', source) + mod = types.ModuleType('_temp') + mod.__spec__ = self.util.spec_from_loader('_temp', loader) + loader.exec_module(mod) + self.assertEqual(mod.state, 'old') + # Update the source file, which should be ignored. + with open(source, 'wb') as fp: + fp.write(b'state = "new"') + loader.exec_module(mod) + self.assertEqual(mod.state, 'new') + with open(pyc, 'rb') as fp: + data = fp.read() + self.assertEqual(int.from_bytes(data[4:8], 'little'), 0b1) + self.assertEqual( + self.util.source_hash(b'state = "new"'), + data[8:16], + ) + (Frozen_SimpleTest, Source_SimpleTest @@ -247,15 +364,17 @@ class BadBytecodeTest: def import_(self, file, module_name): raise NotImplementedError - def manipulate_bytecode(self, name, mapping, manipulator, *, - del_source=False): + def manipulate_bytecode(self, + name, mapping, manipulator, *, + del_source=False, + invalidation_mode=py_compile.PycInvalidationMode.TIMESTAMP): """Manipulate the bytecode of a module by passing it into a callable that returns what to use as the new bytecode.""" try: del sys.modules['_temp'] except KeyError: pass - py_compile.compile(mapping[name]) + py_compile.compile(mapping[name], invalidation_mode=invalidation_mode) if not del_source: bytecode_path = self.util.cache_from_source(mapping[name]) else: @@ -294,24 +413,51 @@ class BadBytecodeTest: del_source=del_source) test('_temp', mapping, bc_path) + def _test_partial_flags(self, test, *, del_source=False): + with util.create_modules('_temp') as mapping: + bc_path = self.manipulate_bytecode('_temp', mapping, + lambda bc: bc[:7], + del_source=del_source) + test('_temp', mapping, bc_path) + + def _test_partial_hash(self, test, *, del_source=False): + with util.create_modules('_temp') as mapping: + bc_path = self.manipulate_bytecode( + '_temp', + mapping, + lambda bc: bc[:13], + del_source=del_source, + invalidation_mode=py_compile.PycInvalidationMode.CHECKED_HASH, + ) + test('_temp', mapping, bc_path) + with util.create_modules('_temp') as mapping: + bc_path = self.manipulate_bytecode( + '_temp', + mapping, + lambda bc: bc[:13], + del_source=del_source, + invalidation_mode=py_compile.PycInvalidationMode.UNCHECKED_HASH, + ) + test('_temp', mapping, bc_path) + def _test_partial_timestamp(self, test, *, del_source=False): with util.create_modules('_temp') as mapping: bc_path = self.manipulate_bytecode('_temp', mapping, - lambda bc: bc[:7], + lambda bc: bc[:11], del_source=del_source) test('_temp', mapping, bc_path) def _test_partial_size(self, test, *, del_source=False): with util.create_modules('_temp') as mapping: bc_path = self.manipulate_bytecode('_temp', mapping, - lambda bc: bc[:11], + lambda bc: bc[:15], del_source=del_source) test('_temp', mapping, bc_path) def _test_no_marshal(self, *, del_source=False): with util.create_modules('_temp') as mapping: bc_path = self.manipulate_bytecode('_temp', mapping, - lambda bc: bc[:12], + lambda bc: bc[:16], del_source=del_source) file_path = mapping['_temp'] if not del_source else bc_path with self.assertRaises(EOFError): @@ -320,7 +466,7 @@ class BadBytecodeTest: def _test_non_code_marshal(self, *, del_source=False): with util.create_modules('_temp') as mapping: bytecode_path = self.manipulate_bytecode('_temp', mapping, - lambda bc: bc[:12] + marshal.dumps(b'abcd'), + lambda bc: bc[:16] + marshal.dumps(b'abcd'), del_source=del_source) file_path = mapping['_temp'] if not del_source else bytecode_path with self.assertRaises(ImportError) as cm: @@ -331,7 +477,7 @@ class BadBytecodeTest: def _test_bad_marshal(self, *, del_source=False): with util.create_modules('_temp') as mapping: bytecode_path = self.manipulate_bytecode('_temp', mapping, - lambda bc: bc[:12] + b'<test>', + lambda bc: bc[:16] + b'<test>', del_source=del_source) file_path = mapping['_temp'] if not del_source else bytecode_path with self.assertRaises(EOFError): @@ -376,7 +522,7 @@ class SourceLoaderBadBytecodeTest: def test(name, mapping, bytecode_path): self.import_(mapping[name], name) with open(bytecode_path, 'rb') as file: - self.assertGreater(len(file.read()), 12) + self.assertGreater(len(file.read()), 16) self._test_empty_file(test) @@ -384,7 +530,7 @@ class SourceLoaderBadBytecodeTest: def test(name, mapping, bytecode_path): self.import_(mapping[name], name) with open(bytecode_path, 'rb') as file: - self.assertGreater(len(file.read()), 12) + self.assertGreater(len(file.read()), 16) self._test_partial_magic(test) @@ -395,7 +541,7 @@ class SourceLoaderBadBytecodeTest: def test(name, mapping, bytecode_path): self.import_(mapping[name], name) with open(bytecode_path, 'rb') as file: - self.assertGreater(len(file.read()), 12) + self.assertGreater(len(file.read()), 16) self._test_magic_only(test) @@ -418,18 +564,38 @@ class SourceLoaderBadBytecodeTest: def test(name, mapping, bc_path): self.import_(mapping[name], name) with open(bc_path, 'rb') as file: - self.assertGreater(len(file.read()), 12) + self.assertGreater(len(file.read()), 16) self._test_partial_timestamp(test) @util.writes_bytecode_files + def test_partial_flags(self): + # When the flags is partial, regenerate the .pyc, else raise EOFError. + def test(name, mapping, bc_path): + self.import_(mapping[name], name) + with open(bc_path, 'rb') as file: + self.assertGreater(len(file.read()), 16) + + self._test_partial_flags(test) + + @util.writes_bytecode_files + def test_partial_hash(self): + # When the hash is partial, regenerate the .pyc, else raise EOFError. + def test(name, mapping, bc_path): + self.import_(mapping[name], name) + with open(bc_path, 'rb') as file: + self.assertGreater(len(file.read()), 16) + + self._test_partial_hash(test) + + @util.writes_bytecode_files def test_partial_size(self): # When the size is partial, regenerate the .pyc, else # raise EOFError. def test(name, mapping, bc_path): self.import_(mapping[name], name) with open(bc_path, 'rb') as file: - self.assertGreater(len(file.read()), 12) + self.assertGreater(len(file.read()), 16) self._test_partial_size(test) @@ -459,13 +625,13 @@ class SourceLoaderBadBytecodeTest: py_compile.compile(mapping['_temp']) bytecode_path = self.util.cache_from_source(mapping['_temp']) with open(bytecode_path, 'r+b') as bytecode_file: - bytecode_file.seek(4) + bytecode_file.seek(8) bytecode_file.write(zeros) self.import_(mapping['_temp'], '_temp') source_mtime = os.path.getmtime(mapping['_temp']) source_timestamp = self.importlib._w_long(source_mtime) with open(bytecode_path, 'rb') as bytecode_file: - bytecode_file.seek(4) + bytecode_file.seek(8) self.assertEqual(bytecode_file.read(4), source_timestamp) # [bytecode read-only] @@ -560,6 +726,20 @@ class SourcelessLoaderBadBytecodeTest: self._test_partial_timestamp(test, del_source=True) + def test_partial_flags(self): + def test(name, mapping, bytecode_path): + with self.assertRaises(EOFError): + self.import_(bytecode_path, name) + + self._test_partial_flags(test, del_source=True) + + def test_partial_hash(self): + def test(name, mapping, bytecode_path): + with self.assertRaises(EOFError): + self.import_(bytecode_path, name) + + self._test_partial_hash(test, del_source=True) + def test_partial_size(self): def test(name, mapping, bytecode_path): with self.assertRaises(EOFError): diff --git a/Lib/test/test_importlib/test_abc.py b/Lib/test/test_importlib/test_abc.py index 54b2da6..4ba28c6 100644 --- a/Lib/test/test_importlib/test_abc.py +++ b/Lib/test/test_importlib/test_abc.py @@ -673,6 +673,7 @@ class SourceLoader(SourceOnlyLoader): if magic is None: magic = self.util.MAGIC_NUMBER data = bytearray(magic) + data.extend(self.init._w_long(0)) data.extend(self.init._w_long(self.source_mtime)) data.extend(self.init._w_long(self.source_size)) code_object = compile(self.source, self.path, 'exec', @@ -836,6 +837,7 @@ class SourceLoaderBytecodeTests(SourceLoaderTestHarness): if bytecode_written: self.assertIn(self.cached, self.loader.written) data = bytearray(self.util.MAGIC_NUMBER) + data.extend(self.init._w_long(0)) data.extend(self.init._w_long(self.loader.source_mtime)) data.extend(self.init._w_long(self.loader.source_size)) data.extend(marshal.dumps(code_object)) diff --git a/Lib/test/test_py_compile.py b/Lib/test/test_py_compile.py index 4a6caa5..bcb686c 100644 --- a/Lib/test/test_py_compile.py +++ b/Lib/test/test_py_compile.py @@ -122,6 +122,24 @@ class PyCompileTests(unittest.TestCase): # Specifying optimized bytecode should lead to a path reflecting that. self.assertIn('opt-2', py_compile.compile(self.source_path, optimize=2)) + def test_invalidation_mode(self): + py_compile.compile( + self.source_path, + invalidation_mode=py_compile.PycInvalidationMode.CHECKED_HASH, + ) + with open(self.cache_path, 'rb') as fp: + flags = importlib._bootstrap_external._classify_pyc( + fp.read(), 'test', {}) + self.assertEqual(flags, 0b11) + py_compile.compile( + self.source_path, + invalidation_mode=py_compile.PycInvalidationMode.UNCHECKED_HASH, + ) + with open(self.cache_path, 'rb') as fp: + flags = importlib._bootstrap_external._classify_pyc( + fp.read(), 'test', {}) + self.assertEqual(flags, 0b1) + if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_zipimport.py b/Lib/test/test_zipimport.py index 67ca39b..901bebd 100644 --- a/Lib/test/test_zipimport.py +++ b/Lib/test/test_zipimport.py @@ -40,7 +40,7 @@ def make_pyc(co, mtime, size): else: mtime = int(-0x100000000 + int(mtime)) pyc = (importlib.util.MAGIC_NUMBER + - struct.pack("<ii", int(mtime), size & 0xFFFFFFFF) + data) + struct.pack("<iii", 0, int(mtime), size & 0xFFFFFFFF) + data) return pyc def module_path_to_dotted_name(path): @@ -187,6 +187,20 @@ class UncompressedZipImportTestCase(ImportHooksBaseTestCase): TESTMOD + pyc_ext: (NOW, test_pyc)} self.doTest(pyc_ext, files, TESTMOD) + def testUncheckedHashBasedPyc(self): + source = b"state = 'old'" + source_hash = importlib.util.source_hash(source) + bytecode = importlib._bootstrap_external._code_to_hash_pyc( + compile(source, "???", "exec"), + source_hash, + False, # unchecked + ) + files = {TESTMOD + ".py": (NOW, "state = 'new'"), + TESTMOD + ".pyc": (NOW - 20, bytecode)} + def check(mod): + self.assertEqual(mod.state, 'old') + self.doTest(None, files, TESTMOD, call=check) + def testEmptyPy(self): files = {TESTMOD + ".py": (NOW, "")} self.doTest(None, files, TESTMOD) @@ -215,7 +229,7 @@ class UncompressedZipImportTestCase(ImportHooksBaseTestCase): badtime_pyc = bytearray(test_pyc) # flip the second bit -- not the first as that one isn't stored in the # .py's mtime in the zip archive. - badtime_pyc[7] ^= 0x02 + badtime_pyc[11] ^= 0x02 files = {TESTMOD + ".py": (NOW, test_src), TESTMOD + pyc_ext: (NOW, badtime_pyc)} self.doTest(".py", files, TESTMOD) |