Diffstat (limited to 'Lib')
-rw-r--r--   Lib/compileall.py                                    |  56
-rw-r--r--   Lib/importlib/_bootstrap_external.py                 | 197
-rw-r--r--   Lib/importlib/util.py                                |   7
-rw-r--r--   Lib/modulefinder.py                                  |   5
-rw-r--r--   Lib/pkgutil.py                                       |   2
-rw-r--r--   Lib/py_compile.py                                    |  25
-rw-r--r--   Lib/test/test_compileall.py                          |  21
-rw-r--r--   Lib/test/test_imp.py                                 |  21
-rw-r--r--   Lib/test/test_import/__init__.py                     |   2
-rw-r--r--   Lib/test/test_importlib/source/test_file_loader.py   | 210
-rw-r--r--   Lib/test/test_importlib/test_abc.py                  |   2
-rw-r--r--   Lib/test/test_py_compile.py                          |  18
-rw-r--r--   Lib/test/test_zipimport.py                           |  18
13 files changed, 485 insertions, 99 deletions
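
The patch below implements PEP 552 (deterministic pycs): the pyc header grows from 12 to 16 bytes. The magic number is now followed by a 32-bit flags word; with bit 0 clear the next 8 bytes are the source mtime and size as before, and with bit 0 set they are an 8-byte hash of the source, bit 1 selecting whether that hash is re-checked against the source at import time. As a reading aid only, not part of the patch, here is a minimal sketch of parsing such a header:

    # Minimal sketch (not part of the patch below): parse the 16-byte pyc
    # header introduced by PEP 552, assuming the layout described above.
    def read_pyc_header(path):
        """Return (magic, flags, details) for a post-PEP-552 pyc file."""
        with open(path, 'rb') as fp:
            header = fp.read(16)
        if len(header) < 16:
            raise EOFError(f'reached EOF while reading pyc header of {path!r}')
        magic = header[:4]
        flags = int.from_bytes(header[4:8], 'little')
        if flags & 0b1:
            # Hash-based pyc: 8-byte source hash; bit 1 says whether imports
            # should re-check the source against it.
            details = ('hash', header[8:16], bool(flags & 0b10))
        else:
            # Timestamp-based pyc: 32-bit source mtime and source size.
            details = ('timestamp',
                       int.from_bytes(header[8:12], 'little'),
                       int.from_bytes(header[12:16], 'little'))
        return magic, flags, details
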
diff --git a/Lib/compileall.py b/Lib/compileall.py index 1c9ceb6..7259212 100644 --- a/Lib/compileall.py +++ b/Lib/compileall.py @@ -52,7 +52,8 @@ def _walk_dir(dir, ddir=None, maxlevels=10, quiet=0): maxlevels=maxlevels - 1, quiet=quiet) def compile_dir(dir, maxlevels=10, ddir=None, force=False, rx=None, - quiet=0, legacy=False, optimize=-1, workers=1): + quiet=0, legacy=False, optimize=-1, workers=1, + invalidation_mode=py_compile.PycInvalidationMode.TIMESTAMP): """Byte-compile all modules in the given directory tree. Arguments (only dir is required): @@ -67,6 +68,7 @@ def compile_dir(dir, maxlevels=10, ddir=None, force=False, rx=None, legacy: if True, produce legacy pyc paths instead of PEP 3147 paths optimize: optimization level or -1 for level of the interpreter workers: maximum number of parallel workers + invalidation_mode: how the up-to-dateness of the pyc will be checked """ if workers is not None and workers < 0: raise ValueError('workers must be greater or equal to 0') @@ -81,18 +83,20 @@ def compile_dir(dir, maxlevels=10, ddir=None, force=False, rx=None, ddir=ddir, force=force, rx=rx, quiet=quiet, legacy=legacy, - optimize=optimize), + optimize=optimize, + invalidation_mode=invalidation_mode), files) success = min(results, default=True) else: for file in files: if not compile_file(file, ddir, force, rx, quiet, - legacy, optimize): + legacy, optimize, invalidation_mode): success = False return success def compile_file(fullname, ddir=None, force=False, rx=None, quiet=0, - legacy=False, optimize=-1): + legacy=False, optimize=-1, + invalidation_mode=py_compile.PycInvalidationMode.TIMESTAMP): """Byte-compile one file. Arguments (only fullname is required): @@ -105,6 +109,7 @@ def compile_file(fullname, ddir=None, force=False, rx=None, quiet=0, no output with 2 legacy: if True, produce legacy pyc paths instead of PEP 3147 paths optimize: optimization level or -1 for level of the interpreter + invalidation_mode: how the up-to-dateness of the pyc will be checked """ success = True if quiet < 2 and isinstance(fullname, os.PathLike): @@ -134,10 +139,10 @@ def compile_file(fullname, ddir=None, force=False, rx=None, quiet=0, if not force: try: mtime = int(os.stat(fullname).st_mtime) - expect = struct.pack('<4sl', importlib.util.MAGIC_NUMBER, - mtime) + expect = struct.pack('<4sll', importlib.util.MAGIC_NUMBER, + 0, mtime) with open(cfile, 'rb') as chandle: - actual = chandle.read(8) + actual = chandle.read(12) if expect == actual: return success except OSError: @@ -146,7 +151,8 @@ def compile_file(fullname, ddir=None, force=False, rx=None, quiet=0, print('Compiling {!r}...'.format(fullname)) try: ok = py_compile.compile(fullname, cfile, dfile, True, - optimize=optimize) + optimize=optimize, + invalidation_mode=invalidation_mode) except py_compile.PyCompileError as err: success = False if quiet >= 2: @@ -175,7 +181,8 @@ def compile_file(fullname, ddir=None, force=False, rx=None, quiet=0, return success def compile_path(skip_curdir=1, maxlevels=0, force=False, quiet=0, - legacy=False, optimize=-1): + legacy=False, optimize=-1, + invalidation_mode=py_compile.PycInvalidationMode.TIMESTAMP): """Byte-compile all module on sys.path. 
Arguments (all optional): @@ -186,6 +193,7 @@ def compile_path(skip_curdir=1, maxlevels=0, force=False, quiet=0, quiet: as for compile_dir() (default 0) legacy: as for compile_dir() (default False) optimize: as for compile_dir() (default -1) + invalidation_mode: as for compiler_dir() """ success = True for dir in sys.path: @@ -193,9 +201,16 @@ def compile_path(skip_curdir=1, maxlevels=0, force=False, quiet=0, if quiet < 2: print('Skipping current directory') else: - success = success and compile_dir(dir, maxlevels, None, - force, quiet=quiet, - legacy=legacy, optimize=optimize) + success = success and compile_dir( + dir, + maxlevels, + None, + force, + quiet=quiet, + legacy=legacy, + optimize=optimize, + invalidation_mode=invalidation_mode, + ) return success @@ -238,6 +253,11 @@ def main(): 'to the equivalent of -l sys.path')) parser.add_argument('-j', '--workers', default=1, type=int, help='Run compileall concurrently') + invalidation_modes = [mode.name.lower().replace('_', '-') + for mode in py_compile.PycInvalidationMode] + parser.add_argument('--invalidation-mode', default='timestamp', + choices=sorted(invalidation_modes), + help='How the pycs will be invalidated at runtime') args = parser.parse_args() compile_dests = args.compile_dest @@ -266,23 +286,29 @@ def main(): if args.workers is not None: args.workers = args.workers or None + ivl_mode = args.invalidation_mode.replace('-', '_').upper() + invalidation_mode = py_compile.PycInvalidationMode[ivl_mode] + success = True try: if compile_dests: for dest in compile_dests: if os.path.isfile(dest): if not compile_file(dest, args.ddir, args.force, args.rx, - args.quiet, args.legacy): + args.quiet, args.legacy, + invalidation_mode=invalidation_mode): success = False else: if not compile_dir(dest, maxlevels, args.ddir, args.force, args.rx, args.quiet, - args.legacy, workers=args.workers): + args.legacy, workers=args.workers, + invalidation_mode=invalidation_mode): success = False return success else: return compile_path(legacy=args.legacy, force=args.force, - quiet=args.quiet) + quiet=args.quiet, + invalidation_mode=invalidation_mode) except KeyboardInterrupt: if args.quiet < 2: print("\n[interrupted]") diff --git a/Lib/importlib/_bootstrap_external.py b/Lib/importlib/_bootstrap_external.py index 41de8a7..e808507 100644 --- a/Lib/importlib/_bootstrap_external.py +++ b/Lib/importlib/_bootstrap_external.py @@ -242,6 +242,7 @@ _code_type = type(_write_atomic.__code__) # Python 3.6rc1 3379 (more thorough __class__ validation #23722) # Python 3.7a0 3390 (add LOAD_METHOD and CALL_METHOD opcodes) # Python 3.7a0 3391 (update GET_AITER #31709) +# Python 3.7a0 3392 (PEP 552: Deterministic pycs) # # MAGIC must change whenever the bytecode emitted by the compiler may no # longer be understood by older implementations of the eval loop (usually @@ -250,7 +251,7 @@ _code_type = type(_write_atomic.__code__) # Whenever MAGIC_NUMBER is changed, the ranges in the magic_values array # in PC/launcher.c must also be updated. -MAGIC_NUMBER = (3391).to_bytes(2, 'little') + b'\r\n' +MAGIC_NUMBER = (3392).to_bytes(2, 'little') + b'\r\n' _RAW_MAGIC_NUMBER = int.from_bytes(MAGIC_NUMBER, 'little') # For import.c _PYCACHE = '__pycache__' @@ -429,63 +430,93 @@ def _find_module_shim(self, fullname): return loader -def _validate_bytecode_header(data, source_stats=None, name=None, path=None): - """Validate the header of the passed-in bytecode against source_stats (if - given) and returning the bytecode that can be compiled by compile(). 
+def _classify_pyc(data, name, exc_details): + """Perform basic validity checking of a pyc header and return the flags field, + which determines how the pyc should be further validated against the source. - All other arguments are used to enhance error reporting. + *data* is the contents of the pyc file. (Only the first 16 bytes are + required, though.) - ImportError is raised when the magic number is incorrect or the bytecode is - found to be stale. EOFError is raised when the data is found to be - truncated. + *name* is the name of the module being imported. It is used for logging. + + *exc_details* is a dictionary passed to ImportError if it raised for + improved debugging. + + ImportError is raised when the magic number is incorrect or when the flags + field is invalid. EOFError is raised when the data is found to be truncated. """ - exc_details = {} - if name is not None: - exc_details['name'] = name - else: - # To prevent having to make all messages have a conditional name. - name = '<bytecode>' - if path is not None: - exc_details['path'] = path magic = data[:4] - raw_timestamp = data[4:8] - raw_size = data[8:12] if magic != MAGIC_NUMBER: - message = 'bad magic number in {!r}: {!r}'.format(name, magic) + message = f'bad magic number in {name!r}: {magic!r}' _bootstrap._verbose_message('{}', message) raise ImportError(message, **exc_details) - elif len(raw_timestamp) != 4: - message = 'reached EOF while reading timestamp in {!r}'.format(name) + if len(data) < 16: + message = f'reached EOF while reading pyc header of {name!r}' _bootstrap._verbose_message('{}', message) raise EOFError(message) - elif len(raw_size) != 4: - message = 'reached EOF while reading size of source in {!r}'.format(name) + flags = _r_long(data[4:8]) + # Only the first two flags are defined. + if flags & ~0b11: + message = f'invalid flags {flags!r} in {name!r}' + raise ImportError(message, **exc_details) + return flags + + +def _validate_timestamp_pyc(data, source_mtime, source_size, name, + exc_details): + """Validate a pyc against the source last-modified time. + + *data* is the contents of the pyc file. (Only the first 16 bytes are + required.) + + *source_mtime* is the last modified timestamp of the source file. + + *source_size* is None or the size of the source file in bytes. + + *name* is the name of the module being imported. It is used for logging. + + *exc_details* is a dictionary passed to ImportError if it raised for + improved debugging. + + An ImportError is raised if the bytecode is stale. 
+ + """ + if _r_long(data[8:12]) != (source_mtime & 0xFFFFFFFF): + message = f'bytecode is stale for {name!r}' _bootstrap._verbose_message('{}', message) - raise EOFError(message) - if source_stats is not None: - try: - source_mtime = int(source_stats['mtime']) - except KeyError: - pass - else: - if _r_long(raw_timestamp) != source_mtime: - message = 'bytecode is stale for {!r}'.format(name) - _bootstrap._verbose_message('{}', message) - raise ImportError(message, **exc_details) - try: - source_size = source_stats['size'] & 0xFFFFFFFF - except KeyError: - pass - else: - if _r_long(raw_size) != source_size: - raise ImportError('bytecode is stale for {!r}'.format(name), - **exc_details) - return data[12:] + raise ImportError(message, **exc_details) + if (source_size is not None and + _r_long(data[12:16]) != (source_size & 0xFFFFFFFF)): + raise ImportError(f'bytecode is stale for {name!r}', **exc_details) + + +def _validate_hash_pyc(data, source_hash, name, exc_details): + """Validate a hash-based pyc by checking the real source hash against the one in + the pyc header. + + *data* is the contents of the pyc file. (Only the first 16 bytes are + required.) + + *source_hash* is the importlib.util.source_hash() of the source file. + + *name* is the name of the module being imported. It is used for logging. + + *exc_details* is a dictionary passed to ImportError if it raised for + improved debugging. + + An ImportError is raised if the bytecode is stale. + + """ + if data[8:16] != source_hash: + raise ImportError( + f'hash in bytecode doesn\'t match hash of source {name!r}', + **exc_details, + ) def _compile_bytecode(data, name=None, bytecode_path=None, source_path=None): - """Compile bytecode as returned by _validate_bytecode_header().""" + """Compile bytecode as found in a pyc.""" code = marshal.loads(data) if isinstance(code, _code_type): _bootstrap._verbose_message('code object from {!r}', bytecode_path) @@ -496,16 +527,28 @@ def _compile_bytecode(data, name=None, bytecode_path=None, source_path=None): raise ImportError('Non-code object in {!r}'.format(bytecode_path), name=name, path=bytecode_path) -def _code_to_bytecode(code, mtime=0, source_size=0): - """Compile a code object into bytecode for writing out to a byte-compiled - file.""" + +def _code_to_timestamp_pyc(code, mtime=0, source_size=0): + "Produce the data for a timestamp-based pyc." data = bytearray(MAGIC_NUMBER) + data.extend(_w_long(0)) data.extend(_w_long(mtime)) data.extend(_w_long(source_size)) data.extend(marshal.dumps(code)) return data +def _code_to_hash_pyc(code, source_hash, checked=True): + "Produce the data for a hash-based pyc." + data = bytearray(MAGIC_NUMBER) + flags = 0b1 | checked << 1 + data.extend(_w_long(flags)) + assert len(source_hash) == 8 + data.extend(source_hash) + data.extend(marshal.dumps(code)) + return data + + def decode_source(source_bytes): """Decode bytes representing source code and return the string. 
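
The helpers added above, _code_to_timestamp_pyc() and _code_to_hash_pyc(), define the header layout that the rest of the patch reads back. The following standalone sketch mirrors that layout without touching importlib internals; it is an illustration of the format, not the helpers themselves:

    # Standalone sketch of the two pyc header layouts produced by the
    # helpers above (illustrative only).
    import marshal
    from importlib.util import MAGIC_NUMBER

    def timestamp_pyc(code, mtime=0, source_size=0):
        data = bytearray(MAGIC_NUMBER)
        data += (0).to_bytes(4, 'little')                         # flags: timestamp-based
        data += (int(mtime) & 0xFFFFFFFF).to_bytes(4, 'little')   # source mtime
        data += (source_size & 0xFFFFFFFF).to_bytes(4, 'little')  # source size
        data += marshal.dumps(code)
        return bytes(data)

    def hash_pyc(code, source_hash, checked=True):
        assert len(source_hash) == 8
        data = bytearray(MAGIC_NUMBER)
        data += (0b1 | checked << 1).to_bytes(4, 'little')        # flags: hash-based,
                                                                  # bit 1 = check source
        data += source_hash                                       # 8-byte source hash
        data += marshal.dumps(code)
        return bytes(data)
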
@@ -751,6 +794,10 @@ class SourceLoader(_LoaderBasics): """ source_path = self.get_filename(fullname) source_mtime = None + source_bytes = None + source_hash = None + hash_based = False + check_source = True try: bytecode_path = cache_from_source(source_path) except NotImplementedError: @@ -767,10 +814,34 @@ class SourceLoader(_LoaderBasics): except OSError: pass else: + exc_details = { + 'name': fullname, + 'path': bytecode_path, + } try: - bytes_data = _validate_bytecode_header(data, - source_stats=st, name=fullname, - path=bytecode_path) + flags = _classify_pyc(data, fullname, exc_details) + bytes_data = memoryview(data)[16:] + hash_based = flags & 0b1 != 0 + if hash_based: + check_source = flags & 0b10 != 0 + if (_imp.check_hash_based_pycs != 'never' and + (check_source or + _imp.check_hash_based_pycs == 'always')): + source_bytes = self.get_data(source_path) + source_hash = _imp.source_hash( + _RAW_MAGIC_NUMBER, + source_bytes, + ) + _validate_hash_pyc(data, source_hash, fullname, + exc_details) + else: + _validate_timestamp_pyc( + data, + source_mtime, + st['size'], + fullname, + exc_details, + ) except (ImportError, EOFError): pass else: @@ -779,13 +850,19 @@ class SourceLoader(_LoaderBasics): return _compile_bytecode(bytes_data, name=fullname, bytecode_path=bytecode_path, source_path=source_path) - source_bytes = self.get_data(source_path) + if source_bytes is None: + source_bytes = self.get_data(source_path) code_object = self.source_to_code(source_bytes, source_path) _bootstrap._verbose_message('code object from {}', source_path) if (not sys.dont_write_bytecode and bytecode_path is not None and source_mtime is not None): - data = _code_to_bytecode(code_object, source_mtime, - len(source_bytes)) + if hash_based: + if source_hash is None: + source_hash = _imp.source_hash(source_bytes) + data = _code_to_hash_pyc(code_object, source_hash, check_source) + else: + data = _code_to_timestamp_pyc(code_object, source_mtime, + len(source_bytes)) try: self._cache_bytecode(source_path, bytecode_path, data) _bootstrap._verbose_message('wrote {!r}', bytecode_path) @@ -887,8 +964,18 @@ class SourcelessFileLoader(FileLoader, _LoaderBasics): def get_code(self, fullname): path = self.get_filename(fullname) data = self.get_data(path) - bytes_data = _validate_bytecode_header(data, name=fullname, path=path) - return _compile_bytecode(bytes_data, name=fullname, bytecode_path=path) + # Call _classify_pyc to do basic validation of the pyc but ignore the + # result. There's no source to check against. 
+ exc_details = { + 'name': fullname, + 'path': path, + } + _classify_pyc(data, fullname, exc_details) + return _compile_bytecode( + memoryview(data)[16:], + name=fullname, + bytecode_path=path, + ) def get_source(self, fullname): """Return None as there is no source code.""" diff --git a/Lib/importlib/util.py b/Lib/importlib/util.py index 41c74d4..9d0a90d 100644 --- a/Lib/importlib/util.py +++ b/Lib/importlib/util.py @@ -5,18 +5,25 @@ from ._bootstrap import _resolve_name from ._bootstrap import spec_from_loader from ._bootstrap import _find_spec from ._bootstrap_external import MAGIC_NUMBER +from ._bootstrap_external import _RAW_MAGIC_NUMBER from ._bootstrap_external import cache_from_source from ._bootstrap_external import decode_source from ._bootstrap_external import source_from_cache from ._bootstrap_external import spec_from_file_location from contextlib import contextmanager +import _imp import functools import sys import types import warnings +def source_hash(source_bytes): + "Return the hash of *source_bytes* as used in hash-based pyc files." + return _imp.source_hash(_RAW_MAGIC_NUMBER, source_bytes) + + def resolve_name(name, package): """Resolve a relative module name to an absolute one.""" if not name.startswith('.'): diff --git a/Lib/modulefinder.py b/Lib/modulefinder.py index e277ca7..10320a7 100644 --- a/Lib/modulefinder.py +++ b/Lib/modulefinder.py @@ -287,11 +287,12 @@ class ModuleFinder: co = compile(fp.read()+'\n', pathname, 'exec') elif type == imp.PY_COMPILED: try: - marshal_data = importlib._bootstrap_external._validate_bytecode_header(fp.read()) + data = fp.read() + importlib._bootstrap_external._classify_pyc(data, fqname, {}) except ImportError as exc: self.msgout(2, "raise ImportError: " + str(exc), pathname) raise - co = marshal.loads(marshal_data) + co = marshal.loads(memoryview(data)[16:]) else: co = None m = self.add_module(fqname) diff --git a/Lib/pkgutil.py b/Lib/pkgutil.py index 9180eae..8474a77 100644 --- a/Lib/pkgutil.py +++ b/Lib/pkgutil.py @@ -46,7 +46,7 @@ def read_code(stream): if magic != importlib.util.MAGIC_NUMBER: return None - stream.read(8) # Skip timestamp and size + stream.read(12) # Skip rest of the header return marshal.load(stream) diff --git a/Lib/py_compile.py b/Lib/py_compile.py index 11c5b50..a0f4def 100644 --- a/Lib/py_compile.py +++ b/Lib/py_compile.py @@ -3,6 +3,7 @@ This module has intimate knowledge of the format of .pyc files. """ +import enum import importlib._bootstrap_external import importlib.machinery import importlib.util @@ -11,7 +12,7 @@ import os.path import sys import traceback -__all__ = ["compile", "main", "PyCompileError"] +__all__ = ["compile", "main", "PyCompileError", "PycInvalidationMode"] class PyCompileError(Exception): @@ -62,7 +63,14 @@ class PyCompileError(Exception): return self.msg -def compile(file, cfile=None, dfile=None, doraise=False, optimize=-1): +class PycInvalidationMode(enum.Enum): + TIMESTAMP = 1 + CHECKED_HASH = 2 + UNCHECKED_HASH = 3 + + +def compile(file, cfile=None, dfile=None, doraise=False, optimize=-1, + invalidation_mode=PycInvalidationMode.TIMESTAMP): """Byte-compile one Python source file to Python bytecode. :param file: The source file name. @@ -79,6 +87,7 @@ def compile(file, cfile=None, dfile=None, doraise=False, optimize=-1): :param optimize: The optimization level for the compiler. Valid values are -1, 0, 1 and 2. A value of -1 means to use the optimization level of the current interpreter, as given by -O command line options. 
+ :param invalidation_mode: :return: Path to the resulting byte compiled file. @@ -136,9 +145,17 @@ def compile(file, cfile=None, dfile=None, doraise=False, optimize=-1): os.makedirs(dirname) except FileExistsError: pass - source_stats = loader.path_stats(file) - bytecode = importlib._bootstrap_external._code_to_bytecode( + if invalidation_mode == PycInvalidationMode.TIMESTAMP: + source_stats = loader.path_stats(file) + bytecode = importlib._bootstrap_external._code_to_timestamp_pyc( code, source_stats['mtime'], source_stats['size']) + else: + source_hash = importlib.util.source_hash(source_bytes) + bytecode = importlib._bootstrap_external._code_to_hash_pyc( + code, + source_hash, + (invalidation_mode == PycInvalidationMode.CHECKED_HASH), + ) mode = importlib._bootstrap_external._calc_mode(file) importlib._bootstrap_external._write_atomic(cfile, bytecode, mode) return cfile diff --git a/Lib/test/test_compileall.py b/Lib/test/test_compileall.py index 2356efc..38d7b99 100644 --- a/Lib/test/test_compileall.py +++ b/Lib/test/test_compileall.py @@ -48,9 +48,9 @@ class CompileallTests(unittest.TestCase): def data(self): with open(self.bc_path, 'rb') as file: - data = file.read(8) + data = file.read(12) mtime = int(os.stat(self.source_path).st_mtime) - compare = struct.pack('<4sl', importlib.util.MAGIC_NUMBER, mtime) + compare = struct.pack('<4sll', importlib.util.MAGIC_NUMBER, 0, mtime) return data, compare @unittest.skipUnless(hasattr(os, 'stat'), 'test needs os.stat()') @@ -70,8 +70,8 @@ class CompileallTests(unittest.TestCase): def test_mtime(self): # Test a change in mtime leads to a new .pyc. - self.recreation_check(struct.pack('<4sl', importlib.util.MAGIC_NUMBER, - 1)) + self.recreation_check(struct.pack('<4sll', importlib.util.MAGIC_NUMBER, + 0, 1)) def test_magic_number(self): # Test a change in mtime leads to a new .pyc. 
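
Both the compileall command line (python -m compileall --invalidation-mode checked-hash DIR) and the tests above funnel into the py_compile API changed earlier in this patch. A usage sketch, assuming an interpreter that includes this patch and a placeholder source file spam.py:

    # Sketch: write a checked-hash pyc with the new keyword argument and
    # inspect the flags word the tests above assert on.
    import importlib.util
    import py_compile

    pyc_path = py_compile.compile(
        'spam.py',   # placeholder source file
        invalidation_mode=py_compile.PycInvalidationMode.CHECKED_HASH,
    )
    with open(pyc_path, 'rb') as fp:
        header = fp.read(16)
    assert header[:4] == importlib.util.MAGIC_NUMBER
    assert int.from_bytes(header[4:8], 'little') == 0b11  # hash-based, checked
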
@@ -519,6 +519,19 @@ class CommandLineTests(unittest.TestCase): out = self.assertRunOK('badfilename') self.assertRegex(out, b"Can't list 'badfilename'") + def test_pyc_invalidation_mode(self): + script_helper.make_script(self.pkgdir, 'f1', '') + pyc = importlib.util.cache_from_source( + os.path.join(self.pkgdir, 'f1.py')) + self.assertRunOK('--invalidation-mode=checked-hash', self.pkgdir) + with open(pyc, 'rb') as fp: + data = fp.read() + self.assertEqual(int.from_bytes(data[4:8], 'little'), 0b11) + self.assertRunOK('--invalidation-mode=unchecked-hash', self.pkgdir) + with open(pyc, 'rb') as fp: + data = fp.read() + self.assertEqual(int.from_bytes(data[4:8], 'little'), 0b01) + @skipUnless(_have_multiprocessing, "requires multiprocessing") def test_workers(self): bar2fn = script_helper.make_script(self.directory, 'bar2', '') diff --git a/Lib/test/test_imp.py b/Lib/test/test_imp.py index b70ec7c..a115e60 100644 --- a/Lib/test/test_imp.py +++ b/Lib/test/test_imp.py @@ -4,11 +4,13 @@ import os import os.path import sys from test import support +from test.support import script_helper import unittest import warnings with warnings.catch_warnings(): warnings.simplefilter('ignore', DeprecationWarning) import imp +import _imp def requires_load_dynamic(meth): @@ -329,6 +331,25 @@ class ImportTests(unittest.TestCase): with self.assertRaises(TypeError): create_dynamic(BadSpec()) + def test_source_hash(self): + self.assertEqual(_imp.source_hash(42, b'hi'), b'\xc6\xe7Z\r\x03:}\xab') + self.assertEqual(_imp.source_hash(43, b'hi'), b'\x85\x9765\xf8\x9a\x8b9') + + def test_pyc_invalidation_mode_from_cmdline(self): + cases = [ + ([], "default"), + (["--check-hash-based-pycs", "default"], "default"), + (["--check-hash-based-pycs", "always"], "always"), + (["--check-hash-based-pycs", "never"], "never"), + ] + for interp_args, expected in cases: + args = interp_args + [ + "-c", + "import _imp; print(_imp.check_hash_based_pycs)", + ] + res = script_helper.assert_python_ok(*args) + self.assertEqual(res.out.strip().decode('utf-8'), expected) + class ReloadTests(unittest.TestCase): diff --git a/Lib/test/test_import/__init__.py b/Lib/test/test_import/__init__.py index 5a610ba..ceea79f 100644 --- a/Lib/test/test_import/__init__.py +++ b/Lib/test/test_import/__init__.py @@ -598,7 +598,7 @@ func_filename = func.__code__.co_filename def test_foreign_code(self): py_compile.compile(self.file_name) with open(self.compiled_name, "rb") as f: - header = f.read(12) + header = f.read(16) code = marshal.load(f) constants = list(code.co_consts) foreign_code = importlib.import_module.__code__ diff --git a/Lib/test/test_importlib/source/test_file_loader.py b/Lib/test/test_importlib/source/test_file_loader.py index a151149..643a02c 100644 --- a/Lib/test/test_importlib/source/test_file_loader.py +++ b/Lib/test/test_importlib/source/test_file_loader.py @@ -235,6 +235,123 @@ class SimpleTest(abc.LoaderTests): warnings.simplefilter('ignore', DeprecationWarning) loader.load_module('bad name') + @util.writes_bytecode_files + def test_checked_hash_based_pyc(self): + with util.create_modules('_temp') as mapping: + source = mapping['_temp'] + pyc = self.util.cache_from_source(source) + with open(source, 'wb') as fp: + fp.write(b'state = "old"') + os.utime(source, (50, 50)) + py_compile.compile( + source, + invalidation_mode=py_compile.PycInvalidationMode.CHECKED_HASH, + ) + loader = self.machinery.SourceFileLoader('_temp', source) + mod = types.ModuleType('_temp') + mod.__spec__ = self.util.spec_from_loader('_temp', loader) + 
loader.exec_module(mod) + self.assertEqual(mod.state, 'old') + # Write a new source with the same mtime and size as before. + with open(source, 'wb') as fp: + fp.write(b'state = "new"') + os.utime(source, (50, 50)) + loader.exec_module(mod) + self.assertEqual(mod.state, 'new') + with open(pyc, 'rb') as fp: + data = fp.read() + self.assertEqual(int.from_bytes(data[4:8], 'little'), 0b11) + self.assertEqual( + self.util.source_hash(b'state = "new"'), + data[8:16], + ) + + @util.writes_bytecode_files + def test_overriden_checked_hash_based_pyc(self): + with util.create_modules('_temp') as mapping, \ + unittest.mock.patch('_imp.check_hash_based_pycs', 'never'): + source = mapping['_temp'] + pyc = self.util.cache_from_source(source) + with open(source, 'wb') as fp: + fp.write(b'state = "old"') + os.utime(source, (50, 50)) + py_compile.compile( + source, + invalidation_mode=py_compile.PycInvalidationMode.CHECKED_HASH, + ) + loader = self.machinery.SourceFileLoader('_temp', source) + mod = types.ModuleType('_temp') + mod.__spec__ = self.util.spec_from_loader('_temp', loader) + loader.exec_module(mod) + self.assertEqual(mod.state, 'old') + # Write a new source with the same mtime and size as before. + with open(source, 'wb') as fp: + fp.write(b'state = "new"') + os.utime(source, (50, 50)) + loader.exec_module(mod) + self.assertEqual(mod.state, 'old') + + @util.writes_bytecode_files + def test_unchecked_hash_based_pyc(self): + with util.create_modules('_temp') as mapping: + source = mapping['_temp'] + pyc = self.util.cache_from_source(source) + with open(source, 'wb') as fp: + fp.write(b'state = "old"') + os.utime(source, (50, 50)) + py_compile.compile( + source, + invalidation_mode=py_compile.PycInvalidationMode.UNCHECKED_HASH, + ) + loader = self.machinery.SourceFileLoader('_temp', source) + mod = types.ModuleType('_temp') + mod.__spec__ = self.util.spec_from_loader('_temp', loader) + loader.exec_module(mod) + self.assertEqual(mod.state, 'old') + # Update the source file, which should be ignored. + with open(source, 'wb') as fp: + fp.write(b'state = "new"') + loader.exec_module(mod) + self.assertEqual(mod.state, 'old') + with open(pyc, 'rb') as fp: + data = fp.read() + self.assertEqual(int.from_bytes(data[4:8], 'little'), 0b1) + self.assertEqual( + self.util.source_hash(b'state = "old"'), + data[8:16], + ) + + @util.writes_bytecode_files + def test_overiden_unchecked_hash_based_pyc(self): + with util.create_modules('_temp') as mapping, \ + unittest.mock.patch('_imp.check_hash_based_pycs', 'always'): + source = mapping['_temp'] + pyc = self.util.cache_from_source(source) + with open(source, 'wb') as fp: + fp.write(b'state = "old"') + os.utime(source, (50, 50)) + py_compile.compile( + source, + invalidation_mode=py_compile.PycInvalidationMode.UNCHECKED_HASH, + ) + loader = self.machinery.SourceFileLoader('_temp', source) + mod = types.ModuleType('_temp') + mod.__spec__ = self.util.spec_from_loader('_temp', loader) + loader.exec_module(mod) + self.assertEqual(mod.state, 'old') + # Update the source file, which should be ignored. 
+ with open(source, 'wb') as fp: + fp.write(b'state = "new"') + loader.exec_module(mod) + self.assertEqual(mod.state, 'new') + with open(pyc, 'rb') as fp: + data = fp.read() + self.assertEqual(int.from_bytes(data[4:8], 'little'), 0b1) + self.assertEqual( + self.util.source_hash(b'state = "new"'), + data[8:16], + ) + (Frozen_SimpleTest, Source_SimpleTest @@ -247,15 +364,17 @@ class BadBytecodeTest: def import_(self, file, module_name): raise NotImplementedError - def manipulate_bytecode(self, name, mapping, manipulator, *, - del_source=False): + def manipulate_bytecode(self, + name, mapping, manipulator, *, + del_source=False, + invalidation_mode=py_compile.PycInvalidationMode.TIMESTAMP): """Manipulate the bytecode of a module by passing it into a callable that returns what to use as the new bytecode.""" try: del sys.modules['_temp'] except KeyError: pass - py_compile.compile(mapping[name]) + py_compile.compile(mapping[name], invalidation_mode=invalidation_mode) if not del_source: bytecode_path = self.util.cache_from_source(mapping[name]) else: @@ -294,24 +413,51 @@ class BadBytecodeTest: del_source=del_source) test('_temp', mapping, bc_path) + def _test_partial_flags(self, test, *, del_source=False): + with util.create_modules('_temp') as mapping: + bc_path = self.manipulate_bytecode('_temp', mapping, + lambda bc: bc[:7], + del_source=del_source) + test('_temp', mapping, bc_path) + + def _test_partial_hash(self, test, *, del_source=False): + with util.create_modules('_temp') as mapping: + bc_path = self.manipulate_bytecode( + '_temp', + mapping, + lambda bc: bc[:13], + del_source=del_source, + invalidation_mode=py_compile.PycInvalidationMode.CHECKED_HASH, + ) + test('_temp', mapping, bc_path) + with util.create_modules('_temp') as mapping: + bc_path = self.manipulate_bytecode( + '_temp', + mapping, + lambda bc: bc[:13], + del_source=del_source, + invalidation_mode=py_compile.PycInvalidationMode.UNCHECKED_HASH, + ) + test('_temp', mapping, bc_path) + def _test_partial_timestamp(self, test, *, del_source=False): with util.create_modules('_temp') as mapping: bc_path = self.manipulate_bytecode('_temp', mapping, - lambda bc: bc[:7], + lambda bc: bc[:11], del_source=del_source) test('_temp', mapping, bc_path) def _test_partial_size(self, test, *, del_source=False): with util.create_modules('_temp') as mapping: bc_path = self.manipulate_bytecode('_temp', mapping, - lambda bc: bc[:11], + lambda bc: bc[:15], del_source=del_source) test('_temp', mapping, bc_path) def _test_no_marshal(self, *, del_source=False): with util.create_modules('_temp') as mapping: bc_path = self.manipulate_bytecode('_temp', mapping, - lambda bc: bc[:12], + lambda bc: bc[:16], del_source=del_source) file_path = mapping['_temp'] if not del_source else bc_path with self.assertRaises(EOFError): @@ -320,7 +466,7 @@ class BadBytecodeTest: def _test_non_code_marshal(self, *, del_source=False): with util.create_modules('_temp') as mapping: bytecode_path = self.manipulate_bytecode('_temp', mapping, - lambda bc: bc[:12] + marshal.dumps(b'abcd'), + lambda bc: bc[:16] + marshal.dumps(b'abcd'), del_source=del_source) file_path = mapping['_temp'] if not del_source else bytecode_path with self.assertRaises(ImportError) as cm: @@ -331,7 +477,7 @@ class BadBytecodeTest: def _test_bad_marshal(self, *, del_source=False): with util.create_modules('_temp') as mapping: bytecode_path = self.manipulate_bytecode('_temp', mapping, - lambda bc: bc[:12] + b'<test>', + lambda bc: bc[:16] + b'<test>', del_source=del_source) file_path = mapping['_temp'] 
if not del_source else bytecode_path with self.assertRaises(EOFError): @@ -376,7 +522,7 @@ class SourceLoaderBadBytecodeTest: def test(name, mapping, bytecode_path): self.import_(mapping[name], name) with open(bytecode_path, 'rb') as file: - self.assertGreater(len(file.read()), 12) + self.assertGreater(len(file.read()), 16) self._test_empty_file(test) @@ -384,7 +530,7 @@ class SourceLoaderBadBytecodeTest: def test(name, mapping, bytecode_path): self.import_(mapping[name], name) with open(bytecode_path, 'rb') as file: - self.assertGreater(len(file.read()), 12) + self.assertGreater(len(file.read()), 16) self._test_partial_magic(test) @@ -395,7 +541,7 @@ class SourceLoaderBadBytecodeTest: def test(name, mapping, bytecode_path): self.import_(mapping[name], name) with open(bytecode_path, 'rb') as file: - self.assertGreater(len(file.read()), 12) + self.assertGreater(len(file.read()), 16) self._test_magic_only(test) @@ -418,18 +564,38 @@ class SourceLoaderBadBytecodeTest: def test(name, mapping, bc_path): self.import_(mapping[name], name) with open(bc_path, 'rb') as file: - self.assertGreater(len(file.read()), 12) + self.assertGreater(len(file.read()), 16) self._test_partial_timestamp(test) @util.writes_bytecode_files + def test_partial_flags(self): + # When the flags is partial, regenerate the .pyc, else raise EOFError. + def test(name, mapping, bc_path): + self.import_(mapping[name], name) + with open(bc_path, 'rb') as file: + self.assertGreater(len(file.read()), 16) + + self._test_partial_flags(test) + + @util.writes_bytecode_files + def test_partial_hash(self): + # When the hash is partial, regenerate the .pyc, else raise EOFError. + def test(name, mapping, bc_path): + self.import_(mapping[name], name) + with open(bc_path, 'rb') as file: + self.assertGreater(len(file.read()), 16) + + self._test_partial_hash(test) + + @util.writes_bytecode_files def test_partial_size(self): # When the size is partial, regenerate the .pyc, else # raise EOFError. 
def test(name, mapping, bc_path): self.import_(mapping[name], name) with open(bc_path, 'rb') as file: - self.assertGreater(len(file.read()), 12) + self.assertGreater(len(file.read()), 16) self._test_partial_size(test) @@ -459,13 +625,13 @@ class SourceLoaderBadBytecodeTest: py_compile.compile(mapping['_temp']) bytecode_path = self.util.cache_from_source(mapping['_temp']) with open(bytecode_path, 'r+b') as bytecode_file: - bytecode_file.seek(4) + bytecode_file.seek(8) bytecode_file.write(zeros) self.import_(mapping['_temp'], '_temp') source_mtime = os.path.getmtime(mapping['_temp']) source_timestamp = self.importlib._w_long(source_mtime) with open(bytecode_path, 'rb') as bytecode_file: - bytecode_file.seek(4) + bytecode_file.seek(8) self.assertEqual(bytecode_file.read(4), source_timestamp) # [bytecode read-only] @@ -560,6 +726,20 @@ class SourcelessLoaderBadBytecodeTest: self._test_partial_timestamp(test, del_source=True) + def test_partial_flags(self): + def test(name, mapping, bytecode_path): + with self.assertRaises(EOFError): + self.import_(bytecode_path, name) + + self._test_partial_flags(test, del_source=True) + + def test_partial_hash(self): + def test(name, mapping, bytecode_path): + with self.assertRaises(EOFError): + self.import_(bytecode_path, name) + + self._test_partial_hash(test, del_source=True) + def test_partial_size(self): def test(name, mapping, bytecode_path): with self.assertRaises(EOFError): diff --git a/Lib/test/test_importlib/test_abc.py b/Lib/test/test_importlib/test_abc.py index 54b2da6..4ba28c6 100644 --- a/Lib/test/test_importlib/test_abc.py +++ b/Lib/test/test_importlib/test_abc.py @@ -673,6 +673,7 @@ class SourceLoader(SourceOnlyLoader): if magic is None: magic = self.util.MAGIC_NUMBER data = bytearray(magic) + data.extend(self.init._w_long(0)) data.extend(self.init._w_long(self.source_mtime)) data.extend(self.init._w_long(self.source_size)) code_object = compile(self.source, self.path, 'exec', @@ -836,6 +837,7 @@ class SourceLoaderBytecodeTests(SourceLoaderTestHarness): if bytecode_written: self.assertIn(self.cached, self.loader.written) data = bytearray(self.util.MAGIC_NUMBER) + data.extend(self.init._w_long(0)) data.extend(self.init._w_long(self.loader.source_mtime)) data.extend(self.init._w_long(self.loader.source_size)) data.extend(marshal.dumps(code_object)) diff --git a/Lib/test/test_py_compile.py b/Lib/test/test_py_compile.py index 4a6caa5..bcb686c 100644 --- a/Lib/test/test_py_compile.py +++ b/Lib/test/test_py_compile.py @@ -122,6 +122,24 @@ class PyCompileTests(unittest.TestCase): # Specifying optimized bytecode should lead to a path reflecting that. 
self.assertIn('opt-2', py_compile.compile(self.source_path, optimize=2)) + def test_invalidation_mode(self): + py_compile.compile( + self.source_path, + invalidation_mode=py_compile.PycInvalidationMode.CHECKED_HASH, + ) + with open(self.cache_path, 'rb') as fp: + flags = importlib._bootstrap_external._classify_pyc( + fp.read(), 'test', {}) + self.assertEqual(flags, 0b11) + py_compile.compile( + self.source_path, + invalidation_mode=py_compile.PycInvalidationMode.UNCHECKED_HASH, + ) + with open(self.cache_path, 'rb') as fp: + flags = importlib._bootstrap_external._classify_pyc( + fp.read(), 'test', {}) + self.assertEqual(flags, 0b1) + if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_zipimport.py b/Lib/test/test_zipimport.py index 67ca39b..901bebd 100644 --- a/Lib/test/test_zipimport.py +++ b/Lib/test/test_zipimport.py @@ -40,7 +40,7 @@ def make_pyc(co, mtime, size): else: mtime = int(-0x100000000 + int(mtime)) pyc = (importlib.util.MAGIC_NUMBER + - struct.pack("<ii", int(mtime), size & 0xFFFFFFFF) + data) + struct.pack("<iii", 0, int(mtime), size & 0xFFFFFFFF) + data) return pyc def module_path_to_dotted_name(path): @@ -187,6 +187,20 @@ class UncompressedZipImportTestCase(ImportHooksBaseTestCase): TESTMOD + pyc_ext: (NOW, test_pyc)} self.doTest(pyc_ext, files, TESTMOD) + def testUncheckedHashBasedPyc(self): + source = b"state = 'old'" + source_hash = importlib.util.source_hash(source) + bytecode = importlib._bootstrap_external._code_to_hash_pyc( + compile(source, "???", "exec"), + source_hash, + False, # unchecked + ) + files = {TESTMOD + ".py": (NOW, "state = 'new'"), + TESTMOD + ".pyc": (NOW - 20, bytecode)} + def check(mod): + self.assertEqual(mod.state, 'old') + self.doTest(None, files, TESTMOD, call=check) + def testEmptyPy(self): files = {TESTMOD + ".py": (NOW, "")} self.doTest(None, files, TESTMOD) @@ -215,7 +229,7 @@ class UncompressedZipImportTestCase(ImportHooksBaseTestCase): badtime_pyc = bytearray(test_pyc) # flip the second bit -- not the first as that one isn't stored in the # .py's mtime in the zip archive. - badtime_pyc[7] ^= 0x02 + badtime_pyc[11] ^= 0x02 files = {TESTMOD + ".py": (NOW, test_src), TESTMOD + pyc_ext: (NOW, badtime_pyc)} self.doTest(".py", files, TESTMOD) |
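
Finally, the test_imp and test_zipimport changes above lean on the runtime side of PEP 552: _imp.check_hash_based_pycs reports the interpreter-wide policy (set, per the tests, with python --check-hash-based-pycs {default,always,never}; that option itself lives outside Lib/), and the new importlib.util.source_hash() computes the 8-byte hash stored in hash-based pycs. A small sketch, again with spam.py as a placeholder and assuming a build that carries this patch:

    # Sketch: inspect the hash-based pyc policy and hash a source file.
    import _imp
    import importlib.util

    print(_imp.check_hash_based_pycs)   # 'default', 'always', or 'never'

    with open('spam.py', 'rb') as fp:   # placeholder source file
        source_bytes = fp.read()
    print(importlib.util.source_hash(source_bytes))  # 8-byte hash, as stored
                                                     # in a hash-based pyc header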