summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorSerhiy Storchaka <storchaka@gmail.com>2020-04-11 07:59:24 (GMT)
committerGitHub <noreply@github.com>2020-04-11 07:59:24 (GMT)
commit4b222c9491d1700e9bdd98e6889b8d0ea1c7321e (patch)
tree65a24b8ad8fa5a39de328107f276e4ee224c204c /Lib
parentcd8295ff758891f21084a6a5ad3403d35dda38f7 (diff)
downloadcpython-4b222c9491d1700e9bdd98e6889b8d0ea1c7321e.zip
cpython-4b222c9491d1700e9bdd98e6889b8d0ea1c7321e.tar.gz
cpython-4b222c9491d1700e9bdd98e6889b8d0ea1c7321e.tar.bz2
bpo-40126: Fix reverting multiple patches in unittest.mock. (GH-19351)
Patcher's __exit__() is now never called if its __enter__() is failed. Returning true from __exit__() silences now the exception.
Diffstat (limited to 'Lib')
-rw-r--r--Lib/unittest/mock.py74
-rw-r--r--Lib/unittest/test/testmock/testpatch.py2
2 files changed, 27 insertions, 49 deletions
diff --git a/Lib/unittest/mock.py b/Lib/unittest/mock.py
index 20daf72..c0178f1 100644
--- a/Lib/unittest/mock.py
+++ b/Lib/unittest/mock.py
@@ -1241,11 +1241,6 @@ def _importer(target):
return thing
-def _is_started(patcher):
- # XXXX horrible
- return hasattr(patcher, 'is_local')
-
-
class _patch(object):
attribute_name = None
@@ -1316,14 +1311,9 @@ class _patch(object):
@contextlib.contextmanager
def decoration_helper(self, patched, args, keywargs):
extra_args = []
- entered_patchers = []
- patching = None
-
- exc_info = tuple()
- try:
+ with contextlib.ExitStack() as exit_stack:
for patching in patched.patchings:
- arg = patching.__enter__()
- entered_patchers.append(patching)
+ arg = exit_stack.enter_context(patching)
if patching.attribute_name is not None:
keywargs.update(arg)
elif patching.new is DEFAULT:
@@ -1331,19 +1321,6 @@ class _patch(object):
args += tuple(extra_args)
yield (args, keywargs)
- except:
- if (patching not in entered_patchers and
- _is_started(patching)):
- # the patcher may have been started, but an exception
- # raised whilst entering one of its additional_patchers
- entered_patchers.append(patching)
- # Pass the exception to __exit__
- exc_info = sys.exc_info()
- # re-raise the exception
- raise
- finally:
- for patching in reversed(entered_patchers):
- patching.__exit__(*exc_info)
def decorate_callable(self, func):
@@ -1520,25 +1497,26 @@ class _patch(object):
self.temp_original = original
self.is_local = local
- setattr(self.target, self.attribute, new_attr)
- if self.attribute_name is not None:
- extra_args = {}
- if self.new is DEFAULT:
- extra_args[self.attribute_name] = new
- for patching in self.additional_patchers:
- arg = patching.__enter__()
- if patching.new is DEFAULT:
- extra_args.update(arg)
- return extra_args
-
- return new
-
+ self._exit_stack = contextlib.ExitStack()
+ try:
+ setattr(self.target, self.attribute, new_attr)
+ if self.attribute_name is not None:
+ extra_args = {}
+ if self.new is DEFAULT:
+ extra_args[self.attribute_name] = new
+ for patching in self.additional_patchers:
+ arg = self._exit_stack.enter_context(patching)
+ if patching.new is DEFAULT:
+ extra_args.update(arg)
+ return extra_args
+
+ return new
+ except:
+ if not self.__exit__(*sys.exc_info()):
+ raise
def __exit__(self, *exc_info):
"""Undo the patch."""
- if not _is_started(self):
- return
-
if self.is_local and self.temp_original is not DEFAULT:
setattr(self.target, self.attribute, self.temp_original)
else:
@@ -1553,9 +1531,9 @@ class _patch(object):
del self.temp_original
del self.is_local
del self.target
- for patcher in reversed(self.additional_patchers):
- if _is_started(patcher):
- patcher.__exit__(*exc_info)
+ exit_stack = self._exit_stack
+ del self._exit_stack
+ return exit_stack.__exit__(*exc_info)
def start(self):
@@ -1571,9 +1549,9 @@ class _patch(object):
self._active_patches.remove(self)
except ValueError:
# If the patch hasn't been started this will fail
- pass
+ return None
- return self.__exit__()
+ return self.__exit__(None, None, None)
@@ -1873,9 +1851,9 @@ class _patch_dict(object):
_patch._active_patches.remove(self)
except ValueError:
# If the patch hasn't been started this will fail
- pass
+ return None
- return self.__exit__()
+ return self.__exit__(None, None, None)
def _clear_dict(in_dict):
diff --git a/Lib/unittest/test/testmock/testpatch.py b/Lib/unittest/test/testmock/testpatch.py
index f1bc0e1..d8c1515 100644
--- a/Lib/unittest/test/testmock/testpatch.py
+++ b/Lib/unittest/test/testmock/testpatch.py
@@ -774,7 +774,7 @@ class PatchTest(unittest.TestCase):
d = {'foo': 'bar'}
original = d.copy()
patcher = patch.dict(d, [('spam', 'eggs')], clear=True)
- self.assertEqual(patcher.stop(), False)
+ self.assertFalse(patcher.stop())
self.assertEqual(d, original)