diff options
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/shutil.py | 22 | ||||
-rw-r--r-- | Lib/test/test_shutil.py | 25 |
2 files changed, 38 insertions, 9 deletions
diff --git a/Lib/shutil.py b/Lib/shutil.py index 74348ba..8d0de72 100644 --- a/Lib/shutil.py +++ b/Lib/shutil.py @@ -432,13 +432,13 @@ def ignore_patterns(*patterns): return _ignore_patterns def _copytree(entries, src, dst, symlinks, ignore, copy_function, - ignore_dangling_symlinks): + ignore_dangling_symlinks, dirs_exist_ok=False): if ignore is not None: ignored_names = ignore(src, set(os.listdir(src))) else: ignored_names = set() - os.makedirs(dst) + os.makedirs(dst, exist_ok=dirs_exist_ok) errors = [] use_srcentry = copy_function is copy2 or copy_function is copy @@ -461,14 +461,15 @@ def _copytree(entries, src, dst, symlinks, ignore, copy_function, # ignore dangling symlink if the flag is on if not os.path.exists(linkto) and ignore_dangling_symlinks: continue - # otherwise let the copy occurs. copy2 will raise an error + # otherwise let the copy occur. copy2 will raise an error if srcentry.is_dir(): copytree(srcobj, dstname, symlinks, ignore, - copy_function) + copy_function, dirs_exist_ok=dirs_exist_ok) else: copy_function(srcobj, dstname) elif srcentry.is_dir(): - copytree(srcobj, dstname, symlinks, ignore, copy_function) + copytree(srcobj, dstname, symlinks, ignore, copy_function, + dirs_exist_ok=dirs_exist_ok) else: # Will raise a SpecialFileError for unsupported file types copy_function(srcentry, dstname) @@ -489,10 +490,12 @@ def _copytree(entries, src, dst, symlinks, ignore, copy_function, return dst def copytree(src, dst, symlinks=False, ignore=None, copy_function=copy2, - ignore_dangling_symlinks=False): - """Recursively copy a directory tree. + ignore_dangling_symlinks=False, dirs_exist_ok=False): + """Recursively copy a directory tree and return the destination directory. + + dirs_exist_ok dictates whether to raise an exception in case dst or any + missing parent directory already exists. - The destination directory must not already exist. If exception(s) occur, an Error is raised with a list of reasons. If the optional symlinks flag is true, symbolic links in the @@ -527,7 +530,8 @@ def copytree(src, dst, symlinks=False, ignore=None, copy_function=copy2, with os.scandir(src) as entries: return _copytree(entries=entries, src=src, dst=dst, symlinks=symlinks, ignore=ignore, copy_function=copy_function, - ignore_dangling_symlinks=ignore_dangling_symlinks) + ignore_dangling_symlinks=ignore_dangling_symlinks, + dirs_exist_ok=dirs_exist_ok) # version vulnerable to race conditions def _rmtree_unsafe(path, onerror): diff --git a/Lib/test/test_shutil.py b/Lib/test/test_shutil.py index ec8fcc3..6f22e53 100644 --- a/Lib/test/test_shutil.py +++ b/Lib/test/test_shutil.py @@ -691,6 +691,31 @@ class TestShutil(unittest.TestCase): actual = read_file((dst_dir, 'test_dir', 'test.txt')) self.assertEqual(actual, '456') + def test_copytree_dirs_exist_ok(self): + src_dir = tempfile.mkdtemp() + dst_dir = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, src_dir) + self.addCleanup(shutil.rmtree, dst_dir) + + write_file((src_dir, 'nonexisting.txt'), '123') + os.mkdir(os.path.join(src_dir, 'existing_dir')) + os.mkdir(os.path.join(dst_dir, 'existing_dir')) + write_file((dst_dir, 'existing_dir', 'existing.txt'), 'will be replaced') + write_file((src_dir, 'existing_dir', 'existing.txt'), 'has been replaced') + + shutil.copytree(src_dir, dst_dir, dirs_exist_ok=True) + self.assertTrue(os.path.isfile(os.path.join(dst_dir, 'nonexisting.txt'))) + self.assertTrue(os.path.isdir(os.path.join(dst_dir, 'existing_dir'))) + self.assertTrue(os.path.isfile(os.path.join(dst_dir, 'existing_dir', + 'existing.txt'))) + actual = read_file((dst_dir, 'nonexisting.txt')) + self.assertEqual(actual, '123') + actual = read_file((dst_dir, 'existing_dir', 'existing.txt')) + self.assertEqual(actual, 'has been replaced') + + with self.assertRaises(FileExistsError): + shutil.copytree(src_dir, dst_dir, dirs_exist_ok=False) + @support.skip_unless_symlink def test_copytree_symlinks(self): tmp_dir = self.mkdtemp() |