summaryrefslogtreecommitdiffstats
path: root/Lib/unittest/mock.py
diff options
context:
space:
mode:
authorLisa Roach <lisaroach14@gmail.com>2019-11-21 18:14:32 (GMT)
committerAndrew Svetlov <andrew.svetlov@gmail.com>2019-11-21 18:14:32 (GMT)
commitb2744c1be73f5af0d2dc4b952389efc90c8de94e (patch)
treea0c6798e4d84800e91a9967b3d6fda05327450f4 /Lib/unittest/mock.py
parent9458c5c42bbe5fb6ef2393c9ee66f012a2c13ca3 (diff)
downloadcpython-b2744c1be73f5af0d2dc4b952389efc90c8de94e.zip
cpython-b2744c1be73f5af0d2dc4b952389efc90c8de94e.tar.gz
cpython-b2744c1be73f5af0d2dc4b952389efc90c8de94e.tar.bz2
[3.8] bpo-38857: AsyncMock fix for awaitable values and StopIteration fix [3.8] (GH-17269) (#17304)
(cherry picked from commit 046442d02bcc6e848e71e93e47f6cde9e279e993) Co-authored-by: Jason Fried <fried@fb.com>
Diffstat (limited to 'Lib/unittest/mock.py')
-rw-r--r--Lib/unittest/mock.py62
1 files changed, 37 insertions, 25 deletions
diff --git a/Lib/unittest/mock.py b/Lib/unittest/mock.py
index 488ab1c..d6e3067 100644
--- a/Lib/unittest/mock.py
+++ b/Lib/unittest/mock.py
@@ -1125,8 +1125,8 @@ class CallableMixin(Base):
_new_parent = _new_parent._mock_new_parent
def _execute_mock_call(self, /, *args, **kwargs):
- # seperate from _increment_mock_call so that awaited functions are
- # executed seperately from their call
+ # separate from _increment_mock_call so that awaited functions are
+ # executed separately from their call, also AsyncMock overrides this method
effect = self.side_effect
if effect is not None:
@@ -2122,29 +2122,45 @@ class AsyncMockMixin(Base):
code_mock.co_flags = inspect.CO_COROUTINE
self.__dict__['__code__'] = code_mock
- async def _mock_call(self, /, *args, **kwargs):
- try:
- result = super()._mock_call(*args, **kwargs)
- except (BaseException, StopIteration) as e:
- side_effect = self.side_effect
- if side_effect is not None and not callable(side_effect):
- raise
- return await _raise(e)
+ async def _execute_mock_call(self, /, *args, **kwargs):
+ # This is nearly just like super(), except for sepcial handling
+ # of coroutines
_call = self.call_args
+ self.await_count += 1
+ self.await_args = _call
+ self.await_args_list.append(_call)
- async def proxy():
- try:
- if inspect.isawaitable(result):
- return await result
- else:
- return result
- finally:
- self.await_count += 1
- self.await_args = _call
- self.await_args_list.append(_call)
+ effect = self.side_effect
+ if effect is not None:
+ if _is_exception(effect):
+ raise effect
+ elif not _callable(effect):
+ try:
+ result = next(effect)
+ except StopIteration:
+ # It is impossible to propogate a StopIteration
+ # through coroutines because of PEP 479
+ raise StopAsyncIteration
+ if _is_exception(result):
+ raise result
+ elif asyncio.iscoroutinefunction(effect):
+ result = await effect(*args, **kwargs)
+ else:
+ result = effect(*args, **kwargs)
- return await proxy()
+ if result is not DEFAULT:
+ return result
+
+ if self._mock_return_value is not DEFAULT:
+ return self.return_value
+
+ if self._mock_wraps is not None:
+ if asyncio.iscoroutinefunction(self._mock_wraps):
+ return await self._mock_wraps(*args, **kwargs)
+ return self._mock_wraps(*args, **kwargs)
+
+ return self.return_value
def assert_awaited(self):
"""
@@ -2852,10 +2868,6 @@ def seal(mock):
seal(m)
-async def _raise(exception):
- raise exception
-
-
class _AsyncIterator:
"""
Wraps an iterator in an asynchronous iterator.