diff options
author | Xtreak <tir.karthi@gmail.com> | 2019-05-28 07:07:39 (GMT) |
---|---|---|
committer | Miss Islington (bot) <31488909+miss-islington@users.noreply.github.com> | 2019-05-28 07:07:38 (GMT) |
commit | 436c2b0d67da68465e709a96daac7340af3a5238 (patch) | |
tree | 771d8d39bd772a7aa72640670e247b7e5bb14f6b | |
parent | 71dc7c5fbd856df83202f39c1f41ccd07c6eceb7 (diff) | |
download | cpython-436c2b0d67da68465e709a96daac7340af3a5238.zip cpython-436c2b0d67da68465e709a96daac7340af3a5238.tar.gz cpython-436c2b0d67da68465e709a96daac7340af3a5238.tar.bz2 |
bpo-36996: Handle async functions when mock.patch is used as a decorator (GH-13562)
Return a coroutine while patching async functions with a decorator.
Co-authored-by: Andrew Svetlov <andrew.svetlov@gmail.com>
https://bugs.python.org/issue36996
-rw-r--r-- | Lib/unittest/mock.py | 84 | ||||
-rw-r--r-- | Lib/unittest/test/testmock/testasync.py | 16 | ||||
-rw-r--r-- | Misc/NEWS.d/next/Library/2019-05-22-22-55-18.bpo-36996.XQx08d.rst | 1 |
3 files changed, 74 insertions, 27 deletions
diff --git a/Lib/unittest/mock.py b/Lib/unittest/mock.py index b91afd8..fac4535 100644 --- a/Lib/unittest/mock.py +++ b/Lib/unittest/mock.py @@ -26,6 +26,7 @@ __all__ = ( __version__ = '1.0' import asyncio +import contextlib import io import inspect import pprint @@ -1220,6 +1221,8 @@ class _patch(object): def __call__(self, func): if isinstance(func, type): return self.decorate_class(func) + if inspect.iscoroutinefunction(func): + return self.decorate_async_callable(func) return self.decorate_callable(func) @@ -1237,41 +1240,68 @@ class _patch(object): return klass + @contextlib.contextmanager + def decoration_helper(self, patched, args, keywargs): + extra_args = [] + entered_patchers = [] + patching = None + + exc_info = tuple() + try: + for patching in patched.patchings: + arg = patching.__enter__() + entered_patchers.append(patching) + if patching.attribute_name is not None: + keywargs.update(arg) + elif patching.new is DEFAULT: + extra_args.append(arg) + + args += tuple(extra_args) + yield (args, keywargs) + except: + if (patching not in entered_patchers and + _is_started(patching)): + # the patcher may have been started, but an exception + # raised whilst entering one of its additional_patchers + entered_patchers.append(patching) + # Pass the exception to __exit__ + exc_info = sys.exc_info() + # re-raise the exception + raise + finally: + for patching in reversed(entered_patchers): + patching.__exit__(*exc_info) + + def decorate_callable(self, func): + # NB. Keep the method in sync with decorate_async_callable() if hasattr(func, 'patchings'): func.patchings.append(self) return func @wraps(func) def patched(*args, **keywargs): - extra_args = [] - entered_patchers = [] + with self.decoration_helper(patched, + args, + keywargs) as (newargs, newkeywargs): + return func(*newargs, **newkeywargs) - exc_info = tuple() - try: - for patching in patched.patchings: - arg = patching.__enter__() - entered_patchers.append(patching) - if patching.attribute_name is not None: - keywargs.update(arg) - elif patching.new is DEFAULT: - extra_args.append(arg) - - args += tuple(extra_args) - return func(*args, **keywargs) - except: - if (patching not in entered_patchers and - _is_started(patching)): - # the patcher may have been started, but an exception - # raised whilst entering one of its additional_patchers - entered_patchers.append(patching) - # Pass the exception to __exit__ - exc_info = sys.exc_info() - # re-raise the exception - raise - finally: - for patching in reversed(entered_patchers): - patching.__exit__(*exc_info) + patched.patchings = [self] + return patched + + + def decorate_async_callable(self, func): + # NB. Keep the method in sync with decorate_callable() + if hasattr(func, 'patchings'): + func.patchings.append(self) + return func + + @wraps(func) + async def patched(*args, **keywargs): + with self.decoration_helper(patched, + args, + keywargs) as (newargs, newkeywargs): + return await func(*newargs, **newkeywargs) patched.patchings = [self] return patched diff --git a/Lib/unittest/test/testmock/testasync.py b/Lib/unittest/test/testmock/testasync.py index 0519d59..ccea4fe 100644 --- a/Lib/unittest/test/testmock/testasync.py +++ b/Lib/unittest/test/testmock/testasync.py @@ -66,6 +66,14 @@ class AsyncPatchDecoratorTest(unittest.TestCase): test_async() + def test_async_def_patch(self): + @patch(f"{__name__}.async_func", AsyncMock()) + async def test_async(): + self.assertIsInstance(async_func, AsyncMock) + + asyncio.run(test_async()) + self.assertTrue(inspect.iscoroutinefunction(async_func)) + class AsyncPatchCMTest(unittest.TestCase): def test_is_async_function_cm(self): @@ -91,6 +99,14 @@ class AsyncPatchCMTest(unittest.TestCase): test_async() + def test_async_def_cm(self): + async def test_async(): + with patch(f"{__name__}.async_func", AsyncMock()): + self.assertIsInstance(async_func, AsyncMock) + self.assertTrue(inspect.iscoroutinefunction(async_func)) + + asyncio.run(test_async()) + class AsyncMockTest(unittest.TestCase): def test_iscoroutinefunction_default(self): diff --git a/Misc/NEWS.d/next/Library/2019-05-22-22-55-18.bpo-36996.XQx08d.rst b/Misc/NEWS.d/next/Library/2019-05-22-22-55-18.bpo-36996.XQx08d.rst new file mode 100644 index 0000000..69d18d9 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2019-05-22-22-55-18.bpo-36996.XQx08d.rst @@ -0,0 +1 @@ +Handle :func:`unittest.mock.patch` used as a decorator on async functions. |