From 69058e20e420181abdc51094474f590d32cd7174 Mon Sep 17 00:00:00 2001 From: Barney Gale Date: Tue, 18 Jun 2024 22:15:18 +0100 Subject: GH-73991: Use same signature for `shutil._rmtree_[un]safe()`. (#120517) Preparatory work for moving `_rmtree_unsafe()` and `_rmtree_safe_fd()` to `pathlib._os` so that they can be used from both `shutil` and `pathlib`. Move implementation-specific setup from `rmtree()` into the safe/unsafe functions, and give them the same signature `(path, dir_fd, onexc)`. In the tests, mock `os.open` rather than `_rmtree_safe_fd()` to ensure the FD-based walk is used, and replace a couple references to `shutil._use_fd_functions` with `shutil.rmtree.avoids_symlink_attacks` (which has the same value). No change of behaviour. --- Lib/shutil.py | 75 +++++++++++++++++++++++++------------------------ Lib/test/test_shutil.py | 14 ++++----- 2 files changed, 44 insertions(+), 45 deletions(-) diff --git a/Lib/shutil.py b/Lib/shutil.py index b0d49e9..0235f6b 100644 --- a/Lib/shutil.py +++ b/Lib/shutil.py @@ -605,7 +605,22 @@ else: return stat.S_ISLNK(st.st_mode) # version vulnerable to race conditions -def _rmtree_unsafe(path, onexc): +def _rmtree_unsafe(path, dir_fd, onexc): + if dir_fd is not None: + raise NotImplementedError("dir_fd unavailable on this platform") + try: + st = os.lstat(path) + except OSError as err: + onexc(os.lstat, path, err) + return + try: + if _rmtree_islink(st): + # symlinks to directories are forbidden, see bug #1669 + raise OSError("Cannot call rmtree on a symbolic link") + except OSError as err: + onexc(os.path.islink, path, err) + # can't continue even if onexc hook returns + return def onerror(err): if not isinstance(err, FileNotFoundError): onexc(os.scandir, err.filename, err) @@ -635,7 +650,26 @@ def _rmtree_unsafe(path, onexc): onexc(os.rmdir, path, err) # Version using fd-based APIs to protect against races -def _rmtree_safe_fd(stack, onexc): +def _rmtree_safe_fd(path, dir_fd, onexc): + # While the unsafe rmtree works fine on bytes, the fd based does not. + if isinstance(path, bytes): + path = os.fsdecode(path) + stack = [(os.lstat, dir_fd, path, None)] + try: + while stack: + _rmtree_safe_fd_step(stack, onexc) + finally: + # Close any file descriptors still on the stack. + while stack: + func, fd, path, entry = stack.pop() + if func is not os.close: + continue + try: + os.close(fd) + except OSError as err: + onexc(os.close, path, err) + +def _rmtree_safe_fd_step(stack, onexc): # Each stack item has four elements: # * func: The first operation to perform: os.lstat, os.close or os.rmdir. # Walking a directory starts with an os.lstat() to detect symlinks; in @@ -710,6 +744,7 @@ _use_fd_functions = ({os.open, os.stat, os.unlink, os.rmdir} <= os.supports_dir_fd and os.scandir in os.supports_fd and os.stat in os.supports_follow_symlinks) +_rmtree_impl = _rmtree_safe_fd if _use_fd_functions else _rmtree_unsafe def rmtree(path, ignore_errors=False, onerror=None, *, onexc=None, dir_fd=None): """Recursively delete a directory tree. @@ -753,41 +788,7 @@ def rmtree(path, ignore_errors=False, onerror=None, *, onexc=None, dir_fd=None): exc_info = type(exc), exc, exc.__traceback__ return onerror(func, path, exc_info) - if _use_fd_functions: - # While the unsafe rmtree works fine on bytes, the fd based does not. - if isinstance(path, bytes): - path = os.fsdecode(path) - stack = [(os.lstat, dir_fd, path, None)] - try: - while stack: - _rmtree_safe_fd(stack, onexc) - finally: - # Close any file descriptors still on the stack. - while stack: - func, fd, path, entry = stack.pop() - if func is not os.close: - continue - try: - os.close(fd) - except OSError as err: - onexc(os.close, path, err) - else: - if dir_fd is not None: - raise NotImplementedError("dir_fd unavailable on this platform") - try: - st = os.lstat(path) - except OSError as err: - onexc(os.lstat, path, err) - return - try: - if _rmtree_islink(st): - # symlinks to directories are forbidden, see bug #1669 - raise OSError("Cannot call rmtree on a symbolic link") - except OSError as err: - onexc(os.path.islink, path, err) - # can't continue even if onexc hook returns - return - return _rmtree_unsafe(path, onexc) + _rmtree_impl(path, dir_fd, onexc) # Allow introspection of whether or not the hardening against symlink # attacks is supported on the current platform diff --git a/Lib/test/test_shutil.py b/Lib/test/test_shutil.py index bccb81e..02ef172 100644 --- a/Lib/test/test_shutil.py +++ b/Lib/test/test_shutil.py @@ -558,25 +558,23 @@ class TestRmTree(BaseTest, unittest.TestCase): os.listdir in os.supports_fd and os.stat in os.supports_follow_symlinks) if _use_fd_functions: - self.assertTrue(shutil._use_fd_functions) self.assertTrue(shutil.rmtree.avoids_symlink_attacks) tmp_dir = self.mkdtemp() d = os.path.join(tmp_dir, 'a') os.mkdir(d) try: - real_rmtree = shutil._rmtree_safe_fd + real_open = os.open class Called(Exception): pass def _raiser(*args, **kwargs): raise Called - shutil._rmtree_safe_fd = _raiser + os.open = _raiser self.assertRaises(Called, shutil.rmtree, d) finally: - shutil._rmtree_safe_fd = real_rmtree + os.open = real_open else: - self.assertFalse(shutil._use_fd_functions) self.assertFalse(shutil.rmtree.avoids_symlink_attacks) - @unittest.skipUnless(shutil._use_fd_functions, "requires safe rmtree") + @unittest.skipUnless(shutil.rmtree.avoids_symlink_attacks, "requires safe rmtree") def test_rmtree_fails_on_close(self): # Test that the error handler is called for failed os.close() and that # os.close() is only called once for a file descriptor. @@ -611,7 +609,7 @@ class TestRmTree(BaseTest, unittest.TestCase): self.assertEqual(errors[1][1], dir1) self.assertEqual(close_count, 2) - @unittest.skipUnless(shutil._use_fd_functions, "dir_fd is not supported") + @unittest.skipUnless(shutil.rmtree.avoids_symlink_attacks, "dir_fd is not supported") def test_rmtree_with_dir_fd(self): tmp_dir = self.mkdtemp() victim = 'killme' @@ -625,7 +623,7 @@ class TestRmTree(BaseTest, unittest.TestCase): shutil.rmtree(victim, dir_fd=dir_fd) self.assertFalse(os.path.exists(fullname)) - @unittest.skipIf(shutil._use_fd_functions, "dir_fd is supported") + @unittest.skipIf(shutil.rmtree.avoids_symlink_attacks, "dir_fd is supported") def test_rmtree_with_dir_fd_unsupported(self): tmp_dir = self.mkdtemp() with self.assertRaises(NotImplementedError): -- cgit v0.12