diff options
-rw-r--r-- | Doc/library/unittest.mock.rst | 215 | ||||
-rw-r--r-- | Doc/whatsnew/3.8.rst | 4 | ||||
-rw-r--r-- | Lib/unittest/mock.py | 406 | ||||
-rw-r--r-- | Lib/unittest/test/testmock/testasync.py | 549 | ||||
-rw-r--r-- | Lib/unittest/test/testmock/testmock.py | 5 | ||||
-rw-r--r-- | Misc/NEWS.d/next/Library/2018-09-13-20-33-24.bpo-26467.cahAk3.rst | 2 |
6 files changed, 1161 insertions, 20 deletions
diff --git a/Doc/library/unittest.mock.rst b/Doc/library/unittest.mock.rst index ed00ee6..21e4709 100644 --- a/Doc/library/unittest.mock.rst +++ b/Doc/library/unittest.mock.rst @@ -201,9 +201,11 @@ The Mock Class .. testsetup:: + import asyncio + import inspect import unittest from unittest.mock import sentinel, DEFAULT, ANY - from unittest.mock import patch, call, Mock, MagicMock, PropertyMock + from unittest.mock import patch, call, Mock, MagicMock, PropertyMock, AsyncMock from unittest.mock import mock_open :class:`Mock` is a flexible mock object intended to replace the use of stubs and @@ -851,6 +853,217 @@ object:: >>> p.assert_called_once_with() +.. class:: AsyncMock(spec=None, side_effect=None, return_value=DEFAULT, wraps=None, name=None, spec_set=None, unsafe=False, **kwargs) + + An asynchronous version of :class:`Mock`. The :class:`AsyncMock` object will + behave so the object is recognized as an async function, and the result of a + call is an awaitable. + + >>> mock = AsyncMock() + >>> asyncio.iscoroutinefunction(mock) + True + >>> inspect.isawaitable(mock()) + True + + The result of ``mock()`` is an async function which will have the outcome + of ``side_effect`` or ``return_value``: + + - if ``side_effect`` is a function, the async function will return the + result of that function, + - if ``side_effect`` is an exception, the async function will raise the + exception, + - if ``side_effect`` is an iterable, the async function will return the + next value of the iterable, however, if the sequence of result is + exhausted, ``StopIteration`` is raised immediately, + - if ``side_effect`` is not defined, the async function will return the + value defined by ``return_value``, hence, by default, the async function + returns a new :class:`AsyncMock` object. + + + Setting the *spec* of a :class:`Mock` or :class:`MagicMock` to an async function + will result in a coroutine object being returned after calling. + + >>> async def async_func(): pass + ... + >>> mock = MagicMock(async_func) + >>> mock + <MagicMock spec='function' id='...'> + >>> mock() + <coroutine object AsyncMockMixin._mock_call at ...> + + .. method:: assert_awaited() + + Assert that the mock was awaited at least once. + + >>> mock = AsyncMock() + >>> async def main(): + ... await mock() + ... + >>> asyncio.run(main()) + >>> mock.assert_awaited() + >>> mock_2 = AsyncMock() + >>> mock_2.assert_awaited() + Traceback (most recent call last): + ... + AssertionError: Expected mock to have been awaited. + + .. method:: assert_awaited_once() + + Assert that the mock was awaited exactly once. + + >>> mock = AsyncMock() + >>> async def main(): + ... await mock() + ... + >>> asyncio.run(main()) + >>> mock.assert_awaited_once() + >>> asyncio.run(main()) + >>> mock.method.assert_awaited_once() + Traceback (most recent call last): + ... + AssertionError: Expected mock to have been awaited once. Awaited 2 times. + + .. method:: assert_awaited_with(*args, **kwargs) + + Assert that the last await was with the specified arguments. + + >>> mock = AsyncMock() + >>> async def main(*args, **kwargs): + ... await mock(*args, **kwargs) + ... + >>> asyncio.run(main('foo', bar='bar')) + >>> mock.assert_awaited_with('foo', bar='bar') + >>> mock.assert_awaited_with('other') + Traceback (most recent call last): + ... + AssertionError: expected call not found. + Expected: mock('other') + Actual: mock('foo', bar='bar') + + .. method:: assert_awaited_once_with(*args, **kwargs) + + Assert that the mock was awaited exactly once and with the specified + arguments. + + >>> mock = AsyncMock() + >>> async def main(*args, **kwargs): + ... await mock(*args, **kwargs) + ... + >>> asyncio.run(main('foo', bar='bar')) + >>> mock.assert_awaited_once_with('foo', bar='bar') + >>> asyncio.run(main('foo', bar='bar')) + >>> mock.assert_awaited_once_with('foo', bar='bar') + Traceback (most recent call last): + ... + AssertionError: Expected mock to have been awaited once. Awaited 2 times. + + .. method:: assert_any_await(*args, **kwargs) + + Assert the mock has ever been awaited with the specified arguments. + + >>> mock = AsyncMock() + >>> async def main(*args, **kwargs): + ... await mock(*args, **kwargs) + ... + >>> asyncio.run(main('foo', bar='bar')) + >>> asyncio.run(main('hello')) + >>> mock.assert_any_await('foo', bar='bar') + >>> mock.assert_any_await('other') + Traceback (most recent call last): + ... + AssertionError: mock('other') await not found + + .. method:: assert_has_awaits(calls, any_order=False) + + Assert the mock has been awaited with the specified calls. + The :attr:`await_args_list` list is checked for the awaits. + + If *any_order* is False (the default) then the awaits must be + sequential. There can be extra calls before or after the + specified awaits. + + If *any_order* is True then the awaits can be in any order, but + they must all appear in :attr:`await_args_list`. + + >>> mock = AsyncMock() + >>> async def main(*args, **kwargs): + ... await mock(*args, **kwargs) + ... + >>> calls = [call("foo"), call("bar")] + >>> mock.assert_has_calls(calls) + Traceback (most recent call last): + ... + AssertionError: Calls not found. + Expected: [call('foo'), call('bar')] + >>> asyncio.run(main('foo')) + >>> asyncio.run(main('bar')) + >>> mock.assert_has_calls(calls) + + .. method:: assert_not_awaited() + + Assert that the mock was never awaited. + + >>> mock = AsyncMock() + >>> mock.assert_not_awaited() + + .. method:: reset_mock(*args, **kwargs) + + See :func:`Mock.reset_mock`. Also sets :attr:`await_count` to 0, + :attr:`await_args` to None, and clears the :attr:`await_args_list`. + + .. attribute:: await_count + + An integer keeping track of how many times the mock object has been awaited. + + >>> mock = AsyncMock() + >>> async def main(): + ... await mock() + ... + >>> asyncio.run(main()) + >>> mock.await_count + 1 + >>> asyncio.run(main()) + >>> mock.await_count + 2 + + .. attribute:: await_args + + This is either ``None`` (if the mock hasn’t been awaited), or the arguments that + the mock was last awaited with. Functions the same as :attr:`Mock.call_args`. + + >>> mock = AsyncMock() + >>> async def main(*args): + ... await mock(*args) + ... + >>> mock.await_args + >>> asyncio.run(main('foo')) + >>> mock.await_args + call('foo') + >>> asyncio.run(main('bar')) + >>> mock.await_args + call('bar') + + + .. attribute:: await_args_list + + This is a list of all the awaits made to the mock object in sequence (so the + length of the list is the number of times it has been awaited). Before any + awaits have been made it is an empty list. + + >>> mock = AsyncMock() + >>> async def main(*args): + ... await mock(*args) + ... + >>> mock.await_args_list + [] + >>> asyncio.run(main('foo')) + >>> mock.await_args_list + [call('foo')] + >>> asyncio.run(main('bar')) + >>> mock.await_args_list + [call('foo'), call('bar')] + + Calling ~~~~~~~ diff --git a/Doc/whatsnew/3.8.rst b/Doc/whatsnew/3.8.rst index 07da404..0a79b6c 100644 --- a/Doc/whatsnew/3.8.rst +++ b/Doc/whatsnew/3.8.rst @@ -538,6 +538,10 @@ unicodedata unittest -------- +* XXX Added :class:`AsyncMock` to support an asynchronous version of :class:`Mock`. + Appropriate new assert functions for testing have been added as well. + (Contributed by Lisa Roach in :issue:`26467`). + * Added :func:`~unittest.addModuleCleanup()` and :meth:`~unittest.TestCase.addClassCleanup()` to unittest to support cleanups for :func:`~unittest.setUpModule()` and diff --git a/Lib/unittest/mock.py b/Lib/unittest/mock.py index 47ed06c..166c100 100644 --- a/Lib/unittest/mock.py +++ b/Lib/unittest/mock.py @@ -13,6 +13,7 @@ __all__ = ( 'ANY', 'call', 'create_autospec', + 'AsyncMock', 'FILTER_DIR', 'NonCallableMock', 'NonCallableMagicMock', @@ -24,13 +25,13 @@ __all__ = ( __version__ = '1.0' - +import asyncio import io import inspect import pprint import sys import builtins -from types import ModuleType, MethodType +from types import CodeType, ModuleType, MethodType from unittest.util import safe_repr from functools import wraps, partial @@ -43,6 +44,13 @@ FILTER_DIR = True # Without this, the __class__ properties wouldn't be set correctly _safe_super = super +def _is_async_obj(obj): + if getattr(obj, '__code__', None): + return asyncio.iscoroutinefunction(obj) or inspect.isawaitable(obj) + else: + return False + + def _is_instance_mock(obj): # can't use isinstance on Mock objects because they override __class__ # The base class for all mocks is NonCallableMock @@ -355,7 +363,20 @@ class NonCallableMock(Base): # every instance has its own class # so we can create magic methods on the # class without stomping on other mocks - new = type(cls.__name__, (cls,), {'__doc__': cls.__doc__}) + bases = (cls,) + if not issubclass(cls, AsyncMock): + # Check if spec is an async object or function + sig = inspect.signature(NonCallableMock.__init__) + bound_args = sig.bind_partial(cls, *args, **kw).arguments + spec_arg = [ + arg for arg in bound_args.keys() + if arg.startswith('spec') + ] + if spec_arg: + # what if spec_set is different than spec? + if _is_async_obj(bound_args[spec_arg[0]]): + bases = (AsyncMockMixin, cls,) + new = type(cls.__name__, bases, {'__doc__': cls.__doc__}) instance = object.__new__(new) return instance @@ -431,6 +452,11 @@ class NonCallableMock(Base): _eat_self=False): _spec_class = None _spec_signature = None + _spec_asyncs = [] + + for attr in dir(spec): + if asyncio.iscoroutinefunction(getattr(spec, attr, None)): + _spec_asyncs.append(attr) if spec is not None and not _is_list(spec): if isinstance(spec, type): @@ -448,7 +474,7 @@ class NonCallableMock(Base): __dict__['_spec_set'] = spec_set __dict__['_spec_signature'] = _spec_signature __dict__['_mock_methods'] = spec - + __dict__['_spec_asyncs'] = _spec_asyncs def __get_return_value(self): ret = self._mock_return_value @@ -886,7 +912,15 @@ class NonCallableMock(Base): For non-callable mocks the callable variant will be used (rather than any custom subclass).""" + _new_name = kw.get("_new_name") + if _new_name in self.__dict__['_spec_asyncs']: + return AsyncMock(**kw) + _type = type(self) + if issubclass(_type, MagicMock) and _new_name in _async_method_magics: + klass = AsyncMock + if issubclass(_type, AsyncMockMixin): + klass = MagicMock if not issubclass(_type, CallableMixin): if issubclass(_type, NonCallableMagicMock): klass = MagicMock @@ -932,14 +966,12 @@ def _try_iter(obj): return obj - class CallableMixin(Base): def __init__(self, spec=None, side_effect=None, return_value=DEFAULT, wraps=None, name=None, spec_set=None, parent=None, _spec_state=None, _new_name='', _new_parent=None, **kwargs): self.__dict__['_mock_return_value'] = return_value - _safe_super(CallableMixin, self).__init__( spec, wraps, name, spec_set, parent, _spec_state, _new_name, _new_parent, **kwargs @@ -1081,7 +1113,6 @@ class Mock(CallableMixin, NonCallableMock): """ - def _dot_lookup(thing, comp, import_path): try: return getattr(thing, comp) @@ -1279,8 +1310,10 @@ class _patch(object): if isinstance(original, type): # If we're patching out a class and there is a spec inherit = True - - Klass = MagicMock + if spec is None and _is_async_obj(original): + Klass = AsyncMock + else: + Klass = MagicMock _kwargs = {} if new_callable is not None: Klass = new_callable @@ -1292,7 +1325,9 @@ class _patch(object): not_callable = '__call__' not in this_spec else: not_callable = not callable(this_spec) - if not_callable: + if _is_async_obj(this_spec): + Klass = AsyncMock + elif not_callable: Klass = NonCallableMagicMock if spec is not None: @@ -1733,7 +1768,7 @@ _non_defaults = { '__reduce__', '__reduce_ex__', '__getinitargs__', '__getnewargs__', '__getstate__', '__setstate__', '__getformat__', '__setformat__', '__repr__', '__dir__', '__subclasses__', '__format__', - '__getnewargs_ex__', + '__getnewargs_ex__', '__aenter__', '__aexit__', '__anext__', '__aiter__', } @@ -1750,6 +1785,11 @@ _magics = { ' '.join([magic_methods, numerics, inplace, right]).split() } +# Magic methods used for async `with` statements +_async_method_magics = {"__aenter__", "__aexit__", "__anext__"} +# `__aiter__` is a plain function but used with async calls +_async_magics = _async_method_magics | {"__aiter__"} + _all_magics = _magics | _non_defaults _unsupported_magics = { @@ -1779,6 +1819,7 @@ _return_values = { '__float__': 1.0, '__bool__': True, '__index__': 1, + '__aexit__': False, } @@ -1811,10 +1852,19 @@ def _get_iter(self): return iter(ret_val) return __iter__ +def _get_async_iter(self): + def __aiter__(): + ret_val = self.__aiter__._mock_return_value + if ret_val is DEFAULT: + return _AsyncIterator(iter([])) + return _AsyncIterator(iter(ret_val)) + return __aiter__ + _side_effect_methods = { '__eq__': _get_eq, '__ne__': _get_ne, '__iter__': _get_iter, + '__aiter__': _get_async_iter } @@ -1879,8 +1929,33 @@ class NonCallableMagicMock(MagicMixin, NonCallableMock): self._mock_set_magics() +class AsyncMagicMixin: + def __init__(self, *args, **kw): + self._mock_set_async_magics() # make magic work for kwargs in init + _safe_super(AsyncMagicMixin, self).__init__(*args, **kw) + self._mock_set_async_magics() # fix magic broken by upper level init + + def _mock_set_async_magics(self): + these_magics = _async_magics -class MagicMock(MagicMixin, Mock): + if getattr(self, "_mock_methods", None) is not None: + these_magics = _async_magics.intersection(self._mock_methods) + remove_magics = _async_magics - these_magics + + for entry in remove_magics: + if entry in type(self).__dict__: + # remove unneeded magic methods + delattr(self, entry) + + # don't overwrite existing attributes if called a second time + these_magics = these_magics - set(type(self).__dict__) + + _type = type(self) + for entry in these_magics: + setattr(_type, entry, MagicProxy(entry, self)) + + +class MagicMock(MagicMixin, AsyncMagicMixin, Mock): """ MagicMock is a subclass of Mock with default implementations of most of the magic methods. You can use MagicMock without having to @@ -1920,6 +1995,218 @@ class MagicProxy(object): return self.create_mock() +class AsyncMockMixin(Base): + awaited = _delegating_property('awaited') + await_count = _delegating_property('await_count') + await_args = _delegating_property('await_args') + await_args_list = _delegating_property('await_args_list') + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # asyncio.iscoroutinefunction() checks _is_coroutine property to say if an + # object is a coroutine. Without this check it looks to see if it is a + # function/method, which in this case it is not (since it is an + # AsyncMock). + # It is set through __dict__ because when spec_set is True, this + # attribute is likely undefined. + self.__dict__['_is_coroutine'] = asyncio.coroutines._is_coroutine + self.__dict__['_mock_awaited'] = _AwaitEvent(self) + self.__dict__['_mock_await_count'] = 0 + self.__dict__['_mock_await_args'] = None + self.__dict__['_mock_await_args_list'] = _CallList() + code_mock = NonCallableMock(spec_set=CodeType) + code_mock.co_flags = inspect.CO_COROUTINE + self.__dict__['__code__'] = code_mock + + async def _mock_call(_mock_self, *args, **kwargs): + self = _mock_self + try: + result = super()._mock_call(*args, **kwargs) + except (BaseException, StopIteration) as e: + side_effect = self.side_effect + if side_effect is not None and not callable(side_effect): + raise + return await _raise(e) + + _call = self.call_args + + async def proxy(): + try: + if inspect.isawaitable(result): + return await result + else: + return result + finally: + self.await_count += 1 + self.await_args = _call + self.await_args_list.append(_call) + await self.awaited._notify() + + return await proxy() + + def assert_awaited(_mock_self): + """ + Assert that the mock was awaited at least once. + """ + self = _mock_self + if self.await_count == 0: + msg = f"Expected {self._mock_name or 'mock'} to have been awaited." + raise AssertionError(msg) + + def assert_awaited_once(_mock_self): + """ + Assert that the mock was awaited exactly once. + """ + self = _mock_self + if not self.await_count == 1: + msg = (f"Expected {self._mock_name or 'mock'} to have been awaited once." + f" Awaited {self.await_count} times.") + raise AssertionError(msg) + + def assert_awaited_with(_mock_self, *args, **kwargs): + """ + Assert that the last await was with the specified arguments. + """ + self = _mock_self + if self.await_args is None: + expected = self._format_mock_call_signature(args, kwargs) + raise AssertionError(f'Expected await: {expected}\nNot awaited') + + def _error_message(): + msg = self._format_mock_failure_message(args, kwargs) + return msg + + expected = self._call_matcher((args, kwargs)) + actual = self._call_matcher(self.await_args) + if expected != actual: + cause = expected if isinstance(expected, Exception) else None + raise AssertionError(_error_message()) from cause + + def assert_awaited_once_with(_mock_self, *args, **kwargs): + """ + Assert that the mock was awaited exactly once and with the specified + arguments. + """ + self = _mock_self + if not self.await_count == 1: + msg = (f"Expected {self._mock_name or 'mock'} to have been awaited once." + f" Awaited {self.await_count} times.") + raise AssertionError(msg) + return self.assert_awaited_with(*args, **kwargs) + + def assert_any_await(_mock_self, *args, **kwargs): + """ + Assert the mock has ever been awaited with the specified arguments. + """ + self = _mock_self + expected = self._call_matcher((args, kwargs)) + actual = [self._call_matcher(c) for c in self.await_args_list] + if expected not in actual: + cause = expected if isinstance(expected, Exception) else None + expected_string = self._format_mock_call_signature(args, kwargs) + raise AssertionError( + '%s await not found' % expected_string + ) from cause + + def assert_has_awaits(_mock_self, calls, any_order=False): + """ + Assert the mock has been awaited with the specified calls. + The :attr:`await_args_list` list is checked for the awaits. + + If `any_order` is False (the default) then the awaits must be + sequential. There can be extra calls before or after the + specified awaits. + + If `any_order` is True then the awaits can be in any order, but + they must all appear in :attr:`await_args_list`. + """ + self = _mock_self + expected = [self._call_matcher(c) for c in calls] + cause = expected if isinstance(expected, Exception) else None + all_awaits = _CallList(self._call_matcher(c) for c in self.await_args_list) + if not any_order: + if expected not in all_awaits: + raise AssertionError( + f'Awaits not found.\nExpected: {_CallList(calls)}\n', + f'Actual: {self.await_args_list}' + ) from cause + return + + all_awaits = list(all_awaits) + + not_found = [] + for kall in expected: + try: + all_awaits.remove(kall) + except ValueError: + not_found.append(kall) + if not_found: + raise AssertionError( + '%r not all found in await list' % (tuple(not_found),) + ) from cause + + def assert_not_awaited(_mock_self): + """ + Assert that the mock was never awaited. + """ + self = _mock_self + if self.await_count != 0: + msg = (f"Expected {self._mock_name or 'mock'} to have been awaited once." + f" Awaited {self.await_count} times.") + raise AssertionError(msg) + + def reset_mock(self, *args, **kwargs): + """ + See :func:`.Mock.reset_mock()` + """ + super().reset_mock(*args, **kwargs) + self.await_count = 0 + self.await_args = None + self.await_args_list = _CallList() + + +class AsyncMock(AsyncMockMixin, AsyncMagicMixin, Mock): + """ + Enhance :class:`Mock` with features allowing to mock + an async function. + + The :class:`AsyncMock` object will behave so the object is + recognized as an async function, and the result of a call is an awaitable: + + >>> mock = AsyncMock() + >>> asyncio.iscoroutinefunction(mock) + True + >>> inspect.isawaitable(mock()) + True + + + The result of ``mock()`` is an async function which will have the outcome + of ``side_effect`` or ``return_value``: + + - if ``side_effect`` is a function, the async function will return the + result of that function, + - if ``side_effect`` is an exception, the async function will raise the + exception, + - if ``side_effect`` is an iterable, the async function will return the + next value of the iterable, however, if the sequence of result is + exhausted, ``StopIteration`` is raised immediately, + - if ``side_effect`` is not defined, the async function will return the + value defined by ``return_value``, hence, by default, the async function + returns a new :class:`AsyncMock` object. + + If the outcome of ``side_effect`` or ``return_value`` is an async function, + the mock async function obtained when the mock object is called will be this + async function itself (and not an async function returning an async + function). + + The test author can also specify a wrapped object with ``wraps``. In this + case, the :class:`Mock` object behavior is the same as with an + :class:`.Mock` object: the wrapped object may have methods + defined as async function functions. + + Based on Martin Richard's asyntest project. + """ + class _ANY(object): "A helper object that compares equal to everything." @@ -2145,7 +2432,6 @@ class _Call(tuple): call = _Call(from_kall=False) - def create_autospec(spec, spec_set=False, instance=False, _parent=None, _name=None, **kwargs): """Create a mock object using another object as a spec. Attributes on the @@ -2171,7 +2457,10 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None, spec = type(spec) is_type = isinstance(spec, type) - + if getattr(spec, '__code__', None): + is_async_func = asyncio.iscoroutinefunction(spec) + else: + is_async_func = False _kwargs = {'spec': spec} if spec_set: _kwargs = {'spec_set': spec} @@ -2188,6 +2477,11 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None, # descriptors don't have a spec # because we don't know what type they return _kwargs = {} + elif is_async_func: + if instance: + raise RuntimeError("Instance can not be True when create_autospec " + "is mocking an async function") + Klass = AsyncMock elif not _callable(spec): Klass = NonCallableMagicMock elif is_type and instance and not _instance_callable(spec): @@ -2204,9 +2498,26 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None, name=_name, **_kwargs) if isinstance(spec, FunctionTypes): + wrapped_mock = mock # should only happen at the top level because we don't # recurse for functions mock = _set_signature(mock, spec) + if is_async_func: + mock._is_coroutine = asyncio.coroutines._is_coroutine + mock.await_count = 0 + mock.await_args = None + mock.await_args_list = _CallList() + + for a in ('assert_awaited', + 'assert_awaited_once', + 'assert_awaited_with', + 'assert_awaited_once_with', + 'assert_any_await', + 'assert_has_awaits', + 'assert_not_awaited'): + def f(*args, **kwargs): + return getattr(wrapped_mock, a)(*args, **kwargs) + setattr(mock, a, f) else: _check_signature(spec, mock, is_type, instance) @@ -2250,9 +2561,13 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None, skipfirst = _must_skip(spec, entry, is_type) kwargs['_eat_self'] = skipfirst - new = MagicMock(parent=parent, name=entry, _new_name=entry, - _new_parent=parent, - **kwargs) + if asyncio.iscoroutinefunction(original): + child_klass = AsyncMock + else: + child_klass = MagicMock + new = child_klass(parent=parent, name=entry, _new_name=entry, + _new_parent=parent, + **kwargs) mock._mock_children[entry] = new _check_signature(original, new, skipfirst=skipfirst) @@ -2438,3 +2753,60 @@ def seal(mock): continue if m._mock_new_parent is mock: seal(m) + + +async def _raise(exception): + raise exception + + +class _AsyncIterator: + """ + Wraps an iterator in an asynchronous iterator. + """ + def __init__(self, iterator): + self.iterator = iterator + code_mock = NonCallableMock(spec_set=CodeType) + code_mock.co_flags = inspect.CO_ITERABLE_COROUTINE + self.__dict__['__code__'] = code_mock + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self.iterator) + except StopIteration: + pass + raise StopAsyncIteration + + +class _AwaitEvent: + def __init__(self, mock): + self._mock = mock + self._condition = None + + async def _notify(self): + condition = self._get_condition() + try: + await condition.acquire() + condition.notify_all() + finally: + condition.release() + + def _get_condition(self): + """ + Creation of condition is delayed, to minimize the chance of using the + wrong loop. + A user may create a mock with _AwaitEvent before selecting the + execution loop. Requiring a user to delay creation is error-prone and + inflexible. Instead, condition is created when user actually starts to + use the mock. + """ + # No synchronization is needed: + # - asyncio is thread unsafe + # - there are no awaits here, method will be executed without + # switching asyncio context. + if self._condition is None: + self._condition = asyncio.Condition() + + return self._condition diff --git a/Lib/unittest/test/testmock/testasync.py b/Lib/unittest/test/testmock/testasync.py new file mode 100644 index 0000000..a9aa143 --- /dev/null +++ b/Lib/unittest/test/testmock/testasync.py @@ -0,0 +1,549 @@ +import asyncio +import inspect +import unittest + +from unittest.mock import call, AsyncMock, patch, MagicMock, create_autospec + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +class AsyncClass: + def __init__(self): + pass + async def async_method(self): + pass + def normal_method(self): + pass + +async def async_func(): + pass + +def normal_func(): + pass + +class NormalClass(object): + def a(self): + pass + + +async_foo_name = f'{__name__}.AsyncClass' +normal_foo_name = f'{__name__}.NormalClass' + + +class AsyncPatchDecoratorTest(unittest.TestCase): + def test_is_coroutine_function_patch(self): + @patch.object(AsyncClass, 'async_method') + def test_async(mock_method): + self.assertTrue(asyncio.iscoroutinefunction(mock_method)) + test_async() + + def test_is_async_patch(self): + @patch.object(AsyncClass, 'async_method') + def test_async(mock_method): + m = mock_method() + self.assertTrue(inspect.isawaitable(m)) + asyncio.run(m) + + @patch(f'{async_foo_name}.async_method') + def test_no_parent_attribute(mock_method): + m = mock_method() + self.assertTrue(inspect.isawaitable(m)) + asyncio.run(m) + + test_async() + test_no_parent_attribute() + + def test_is_AsyncMock_patch(self): + @patch.object(AsyncClass, 'async_method') + def test_async(mock_method): + self.assertIsInstance(mock_method, AsyncMock) + + test_async() + + +class AsyncPatchCMTest(unittest.TestCase): + def test_is_async_function_cm(self): + def test_async(): + with patch.object(AsyncClass, 'async_method') as mock_method: + self.assertTrue(asyncio.iscoroutinefunction(mock_method)) + + test_async() + + def test_is_async_cm(self): + def test_async(): + with patch.object(AsyncClass, 'async_method') as mock_method: + m = mock_method() + self.assertTrue(inspect.isawaitable(m)) + asyncio.run(m) + + test_async() + + def test_is_AsyncMock_cm(self): + def test_async(): + with patch.object(AsyncClass, 'async_method') as mock_method: + self.assertIsInstance(mock_method, AsyncMock) + + test_async() + + +class AsyncMockTest(unittest.TestCase): + def test_iscoroutinefunction_default(self): + mock = AsyncMock() + self.assertTrue(asyncio.iscoroutinefunction(mock)) + + def test_iscoroutinefunction_function(self): + async def foo(): pass + mock = AsyncMock(foo) + self.assertTrue(asyncio.iscoroutinefunction(mock)) + self.assertTrue(inspect.iscoroutinefunction(mock)) + + def test_isawaitable(self): + mock = AsyncMock() + m = mock() + self.assertTrue(inspect.isawaitable(m)) + asyncio.run(m) + self.assertIn('assert_awaited', dir(mock)) + + def test_iscoroutinefunction_normal_function(self): + def foo(): pass + mock = AsyncMock(foo) + self.assertTrue(asyncio.iscoroutinefunction(mock)) + self.assertTrue(inspect.iscoroutinefunction(mock)) + + def test_future_isfuture(self): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + fut = asyncio.Future() + loop.stop() + loop.close() + mock = AsyncMock(fut) + self.assertIsInstance(mock, asyncio.Future) + + +class AsyncAutospecTest(unittest.TestCase): + def test_is_AsyncMock_patch(self): + @patch(async_foo_name, autospec=True) + def test_async(mock_method): + self.assertIsInstance(mock_method.async_method, AsyncMock) + self.assertIsInstance(mock_method, MagicMock) + + @patch(async_foo_name, autospec=True) + def test_normal_method(mock_method): + self.assertIsInstance(mock_method.normal_method, MagicMock) + + test_async() + test_normal_method() + + def test_create_autospec_instance(self): + with self.assertRaises(RuntimeError): + create_autospec(async_func, instance=True) + + def test_create_autospec(self): + spec = create_autospec(async_func) + self.assertTrue(asyncio.iscoroutinefunction(spec)) + + +class AsyncSpecTest(unittest.TestCase): + def test_spec_as_async_positional_magicmock(self): + mock = MagicMock(async_func) + self.assertIsInstance(mock, MagicMock) + m = mock() + self.assertTrue(inspect.isawaitable(m)) + asyncio.run(m) + + def test_spec_as_async_kw_magicmock(self): + mock = MagicMock(spec=async_func) + self.assertIsInstance(mock, MagicMock) + m = mock() + self.assertTrue(inspect.isawaitable(m)) + asyncio.run(m) + + def test_spec_as_async_kw_AsyncMock(self): + mock = AsyncMock(spec=async_func) + self.assertIsInstance(mock, AsyncMock) + m = mock() + self.assertTrue(inspect.isawaitable(m)) + asyncio.run(m) + + def test_spec_as_async_positional_AsyncMock(self): + mock = AsyncMock(async_func) + self.assertIsInstance(mock, AsyncMock) + m = mock() + self.assertTrue(inspect.isawaitable(m)) + asyncio.run(m) + + def test_spec_as_normal_kw_AsyncMock(self): + mock = AsyncMock(spec=normal_func) + self.assertIsInstance(mock, AsyncMock) + m = mock() + self.assertTrue(inspect.isawaitable(m)) + asyncio.run(m) + + def test_spec_as_normal_positional_AsyncMock(self): + mock = AsyncMock(normal_func) + self.assertIsInstance(mock, AsyncMock) + m = mock() + self.assertTrue(inspect.isawaitable(m)) + asyncio.run(m) + + def test_spec_async_mock(self): + @patch.object(AsyncClass, 'async_method', spec=True) + def test_async(mock_method): + self.assertIsInstance(mock_method, AsyncMock) + + test_async() + + def test_spec_parent_not_async_attribute_is(self): + @patch(async_foo_name, spec=True) + def test_async(mock_method): + self.assertIsInstance(mock_method, MagicMock) + self.assertIsInstance(mock_method.async_method, AsyncMock) + + test_async() + + def test_target_async_spec_not(self): + @patch.object(AsyncClass, 'async_method', spec=NormalClass.a) + def test_async_attribute(mock_method): + self.assertIsInstance(mock_method, MagicMock) + self.assertFalse(inspect.iscoroutine(mock_method)) + self.assertFalse(inspect.isawaitable(mock_method)) + + test_async_attribute() + + def test_target_not_async_spec_is(self): + @patch.object(NormalClass, 'a', spec=async_func) + def test_attribute_not_async_spec_is(mock_async_func): + self.assertIsInstance(mock_async_func, AsyncMock) + test_attribute_not_async_spec_is() + + def test_spec_async_attributes(self): + @patch(normal_foo_name, spec=AsyncClass) + def test_async_attributes_coroutines(MockNormalClass): + self.assertIsInstance(MockNormalClass.async_method, AsyncMock) + self.assertIsInstance(MockNormalClass, MagicMock) + + test_async_attributes_coroutines() + + +class AsyncSpecSetTest(unittest.TestCase): + def test_is_AsyncMock_patch(self): + @patch.object(AsyncClass, 'async_method', spec_set=True) + def test_async(async_method): + self.assertIsInstance(async_method, AsyncMock) + + def test_is_async_AsyncMock(self): + mock = AsyncMock(spec_set=AsyncClass.async_method) + self.assertTrue(asyncio.iscoroutinefunction(mock)) + self.assertIsInstance(mock, AsyncMock) + + def test_is_child_AsyncMock(self): + mock = MagicMock(spec_set=AsyncClass) + self.assertTrue(asyncio.iscoroutinefunction(mock.async_method)) + self.assertFalse(asyncio.iscoroutinefunction(mock.normal_method)) + self.assertIsInstance(mock.async_method, AsyncMock) + self.assertIsInstance(mock.normal_method, MagicMock) + self.assertIsInstance(mock, MagicMock) + + +class AsyncArguments(unittest.TestCase): + def test_add_return_value(self): + async def addition(self, var): + return var + 1 + + mock = AsyncMock(addition, return_value=10) + output = asyncio.run(mock(5)) + + self.assertEqual(output, 10) + + def test_add_side_effect_exception(self): + async def addition(var): + return var + 1 + mock = AsyncMock(addition, side_effect=Exception('err')) + with self.assertRaises(Exception): + asyncio.run(mock(5)) + + def test_add_side_effect_function(self): + async def addition(var): + return var + 1 + mock = AsyncMock(side_effect=addition) + result = asyncio.run(mock(5)) + self.assertEqual(result, 6) + + def test_add_side_effect_iterable(self): + vals = [1, 2, 3] + mock = AsyncMock(side_effect=vals) + for item in vals: + self.assertEqual(item, asyncio.run(mock())) + + with self.assertRaises(RuntimeError) as e: + asyncio.run(mock()) + self.assertEqual( + e.exception, + RuntimeError('coroutine raised StopIteration') + ) + + +class AsyncContextManagerTest(unittest.TestCase): + class WithAsyncContextManager: + def __init__(self): + self.entered = False + self.exited = False + + async def __aenter__(self, *args, **kwargs): + self.entered = True + return self + + async def __aexit__(self, *args, **kwargs): + self.exited = True + + def test_magic_methods_are_async_mocks(self): + mock = MagicMock(self.WithAsyncContextManager()) + self.assertIsInstance(mock.__aenter__, AsyncMock) + self.assertIsInstance(mock.__aexit__, AsyncMock) + + def test_mock_supports_async_context_manager(self): + called = False + instance = self.WithAsyncContextManager() + mock_instance = MagicMock(instance) + + async def use_context_manager(): + nonlocal called + async with mock_instance as result: + called = True + return result + + result = asyncio.run(use_context_manager()) + self.assertFalse(instance.entered) + self.assertFalse(instance.exited) + self.assertTrue(called) + self.assertTrue(mock_instance.entered) + self.assertTrue(mock_instance.exited) + self.assertTrue(mock_instance.__aenter__.called) + self.assertTrue(mock_instance.__aexit__.called) + self.assertIsNot(mock_instance, result) + self.assertIsInstance(result, AsyncMock) + + def test_mock_customize_async_context_manager(self): + instance = self.WithAsyncContextManager() + mock_instance = MagicMock(instance) + + expected_result = object() + mock_instance.__aenter__.return_value = expected_result + + async def use_context_manager(): + async with mock_instance as result: + return result + + self.assertIs(asyncio.run(use_context_manager()), expected_result) + + def test_mock_customize_async_context_manager_with_coroutine(self): + enter_called = False + exit_called = False + + async def enter_coroutine(*args): + nonlocal enter_called + enter_called = True + + async def exit_coroutine(*args): + nonlocal exit_called + exit_called = True + + instance = self.WithAsyncContextManager() + mock_instance = MagicMock(instance) + + mock_instance.__aenter__ = enter_coroutine + mock_instance.__aexit__ = exit_coroutine + + async def use_context_manager(): + async with mock_instance: + pass + + asyncio.run(use_context_manager()) + self.assertTrue(enter_called) + self.assertTrue(exit_called) + + def test_context_manager_raise_exception_by_default(self): + async def raise_in(context_manager): + async with context_manager: + raise TypeError() + + instance = self.WithAsyncContextManager() + mock_instance = MagicMock(instance) + with self.assertRaises(TypeError): + asyncio.run(raise_in(mock_instance)) + + +class AsyncIteratorTest(unittest.TestCase): + class WithAsyncIterator(object): + def __init__(self): + self.items = ["foo", "NormalFoo", "baz"] + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return self.items.pop() + except IndexError: + pass + + raise StopAsyncIteration + + def test_mock_aiter_and_anext(self): + instance = self.WithAsyncIterator() + mock_instance = MagicMock(instance) + + self.assertEqual(asyncio.iscoroutine(instance.__aiter__), + asyncio.iscoroutine(mock_instance.__aiter__)) + self.assertEqual(asyncio.iscoroutine(instance.__anext__), + asyncio.iscoroutine(mock_instance.__anext__)) + + iterator = instance.__aiter__() + if asyncio.iscoroutine(iterator): + iterator = asyncio.run(iterator) + + mock_iterator = mock_instance.__aiter__() + if asyncio.iscoroutine(mock_iterator): + mock_iterator = asyncio.run(mock_iterator) + + self.assertEqual(asyncio.iscoroutine(iterator.__aiter__), + asyncio.iscoroutine(mock_iterator.__aiter__)) + self.assertEqual(asyncio.iscoroutine(iterator.__anext__), + asyncio.iscoroutine(mock_iterator.__anext__)) + + def test_mock_async_for(self): + async def iterate(iterator): + accumulator = [] + async for item in iterator: + accumulator.append(item) + + return accumulator + + expected = ["FOO", "BAR", "BAZ"] + with self.subTest("iterate through default value"): + mock_instance = MagicMock(self.WithAsyncIterator()) + self.assertEqual([], asyncio.run(iterate(mock_instance))) + + with self.subTest("iterate through set return_value"): + mock_instance = MagicMock(self.WithAsyncIterator()) + mock_instance.__aiter__.return_value = expected[:] + self.assertEqual(expected, asyncio.run(iterate(mock_instance))) + + with self.subTest("iterate through set return_value iterator"): + mock_instance = MagicMock(self.WithAsyncIterator()) + mock_instance.__aiter__.return_value = iter(expected[:]) + self.assertEqual(expected, asyncio.run(iterate(mock_instance))) + + +class AsyncMockAssert(unittest.TestCase): + def setUp(self): + self.mock = AsyncMock() + + async def _runnable_test(self, *args): + if not args: + await self.mock() + else: + await self.mock(*args) + + def test_assert_awaited(self): + with self.assertRaises(AssertionError): + self.mock.assert_awaited() + + asyncio.run(self._runnable_test()) + self.mock.assert_awaited() + + def test_assert_awaited_once(self): + with self.assertRaises(AssertionError): + self.mock.assert_awaited_once() + + asyncio.run(self._runnable_test()) + self.mock.assert_awaited_once() + + asyncio.run(self._runnable_test()) + with self.assertRaises(AssertionError): + self.mock.assert_awaited_once() + + def test_assert_awaited_with(self): + asyncio.run(self._runnable_test()) + with self.assertRaises(AssertionError): + self.mock.assert_awaited_with('foo') + + asyncio.run(self._runnable_test('foo')) + self.mock.assert_awaited_with('foo') + + asyncio.run(self._runnable_test('SomethingElse')) + with self.assertRaises(AssertionError): + self.mock.assert_awaited_with('foo') + + def test_assert_awaited_once_with(self): + with self.assertRaises(AssertionError): + self.mock.assert_awaited_once_with('foo') + + asyncio.run(self._runnable_test('foo')) + self.mock.assert_awaited_once_with('foo') + + asyncio.run(self._runnable_test('foo')) + with self.assertRaises(AssertionError): + self.mock.assert_awaited_once_with('foo') + + def test_assert_any_wait(self): + with self.assertRaises(AssertionError): + self.mock.assert_any_await('NormalFoo') + + asyncio.run(self._runnable_test('foo')) + with self.assertRaises(AssertionError): + self.mock.assert_any_await('NormalFoo') + + asyncio.run(self._runnable_test('NormalFoo')) + self.mock.assert_any_await('NormalFoo') + + asyncio.run(self._runnable_test('SomethingElse')) + self.mock.assert_any_await('NormalFoo') + + def test_assert_has_awaits_no_order(self): + calls = [call('NormalFoo'), call('baz')] + + with self.assertRaises(AssertionError): + self.mock.assert_has_awaits(calls) + + asyncio.run(self._runnable_test('foo')) + with self.assertRaises(AssertionError): + self.mock.assert_has_awaits(calls) + + asyncio.run(self._runnable_test('NormalFoo')) + with self.assertRaises(AssertionError): + self.mock.assert_has_awaits(calls) + + asyncio.run(self._runnable_test('baz')) + self.mock.assert_has_awaits(calls) + + asyncio.run(self._runnable_test('SomethingElse')) + self.mock.assert_has_awaits(calls) + + def test_assert_has_awaits_ordered(self): + calls = [call('NormalFoo'), call('baz')] + with self.assertRaises(AssertionError): + self.mock.assert_has_awaits(calls, any_order=True) + + asyncio.run(self._runnable_test('baz')) + with self.assertRaises(AssertionError): + self.mock.assert_has_awaits(calls, any_order=True) + + asyncio.run(self._runnable_test('foo')) + with self.assertRaises(AssertionError): + self.mock.assert_has_awaits(calls, any_order=True) + + asyncio.run(self._runnable_test('NormalFoo')) + self.mock.assert_has_awaits(calls, any_order=True) + + asyncio.run(self._runnable_test('qux')) + self.mock.assert_has_awaits(calls, any_order=True) + + def test_assert_not_awaited(self): + self.mock.assert_not_awaited() + + asyncio.run(self._runnable_test()) + with self.assertRaises(AssertionError): + self.mock.assert_not_awaited() diff --git a/Lib/unittest/test/testmock/testmock.py b/Lib/unittest/test/testmock/testmock.py index b20b8e2..307b8b7 100644 --- a/Lib/unittest/test/testmock/testmock.py +++ b/Lib/unittest/test/testmock/testmock.py @@ -9,7 +9,7 @@ from unittest import mock from unittest.mock import ( call, DEFAULT, patch, sentinel, MagicMock, Mock, NonCallableMock, - NonCallableMagicMock, _Call, _CallList, + NonCallableMagicMock, AsyncMock, _Call, _CallList, create_autospec ) @@ -1618,7 +1618,8 @@ class MockTest(unittest.TestCase): def test_adding_child_mock(self): - for Klass in NonCallableMock, Mock, MagicMock, NonCallableMagicMock: + for Klass in (NonCallableMock, Mock, MagicMock, NonCallableMagicMock, + AsyncMock): mock = Klass() mock.foo = Mock() diff --git a/Misc/NEWS.d/next/Library/2018-09-13-20-33-24.bpo-26467.cahAk3.rst b/Misc/NEWS.d/next/Library/2018-09-13-20-33-24.bpo-26467.cahAk3.rst new file mode 100644 index 0000000..4cf3f2a --- /dev/null +++ b/Misc/NEWS.d/next/Library/2018-09-13-20-33-24.bpo-26467.cahAk3.rst @@ -0,0 +1,2 @@ +Added AsyncMock to support using unittest to mock asyncio coroutines. +Patch by Lisa Roach. |