summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorXtreak <tir.karthi@gmail.com>2019-05-28 07:07:39 (GMT)
committerMiss Islington (bot) <31488909+miss-islington@users.noreply.github.com>2019-05-28 07:07:38 (GMT)
commit436c2b0d67da68465e709a96daac7340af3a5238 (patch)
tree771d8d39bd772a7aa72640670e247b7e5bb14f6b
parent71dc7c5fbd856df83202f39c1f41ccd07c6eceb7 (diff)
downloadcpython-436c2b0d67da68465e709a96daac7340af3a5238.zip
cpython-436c2b0d67da68465e709a96daac7340af3a5238.tar.gz
cpython-436c2b0d67da68465e709a96daac7340af3a5238.tar.bz2
bpo-36996: Handle async functions when mock.patch is used as a decorator (GH-13562)
Return a coroutine while patching async functions with a decorator. Co-authored-by: Andrew Svetlov <andrew.svetlov@gmail.com> https://bugs.python.org/issue36996
-rw-r--r--Lib/unittest/mock.py84
-rw-r--r--Lib/unittest/test/testmock/testasync.py16
-rw-r--r--Misc/NEWS.d/next/Library/2019-05-22-22-55-18.bpo-36996.XQx08d.rst1
3 files changed, 74 insertions, 27 deletions
diff --git a/Lib/unittest/mock.py b/Lib/unittest/mock.py
index b91afd8..fac4535 100644
--- a/Lib/unittest/mock.py
+++ b/Lib/unittest/mock.py
@@ -26,6 +26,7 @@ __all__ = (
__version__ = '1.0'
import asyncio
+import contextlib
import io
import inspect
import pprint
@@ -1220,6 +1221,8 @@ class _patch(object):
def __call__(self, func):
if isinstance(func, type):
return self.decorate_class(func)
+ if inspect.iscoroutinefunction(func):
+ return self.decorate_async_callable(func)
return self.decorate_callable(func)
@@ -1237,41 +1240,68 @@ class _patch(object):
return klass
+ @contextlib.contextmanager
+ def decoration_helper(self, patched, args, keywargs):
+ extra_args = []
+ entered_patchers = []
+ patching = None
+
+ exc_info = tuple()
+ try:
+ for patching in patched.patchings:
+ arg = patching.__enter__()
+ entered_patchers.append(patching)
+ if patching.attribute_name is not None:
+ keywargs.update(arg)
+ elif patching.new is DEFAULT:
+ extra_args.append(arg)
+
+ args += tuple(extra_args)
+ yield (args, keywargs)
+ except:
+ if (patching not in entered_patchers and
+ _is_started(patching)):
+ # the patcher may have been started, but an exception
+ # raised whilst entering one of its additional_patchers
+ entered_patchers.append(patching)
+ # Pass the exception to __exit__
+ exc_info = sys.exc_info()
+ # re-raise the exception
+ raise
+ finally:
+ for patching in reversed(entered_patchers):
+ patching.__exit__(*exc_info)
+
+
def decorate_callable(self, func):
+ # NB. Keep the method in sync with decorate_async_callable()
if hasattr(func, 'patchings'):
func.patchings.append(self)
return func
@wraps(func)
def patched(*args, **keywargs):
- extra_args = []
- entered_patchers = []
+ with self.decoration_helper(patched,
+ args,
+ keywargs) as (newargs, newkeywargs):
+ return func(*newargs, **newkeywargs)
- exc_info = tuple()
- try:
- for patching in patched.patchings:
- arg = patching.__enter__()
- entered_patchers.append(patching)
- if patching.attribute_name is not None:
- keywargs.update(arg)
- elif patching.new is DEFAULT:
- extra_args.append(arg)
-
- args += tuple(extra_args)
- return func(*args, **keywargs)
- except:
- if (patching not in entered_patchers and
- _is_started(patching)):
- # the patcher may have been started, but an exception
- # raised whilst entering one of its additional_patchers
- entered_patchers.append(patching)
- # Pass the exception to __exit__
- exc_info = sys.exc_info()
- # re-raise the exception
- raise
- finally:
- for patching in reversed(entered_patchers):
- patching.__exit__(*exc_info)
+ patched.patchings = [self]
+ return patched
+
+
+ def decorate_async_callable(self, func):
+ # NB. Keep the method in sync with decorate_callable()
+ if hasattr(func, 'patchings'):
+ func.patchings.append(self)
+ return func
+
+ @wraps(func)
+ async def patched(*args, **keywargs):
+ with self.decoration_helper(patched,
+ args,
+ keywargs) as (newargs, newkeywargs):
+ return await func(*newargs, **newkeywargs)
patched.patchings = [self]
return patched
diff --git a/Lib/unittest/test/testmock/testasync.py b/Lib/unittest/test/testmock/testasync.py
index 0519d59..ccea4fe 100644
--- a/Lib/unittest/test/testmock/testasync.py
+++ b/Lib/unittest/test/testmock/testasync.py
@@ -66,6 +66,14 @@ class AsyncPatchDecoratorTest(unittest.TestCase):
test_async()
+ def test_async_def_patch(self):
+ @patch(f"{__name__}.async_func", AsyncMock())
+ async def test_async():
+ self.assertIsInstance(async_func, AsyncMock)
+
+ asyncio.run(test_async())
+ self.assertTrue(inspect.iscoroutinefunction(async_func))
+
class AsyncPatchCMTest(unittest.TestCase):
def test_is_async_function_cm(self):
@@ -91,6 +99,14 @@ class AsyncPatchCMTest(unittest.TestCase):
test_async()
+ def test_async_def_cm(self):
+ async def test_async():
+ with patch(f"{__name__}.async_func", AsyncMock()):
+ self.assertIsInstance(async_func, AsyncMock)
+ self.assertTrue(inspect.iscoroutinefunction(async_func))
+
+ asyncio.run(test_async())
+
class AsyncMockTest(unittest.TestCase):
def test_iscoroutinefunction_default(self):
diff --git a/Misc/NEWS.d/next/Library/2019-05-22-22-55-18.bpo-36996.XQx08d.rst b/Misc/NEWS.d/next/Library/2019-05-22-22-55-18.bpo-36996.XQx08d.rst
new file mode 100644
index 0000000..69d18d9
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2019-05-22-22-55-18.bpo-36996.XQx08d.rst
@@ -0,0 +1 @@
+Handle :func:`unittest.mock.patch` used as a decorator on async functions.