diff options
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/unittest/mock.py | 59 | ||||
-rw-r--r-- | Lib/unittest/test/testmock/testasync.py | 63 |
2 files changed, 99 insertions, 23 deletions
diff --git a/Lib/unittest/mock.py b/Lib/unittest/mock.py index b14bf01..b91afd8 100644 --- a/Lib/unittest/mock.py +++ b/Lib/unittest/mock.py @@ -51,6 +51,13 @@ def _is_async_obj(obj): return False +def _is_async_func(func): + if getattr(func, '__code__', None): + return asyncio.iscoroutinefunction(func) + else: + return False + + def _is_instance_mock(obj): # can't use isinstance on Mock objects because they override __class__ # The base class for all mocks is NonCallableMock @@ -225,6 +232,34 @@ def _setup_func(funcopy, mock, sig): mock._mock_delegate = funcopy +def _setup_async_mock(mock): + mock._is_coroutine = asyncio.coroutines._is_coroutine + mock.await_count = 0 + mock.await_args = None + mock.await_args_list = _CallList() + mock.awaited = _AwaitEvent(mock) + + # Mock is not configured yet so the attributes are set + # to a function and then the corresponding mock helper function + # is called when the helper is accessed similar to _setup_func. + def wrapper(attr, *args, **kwargs): + return getattr(mock.mock, attr)(*args, **kwargs) + + for attribute in ('assert_awaited', + 'assert_awaited_once', + 'assert_awaited_with', + 'assert_awaited_once_with', + 'assert_any_await', + 'assert_has_awaits', + 'assert_not_awaited'): + + # setattr(mock, attribute, wrapper) causes late binding + # hence attribute will always be the last value in the loop + # Use partial(wrapper, attribute) to ensure the attribute is bound + # correctly. + setattr(mock, attribute, partial(wrapper, attribute)) + + def _is_magic(name): return '__%s__' % name[2:-2] == name @@ -2151,7 +2186,7 @@ class AsyncMockMixin(Base): """ self = _mock_self if self.await_count != 0: - msg = (f"Expected {self._mock_name or 'mock'} to have been awaited once." + msg = (f"Expected {self._mock_name or 'mock'} to not have been awaited." f" Awaited {self.await_count} times.") raise AssertionError(msg) @@ -2457,10 +2492,7 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None, spec = type(spec) is_type = isinstance(spec, type) - if getattr(spec, '__code__', None): - is_async_func = asyncio.iscoroutinefunction(spec) - else: - is_async_func = False + is_async_func = _is_async_func(spec) _kwargs = {'spec': spec} if spec_set: _kwargs = {'spec_set': spec} @@ -2498,26 +2530,11 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None, name=_name, **_kwargs) if isinstance(spec, FunctionTypes): - wrapped_mock = mock # should only happen at the top level because we don't # recurse for functions mock = _set_signature(mock, spec) if is_async_func: - mock._is_coroutine = asyncio.coroutines._is_coroutine - mock.await_count = 0 - mock.await_args = None - mock.await_args_list = _CallList() - - for a in ('assert_awaited', - 'assert_awaited_once', - 'assert_awaited_with', - 'assert_awaited_once_with', - 'assert_any_await', - 'assert_has_awaits', - 'assert_not_awaited'): - def f(*args, **kwargs): - return getattr(wrapped_mock, a)(*args, **kwargs) - setattr(mock, a, f) + _setup_async_mock(mock) else: _check_signature(spec, mock, is_type, instance) diff --git a/Lib/unittest/test/testmock/testasync.py b/Lib/unittest/test/testmock/testasync.py index a9aa143..0519d59 100644 --- a/Lib/unittest/test/testmock/testasync.py +++ b/Lib/unittest/test/testmock/testasync.py @@ -2,7 +2,8 @@ import asyncio import inspect import unittest -from unittest.mock import call, AsyncMock, patch, MagicMock, create_autospec +from unittest.mock import (call, AsyncMock, patch, MagicMock, create_autospec, + _AwaitEvent) def tearDownModule(): @@ -20,6 +21,9 @@ class AsyncClass: async def async_func(): pass +async def async_func_args(a, b, *, c): + pass + def normal_func(): pass @@ -141,8 +145,63 @@ class AsyncAutospecTest(unittest.TestCase): create_autospec(async_func, instance=True) def test_create_autospec(self): - spec = create_autospec(async_func) + spec = create_autospec(async_func_args) + awaitable = spec(1, 2, c=3) + async def main(): + await awaitable + + self.assertEqual(spec.await_count, 0) + self.assertIsNone(spec.await_args) + self.assertEqual(spec.await_args_list, []) + self.assertIsInstance(spec.awaited, _AwaitEvent) + spec.assert_not_awaited() + + asyncio.run(main()) + self.assertTrue(asyncio.iscoroutinefunction(spec)) + self.assertTrue(asyncio.iscoroutine(awaitable)) + self.assertEqual(spec.await_count, 1) + self.assertEqual(spec.await_args, call(1, 2, c=3)) + self.assertEqual(spec.await_args_list, [call(1, 2, c=3)]) + spec.assert_awaited_once() + spec.assert_awaited_once_with(1, 2, c=3) + spec.assert_awaited_with(1, 2, c=3) + spec.assert_awaited() + + def test_patch_with_autospec(self): + + async def test_async(): + with patch(f"{__name__}.async_func_args", autospec=True) as mock_method: + awaitable = mock_method(1, 2, c=3) + self.assertIsInstance(mock_method.mock, AsyncMock) + + self.assertTrue(asyncio.iscoroutinefunction(mock_method)) + self.assertTrue(asyncio.iscoroutine(awaitable)) + self.assertTrue(inspect.isawaitable(awaitable)) + + # Verify the default values during mock setup + self.assertEqual(mock_method.await_count, 0) + self.assertEqual(mock_method.await_args_list, []) + self.assertIsNone(mock_method.await_args) + self.assertIsInstance(mock_method.awaited, _AwaitEvent) + mock_method.assert_not_awaited() + + await awaitable + + self.assertEqual(mock_method.await_count, 1) + self.assertEqual(mock_method.await_args, call(1, 2, c=3)) + self.assertEqual(mock_method.await_args_list, [call(1, 2, c=3)]) + mock_method.assert_awaited_once() + mock_method.assert_awaited_once_with(1, 2, c=3) + mock_method.assert_awaited_with(1, 2, c=3) + mock_method.assert_awaited() + + mock_method.reset_mock() + self.assertEqual(mock_method.await_count, 0) + self.assertIsNone(mock_method.await_args) + self.assertEqual(mock_method.await_args_list, []) + + asyncio.run(test_async()) class AsyncSpecTest(unittest.TestCase): |