diff options
author | Antoine Pitrou <solipsis@pitrou.net> | 2012-11-17 22:50:08 (GMT) |
---|---|---|
committer | Antoine Pitrou <solipsis@pitrou.net> | 2012-11-17 22:50:08 (GMT) |
commit | 17babc5e97f49f03e0a271b1c80f7adaf2eb5f48 (patch) | |
tree | eb4037a09cbdc572517b9ac395a2765bfb8df867 /Lib/zipfile.py | |
parent | a39a22dc0b20f0373792322b054228f1f62b736a (diff) | |
download | cpython-17babc5e97f49f03e0a271b1c80f7adaf2eb5f48.zip cpython-17babc5e97f49f03e0a271b1c80f7adaf2eb5f48.tar.gz cpython-17babc5e97f49f03e0a271b1c80f7adaf2eb5f48.tar.bz2 |
Issue #16408: Fix file descriptors not being closed in error conditions in the zipfile module.
Patch by Serhiy Storchaka.
Diffstat (limited to 'Lib/zipfile.py')
-rw-r--r-- | Lib/zipfile.py | 427 |
1 files changed, 205 insertions, 222 deletions
diff --git a/Lib/zipfile.py b/Lib/zipfile.py index 435a3e9..2da70b5 100644 --- a/Lib/zipfile.py +++ b/Lib/zipfile.py @@ -719,30 +719,34 @@ class ZipFile: self.fp = file self.filename = getattr(file, 'name', None) - if key == 'r': - self._GetContents() - elif key == 'w': - # set the modified flag so central directory gets written - # even if no files are added to the archive - self._didModify = True - elif key == 'a': - try: - # See if file is a zip file + try: + if key == 'r': self._RealGetContents() - # seek to start of directory and overwrite - self.fp.seek(self.start_dir, 0) - except BadZipFile: - # file is not a zip file, just append - self.fp.seek(0, 2) - + elif key == 'w': # set the modified flag so central directory gets written # even if no files are added to the archive self._didModify = True - else: + elif key == 'a': + try: + # See if file is a zip file + self._RealGetContents() + # seek to start of directory and overwrite + self.fp.seek(self.start_dir, 0) + except BadZipFile: + # file is not a zip file, just append + self.fp.seek(0, 2) + + # set the modified flag so central directory gets written + # even if no files are added to the archive + self._didModify = True + else: + raise RuntimeError('Mode must be "r", "w" or "a"') + except: + fp = self.fp + self.fp = None if not self._filePassed: - self.fp.close() - self.fp = None - raise RuntimeError('Mode must be "r", "w" or "a"') + fp.close() + raise def __enter__(self): return self @@ -750,17 +754,6 @@ class ZipFile: def __exit__(self, type, value, traceback): self.close() - def _GetContents(self): - """Read the directory, making sure we close the file if the format - is bad.""" - try: - self._RealGetContents() - except BadZipFile: - if not self._filePassed: - self.fp.close() - self.fp = None - raise - def _RealGetContents(self): """Read in the table of contents for the ZIP file.""" fp = self.fp @@ -862,9 +855,9 @@ class ZipFile: try: # Read by chunks, to avoid an OverflowError or a # MemoryError with very large embedded files. - f = self.open(zinfo.filename, "r") - while f.read(chunk_size): # Check CRC-32 - pass + with self.open(zinfo.filename, "r") as f: + while f.read(chunk_size): # Check CRC-32 + pass except BadZipFile: return zinfo.filename @@ -926,76 +919,70 @@ class ZipFile: else: zef_file = io.open(self.filename, 'rb') - # Make sure we have an info object - if isinstance(name, ZipInfo): - # 'name' is already an info object - zinfo = name - else: - # Get info object for name - try: + try: + # Make sure we have an info object + if isinstance(name, ZipInfo): + # 'name' is already an info object + zinfo = name + else: + # Get info object for name zinfo = self.getinfo(name) - except KeyError: - if not self._filePassed: - zef_file.close() - raise - zef_file.seek(zinfo.header_offset, 0) - - # Skip the file header: - fheader = zef_file.read(sizeFileHeader) - if fheader[0:4] != stringFileHeader: - raise BadZipFile("Bad magic number for file header") - - fheader = struct.unpack(structFileHeader, fheader) - fname = zef_file.read(fheader[_FH_FILENAME_LENGTH]) - if fheader[_FH_EXTRA_FIELD_LENGTH]: - zef_file.read(fheader[_FH_EXTRA_FIELD_LENGTH]) - - if zinfo.flag_bits & 0x800: - # UTF-8 filename - fname_str = fname.decode("utf-8") - else: - fname_str = fname.decode("cp437") + zef_file.seek(zinfo.header_offset, 0) + + # Skip the file header: + fheader = zef_file.read(sizeFileHeader) + if fheader[0:4] != stringFileHeader: + raise BadZipFile("Bad magic number for file header") - if fname_str != zinfo.orig_filename: + fheader = struct.unpack(structFileHeader, fheader) + fname = zef_file.read(fheader[_FH_FILENAME_LENGTH]) + if fheader[_FH_EXTRA_FIELD_LENGTH]: + zef_file.read(fheader[_FH_EXTRA_FIELD_LENGTH]) + + if zinfo.flag_bits & 0x800: + # UTF-8 filename + fname_str = fname.decode("utf-8") + else: + fname_str = fname.decode("cp437") + + if fname_str != zinfo.orig_filename: + raise BadZipFile( + 'File name in directory %r and header %r differ.' + % (zinfo.orig_filename, fname)) + + # check for encrypted flag & handle password + is_encrypted = zinfo.flag_bits & 0x1 + zd = None + if is_encrypted: + if not pwd: + pwd = self.pwd + if not pwd: + raise RuntimeError("File %s is encrypted, password " + "required for extraction" % name) + + zd = _ZipDecrypter(pwd) + # The first 12 bytes in the cypher stream is an encryption header + # used to strengthen the algorithm. The first 11 bytes are + # completely random, while the 12th contains the MSB of the CRC, + # or the MSB of the file time depending on the header type + # and is used to check the correctness of the password. + header = zef_file.read(12) + h = list(map(zd, header[0:12])) + if zinfo.flag_bits & 0x8: + # compare against the file type from extended local headers + check_byte = (zinfo._raw_time >> 8) & 0xff + else: + # compare against the CRC otherwise + check_byte = (zinfo.CRC >> 24) & 0xff + if h[11] != check_byte: + raise RuntimeError("Bad password for file", name) + + return ZipExtFile(zef_file, mode, zinfo, zd, + close_fileobj=not self._filePassed) + except: if not self._filePassed: zef_file.close() - raise BadZipFile( - 'File name in directory %r and header %r differ.' - % (zinfo.orig_filename, fname)) - - # check for encrypted flag & handle password - is_encrypted = zinfo.flag_bits & 0x1 - zd = None - if is_encrypted: - if not pwd: - pwd = self.pwd - if not pwd: - if not self._filePassed: - zef_file.close() - raise RuntimeError("File %s is encrypted, " - "password required for extraction" % name) - - zd = _ZipDecrypter(pwd) - # The first 12 bytes in the cypher stream is an encryption header - # used to strengthen the algorithm. The first 11 bytes are - # completely random, while the 12th contains the MSB of the CRC, - # or the MSB of the file time depending on the header type - # and is used to check the correctness of the password. - header = zef_file.read(12) - h = list(map(zd, header[0:12])) - if zinfo.flag_bits & 0x8: - # compare against the file type from extended local headers - check_byte = (zinfo._raw_time >> 8) & 0xff - else: - # compare against the CRC otherwise - check_byte = (zinfo.CRC >> 24) & 0xff - if h[11] != check_byte: - if not self._filePassed: - zef_file.close() - raise RuntimeError("Bad password for file", name) - - return ZipExtFile(zef_file, mode, zinfo, zd, - close_fileobj=not self._filePassed) + raise def extract(self, member, path=None, pwd=None): """Extract a member from the archive to the current working directory, @@ -1052,11 +1039,9 @@ class ZipFile: os.mkdir(targetpath) return targetpath - source = self.open(member, pwd=pwd) - target = open(targetpath, "wb") - shutil.copyfileobj(source, target) - source.close() - target.close() + with self.open(member, pwd=pwd) as source, \ + open(targetpath, "wb") as target: + shutil.copyfileobj(source, target) return targetpath @@ -1220,101 +1205,103 @@ class ZipFile: if self.fp is None: return - if self.mode in ("w", "a") and self._didModify: # write ending records - count = 0 - pos1 = self.fp.tell() - for zinfo in self.filelist: # write central directory - count = count + 1 - dt = zinfo.date_time - dosdate = (dt[0] - 1980) << 9 | dt[1] << 5 | dt[2] - dostime = dt[3] << 11 | dt[4] << 5 | (dt[5] // 2) - extra = [] - if zinfo.file_size > ZIP64_LIMIT \ - or zinfo.compress_size > ZIP64_LIMIT: - extra.append(zinfo.file_size) - extra.append(zinfo.compress_size) - file_size = 0xffffffff - compress_size = 0xffffffff - else: - file_size = zinfo.file_size - compress_size = zinfo.compress_size - - if zinfo.header_offset > ZIP64_LIMIT: - extra.append(zinfo.header_offset) - header_offset = 0xffffffff - else: - header_offset = zinfo.header_offset - - extra_data = zinfo.extra - if extra: - # Append a ZIP64 field to the extra's - extra_data = struct.pack( - '<HH' + 'Q'*len(extra), - 1, 8*len(extra), *extra) + extra_data - - extract_version = max(45, zinfo.extract_version) - create_version = max(45, zinfo.create_version) - else: - extract_version = zinfo.extract_version - create_version = zinfo.create_version - - try: - filename, flag_bits = zinfo._encodeFilenameFlags() - centdir = struct.pack(structCentralDir, - stringCentralDir, create_version, - zinfo.create_system, extract_version, zinfo.reserved, - flag_bits, zinfo.compress_type, dostime, dosdate, - zinfo.CRC, compress_size, file_size, - len(filename), len(extra_data), len(zinfo.comment), - 0, zinfo.internal_attr, zinfo.external_attr, - header_offset) - except DeprecationWarning: - print((structCentralDir, stringCentralDir, create_version, - zinfo.create_system, extract_version, zinfo.reserved, - zinfo.flag_bits, zinfo.compress_type, dostime, dosdate, - zinfo.CRC, compress_size, file_size, - len(zinfo.filename), len(extra_data), len(zinfo.comment), - 0, zinfo.internal_attr, zinfo.external_attr, - header_offset), file=sys.stderr) - raise - self.fp.write(centdir) - self.fp.write(filename) - self.fp.write(extra_data) - self.fp.write(zinfo.comment) - - pos2 = self.fp.tell() - # Write end-of-zip-archive record - centDirCount = count - centDirSize = pos2 - pos1 - centDirOffset = pos1 - if (centDirCount >= ZIP_FILECOUNT_LIMIT or - centDirOffset > ZIP64_LIMIT or - centDirSize > ZIP64_LIMIT): - # Need to write the ZIP64 end-of-archive records - zip64endrec = struct.pack( - structEndArchive64, stringEndArchive64, - 44, 45, 45, 0, 0, centDirCount, centDirCount, - centDirSize, centDirOffset) - self.fp.write(zip64endrec) - - zip64locrec = struct.pack( - structEndArchive64Locator, - stringEndArchive64Locator, 0, pos2, 1) - self.fp.write(zip64locrec) - centDirCount = min(centDirCount, 0xFFFF) - centDirSize = min(centDirSize, 0xFFFFFFFF) - centDirOffset = min(centDirOffset, 0xFFFFFFFF) - - endrec = struct.pack(structEndArchive, stringEndArchive, - 0, 0, centDirCount, centDirCount, - centDirSize, centDirOffset, len(self._comment)) - self.fp.write(endrec) - self.fp.write(self._comment) - self.fp.flush() - - if not self._filePassed: - self.fp.close() - self.fp = None + try: + if self.mode in ("w", "a") and self._didModify: # write ending records + count = 0 + pos1 = self.fp.tell() + for zinfo in self.filelist: # write central directory + count = count + 1 + dt = zinfo.date_time + dosdate = (dt[0] - 1980) << 9 | dt[1] << 5 | dt[2] + dostime = dt[3] << 11 | dt[4] << 5 | (dt[5] // 2) + extra = [] + if zinfo.file_size > ZIP64_LIMIT \ + or zinfo.compress_size > ZIP64_LIMIT: + extra.append(zinfo.file_size) + extra.append(zinfo.compress_size) + file_size = 0xffffffff + compress_size = 0xffffffff + else: + file_size = zinfo.file_size + compress_size = zinfo.compress_size + + if zinfo.header_offset > ZIP64_LIMIT: + extra.append(zinfo.header_offset) + header_offset = 0xffffffff + else: + header_offset = zinfo.header_offset + + extra_data = zinfo.extra + if extra: + # Append a ZIP64 field to the extra's + extra_data = struct.pack( + '<HH' + 'Q'*len(extra), + 1, 8*len(extra), *extra) + extra_data + + extract_version = max(45, zinfo.extract_version) + create_version = max(45, zinfo.create_version) + else: + extract_version = zinfo.extract_version + create_version = zinfo.create_version + + try: + filename, flag_bits = zinfo._encodeFilenameFlags() + centdir = struct.pack(structCentralDir, + stringCentralDir, create_version, + zinfo.create_system, extract_version, zinfo.reserved, + flag_bits, zinfo.compress_type, dostime, dosdate, + zinfo.CRC, compress_size, file_size, + len(filename), len(extra_data), len(zinfo.comment), + 0, zinfo.internal_attr, zinfo.external_attr, + header_offset) + except DeprecationWarning: + print((structCentralDir, stringCentralDir, create_version, + zinfo.create_system, extract_version, zinfo.reserved, + zinfo.flag_bits, zinfo.compress_type, dostime, dosdate, + zinfo.CRC, compress_size, file_size, + len(zinfo.filename), len(extra_data), len(zinfo.comment), + 0, zinfo.internal_attr, zinfo.external_attr, + header_offset), file=sys.stderr) + raise + self.fp.write(centdir) + self.fp.write(filename) + self.fp.write(extra_data) + self.fp.write(zinfo.comment) + + pos2 = self.fp.tell() + # Write end-of-zip-archive record + centDirCount = count + centDirSize = pos2 - pos1 + centDirOffset = pos1 + if (centDirCount >= ZIP_FILECOUNT_LIMIT or + centDirOffset > ZIP64_LIMIT or + centDirSize > ZIP64_LIMIT): + # Need to write the ZIP64 end-of-archive records + zip64endrec = struct.pack( + structEndArchive64, stringEndArchive64, + 44, 45, 45, 0, 0, centDirCount, centDirCount, + centDirSize, centDirOffset) + self.fp.write(zip64endrec) + + zip64locrec = struct.pack( + structEndArchive64Locator, + stringEndArchive64Locator, 0, pos2, 1) + self.fp.write(zip64locrec) + centDirCount = min(centDirCount, 0xFFFF) + centDirSize = min(centDirSize, 0xFFFFFFFF) + centDirOffset = min(centDirOffset, 0xFFFFFFFF) + + endrec = struct.pack(structEndArchive, stringEndArchive, + 0, 0, centDirCount, centDirCount, + centDirSize, centDirOffset, len(self._comment)) + self.fp.write(endrec) + self.fp.write(self._comment) + self.fp.flush() + finally: + fp = self.fp + self.fp = None + if not self._filePassed: + fp.close() class PyZipFile(ZipFile): @@ -1481,16 +1468,15 @@ def main(args = None): if len(args) != 2: print(USAGE) sys.exit(1) - zf = ZipFile(args[1], 'r') - zf.printdir() - zf.close() + with ZipFile(args[1], 'r') as zf: + zf.printdir() elif args[0] == '-t': if len(args) != 2: print(USAGE) sys.exit(1) - zf = ZipFile(args[1], 'r') - badfile = zf.testzip() + with ZipFile(args[1], 'r') as zf: + badfile = zf.testzip() if badfile: print("The following enclosed file is corrupted: {!r}".format(badfile)) print("Done testing") @@ -1500,20 +1486,19 @@ def main(args = None): print(USAGE) sys.exit(1) - zf = ZipFile(args[1], 'r') - out = args[2] - for path in zf.namelist(): - if path.startswith('./'): - tgt = os.path.join(out, path[2:]) - else: - tgt = os.path.join(out, path) + with ZipFile(args[1], 'r') as zf: + out = args[2] + for path in zf.namelist(): + if path.startswith('./'): + tgt = os.path.join(out, path[2:]) + else: + tgt = os.path.join(out, path) - tgtdir = os.path.dirname(tgt) - if not os.path.exists(tgtdir): - os.makedirs(tgtdir) - with open(tgt, 'wb') as fp: - fp.write(zf.read(path)) - zf.close() + tgtdir = os.path.dirname(tgt) + if not os.path.exists(tgtdir): + os.makedirs(tgtdir) + with open(tgt, 'wb') as fp: + fp.write(zf.read(path)) elif args[0] == '-c': if len(args) < 3: @@ -1529,11 +1514,9 @@ def main(args = None): os.path.join(path, nm), os.path.join(zippath, nm)) # else: ignore - zf = ZipFile(args[1], 'w', allowZip64=True) - for src in args[2:]: - addToZip(zf, src, os.path.basename(src)) - - zf.close() + with ZipFile(args[1], 'w', allowZip64=True) as zf: + for src in args[2:]: + addToZip(zf, src, os.path.basename(src)) if __name__ == "__main__": main() |