diff options
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/shutil.py | 13 | ||||
-rw-r--r-- | Lib/test/test_shutil.py | 51 |
2 files changed, 40 insertions, 24 deletions
diff --git a/Lib/shutil.py b/Lib/shutil.py index 8ad0c64..f3a31a4 100644 --- a/Lib/shutil.py +++ b/Lib/shutil.py @@ -147,8 +147,8 @@ def ignore_patterns(*patterns): return set(ignored_names) return _ignore_patterns -def copytree(src, dst, symlinks=False, ignore=None): - """Recursively copy a directory tree using copy2(). +def copytree(src, dst, symlinks=False, ignore=None, copy_function=copy2): + """Recursively copy a directory tree. The destination directory must not already exist. If exception(s) occur, an Error is raised with a list of reasons. @@ -170,7 +170,10 @@ def copytree(src, dst, symlinks=False, ignore=None): list of names relative to the `src` directory that should not be copied. - XXX Consider this example code rather than the ultimate tool. + The optional copy_function argument is a callable that will be used + to copy each file. It will be called with the source path and the + destination path as arguments. By default, copy2() is used, but any + function that supports the same signature (like copy()) can be used. """ names = os.listdir(src) @@ -191,10 +194,10 @@ def copytree(src, dst, symlinks=False, ignore=None): linkto = os.readlink(srcname) os.symlink(linkto, dstname) elif os.path.isdir(srcname): - copytree(srcname, dstname, symlinks, ignore) + copytree(srcname, dstname, symlinks, ignore, copy_function) else: # Will raise a SpecialFileError for unsupported file types - copy2(srcname, dstname) + copy_function(srcname, dstname) # catch the Error from the recursive copytree so that we can # continue with other files except Error as err: diff --git a/Lib/test/test_shutil.py b/Lib/test/test_shutil.py index 3faa95e..dfa6f9f 100644 --- a/Lib/test/test_shutil.py +++ b/Lib/test/test_shutil.py @@ -74,6 +74,7 @@ class TestShutil(unittest.TestCase): d = tempfile.mkdtemp() self.tempdirs.append(d) return d + def test_rmtree_errors(self): # filename is guaranteed not to exist filename = tempfile.mktemp() @@ -140,11 +141,12 @@ class TestShutil(unittest.TestCase): self.assertRaises(OSError, shutil.rmtree, path) os.remove(path) + def _write_data(self, path, data): + f = open(path, "w") + f.write(data) + f.close() + def test_copytree_simple(self): - def write_data(path, data): - f = open(path, "w") - f.write(data) - f.close() def read_data(path): f = open(path) @@ -154,11 +156,9 @@ class TestShutil(unittest.TestCase): src_dir = tempfile.mkdtemp() dst_dir = os.path.join(tempfile.mkdtemp(), 'destination') - - write_data(os.path.join(src_dir, 'test.txt'), '123') - + self._write_data(os.path.join(src_dir, 'test.txt'), '123') os.mkdir(os.path.join(src_dir, 'test_dir')) - write_data(os.path.join(src_dir, 'test_dir', 'test.txt'), '456') + self._write_data(os.path.join(src_dir, 'test_dir', 'test.txt'), '456') try: shutil.copytree(src_dir, dst_dir) @@ -187,11 +187,6 @@ class TestShutil(unittest.TestCase): def test_copytree_with_exclude(self): - def write_data(path, data): - f = open(path, "w") - f.write(data) - f.close() - def read_data(path): f = open(path) data = f.read() @@ -204,16 +199,18 @@ class TestShutil(unittest.TestCase): src_dir = tempfile.mkdtemp() try: dst_dir = join(tempfile.mkdtemp(), 'destination') - write_data(join(src_dir, 'test.txt'), '123') - write_data(join(src_dir, 'test.tmp'), '123') + self._write_data(join(src_dir, 'test.txt'), '123') + self._write_data(join(src_dir, 'test.tmp'), '123') os.mkdir(join(src_dir, 'test_dir')) - write_data(join(src_dir, 'test_dir', 'test.txt'), '456') + self._write_data(join(src_dir, 'test_dir', 'test.txt'), '456') os.mkdir(join(src_dir, 'test_dir2')) - write_data(join(src_dir, 'test_dir2', 'test.txt'), '456') + self._write_data(join(src_dir, 'test_dir2', 'test.txt'), '456') os.mkdir(join(src_dir, 'test_dir2', 'subdir')) os.mkdir(join(src_dir, 'test_dir2', 'subdir2')) - write_data(join(src_dir, 'test_dir2', 'subdir', 'test.txt'), '456') - write_data(join(src_dir, 'test_dir2', 'subdir2', 'test.py'), '456') + self._write_data(join(src_dir, 'test_dir2', 'subdir', 'test.txt'), + '456') + self._write_data(join(src_dir, 'test_dir2', 'subdir2', 'test.py'), + '456') # testing glob-like patterns @@ -339,6 +336,21 @@ class TestShutil(unittest.TestCase): shutil.rmtree(TESTFN, ignore_errors=True) shutil.rmtree(TESTFN2, ignore_errors=True) + def test_copytree_special_func(self): + + src_dir = self.mkdtemp() + dst_dir = os.path.join(self.mkdtemp(), 'destination') + self._write_data(os.path.join(src_dir, 'test.txt'), '123') + os.mkdir(os.path.join(src_dir, 'test_dir')) + self._write_data(os.path.join(src_dir, 'test_dir', 'test.txt'), '456') + + copied = [] + def _copy(src, dst): + copied.append((src, dst)) + + shutil.copytree(src_dir, dst_dir, copy_function=_copy) + self.assertEquals(len(copied), 2) + @unittest.skipUnless(zlib, "requires zlib") def test_make_tarball(self): # creating something to tar @@ -728,6 +740,7 @@ class TestMove(unittest.TestCase): finally: shutil.rmtree(TESTFN, ignore_errors=True) + def test_main(): support.run_unittest(TestShutil, TestMove) |