summaryrefslogtreecommitdiffstats
path: root/Lib/zipfile.py
diff options
context:
space:
mode:
authorAntoine Pitrou <solipsis@pitrou.net>2012-11-17 22:52:05 (GMT)
committerAntoine Pitrou <solipsis@pitrou.net>2012-11-17 22:52:05 (GMT)
commit8572da5e961f6a645e3e8932568afd889448e78b (patch)
tree238300d2bf447c04054720f51c2b58adda55f54a /Lib/zipfile.py
parent6d5ad227a50c6c5a78e48a98095788953ab49512 (diff)
parent17babc5e97f49f03e0a271b1c80f7adaf2eb5f48 (diff)
downloadcpython-8572da5e961f6a645e3e8932568afd889448e78b.zip
cpython-8572da5e961f6a645e3e8932568afd889448e78b.tar.gz
cpython-8572da5e961f6a645e3e8932568afd889448e78b.tar.bz2
Issue #16408: Fix file descriptors not being closed in error conditions in the zipfile module.
Patch by Serhiy Storchaka.
Diffstat (limited to 'Lib/zipfile.py')
-rw-r--r--Lib/zipfile.py451
1 files changed, 217 insertions, 234 deletions
diff --git a/Lib/zipfile.py b/Lib/zipfile.py
index 209dc4a..68051c8 100644
--- a/Lib/zipfile.py
+++ b/Lib/zipfile.py
@@ -906,30 +906,34 @@ class ZipFile:
self.fp = file
self.filename = getattr(file, 'name', None)
- if key == 'r':
- self._GetContents()
- elif key == 'w':
- # set the modified flag so central directory gets written
- # even if no files are added to the archive
- self._didModify = True
- elif key == 'a':
- try:
- # See if file is a zip file
+ try:
+ if key == 'r':
self._RealGetContents()
- # seek to start of directory and overwrite
- self.fp.seek(self.start_dir, 0)
- except BadZipFile:
- # file is not a zip file, just append
- self.fp.seek(0, 2)
-
+ elif key == 'w':
# set the modified flag so central directory gets written
# even if no files are added to the archive
self._didModify = True
- else:
+ elif key == 'a':
+ try:
+ # See if file is a zip file
+ self._RealGetContents()
+ # seek to start of directory and overwrite
+ self.fp.seek(self.start_dir, 0)
+ except BadZipFile:
+ # file is not a zip file, just append
+ self.fp.seek(0, 2)
+
+ # set the modified flag so central directory gets written
+ # even if no files are added to the archive
+ self._didModify = True
+ else:
+ raise RuntimeError('Mode must be "r", "w" or "a"')
+ except:
+ fp = self.fp
+ self.fp = None
if not self._filePassed:
- self.fp.close()
- self.fp = None
- raise RuntimeError('Mode must be "r", "w" or "a"')
+ fp.close()
+ raise
def __enter__(self):
return self
@@ -937,17 +941,6 @@ class ZipFile:
def __exit__(self, type, value, traceback):
self.close()
- def _GetContents(self):
- """Read the directory, making sure we close the file if the format
- is bad."""
- try:
- self._RealGetContents()
- except BadZipFile:
- if not self._filePassed:
- self.fp.close()
- self.fp = None
- raise
-
def _RealGetContents(self):
"""Read in the table of contents for the ZIP file."""
fp = self.fp
@@ -1049,9 +1042,9 @@ class ZipFile:
try:
# Read by chunks, to avoid an OverflowError or a
# MemoryError with very large embedded files.
- f = self.open(zinfo.filename, "r")
- while f.read(chunk_size): # Check CRC-32
- pass
+ with self.open(zinfo.filename, "r") as f:
+ while f.read(chunk_size): # Check CRC-32
+ pass
except BadZipFile:
return zinfo.filename
@@ -1113,84 +1106,78 @@ class ZipFile:
else:
zef_file = io.open(self.filename, 'rb')
- # Make sure we have an info object
- if isinstance(name, ZipInfo):
- # 'name' is already an info object
- zinfo = name
- else:
- # Get info object for name
- try:
+ try:
+ # Make sure we have an info object
+ if isinstance(name, ZipInfo):
+ # 'name' is already an info object
+ zinfo = name
+ else:
+ # Get info object for name
zinfo = self.getinfo(name)
- except KeyError:
- if not self._filePassed:
- zef_file.close()
- raise
- zef_file.seek(zinfo.header_offset, 0)
-
- # Skip the file header:
- fheader = zef_file.read(sizeFileHeader)
- if fheader[0:4] != stringFileHeader:
- raise BadZipFile("Bad magic number for file header")
-
- fheader = struct.unpack(structFileHeader, fheader)
- fname = zef_file.read(fheader[_FH_FILENAME_LENGTH])
- if fheader[_FH_EXTRA_FIELD_LENGTH]:
- zef_file.read(fheader[_FH_EXTRA_FIELD_LENGTH])
-
- if zinfo.flag_bits & 0x20:
- # Zip 2.7: compressed patched data
- raise NotImplementedError("compressed patched data (flag bit 5)")
-
- if zinfo.flag_bits & 0x40:
- # strong encryption
- raise NotImplementedError("strong encryption (flag bit 6)")
-
- if zinfo.flag_bits & 0x800:
- # UTF-8 filename
- fname_str = fname.decode("utf-8")
- else:
- fname_str = fname.decode("cp437")
+ zef_file.seek(zinfo.header_offset, 0)
+
+ # Skip the file header:
+ fheader = zef_file.read(sizeFileHeader)
+ if fheader[0:4] != stringFileHeader:
+ raise BadZipFile("Bad magic number for file header")
+
+ fheader = struct.unpack(structFileHeader, fheader)
+ fname = zef_file.read(fheader[_FH_FILENAME_LENGTH])
+ if fheader[_FH_EXTRA_FIELD_LENGTH]:
+ zef_file.read(fheader[_FH_EXTRA_FIELD_LENGTH])
+
+ if zinfo.flag_bits & 0x20:
+ # Zip 2.7: compressed patched data
+ raise NotImplementedError("compressed patched data (flag bit 5)")
+
+ if zinfo.flag_bits & 0x40:
+ # strong encryption
+ raise NotImplementedError("strong encryption (flag bit 6)")
- if fname_str != zinfo.orig_filename:
+ if zinfo.flag_bits & 0x800:
+ # UTF-8 filename
+ fname_str = fname.decode("utf-8")
+ else:
+ fname_str = fname.decode("cp437")
+
+ if fname_str != zinfo.orig_filename:
+ raise BadZipFile(
+ 'File name in directory %r and header %r differ.'
+ % (zinfo.orig_filename, fname))
+
+ # check for encrypted flag & handle password
+ is_encrypted = zinfo.flag_bits & 0x1
+ zd = None
+ if is_encrypted:
+ if not pwd:
+ pwd = self.pwd
+ if not pwd:
+ raise RuntimeError("File %s is encrypted, password "
+ "required for extraction" % name)
+
+ zd = _ZipDecrypter(pwd)
+ # The first 12 bytes in the cypher stream is an encryption header
+ # used to strengthen the algorithm. The first 11 bytes are
+ # completely random, while the 12th contains the MSB of the CRC,
+ # or the MSB of the file time depending on the header type
+ # and is used to check the correctness of the password.
+ header = zef_file.read(12)
+ h = list(map(zd, header[0:12]))
+ if zinfo.flag_bits & 0x8:
+ # compare against the file type from extended local headers
+ check_byte = (zinfo._raw_time >> 8) & 0xff
+ else:
+ # compare against the CRC otherwise
+ check_byte = (zinfo.CRC >> 24) & 0xff
+ if h[11] != check_byte:
+ raise RuntimeError("Bad password for file", name)
+
+ return ZipExtFile(zef_file, mode, zinfo, zd,
+ close_fileobj=not self._filePassed)
+ except:
if not self._filePassed:
zef_file.close()
- raise BadZipFile(
- 'File name in directory %r and header %r differ.'
- % (zinfo.orig_filename, fname))
-
- # check for encrypted flag & handle password
- is_encrypted = zinfo.flag_bits & 0x1
- zd = None
- if is_encrypted:
- if not pwd:
- pwd = self.pwd
- if not pwd:
- if not self._filePassed:
- zef_file.close()
- raise RuntimeError("File %s is encrypted, "
- "password required for extraction" % name)
-
- zd = _ZipDecrypter(pwd)
- # The first 12 bytes in the cypher stream is an encryption header
- # used to strengthen the algorithm. The first 11 bytes are
- # completely random, while the 12th contains the MSB of the CRC,
- # or the MSB of the file time depending on the header type
- # and is used to check the correctness of the password.
- header = zef_file.read(12)
- h = list(map(zd, header[0:12]))
- if zinfo.flag_bits & 0x8:
- # compare against the file type from extended local headers
- check_byte = (zinfo._raw_time >> 8) & 0xff
- else:
- # compare against the CRC otherwise
- check_byte = (zinfo.CRC >> 24) & 0xff
- if h[11] != check_byte:
- if not self._filePassed:
- zef_file.close()
- raise RuntimeError("Bad password for file", name)
-
- return ZipExtFile(zef_file, mode, zinfo, zd,
- close_fileobj=not self._filePassed)
+ raise
def extract(self, member, path=None, pwd=None):
"""Extract a member from the archive to the current working directory,
@@ -1247,11 +1234,9 @@ class ZipFile:
os.mkdir(targetpath)
return targetpath
- source = self.open(member, pwd=pwd)
- target = open(targetpath, "wb")
- shutil.copyfileobj(source, target)
- source.close()
- target.close()
+ with self.open(member, pwd=pwd) as source, \
+ open(targetpath, "wb") as target:
+ shutil.copyfileobj(source, target)
return targetpath
@@ -1412,105 +1397,107 @@ class ZipFile:
if self.fp is None:
return
- if self.mode in ("w", "a") and self._didModify: # write ending records
- count = 0
- pos1 = self.fp.tell()
- for zinfo in self.filelist: # write central directory
- count = count + 1
- dt = zinfo.date_time
- dosdate = (dt[0] - 1980) << 9 | dt[1] << 5 | dt[2]
- dostime = dt[3] << 11 | dt[4] << 5 | (dt[5] // 2)
- extra = []
- if zinfo.file_size > ZIP64_LIMIT \
- or zinfo.compress_size > ZIP64_LIMIT:
- extra.append(zinfo.file_size)
- extra.append(zinfo.compress_size)
- file_size = 0xffffffff
- compress_size = 0xffffffff
- else:
- file_size = zinfo.file_size
- compress_size = zinfo.compress_size
-
- if zinfo.header_offset > ZIP64_LIMIT:
- extra.append(zinfo.header_offset)
- header_offset = 0xffffffff
- else:
- header_offset = zinfo.header_offset
-
- extra_data = zinfo.extra
- min_version = 0
- if extra:
- # Append a ZIP64 field to the extra's
- extra_data = struct.pack(
- '<HH' + 'Q'*len(extra),
- 1, 8*len(extra), *extra) + extra_data
-
- min_version = ZIP64_VERSION
-
- if zinfo.compress_type == ZIP_BZIP2:
- min_version = max(BZIP2_VERSION, min_version)
- elif zinfo.compress_type == ZIP_LZMA:
- min_version = max(LZMA_VERSION, min_version)
-
- extract_version = max(min_version, zinfo.extract_version)
- create_version = max(min_version, zinfo.create_version)
- try:
- filename, flag_bits = zinfo._encodeFilenameFlags()
- centdir = struct.pack(structCentralDir,
- stringCentralDir, create_version,
- zinfo.create_system, extract_version, zinfo.reserved,
- flag_bits, zinfo.compress_type, dostime, dosdate,
- zinfo.CRC, compress_size, file_size,
- len(filename), len(extra_data), len(zinfo.comment),
- 0, zinfo.internal_attr, zinfo.external_attr,
- header_offset)
- except DeprecationWarning:
- print((structCentralDir, stringCentralDir, create_version,
- zinfo.create_system, extract_version, zinfo.reserved,
- zinfo.flag_bits, zinfo.compress_type, dostime, dosdate,
- zinfo.CRC, compress_size, file_size,
- len(zinfo.filename), len(extra_data), len(zinfo.comment),
- 0, zinfo.internal_attr, zinfo.external_attr,
- header_offset), file=sys.stderr)
- raise
- self.fp.write(centdir)
- self.fp.write(filename)
- self.fp.write(extra_data)
- self.fp.write(zinfo.comment)
-
- pos2 = self.fp.tell()
- # Write end-of-zip-archive record
- centDirCount = count
- centDirSize = pos2 - pos1
- centDirOffset = pos1
- if (centDirCount >= ZIP_FILECOUNT_LIMIT or
- centDirOffset > ZIP64_LIMIT or
- centDirSize > ZIP64_LIMIT):
- # Need to write the ZIP64 end-of-archive records
- zip64endrec = struct.pack(
- structEndArchive64, stringEndArchive64,
- 44, 45, 45, 0, 0, centDirCount, centDirCount,
- centDirSize, centDirOffset)
- self.fp.write(zip64endrec)
-
- zip64locrec = struct.pack(
- structEndArchive64Locator,
- stringEndArchive64Locator, 0, pos2, 1)
- self.fp.write(zip64locrec)
- centDirCount = min(centDirCount, 0xFFFF)
- centDirSize = min(centDirSize, 0xFFFFFFFF)
- centDirOffset = min(centDirOffset, 0xFFFFFFFF)
-
- endrec = struct.pack(structEndArchive, stringEndArchive,
- 0, 0, centDirCount, centDirCount,
- centDirSize, centDirOffset, len(self._comment))
- self.fp.write(endrec)
- self.fp.write(self._comment)
- self.fp.flush()
-
- if not self._filePassed:
- self.fp.close()
- self.fp = None
+ try:
+ if self.mode in ("w", "a") and self._didModify: # write ending records
+ count = 0
+ pos1 = self.fp.tell()
+ for zinfo in self.filelist: # write central directory
+ count = count + 1
+ dt = zinfo.date_time
+ dosdate = (dt[0] - 1980) << 9 | dt[1] << 5 | dt[2]
+ dostime = dt[3] << 11 | dt[4] << 5 | (dt[5] // 2)
+ extra = []
+ if zinfo.file_size > ZIP64_LIMIT \
+ or zinfo.compress_size > ZIP64_LIMIT:
+ extra.append(zinfo.file_size)
+ extra.append(zinfo.compress_size)
+ file_size = 0xffffffff
+ compress_size = 0xffffffff
+ else:
+ file_size = zinfo.file_size
+ compress_size = zinfo.compress_size
+
+ if zinfo.header_offset > ZIP64_LIMIT:
+ extra.append(zinfo.header_offset)
+ header_offset = 0xffffffff
+ else:
+ header_offset = zinfo.header_offset
+
+ extra_data = zinfo.extra
+ min_version = 0
+ if extra:
+ # Append a ZIP64 field to the extra's
+ extra_data = struct.pack(
+ '<HH' + 'Q'*len(extra),
+ 1, 8*len(extra), *extra) + extra_data
+
+ min_version = ZIP64_VERSION
+
+ if zinfo.compress_type == ZIP_BZIP2:
+ min_version = max(BZIP2_VERSION, min_version)
+ elif zinfo.compress_type == ZIP_LZMA:
+ min_version = max(LZMA_VERSION, min_version)
+
+ extract_version = max(min_version, zinfo.extract_version)
+ create_version = max(min_version, zinfo.create_version)
+ try:
+ filename, flag_bits = zinfo._encodeFilenameFlags()
+ centdir = struct.pack(structCentralDir,
+ stringCentralDir, create_version,
+ zinfo.create_system, extract_version, zinfo.reserved,
+ flag_bits, zinfo.compress_type, dostime, dosdate,
+ zinfo.CRC, compress_size, file_size,
+ len(filename), len(extra_data), len(zinfo.comment),
+ 0, zinfo.internal_attr, zinfo.external_attr,
+ header_offset)
+ except DeprecationWarning:
+ print((structCentralDir, stringCentralDir, create_version,
+ zinfo.create_system, extract_version, zinfo.reserved,
+ zinfo.flag_bits, zinfo.compress_type, dostime, dosdate,
+ zinfo.CRC, compress_size, file_size,
+ len(zinfo.filename), len(extra_data), len(zinfo.comment),
+ 0, zinfo.internal_attr, zinfo.external_attr,
+ header_offset), file=sys.stderr)
+ raise
+ self.fp.write(centdir)
+ self.fp.write(filename)
+ self.fp.write(extra_data)
+ self.fp.write(zinfo.comment)
+
+ pos2 = self.fp.tell()
+ # Write end-of-zip-archive record
+ centDirCount = count
+ centDirSize = pos2 - pos1
+ centDirOffset = pos1
+ if (centDirCount >= ZIP_FILECOUNT_LIMIT or
+ centDirOffset > ZIP64_LIMIT or
+ centDirSize > ZIP64_LIMIT):
+ # Need to write the ZIP64 end-of-archive records
+ zip64endrec = struct.pack(
+ structEndArchive64, stringEndArchive64,
+ 44, 45, 45, 0, 0, centDirCount, centDirCount,
+ centDirSize, centDirOffset)
+ self.fp.write(zip64endrec)
+
+ zip64locrec = struct.pack(
+ structEndArchive64Locator,
+ stringEndArchive64Locator, 0, pos2, 1)
+ self.fp.write(zip64locrec)
+ centDirCount = min(centDirCount, 0xFFFF)
+ centDirSize = min(centDirSize, 0xFFFFFFFF)
+ centDirOffset = min(centDirOffset, 0xFFFFFFFF)
+
+ endrec = struct.pack(structEndArchive, stringEndArchive,
+ 0, 0, centDirCount, centDirCount,
+ centDirSize, centDirOffset, len(self._comment))
+ self.fp.write(endrec)
+ self.fp.write(self._comment)
+ self.fp.flush()
+ finally:
+ fp = self.fp
+ self.fp = None
+ if not self._filePassed:
+ fp.close()
class PyZipFile(ZipFile):
@@ -1677,16 +1664,15 @@ def main(args = None):
if len(args) != 2:
print(USAGE)
sys.exit(1)
- zf = ZipFile(args[1], 'r')
- zf.printdir()
- zf.close()
+ with ZipFile(args[1], 'r') as zf:
+ zf.printdir()
elif args[0] == '-t':
if len(args) != 2:
print(USAGE)
sys.exit(1)
- zf = ZipFile(args[1], 'r')
- badfile = zf.testzip()
+ with ZipFile(args[1], 'r') as zf:
+ badfile = zf.testzip()
if badfile:
print("The following enclosed file is corrupted: {!r}".format(badfile))
print("Done testing")
@@ -1696,20 +1682,19 @@ def main(args = None):
print(USAGE)
sys.exit(1)
- zf = ZipFile(args[1], 'r')
- out = args[2]
- for path in zf.namelist():
- if path.startswith('./'):
- tgt = os.path.join(out, path[2:])
- else:
- tgt = os.path.join(out, path)
+ with ZipFile(args[1], 'r') as zf:
+ out = args[2]
+ for path in zf.namelist():
+ if path.startswith('./'):
+ tgt = os.path.join(out, path[2:])
+ else:
+ tgt = os.path.join(out, path)
- tgtdir = os.path.dirname(tgt)
- if not os.path.exists(tgtdir):
- os.makedirs(tgtdir)
- with open(tgt, 'wb') as fp:
- fp.write(zf.read(path))
- zf.close()
+ tgtdir = os.path.dirname(tgt)
+ if not os.path.exists(tgtdir):
+ os.makedirs(tgtdir)
+ with open(tgt, 'wb') as fp:
+ fp.write(zf.read(path))
elif args[0] == '-c':
if len(args) < 3:
@@ -1725,11 +1710,9 @@ def main(args = None):
os.path.join(path, nm), os.path.join(zippath, nm))
# else: ignore
- zf = ZipFile(args[1], 'w', allowZip64=True)
- for src in args[2:]:
- addToZip(zf, src, os.path.basename(src))
-
- zf.close()
+ with ZipFile(args[1], 'w', allowZip64=True) as zf:
+ for src in args[2:]:
+ addToZip(zf, src, os.path.basename(src))
if __name__ == "__main__":
main()