summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
Diffstat (limited to 'Lib')
-rw-r--r--Lib/tarfile.py26
-rw-r--r--Lib/test/test_tarfile.py35
2 files changed, 47 insertions, 14 deletions
diff --git a/Lib/tarfile.py b/Lib/tarfile.py
index e4c5863..60259bc 100644
--- a/Lib/tarfile.py
+++ b/Lib/tarfile.py
@@ -1179,17 +1179,16 @@ class TarFile(object):
# Fill the TarInfo object with all
# information we can get.
- tarinfo.name = arcname
- tarinfo.mode = stmd
- tarinfo.uid = statres.st_uid
- tarinfo.gid = statres.st_gid
- if stat.S_ISDIR(stmd):
- # For a directory, the size must be 0
- tarinfo.size = 0
- else:
+ tarinfo.name = arcname
+ tarinfo.mode = stmd
+ tarinfo.uid = statres.st_uid
+ tarinfo.gid = statres.st_gid
+ if stat.S_ISREG(stmd):
tarinfo.size = statres.st_size
+ else:
+ tarinfo.size = 0L
tarinfo.mtime = statres.st_mtime
- tarinfo.type = type
+ tarinfo.type = type
tarinfo.linkname = linkname
if pwd:
try:
@@ -1280,16 +1279,15 @@ class TarFile(object):
self.addfile(tarinfo, f)
f.close()
- if tarinfo.type in (LNKTYPE, SYMTYPE, FIFOTYPE, CHRTYPE, BLKTYPE):
- tarinfo.size = 0L
- self.addfile(tarinfo)
-
- if tarinfo.isdir():
+ elif tarinfo.isdir():
self.addfile(tarinfo)
if recursive:
for f in os.listdir(name):
self.add(os.path.join(name, f), os.path.join(arcname, f))
+ else:
+ self.addfile(tarinfo)
+
def addfile(self, tarinfo, fileobj=None):
"""Add the TarInfo object `tarinfo' to the archive. If `fileobj' is
given, tarinfo.size bytes are read from it and added to the archive.
diff --git a/Lib/test/test_tarfile.py b/Lib/test/test_tarfile.py
index cc5e505..66409cd 100644
--- a/Lib/test/test_tarfile.py
+++ b/Lib/test/test_tarfile.py
@@ -230,6 +230,40 @@ class WriteTest(BaseTest):
else:
self.dst.addfile(tarinfo, f)
+class WriteSize0Test(BaseTest):
+ mode = 'w'
+
+ def setUp(self):
+ self.tmpdir = dirname()
+ self.dstname = tmpname()
+ self.dst = tarfile.open(self.dstname, "w")
+
+ def tearDown(self):
+ self.dst.close()
+
+ def test_file(self):
+ path = os.path.join(self.tmpdir, "file")
+ file(path, "w")
+ tarinfo = self.dst.gettarinfo(path)
+ self.assertEqual(tarinfo.size, 0)
+ file(path, "w").write("aaa")
+ tarinfo = self.dst.gettarinfo(path)
+ self.assertEqual(tarinfo.size, 3)
+
+ def test_directory(self):
+ path = os.path.join(self.tmpdir, "directory")
+ os.mkdir(path)
+ tarinfo = self.dst.gettarinfo(path)
+ self.assertEqual(tarinfo.size, 0)
+
+ def test_symlink(self):
+ if hasattr(os, "symlink"):
+ path = os.path.join(self.tmpdir, "symlink")
+ os.symlink("link_target", path)
+ tarinfo = self.dst.gettarinfo(path)
+ self.assertEqual(tarinfo.size, 0)
+
+
class WriteStreamTest(WriteTest):
sep = '|'
@@ -399,6 +433,7 @@ def test_main():
ReadAsteriskTest,
ReadStreamAsteriskTest,
WriteTest,
+ WriteSize0Test,
WriteStreamTest,
WriteGNULongTest,
]