summaryrefslogtreecommitdiffstats
path: root/Lib/unittest/mock.py
diff options
context:
space:
mode:
authorXtreak <tir.karthi@gmail.com>2019-05-28 07:07:39 (GMT)
committerMiss Islington (bot) <31488909+miss-islington@users.noreply.github.com>2019-05-28 07:07:38 (GMT)
commit436c2b0d67da68465e709a96daac7340af3a5238 (patch)
tree771d8d39bd772a7aa72640670e247b7e5bb14f6b /Lib/unittest/mock.py
parent71dc7c5fbd856df83202f39c1f41ccd07c6eceb7 (diff)
downloadcpython-436c2b0d67da68465e709a96daac7340af3a5238.zip
cpython-436c2b0d67da68465e709a96daac7340af3a5238.tar.gz
cpython-436c2b0d67da68465e709a96daac7340af3a5238.tar.bz2
bpo-36996: Handle async functions when mock.patch is used as a decorator (GH-13562)
Return a coroutine while patching async functions with a decorator. Co-authored-by: Andrew Svetlov <andrew.svetlov@gmail.com> https://bugs.python.org/issue36996
Diffstat (limited to 'Lib/unittest/mock.py')
-rw-r--r--Lib/unittest/mock.py84
1 files changed, 57 insertions, 27 deletions
diff --git a/Lib/unittest/mock.py b/Lib/unittest/mock.py
index b91afd8..fac4535 100644
--- a/Lib/unittest/mock.py
+++ b/Lib/unittest/mock.py
@@ -26,6 +26,7 @@ __all__ = (
__version__ = '1.0'
import asyncio
+import contextlib
import io
import inspect
import pprint
@@ -1220,6 +1221,8 @@ class _patch(object):
def __call__(self, func):
if isinstance(func, type):
return self.decorate_class(func)
+ if inspect.iscoroutinefunction(func):
+ return self.decorate_async_callable(func)
return self.decorate_callable(func)
@@ -1237,41 +1240,68 @@ class _patch(object):
return klass
+ @contextlib.contextmanager
+ def decoration_helper(self, patched, args, keywargs):
+ extra_args = []
+ entered_patchers = []
+ patching = None
+
+ exc_info = tuple()
+ try:
+ for patching in patched.patchings:
+ arg = patching.__enter__()
+ entered_patchers.append(patching)
+ if patching.attribute_name is not None:
+ keywargs.update(arg)
+ elif patching.new is DEFAULT:
+ extra_args.append(arg)
+
+ args += tuple(extra_args)
+ yield (args, keywargs)
+ except:
+ if (patching not in entered_patchers and
+ _is_started(patching)):
+ # the patcher may have been started, but an exception
+ # raised whilst entering one of its additional_patchers
+ entered_patchers.append(patching)
+ # Pass the exception to __exit__
+ exc_info = sys.exc_info()
+ # re-raise the exception
+ raise
+ finally:
+ for patching in reversed(entered_patchers):
+ patching.__exit__(*exc_info)
+
+
def decorate_callable(self, func):
+ # NB. Keep the method in sync with decorate_async_callable()
if hasattr(func, 'patchings'):
func.patchings.append(self)
return func
@wraps(func)
def patched(*args, **keywargs):
- extra_args = []
- entered_patchers = []
+ with self.decoration_helper(patched,
+ args,
+ keywargs) as (newargs, newkeywargs):
+ return func(*newargs, **newkeywargs)
- exc_info = tuple()
- try:
- for patching in patched.patchings:
- arg = patching.__enter__()
- entered_patchers.append(patching)
- if patching.attribute_name is not None:
- keywargs.update(arg)
- elif patching.new is DEFAULT:
- extra_args.append(arg)
-
- args += tuple(extra_args)
- return func(*args, **keywargs)
- except:
- if (patching not in entered_patchers and
- _is_started(patching)):
- # the patcher may have been started, but an exception
- # raised whilst entering one of its additional_patchers
- entered_patchers.append(patching)
- # Pass the exception to __exit__
- exc_info = sys.exc_info()
- # re-raise the exception
- raise
- finally:
- for patching in reversed(entered_patchers):
- patching.__exit__(*exc_info)
+ patched.patchings = [self]
+ return patched
+
+
+ def decorate_async_callable(self, func):
+ # NB. Keep the method in sync with decorate_callable()
+ if hasattr(func, 'patchings'):
+ func.patchings.append(self)
+ return func
+
+ @wraps(func)
+ async def patched(*args, **keywargs):
+ with self.decoration_helper(patched,
+ args,
+ keywargs) as (newargs, newkeywargs):
+ return await func(*newargs, **newkeywargs)
patched.patchings = [self]
return patched