From 09aa752067863ecf749010b315227f41e25571a5 Mon Sep 17 00:00:00 2001 From: "Gregory P. Smith" Date: Sun, 3 Feb 2013 00:36:32 -0800 Subject: Refactor recently added bugfix into more testable code by using a method for windows file name sanitization. Splits the unittest up into several based on platform. --- Lib/test/test_zipfile.py | 41 +++++++++++++++++++++++++++++------------ Lib/zipfile.py | 26 +++++++++++++++++++------- 2 files changed, 48 insertions(+), 19 deletions(-) diff --git a/Lib/test/test_zipfile.py b/Lib/test/test_zipfile.py index c1e20b2..5e837cd 100644 --- a/Lib/test/test_zipfile.py +++ b/Lib/test/test_zipfile.py @@ -538,8 +538,15 @@ class TestsWithSourceFile(unittest.TestCase): with open(filename, 'rb') as f: self.assertEqual(f.read(), content) - def test_extract_hackers_arcnames(self): - hacknames = [ + def test_sanitize_windows_name(self): + san = zipfile.ZipFile._sanitize_windows_name + # Passing pathsep in allows this test to work regardless of platform. + self.assertEqual(san(r',,?,C:,foo,bar/z', ','), r'_,C_,foo,bar/z') + self.assertEqual(san(r'a\b,ce|f"g?h*i', ','), r'a\b,c_d_e_f_g_h_i') + self.assertEqual(san('../../foo../../ba..r', '/'), r'foo/ba..r') + + def test_extract_hackers_arcnames_common_cases(self): + common_hacknames = [ ('../foo/bar', 'foo/bar'), ('foo/../bar', 'foo/bar'), ('foo/../../bar', 'foo/bar'), @@ -549,8 +556,12 @@ class TestsWithSourceFile(unittest.TestCase): ('/foo/../bar', 'foo/bar'), ('/foo/../../bar', 'foo/bar'), ] - if os.path.sep == '\\': # Windows. - hacknames.extend([ + self._test_extract_hackers_arcnames(common_hacknames) + + @unittest.skipIf(os.path.sep != '\\', 'Requires \\ as path separator.') + def test_extract_hackers_arcnames_windows_only(self): + """Test combination of path fixing and windows name sanitization.""" + windows_hacknames = [ (r'..\foo\bar', 'foo/bar'), (r'..\/foo\/bar', 'foo/bar'), (r'foo/\..\/bar', 'foo/bar'), @@ -570,14 +581,19 @@ class TestsWithSourceFile(unittest.TestCase): (r'C:/../C:/foo/bar', 'C_/foo/bar'), (r'a:b\ce|f"g?h*i', 'b/c_d_e_f_g_h_i'), ('../../foo../../ba..r', 'foo/ba..r'), - ]) - else: # Unix - hacknames.extend([ - ('//foo/bar', 'foo/bar'), - ('../../foo../../ba..r', 'foo../ba..r'), - (r'foo/..\bar', r'foo/..\bar'), - ]) + ] + self._test_extract_hackers_arcnames(windows_hacknames) + + @unittest.skipIf(os.path.sep != '/', r'Requires / as path separator.') + def test_extract_hackers_arcnames_posix_only(self): + posix_hacknames = [ + ('//foo/bar', 'foo/bar'), + ('../../foo../../ba..r', 'foo../ba..r'), + (r'foo/..\bar', r'foo/..\bar'), + ] + self._test_extract_hackers_arcnames(posix_hacknames) + def _test_extract_hackers_arcnames(self, hacknames): for arcname, fixedname in hacknames: content = b'foobar' + arcname.encode() with zipfile.ZipFile(TESTFN2, 'w', zipfile.ZIP_STORED) as zipfp: @@ -594,7 +610,8 @@ class TestsWithSourceFile(unittest.TestCase): with zipfile.ZipFile(TESTFN2, 'r') as zipfp: writtenfile = zipfp.extract(arcname, targetpath) self.assertEqual(writtenfile, correctfile, - msg="extract %r" % arcname) + msg='extract %r: %r != %r' % + (arcname, writtenfile, correctfile)) self.check_file(correctfile, content) shutil.rmtree('target') diff --git a/Lib/zipfile.py b/Lib/zipfile.py index 8b355d6..3448c61 100644 --- a/Lib/zipfile.py +++ b/Lib/zipfile.py @@ -883,6 +883,7 @@ class ZipFile: """ fp = None # Set here since __del__ checks it + _windows_illegal_name_trans_table = None def __init__(self, file, mode="r", compression=ZIP_STORED, allowZip64=False): """Open the ZIP file with mode read "r", write "w" or append "a".""" @@ -1223,6 +1224,21 @@ class ZipFile: for zipinfo in members: self.extract(zipinfo, path, pwd) + @classmethod + def _sanitize_windows_name(cls, arcname, pathsep): + """Replace bad characters and remove trailing dots from parts.""" + table = cls._windows_illegal_name_trans_table + if not table: + illegal = ':<>|"?*' + table = str.maketrans(illegal, '_' * len(illegal)) + cls._windows_illegal_name_trans_table = table + arcname = arcname.translate(table) + # remove trailing dots + arcname = (x.rstrip('.') for x in arcname.split(pathsep)) + # rejoin, removing empty parts. + arcname = pathsep.join(x for x in arcname if x) + return arcname + def _extract_member(self, member, targetpath, pwd): """Extract the ZipInfo object 'member' to a physical file on the path targetpath. @@ -1236,16 +1252,12 @@ class ZipFile: # interpret absolute pathname as relative, remove drive letter or # UNC path, redundant separators, "." and ".." components. arcname = os.path.splitdrive(arcname)[1] + invalid_path_parts = ('', os.path.curdir, os.path.pardir) arcname = os.path.sep.join(x for x in arcname.split(os.path.sep) - if x not in ('', os.path.curdir, os.path.pardir)) + if x not in invalid_path_parts) if os.path.sep == '\\': # filter illegal characters on Windows - illegal = ':<>|"?*' - table = str.maketrans(illegal, '_' * len(illegal)) - arcname = arcname.translate(table) - # remove trailing dots - arcname = (x.rstrip('.') for x in arcname.split(os.path.sep)) - arcname = os.path.sep.join(x for x in arcname if x) + arcname = self._sanitize_windows_name(arcname, os.path.sep) targetpath = os.path.join(targetpath, arcname) targetpath = os.path.normpath(targetpath) -- cgit v0.12