diff options
author | Lisa Roach <lisaroach14@gmail.com> | 2019-11-21 18:14:32 (GMT) |
---|---|---|
committer | Andrew Svetlov <andrew.svetlov@gmail.com> | 2019-11-21 18:14:32 (GMT) |
commit | b2744c1be73f5af0d2dc4b952389efc90c8de94e (patch) | |
tree | a0c6798e4d84800e91a9967b3d6fda05327450f4 /Lib/unittest/mock.py | |
parent | 9458c5c42bbe5fb6ef2393c9ee66f012a2c13ca3 (diff) | |
download | cpython-b2744c1be73f5af0d2dc4b952389efc90c8de94e.zip cpython-b2744c1be73f5af0d2dc4b952389efc90c8de94e.tar.gz cpython-b2744c1be73f5af0d2dc4b952389efc90c8de94e.tar.bz2 |
[3.8] bpo-38857: AsyncMock fix for awaitable values and StopIteration fix [3.8] (GH-17269) (#17304)
(cherry picked from commit 046442d02bcc6e848e71e93e47f6cde9e279e993)
Co-authored-by: Jason Fried <fried@fb.com>
Diffstat (limited to 'Lib/unittest/mock.py')
-rw-r--r-- | Lib/unittest/mock.py | 62 |
1 files changed, 37 insertions, 25 deletions
diff --git a/Lib/unittest/mock.py b/Lib/unittest/mock.py index 488ab1c..d6e3067 100644 --- a/Lib/unittest/mock.py +++ b/Lib/unittest/mock.py @@ -1125,8 +1125,8 @@ class CallableMixin(Base): _new_parent = _new_parent._mock_new_parent def _execute_mock_call(self, /, *args, **kwargs): - # seperate from _increment_mock_call so that awaited functions are - # executed seperately from their call + # separate from _increment_mock_call so that awaited functions are + # executed separately from their call, also AsyncMock overrides this method effect = self.side_effect if effect is not None: @@ -2122,29 +2122,45 @@ class AsyncMockMixin(Base): code_mock.co_flags = inspect.CO_COROUTINE self.__dict__['__code__'] = code_mock - async def _mock_call(self, /, *args, **kwargs): - try: - result = super()._mock_call(*args, **kwargs) - except (BaseException, StopIteration) as e: - side_effect = self.side_effect - if side_effect is not None and not callable(side_effect): - raise - return await _raise(e) + async def _execute_mock_call(self, /, *args, **kwargs): + # This is nearly just like super(), except for sepcial handling + # of coroutines _call = self.call_args + self.await_count += 1 + self.await_args = _call + self.await_args_list.append(_call) - async def proxy(): - try: - if inspect.isawaitable(result): - return await result - else: - return result - finally: - self.await_count += 1 - self.await_args = _call - self.await_args_list.append(_call) + effect = self.side_effect + if effect is not None: + if _is_exception(effect): + raise effect + elif not _callable(effect): + try: + result = next(effect) + except StopIteration: + # It is impossible to propogate a StopIteration + # through coroutines because of PEP 479 + raise StopAsyncIteration + if _is_exception(result): + raise result + elif asyncio.iscoroutinefunction(effect): + result = await effect(*args, **kwargs) + else: + result = effect(*args, **kwargs) - return await proxy() + if result is not DEFAULT: + return result + + if self._mock_return_value is not DEFAULT: + return self.return_value + + if self._mock_wraps is not None: + if asyncio.iscoroutinefunction(self._mock_wraps): + return await self._mock_wraps(*args, **kwargs) + return self._mock_wraps(*args, **kwargs) + + return self.return_value def assert_awaited(self): """ @@ -2852,10 +2868,6 @@ def seal(mock): seal(m) -async def _raise(exception): - raise exception - - class _AsyncIterator: """ Wraps an iterator in an asynchronous iterator. |