summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Doc/library/unittest.mock.rst215
-rw-r--r--Doc/whatsnew/3.8.rst4
-rw-r--r--Lib/unittest/mock.py406
-rw-r--r--Lib/unittest/test/testmock/testasync.py549
-rw-r--r--Lib/unittest/test/testmock/testmock.py5
-rw-r--r--Misc/NEWS.d/next/Library/2018-09-13-20-33-24.bpo-26467.cahAk3.rst2
6 files changed, 1161 insertions, 20 deletions
diff --git a/Doc/library/unittest.mock.rst b/Doc/library/unittest.mock.rst
index ed00ee6..21e4709 100644
--- a/Doc/library/unittest.mock.rst
+++ b/Doc/library/unittest.mock.rst
@@ -201,9 +201,11 @@ The Mock Class
.. testsetup::
+ import asyncio
+ import inspect
import unittest
from unittest.mock import sentinel, DEFAULT, ANY
- from unittest.mock import patch, call, Mock, MagicMock, PropertyMock
+ from unittest.mock import patch, call, Mock, MagicMock, PropertyMock, AsyncMock
from unittest.mock import mock_open
:class:`Mock` is a flexible mock object intended to replace the use of stubs and
@@ -851,6 +853,217 @@ object::
>>> p.assert_called_once_with()
+.. class:: AsyncMock(spec=None, side_effect=None, return_value=DEFAULT, wraps=None, name=None, spec_set=None, unsafe=False, **kwargs)
+
+ An asynchronous version of :class:`Mock`. The :class:`AsyncMock` object will
+ behave so the object is recognized as an async function, and the result of a
+ call is an awaitable.
+
+ >>> mock = AsyncMock()
+ >>> asyncio.iscoroutinefunction(mock)
+ True
+ >>> inspect.isawaitable(mock())
+ True
+
+ The result of ``mock()`` is an async function which will have the outcome
+ of ``side_effect`` or ``return_value``:
+
+ - if ``side_effect`` is a function, the async function will return the
+ result of that function,
+ - if ``side_effect`` is an exception, the async function will raise the
+ exception,
+ - if ``side_effect`` is an iterable, the async function will return the
+ next value of the iterable, however, if the sequence of result is
+ exhausted, ``StopIteration`` is raised immediately,
+ - if ``side_effect`` is not defined, the async function will return the
+ value defined by ``return_value``, hence, by default, the async function
+ returns a new :class:`AsyncMock` object.
+
+
+ Setting the *spec* of a :class:`Mock` or :class:`MagicMock` to an async function
+ will result in a coroutine object being returned after calling.
+
+ >>> async def async_func(): pass
+ ...
+ >>> mock = MagicMock(async_func)
+ >>> mock
+ <MagicMock spec='function' id='...'>
+ >>> mock()
+ <coroutine object AsyncMockMixin._mock_call at ...>
+
+ .. method:: assert_awaited()
+
+ Assert that the mock was awaited at least once.
+
+ >>> mock = AsyncMock()
+ >>> async def main():
+ ... await mock()
+ ...
+ >>> asyncio.run(main())
+ >>> mock.assert_awaited()
+ >>> mock_2 = AsyncMock()
+ >>> mock_2.assert_awaited()
+ Traceback (most recent call last):
+ ...
+ AssertionError: Expected mock to have been awaited.
+
+ .. method:: assert_awaited_once()
+
+ Assert that the mock was awaited exactly once.
+
+ >>> mock = AsyncMock()
+ >>> async def main():
+ ... await mock()
+ ...
+ >>> asyncio.run(main())
+ >>> mock.assert_awaited_once()
+ >>> asyncio.run(main())
+ >>> mock.method.assert_awaited_once()
+ Traceback (most recent call last):
+ ...
+ AssertionError: Expected mock to have been awaited once. Awaited 2 times.
+
+ .. method:: assert_awaited_with(*args, **kwargs)
+
+ Assert that the last await was with the specified arguments.
+
+ >>> mock = AsyncMock()
+ >>> async def main(*args, **kwargs):
+ ... await mock(*args, **kwargs)
+ ...
+ >>> asyncio.run(main('foo', bar='bar'))
+ >>> mock.assert_awaited_with('foo', bar='bar')
+ >>> mock.assert_awaited_with('other')
+ Traceback (most recent call last):
+ ...
+ AssertionError: expected call not found.
+ Expected: mock('other')
+ Actual: mock('foo', bar='bar')
+
+ .. method:: assert_awaited_once_with(*args, **kwargs)
+
+ Assert that the mock was awaited exactly once and with the specified
+ arguments.
+
+ >>> mock = AsyncMock()
+ >>> async def main(*args, **kwargs):
+ ... await mock(*args, **kwargs)
+ ...
+ >>> asyncio.run(main('foo', bar='bar'))
+ >>> mock.assert_awaited_once_with('foo', bar='bar')
+ >>> asyncio.run(main('foo', bar='bar'))
+ >>> mock.assert_awaited_once_with('foo', bar='bar')
+ Traceback (most recent call last):
+ ...
+ AssertionError: Expected mock to have been awaited once. Awaited 2 times.
+
+ .. method:: assert_any_await(*args, **kwargs)
+
+ Assert the mock has ever been awaited with the specified arguments.
+
+ >>> mock = AsyncMock()
+ >>> async def main(*args, **kwargs):
+ ... await mock(*args, **kwargs)
+ ...
+ >>> asyncio.run(main('foo', bar='bar'))
+ >>> asyncio.run(main('hello'))
+ >>> mock.assert_any_await('foo', bar='bar')
+ >>> mock.assert_any_await('other')
+ Traceback (most recent call last):
+ ...
+ AssertionError: mock('other') await not found
+
+ .. method:: assert_has_awaits(calls, any_order=False)
+
+ Assert the mock has been awaited with the specified calls.
+ The :attr:`await_args_list` list is checked for the awaits.
+
+ If *any_order* is False (the default) then the awaits must be
+ sequential. There can be extra calls before or after the
+ specified awaits.
+
+ If *any_order* is True then the awaits can be in any order, but
+ they must all appear in :attr:`await_args_list`.
+
+ >>> mock = AsyncMock()
+ >>> async def main(*args, **kwargs):
+ ... await mock(*args, **kwargs)
+ ...
+ >>> calls = [call("foo"), call("bar")]
+ >>> mock.assert_has_calls(calls)
+ Traceback (most recent call last):
+ ...
+ AssertionError: Calls not found.
+ Expected: [call('foo'), call('bar')]
+ >>> asyncio.run(main('foo'))
+ >>> asyncio.run(main('bar'))
+ >>> mock.assert_has_calls(calls)
+
+ .. method:: assert_not_awaited()
+
+ Assert that the mock was never awaited.
+
+ >>> mock = AsyncMock()
+ >>> mock.assert_not_awaited()
+
+ .. method:: reset_mock(*args, **kwargs)
+
+ See :func:`Mock.reset_mock`. Also sets :attr:`await_count` to 0,
+ :attr:`await_args` to None, and clears the :attr:`await_args_list`.
+
+ .. attribute:: await_count
+
+ An integer keeping track of how many times the mock object has been awaited.
+
+ >>> mock = AsyncMock()
+ >>> async def main():
+ ... await mock()
+ ...
+ >>> asyncio.run(main())
+ >>> mock.await_count
+ 1
+ >>> asyncio.run(main())
+ >>> mock.await_count
+ 2
+
+ .. attribute:: await_args
+
+ This is either ``None`` (if the mock hasn’t been awaited), or the arguments that
+ the mock was last awaited with. Functions the same as :attr:`Mock.call_args`.
+
+ >>> mock = AsyncMock()
+ >>> async def main(*args):
+ ... await mock(*args)
+ ...
+ >>> mock.await_args
+ >>> asyncio.run(main('foo'))
+ >>> mock.await_args
+ call('foo')
+ >>> asyncio.run(main('bar'))
+ >>> mock.await_args
+ call('bar')
+
+
+ .. attribute:: await_args_list
+
+ This is a list of all the awaits made to the mock object in sequence (so the
+ length of the list is the number of times it has been awaited). Before any
+ awaits have been made it is an empty list.
+
+ >>> mock = AsyncMock()
+ >>> async def main(*args):
+ ... await mock(*args)
+ ...
+ >>> mock.await_args_list
+ []
+ >>> asyncio.run(main('foo'))
+ >>> mock.await_args_list
+ [call('foo')]
+ >>> asyncio.run(main('bar'))
+ >>> mock.await_args_list
+ [call('foo'), call('bar')]
+
+
Calling
~~~~~~~
diff --git a/Doc/whatsnew/3.8.rst b/Doc/whatsnew/3.8.rst
index 07da404..0a79b6c 100644
--- a/Doc/whatsnew/3.8.rst
+++ b/Doc/whatsnew/3.8.rst
@@ -538,6 +538,10 @@ unicodedata
unittest
--------
+* XXX Added :class:`AsyncMock` to support an asynchronous version of :class:`Mock`.
+ Appropriate new assert functions for testing have been added as well.
+ (Contributed by Lisa Roach in :issue:`26467`).
+
* Added :func:`~unittest.addModuleCleanup()` and
:meth:`~unittest.TestCase.addClassCleanup()` to unittest to support
cleanups for :func:`~unittest.setUpModule()` and
diff --git a/Lib/unittest/mock.py b/Lib/unittest/mock.py
index 47ed06c..166c100 100644
--- a/Lib/unittest/mock.py
+++ b/Lib/unittest/mock.py
@@ -13,6 +13,7 @@ __all__ = (
'ANY',
'call',
'create_autospec',
+ 'AsyncMock',
'FILTER_DIR',
'NonCallableMock',
'NonCallableMagicMock',
@@ -24,13 +25,13 @@ __all__ = (
__version__ = '1.0'
-
+import asyncio
import io
import inspect
import pprint
import sys
import builtins
-from types import ModuleType, MethodType
+from types import CodeType, ModuleType, MethodType
from unittest.util import safe_repr
from functools import wraps, partial
@@ -43,6 +44,13 @@ FILTER_DIR = True
# Without this, the __class__ properties wouldn't be set correctly
_safe_super = super
+def _is_async_obj(obj):
+ if getattr(obj, '__code__', None):
+ return asyncio.iscoroutinefunction(obj) or inspect.isawaitable(obj)
+ else:
+ return False
+
+
def _is_instance_mock(obj):
# can't use isinstance on Mock objects because they override __class__
# The base class for all mocks is NonCallableMock
@@ -355,7 +363,20 @@ class NonCallableMock(Base):
# every instance has its own class
# so we can create magic methods on the
# class without stomping on other mocks
- new = type(cls.__name__, (cls,), {'__doc__': cls.__doc__})
+ bases = (cls,)
+ if not issubclass(cls, AsyncMock):
+ # Check if spec is an async object or function
+ sig = inspect.signature(NonCallableMock.__init__)
+ bound_args = sig.bind_partial(cls, *args, **kw).arguments
+ spec_arg = [
+ arg for arg in bound_args.keys()
+ if arg.startswith('spec')
+ ]
+ if spec_arg:
+ # what if spec_set is different than spec?
+ if _is_async_obj(bound_args[spec_arg[0]]):
+ bases = (AsyncMockMixin, cls,)
+ new = type(cls.__name__, bases, {'__doc__': cls.__doc__})
instance = object.__new__(new)
return instance
@@ -431,6 +452,11 @@ class NonCallableMock(Base):
_eat_self=False):
_spec_class = None
_spec_signature = None
+ _spec_asyncs = []
+
+ for attr in dir(spec):
+ if asyncio.iscoroutinefunction(getattr(spec, attr, None)):
+ _spec_asyncs.append(attr)
if spec is not None and not _is_list(spec):
if isinstance(spec, type):
@@ -448,7 +474,7 @@ class NonCallableMock(Base):
__dict__['_spec_set'] = spec_set
__dict__['_spec_signature'] = _spec_signature
__dict__['_mock_methods'] = spec
-
+ __dict__['_spec_asyncs'] = _spec_asyncs
def __get_return_value(self):
ret = self._mock_return_value
@@ -886,7 +912,15 @@ class NonCallableMock(Base):
For non-callable mocks the callable variant will be used (rather than
any custom subclass)."""
+ _new_name = kw.get("_new_name")
+ if _new_name in self.__dict__['_spec_asyncs']:
+ return AsyncMock(**kw)
+
_type = type(self)
+ if issubclass(_type, MagicMock) and _new_name in _async_method_magics:
+ klass = AsyncMock
+ if issubclass(_type, AsyncMockMixin):
+ klass = MagicMock
if not issubclass(_type, CallableMixin):
if issubclass(_type, NonCallableMagicMock):
klass = MagicMock
@@ -932,14 +966,12 @@ def _try_iter(obj):
return obj
-
class CallableMixin(Base):
def __init__(self, spec=None, side_effect=None, return_value=DEFAULT,
wraps=None, name=None, spec_set=None, parent=None,
_spec_state=None, _new_name='', _new_parent=None, **kwargs):
self.__dict__['_mock_return_value'] = return_value
-
_safe_super(CallableMixin, self).__init__(
spec, wraps, name, spec_set, parent,
_spec_state, _new_name, _new_parent, **kwargs
@@ -1081,7 +1113,6 @@ class Mock(CallableMixin, NonCallableMock):
"""
-
def _dot_lookup(thing, comp, import_path):
try:
return getattr(thing, comp)
@@ -1279,8 +1310,10 @@ class _patch(object):
if isinstance(original, type):
# If we're patching out a class and there is a spec
inherit = True
-
- Klass = MagicMock
+ if spec is None and _is_async_obj(original):
+ Klass = AsyncMock
+ else:
+ Klass = MagicMock
_kwargs = {}
if new_callable is not None:
Klass = new_callable
@@ -1292,7 +1325,9 @@ class _patch(object):
not_callable = '__call__' not in this_spec
else:
not_callable = not callable(this_spec)
- if not_callable:
+ if _is_async_obj(this_spec):
+ Klass = AsyncMock
+ elif not_callable:
Klass = NonCallableMagicMock
if spec is not None:
@@ -1733,7 +1768,7 @@ _non_defaults = {
'__reduce__', '__reduce_ex__', '__getinitargs__', '__getnewargs__',
'__getstate__', '__setstate__', '__getformat__', '__setformat__',
'__repr__', '__dir__', '__subclasses__', '__format__',
- '__getnewargs_ex__',
+ '__getnewargs_ex__', '__aenter__', '__aexit__', '__anext__', '__aiter__',
}
@@ -1750,6 +1785,11 @@ _magics = {
' '.join([magic_methods, numerics, inplace, right]).split()
}
+# Magic methods used for async `with` statements
+_async_method_magics = {"__aenter__", "__aexit__", "__anext__"}
+# `__aiter__` is a plain function but used with async calls
+_async_magics = _async_method_magics | {"__aiter__"}
+
_all_magics = _magics | _non_defaults
_unsupported_magics = {
@@ -1779,6 +1819,7 @@ _return_values = {
'__float__': 1.0,
'__bool__': True,
'__index__': 1,
+ '__aexit__': False,
}
@@ -1811,10 +1852,19 @@ def _get_iter(self):
return iter(ret_val)
return __iter__
+def _get_async_iter(self):
+ def __aiter__():
+ ret_val = self.__aiter__._mock_return_value
+ if ret_val is DEFAULT:
+ return _AsyncIterator(iter([]))
+ return _AsyncIterator(iter(ret_val))
+ return __aiter__
+
_side_effect_methods = {
'__eq__': _get_eq,
'__ne__': _get_ne,
'__iter__': _get_iter,
+ '__aiter__': _get_async_iter
}
@@ -1879,8 +1929,33 @@ class NonCallableMagicMock(MagicMixin, NonCallableMock):
self._mock_set_magics()
+class AsyncMagicMixin:
+ def __init__(self, *args, **kw):
+ self._mock_set_async_magics() # make magic work for kwargs in init
+ _safe_super(AsyncMagicMixin, self).__init__(*args, **kw)
+ self._mock_set_async_magics() # fix magic broken by upper level init
+
+ def _mock_set_async_magics(self):
+ these_magics = _async_magics
-class MagicMock(MagicMixin, Mock):
+ if getattr(self, "_mock_methods", None) is not None:
+ these_magics = _async_magics.intersection(self._mock_methods)
+ remove_magics = _async_magics - these_magics
+
+ for entry in remove_magics:
+ if entry in type(self).__dict__:
+ # remove unneeded magic methods
+ delattr(self, entry)
+
+ # don't overwrite existing attributes if called a second time
+ these_magics = these_magics - set(type(self).__dict__)
+
+ _type = type(self)
+ for entry in these_magics:
+ setattr(_type, entry, MagicProxy(entry, self))
+
+
+class MagicMock(MagicMixin, AsyncMagicMixin, Mock):
"""
MagicMock is a subclass of Mock with default implementations
of most of the magic methods. You can use MagicMock without having to
@@ -1920,6 +1995,218 @@ class MagicProxy(object):
return self.create_mock()
+class AsyncMockMixin(Base):
+ awaited = _delegating_property('awaited')
+ await_count = _delegating_property('await_count')
+ await_args = _delegating_property('await_args')
+ await_args_list = _delegating_property('await_args_list')
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ # asyncio.iscoroutinefunction() checks _is_coroutine property to say if an
+ # object is a coroutine. Without this check it looks to see if it is a
+ # function/method, which in this case it is not (since it is an
+ # AsyncMock).
+ # It is set through __dict__ because when spec_set is True, this
+ # attribute is likely undefined.
+ self.__dict__['_is_coroutine'] = asyncio.coroutines._is_coroutine
+ self.__dict__['_mock_awaited'] = _AwaitEvent(self)
+ self.__dict__['_mock_await_count'] = 0
+ self.__dict__['_mock_await_args'] = None
+ self.__dict__['_mock_await_args_list'] = _CallList()
+ code_mock = NonCallableMock(spec_set=CodeType)
+ code_mock.co_flags = inspect.CO_COROUTINE
+ self.__dict__['__code__'] = code_mock
+
+ async def _mock_call(_mock_self, *args, **kwargs):
+ self = _mock_self
+ try:
+ result = super()._mock_call(*args, **kwargs)
+ except (BaseException, StopIteration) as e:
+ side_effect = self.side_effect
+ if side_effect is not None and not callable(side_effect):
+ raise
+ return await _raise(e)
+
+ _call = self.call_args
+
+ async def proxy():
+ try:
+ if inspect.isawaitable(result):
+ return await result
+ else:
+ return result
+ finally:
+ self.await_count += 1
+ self.await_args = _call
+ self.await_args_list.append(_call)
+ await self.awaited._notify()
+
+ return await proxy()
+
+ def assert_awaited(_mock_self):
+ """
+ Assert that the mock was awaited at least once.
+ """
+ self = _mock_self
+ if self.await_count == 0:
+ msg = f"Expected {self._mock_name or 'mock'} to have been awaited."
+ raise AssertionError(msg)
+
+ def assert_awaited_once(_mock_self):
+ """
+ Assert that the mock was awaited exactly once.
+ """
+ self = _mock_self
+ if not self.await_count == 1:
+ msg = (f"Expected {self._mock_name or 'mock'} to have been awaited once."
+ f" Awaited {self.await_count} times.")
+ raise AssertionError(msg)
+
+ def assert_awaited_with(_mock_self, *args, **kwargs):
+ """
+ Assert that the last await was with the specified arguments.
+ """
+ self = _mock_self
+ if self.await_args is None:
+ expected = self._format_mock_call_signature(args, kwargs)
+ raise AssertionError(f'Expected await: {expected}\nNot awaited')
+
+ def _error_message():
+ msg = self._format_mock_failure_message(args, kwargs)
+ return msg
+
+ expected = self._call_matcher((args, kwargs))
+ actual = self._call_matcher(self.await_args)
+ if expected != actual:
+ cause = expected if isinstance(expected, Exception) else None
+ raise AssertionError(_error_message()) from cause
+
+ def assert_awaited_once_with(_mock_self, *args, **kwargs):
+ """
+ Assert that the mock was awaited exactly once and with the specified
+ arguments.
+ """
+ self = _mock_self
+ if not self.await_count == 1:
+ msg = (f"Expected {self._mock_name or 'mock'} to have been awaited once."
+ f" Awaited {self.await_count} times.")
+ raise AssertionError(msg)
+ return self.assert_awaited_with(*args, **kwargs)
+
+ def assert_any_await(_mock_self, *args, **kwargs):
+ """
+ Assert the mock has ever been awaited with the specified arguments.
+ """
+ self = _mock_self
+ expected = self._call_matcher((args, kwargs))
+ actual = [self._call_matcher(c) for c in self.await_args_list]
+ if expected not in actual:
+ cause = expected if isinstance(expected, Exception) else None
+ expected_string = self._format_mock_call_signature(args, kwargs)
+ raise AssertionError(
+ '%s await not found' % expected_string
+ ) from cause
+
+ def assert_has_awaits(_mock_self, calls, any_order=False):
+ """
+ Assert the mock has been awaited with the specified calls.
+ The :attr:`await_args_list` list is checked for the awaits.
+
+ If `any_order` is False (the default) then the awaits must be
+ sequential. There can be extra calls before or after the
+ specified awaits.
+
+ If `any_order` is True then the awaits can be in any order, but
+ they must all appear in :attr:`await_args_list`.
+ """
+ self = _mock_self
+ expected = [self._call_matcher(c) for c in calls]
+ cause = expected if isinstance(expected, Exception) else None
+ all_awaits = _CallList(self._call_matcher(c) for c in self.await_args_list)
+ if not any_order:
+ if expected not in all_awaits:
+ raise AssertionError(
+ f'Awaits not found.\nExpected: {_CallList(calls)}\n',
+ f'Actual: {self.await_args_list}'
+ ) from cause
+ return
+
+ all_awaits = list(all_awaits)
+
+ not_found = []
+ for kall in expected:
+ try:
+ all_awaits.remove(kall)
+ except ValueError:
+ not_found.append(kall)
+ if not_found:
+ raise AssertionError(
+ '%r not all found in await list' % (tuple(not_found),)
+ ) from cause
+
+ def assert_not_awaited(_mock_self):
+ """
+ Assert that the mock was never awaited.
+ """
+ self = _mock_self
+ if self.await_count != 0:
+ msg = (f"Expected {self._mock_name or 'mock'} to have been awaited once."
+ f" Awaited {self.await_count} times.")
+ raise AssertionError(msg)
+
+ def reset_mock(self, *args, **kwargs):
+ """
+ See :func:`.Mock.reset_mock()`
+ """
+ super().reset_mock(*args, **kwargs)
+ self.await_count = 0
+ self.await_args = None
+ self.await_args_list = _CallList()
+
+
+class AsyncMock(AsyncMockMixin, AsyncMagicMixin, Mock):
+ """
+ Enhance :class:`Mock` with features allowing to mock
+ an async function.
+
+ The :class:`AsyncMock` object will behave so the object is
+ recognized as an async function, and the result of a call is an awaitable:
+
+ >>> mock = AsyncMock()
+ >>> asyncio.iscoroutinefunction(mock)
+ True
+ >>> inspect.isawaitable(mock())
+ True
+
+
+ The result of ``mock()`` is an async function which will have the outcome
+ of ``side_effect`` or ``return_value``:
+
+ - if ``side_effect`` is a function, the async function will return the
+ result of that function,
+ - if ``side_effect`` is an exception, the async function will raise the
+ exception,
+ - if ``side_effect`` is an iterable, the async function will return the
+ next value of the iterable, however, if the sequence of result is
+ exhausted, ``StopIteration`` is raised immediately,
+ - if ``side_effect`` is not defined, the async function will return the
+ value defined by ``return_value``, hence, by default, the async function
+ returns a new :class:`AsyncMock` object.
+
+ If the outcome of ``side_effect`` or ``return_value`` is an async function,
+ the mock async function obtained when the mock object is called will be this
+ async function itself (and not an async function returning an async
+ function).
+
+ The test author can also specify a wrapped object with ``wraps``. In this
+ case, the :class:`Mock` object behavior is the same as with an
+ :class:`.Mock` object: the wrapped object may have methods
+ defined as async function functions.
+
+ Based on Martin Richard's asyntest project.
+ """
+
class _ANY(object):
"A helper object that compares equal to everything."
@@ -2145,7 +2432,6 @@ class _Call(tuple):
call = _Call(from_kall=False)
-
def create_autospec(spec, spec_set=False, instance=False, _parent=None,
_name=None, **kwargs):
"""Create a mock object using another object as a spec. Attributes on the
@@ -2171,7 +2457,10 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None,
spec = type(spec)
is_type = isinstance(spec, type)
-
+ if getattr(spec, '__code__', None):
+ is_async_func = asyncio.iscoroutinefunction(spec)
+ else:
+ is_async_func = False
_kwargs = {'spec': spec}
if spec_set:
_kwargs = {'spec_set': spec}
@@ -2188,6 +2477,11 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None,
# descriptors don't have a spec
# because we don't know what type they return
_kwargs = {}
+ elif is_async_func:
+ if instance:
+ raise RuntimeError("Instance can not be True when create_autospec "
+ "is mocking an async function")
+ Klass = AsyncMock
elif not _callable(spec):
Klass = NonCallableMagicMock
elif is_type and instance and not _instance_callable(spec):
@@ -2204,9 +2498,26 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None,
name=_name, **_kwargs)
if isinstance(spec, FunctionTypes):
+ wrapped_mock = mock
# should only happen at the top level because we don't
# recurse for functions
mock = _set_signature(mock, spec)
+ if is_async_func:
+ mock._is_coroutine = asyncio.coroutines._is_coroutine
+ mock.await_count = 0
+ mock.await_args = None
+ mock.await_args_list = _CallList()
+
+ for a in ('assert_awaited',
+ 'assert_awaited_once',
+ 'assert_awaited_with',
+ 'assert_awaited_once_with',
+ 'assert_any_await',
+ 'assert_has_awaits',
+ 'assert_not_awaited'):
+ def f(*args, **kwargs):
+ return getattr(wrapped_mock, a)(*args, **kwargs)
+ setattr(mock, a, f)
else:
_check_signature(spec, mock, is_type, instance)
@@ -2250,9 +2561,13 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None,
skipfirst = _must_skip(spec, entry, is_type)
kwargs['_eat_self'] = skipfirst
- new = MagicMock(parent=parent, name=entry, _new_name=entry,
- _new_parent=parent,
- **kwargs)
+ if asyncio.iscoroutinefunction(original):
+ child_klass = AsyncMock
+ else:
+ child_klass = MagicMock
+ new = child_klass(parent=parent, name=entry, _new_name=entry,
+ _new_parent=parent,
+ **kwargs)
mock._mock_children[entry] = new
_check_signature(original, new, skipfirst=skipfirst)
@@ -2438,3 +2753,60 @@ def seal(mock):
continue
if m._mock_new_parent is mock:
seal(m)
+
+
+async def _raise(exception):
+ raise exception
+
+
+class _AsyncIterator:
+ """
+ Wraps an iterator in an asynchronous iterator.
+ """
+ def __init__(self, iterator):
+ self.iterator = iterator
+ code_mock = NonCallableMock(spec_set=CodeType)
+ code_mock.co_flags = inspect.CO_ITERABLE_COROUTINE
+ self.__dict__['__code__'] = code_mock
+
+ def __aiter__(self):
+ return self
+
+ async def __anext__(self):
+ try:
+ return next(self.iterator)
+ except StopIteration:
+ pass
+ raise StopAsyncIteration
+
+
+class _AwaitEvent:
+ def __init__(self, mock):
+ self._mock = mock
+ self._condition = None
+
+ async def _notify(self):
+ condition = self._get_condition()
+ try:
+ await condition.acquire()
+ condition.notify_all()
+ finally:
+ condition.release()
+
+ def _get_condition(self):
+ """
+ Creation of condition is delayed, to minimize the chance of using the
+ wrong loop.
+ A user may create a mock with _AwaitEvent before selecting the
+ execution loop. Requiring a user to delay creation is error-prone and
+ inflexible. Instead, condition is created when user actually starts to
+ use the mock.
+ """
+ # No synchronization is needed:
+ # - asyncio is thread unsafe
+ # - there are no awaits here, method will be executed without
+ # switching asyncio context.
+ if self._condition is None:
+ self._condition = asyncio.Condition()
+
+ return self._condition
diff --git a/Lib/unittest/test/testmock/testasync.py b/Lib/unittest/test/testmock/testasync.py
new file mode 100644
index 0000000..a9aa143
--- /dev/null
+++ b/Lib/unittest/test/testmock/testasync.py
@@ -0,0 +1,549 @@
+import asyncio
+import inspect
+import unittest
+
+from unittest.mock import call, AsyncMock, patch, MagicMock, create_autospec
+
+
+def tearDownModule():
+ asyncio.set_event_loop_policy(None)
+
+
+class AsyncClass:
+ def __init__(self):
+ pass
+ async def async_method(self):
+ pass
+ def normal_method(self):
+ pass
+
+async def async_func():
+ pass
+
+def normal_func():
+ pass
+
+class NormalClass(object):
+ def a(self):
+ pass
+
+
+async_foo_name = f'{__name__}.AsyncClass'
+normal_foo_name = f'{__name__}.NormalClass'
+
+
+class AsyncPatchDecoratorTest(unittest.TestCase):
+ def test_is_coroutine_function_patch(self):
+ @patch.object(AsyncClass, 'async_method')
+ def test_async(mock_method):
+ self.assertTrue(asyncio.iscoroutinefunction(mock_method))
+ test_async()
+
+ def test_is_async_patch(self):
+ @patch.object(AsyncClass, 'async_method')
+ def test_async(mock_method):
+ m = mock_method()
+ self.assertTrue(inspect.isawaitable(m))
+ asyncio.run(m)
+
+ @patch(f'{async_foo_name}.async_method')
+ def test_no_parent_attribute(mock_method):
+ m = mock_method()
+ self.assertTrue(inspect.isawaitable(m))
+ asyncio.run(m)
+
+ test_async()
+ test_no_parent_attribute()
+
+ def test_is_AsyncMock_patch(self):
+ @patch.object(AsyncClass, 'async_method')
+ def test_async(mock_method):
+ self.assertIsInstance(mock_method, AsyncMock)
+
+ test_async()
+
+
+class AsyncPatchCMTest(unittest.TestCase):
+ def test_is_async_function_cm(self):
+ def test_async():
+ with patch.object(AsyncClass, 'async_method') as mock_method:
+ self.assertTrue(asyncio.iscoroutinefunction(mock_method))
+
+ test_async()
+
+ def test_is_async_cm(self):
+ def test_async():
+ with patch.object(AsyncClass, 'async_method') as mock_method:
+ m = mock_method()
+ self.assertTrue(inspect.isawaitable(m))
+ asyncio.run(m)
+
+ test_async()
+
+ def test_is_AsyncMock_cm(self):
+ def test_async():
+ with patch.object(AsyncClass, 'async_method') as mock_method:
+ self.assertIsInstance(mock_method, AsyncMock)
+
+ test_async()
+
+
+class AsyncMockTest(unittest.TestCase):
+ def test_iscoroutinefunction_default(self):
+ mock = AsyncMock()
+ self.assertTrue(asyncio.iscoroutinefunction(mock))
+
+ def test_iscoroutinefunction_function(self):
+ async def foo(): pass
+ mock = AsyncMock(foo)
+ self.assertTrue(asyncio.iscoroutinefunction(mock))
+ self.assertTrue(inspect.iscoroutinefunction(mock))
+
+ def test_isawaitable(self):
+ mock = AsyncMock()
+ m = mock()
+ self.assertTrue(inspect.isawaitable(m))
+ asyncio.run(m)
+ self.assertIn('assert_awaited', dir(mock))
+
+ def test_iscoroutinefunction_normal_function(self):
+ def foo(): pass
+ mock = AsyncMock(foo)
+ self.assertTrue(asyncio.iscoroutinefunction(mock))
+ self.assertTrue(inspect.iscoroutinefunction(mock))
+
+ def test_future_isfuture(self):
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+ fut = asyncio.Future()
+ loop.stop()
+ loop.close()
+ mock = AsyncMock(fut)
+ self.assertIsInstance(mock, asyncio.Future)
+
+
+class AsyncAutospecTest(unittest.TestCase):
+ def test_is_AsyncMock_patch(self):
+ @patch(async_foo_name, autospec=True)
+ def test_async(mock_method):
+ self.assertIsInstance(mock_method.async_method, AsyncMock)
+ self.assertIsInstance(mock_method, MagicMock)
+
+ @patch(async_foo_name, autospec=True)
+ def test_normal_method(mock_method):
+ self.assertIsInstance(mock_method.normal_method, MagicMock)
+
+ test_async()
+ test_normal_method()
+
+ def test_create_autospec_instance(self):
+ with self.assertRaises(RuntimeError):
+ create_autospec(async_func, instance=True)
+
+ def test_create_autospec(self):
+ spec = create_autospec(async_func)
+ self.assertTrue(asyncio.iscoroutinefunction(spec))
+
+
+class AsyncSpecTest(unittest.TestCase):
+ def test_spec_as_async_positional_magicmock(self):
+ mock = MagicMock(async_func)
+ self.assertIsInstance(mock, MagicMock)
+ m = mock()
+ self.assertTrue(inspect.isawaitable(m))
+ asyncio.run(m)
+
+ def test_spec_as_async_kw_magicmock(self):
+ mock = MagicMock(spec=async_func)
+ self.assertIsInstance(mock, MagicMock)
+ m = mock()
+ self.assertTrue(inspect.isawaitable(m))
+ asyncio.run(m)
+
+ def test_spec_as_async_kw_AsyncMock(self):
+ mock = AsyncMock(spec=async_func)
+ self.assertIsInstance(mock, AsyncMock)
+ m = mock()
+ self.assertTrue(inspect.isawaitable(m))
+ asyncio.run(m)
+
+ def test_spec_as_async_positional_AsyncMock(self):
+ mock = AsyncMock(async_func)
+ self.assertIsInstance(mock, AsyncMock)
+ m = mock()
+ self.assertTrue(inspect.isawaitable(m))
+ asyncio.run(m)
+
+ def test_spec_as_normal_kw_AsyncMock(self):
+ mock = AsyncMock(spec=normal_func)
+ self.assertIsInstance(mock, AsyncMock)
+ m = mock()
+ self.assertTrue(inspect.isawaitable(m))
+ asyncio.run(m)
+
+ def test_spec_as_normal_positional_AsyncMock(self):
+ mock = AsyncMock(normal_func)
+ self.assertIsInstance(mock, AsyncMock)
+ m = mock()
+ self.assertTrue(inspect.isawaitable(m))
+ asyncio.run(m)
+
+ def test_spec_async_mock(self):
+ @patch.object(AsyncClass, 'async_method', spec=True)
+ def test_async(mock_method):
+ self.assertIsInstance(mock_method, AsyncMock)
+
+ test_async()
+
+ def test_spec_parent_not_async_attribute_is(self):
+ @patch(async_foo_name, spec=True)
+ def test_async(mock_method):
+ self.assertIsInstance(mock_method, MagicMock)
+ self.assertIsInstance(mock_method.async_method, AsyncMock)
+
+ test_async()
+
+ def test_target_async_spec_not(self):
+ @patch.object(AsyncClass, 'async_method', spec=NormalClass.a)
+ def test_async_attribute(mock_method):
+ self.assertIsInstance(mock_method, MagicMock)
+ self.assertFalse(inspect.iscoroutine(mock_method))
+ self.assertFalse(inspect.isawaitable(mock_method))
+
+ test_async_attribute()
+
+ def test_target_not_async_spec_is(self):
+ @patch.object(NormalClass, 'a', spec=async_func)
+ def test_attribute_not_async_spec_is(mock_async_func):
+ self.assertIsInstance(mock_async_func, AsyncMock)
+ test_attribute_not_async_spec_is()
+
+ def test_spec_async_attributes(self):
+ @patch(normal_foo_name, spec=AsyncClass)
+ def test_async_attributes_coroutines(MockNormalClass):
+ self.assertIsInstance(MockNormalClass.async_method, AsyncMock)
+ self.assertIsInstance(MockNormalClass, MagicMock)
+
+ test_async_attributes_coroutines()
+
+
+class AsyncSpecSetTest(unittest.TestCase):
+ def test_is_AsyncMock_patch(self):
+ @patch.object(AsyncClass, 'async_method', spec_set=True)
+ def test_async(async_method):
+ self.assertIsInstance(async_method, AsyncMock)
+
+ def test_is_async_AsyncMock(self):
+ mock = AsyncMock(spec_set=AsyncClass.async_method)
+ self.assertTrue(asyncio.iscoroutinefunction(mock))
+ self.assertIsInstance(mock, AsyncMock)
+
+ def test_is_child_AsyncMock(self):
+ mock = MagicMock(spec_set=AsyncClass)
+ self.assertTrue(asyncio.iscoroutinefunction(mock.async_method))
+ self.assertFalse(asyncio.iscoroutinefunction(mock.normal_method))
+ self.assertIsInstance(mock.async_method, AsyncMock)
+ self.assertIsInstance(mock.normal_method, MagicMock)
+ self.assertIsInstance(mock, MagicMock)
+
+
+class AsyncArguments(unittest.TestCase):
+ def test_add_return_value(self):
+ async def addition(self, var):
+ return var + 1
+
+ mock = AsyncMock(addition, return_value=10)
+ output = asyncio.run(mock(5))
+
+ self.assertEqual(output, 10)
+
+ def test_add_side_effect_exception(self):
+ async def addition(var):
+ return var + 1
+ mock = AsyncMock(addition, side_effect=Exception('err'))
+ with self.assertRaises(Exception):
+ asyncio.run(mock(5))
+
+ def test_add_side_effect_function(self):
+ async def addition(var):
+ return var + 1
+ mock = AsyncMock(side_effect=addition)
+ result = asyncio.run(mock(5))
+ self.assertEqual(result, 6)
+
+ def test_add_side_effect_iterable(self):
+ vals = [1, 2, 3]
+ mock = AsyncMock(side_effect=vals)
+ for item in vals:
+ self.assertEqual(item, asyncio.run(mock()))
+
+ with self.assertRaises(RuntimeError) as e:
+ asyncio.run(mock())
+ self.assertEqual(
+ e.exception,
+ RuntimeError('coroutine raised StopIteration')
+ )
+
+
+class AsyncContextManagerTest(unittest.TestCase):
+ class WithAsyncContextManager:
+ def __init__(self):
+ self.entered = False
+ self.exited = False
+
+ async def __aenter__(self, *args, **kwargs):
+ self.entered = True
+ return self
+
+ async def __aexit__(self, *args, **kwargs):
+ self.exited = True
+
+ def test_magic_methods_are_async_mocks(self):
+ mock = MagicMock(self.WithAsyncContextManager())
+ self.assertIsInstance(mock.__aenter__, AsyncMock)
+ self.assertIsInstance(mock.__aexit__, AsyncMock)
+
+ def test_mock_supports_async_context_manager(self):
+ called = False
+ instance = self.WithAsyncContextManager()
+ mock_instance = MagicMock(instance)
+
+ async def use_context_manager():
+ nonlocal called
+ async with mock_instance as result:
+ called = True
+ return result
+
+ result = asyncio.run(use_context_manager())
+ self.assertFalse(instance.entered)
+ self.assertFalse(instance.exited)
+ self.assertTrue(called)
+ self.assertTrue(mock_instance.entered)
+ self.assertTrue(mock_instance.exited)
+ self.assertTrue(mock_instance.__aenter__.called)
+ self.assertTrue(mock_instance.__aexit__.called)
+ self.assertIsNot(mock_instance, result)
+ self.assertIsInstance(result, AsyncMock)
+
+ def test_mock_customize_async_context_manager(self):
+ instance = self.WithAsyncContextManager()
+ mock_instance = MagicMock(instance)
+
+ expected_result = object()
+ mock_instance.__aenter__.return_value = expected_result
+
+ async def use_context_manager():
+ async with mock_instance as result:
+ return result
+
+ self.assertIs(asyncio.run(use_context_manager()), expected_result)
+
+ def test_mock_customize_async_context_manager_with_coroutine(self):
+ enter_called = False
+ exit_called = False
+
+ async def enter_coroutine(*args):
+ nonlocal enter_called
+ enter_called = True
+
+ async def exit_coroutine(*args):
+ nonlocal exit_called
+ exit_called = True
+
+ instance = self.WithAsyncContextManager()
+ mock_instance = MagicMock(instance)
+
+ mock_instance.__aenter__ = enter_coroutine
+ mock_instance.__aexit__ = exit_coroutine
+
+ async def use_context_manager():
+ async with mock_instance:
+ pass
+
+ asyncio.run(use_context_manager())
+ self.assertTrue(enter_called)
+ self.assertTrue(exit_called)
+
+ def test_context_manager_raise_exception_by_default(self):
+ async def raise_in(context_manager):
+ async with context_manager:
+ raise TypeError()
+
+ instance = self.WithAsyncContextManager()
+ mock_instance = MagicMock(instance)
+ with self.assertRaises(TypeError):
+ asyncio.run(raise_in(mock_instance))
+
+
+class AsyncIteratorTest(unittest.TestCase):
+ class WithAsyncIterator(object):
+ def __init__(self):
+ self.items = ["foo", "NormalFoo", "baz"]
+
+ def __aiter__(self):
+ return self
+
+ async def __anext__(self):
+ try:
+ return self.items.pop()
+ except IndexError:
+ pass
+
+ raise StopAsyncIteration
+
+ def test_mock_aiter_and_anext(self):
+ instance = self.WithAsyncIterator()
+ mock_instance = MagicMock(instance)
+
+ self.assertEqual(asyncio.iscoroutine(instance.__aiter__),
+ asyncio.iscoroutine(mock_instance.__aiter__))
+ self.assertEqual(asyncio.iscoroutine(instance.__anext__),
+ asyncio.iscoroutine(mock_instance.__anext__))
+
+ iterator = instance.__aiter__()
+ if asyncio.iscoroutine(iterator):
+ iterator = asyncio.run(iterator)
+
+ mock_iterator = mock_instance.__aiter__()
+ if asyncio.iscoroutine(mock_iterator):
+ mock_iterator = asyncio.run(mock_iterator)
+
+ self.assertEqual(asyncio.iscoroutine(iterator.__aiter__),
+ asyncio.iscoroutine(mock_iterator.__aiter__))
+ self.assertEqual(asyncio.iscoroutine(iterator.__anext__),
+ asyncio.iscoroutine(mock_iterator.__anext__))
+
+ def test_mock_async_for(self):
+ async def iterate(iterator):
+ accumulator = []
+ async for item in iterator:
+ accumulator.append(item)
+
+ return accumulator
+
+ expected = ["FOO", "BAR", "BAZ"]
+ with self.subTest("iterate through default value"):
+ mock_instance = MagicMock(self.WithAsyncIterator())
+ self.assertEqual([], asyncio.run(iterate(mock_instance)))
+
+ with self.subTest("iterate through set return_value"):
+ mock_instance = MagicMock(self.WithAsyncIterator())
+ mock_instance.__aiter__.return_value = expected[:]
+ self.assertEqual(expected, asyncio.run(iterate(mock_instance)))
+
+ with self.subTest("iterate through set return_value iterator"):
+ mock_instance = MagicMock(self.WithAsyncIterator())
+ mock_instance.__aiter__.return_value = iter(expected[:])
+ self.assertEqual(expected, asyncio.run(iterate(mock_instance)))
+
+
+class AsyncMockAssert(unittest.TestCase):
+ def setUp(self):
+ self.mock = AsyncMock()
+
+ async def _runnable_test(self, *args):
+ if not args:
+ await self.mock()
+ else:
+ await self.mock(*args)
+
+ def test_assert_awaited(self):
+ with self.assertRaises(AssertionError):
+ self.mock.assert_awaited()
+
+ asyncio.run(self._runnable_test())
+ self.mock.assert_awaited()
+
+ def test_assert_awaited_once(self):
+ with self.assertRaises(AssertionError):
+ self.mock.assert_awaited_once()
+
+ asyncio.run(self._runnable_test())
+ self.mock.assert_awaited_once()
+
+ asyncio.run(self._runnable_test())
+ with self.assertRaises(AssertionError):
+ self.mock.assert_awaited_once()
+
+ def test_assert_awaited_with(self):
+ asyncio.run(self._runnable_test())
+ with self.assertRaises(AssertionError):
+ self.mock.assert_awaited_with('foo')
+
+ asyncio.run(self._runnable_test('foo'))
+ self.mock.assert_awaited_with('foo')
+
+ asyncio.run(self._runnable_test('SomethingElse'))
+ with self.assertRaises(AssertionError):
+ self.mock.assert_awaited_with('foo')
+
+ def test_assert_awaited_once_with(self):
+ with self.assertRaises(AssertionError):
+ self.mock.assert_awaited_once_with('foo')
+
+ asyncio.run(self._runnable_test('foo'))
+ self.mock.assert_awaited_once_with('foo')
+
+ asyncio.run(self._runnable_test('foo'))
+ with self.assertRaises(AssertionError):
+ self.mock.assert_awaited_once_with('foo')
+
+ def test_assert_any_wait(self):
+ with self.assertRaises(AssertionError):
+ self.mock.assert_any_await('NormalFoo')
+
+ asyncio.run(self._runnable_test('foo'))
+ with self.assertRaises(AssertionError):
+ self.mock.assert_any_await('NormalFoo')
+
+ asyncio.run(self._runnable_test('NormalFoo'))
+ self.mock.assert_any_await('NormalFoo')
+
+ asyncio.run(self._runnable_test('SomethingElse'))
+ self.mock.assert_any_await('NormalFoo')
+
+ def test_assert_has_awaits_no_order(self):
+ calls = [call('NormalFoo'), call('baz')]
+
+ with self.assertRaises(AssertionError):
+ self.mock.assert_has_awaits(calls)
+
+ asyncio.run(self._runnable_test('foo'))
+ with self.assertRaises(AssertionError):
+ self.mock.assert_has_awaits(calls)
+
+ asyncio.run(self._runnable_test('NormalFoo'))
+ with self.assertRaises(AssertionError):
+ self.mock.assert_has_awaits(calls)
+
+ asyncio.run(self._runnable_test('baz'))
+ self.mock.assert_has_awaits(calls)
+
+ asyncio.run(self._runnable_test('SomethingElse'))
+ self.mock.assert_has_awaits(calls)
+
+ def test_assert_has_awaits_ordered(self):
+ calls = [call('NormalFoo'), call('baz')]
+ with self.assertRaises(AssertionError):
+ self.mock.assert_has_awaits(calls, any_order=True)
+
+ asyncio.run(self._runnable_test('baz'))
+ with self.assertRaises(AssertionError):
+ self.mock.assert_has_awaits(calls, any_order=True)
+
+ asyncio.run(self._runnable_test('foo'))
+ with self.assertRaises(AssertionError):
+ self.mock.assert_has_awaits(calls, any_order=True)
+
+ asyncio.run(self._runnable_test('NormalFoo'))
+ self.mock.assert_has_awaits(calls, any_order=True)
+
+ asyncio.run(self._runnable_test('qux'))
+ self.mock.assert_has_awaits(calls, any_order=True)
+
+ def test_assert_not_awaited(self):
+ self.mock.assert_not_awaited()
+
+ asyncio.run(self._runnable_test())
+ with self.assertRaises(AssertionError):
+ self.mock.assert_not_awaited()
diff --git a/Lib/unittest/test/testmock/testmock.py b/Lib/unittest/test/testmock/testmock.py
index b20b8e2..307b8b7 100644
--- a/Lib/unittest/test/testmock/testmock.py
+++ b/Lib/unittest/test/testmock/testmock.py
@@ -9,7 +9,7 @@ from unittest import mock
from unittest.mock import (
call, DEFAULT, patch, sentinel,
MagicMock, Mock, NonCallableMock,
- NonCallableMagicMock, _Call, _CallList,
+ NonCallableMagicMock, AsyncMock, _Call, _CallList,
create_autospec
)
@@ -1618,7 +1618,8 @@ class MockTest(unittest.TestCase):
def test_adding_child_mock(self):
- for Klass in NonCallableMock, Mock, MagicMock, NonCallableMagicMock:
+ for Klass in (NonCallableMock, Mock, MagicMock, NonCallableMagicMock,
+ AsyncMock):
mock = Klass()
mock.foo = Mock()
diff --git a/Misc/NEWS.d/next/Library/2018-09-13-20-33-24.bpo-26467.cahAk3.rst b/Misc/NEWS.d/next/Library/2018-09-13-20-33-24.bpo-26467.cahAk3.rst
new file mode 100644
index 0000000..4cf3f2a
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2018-09-13-20-33-24.bpo-26467.cahAk3.rst
@@ -0,0 +1,2 @@
+Added AsyncMock to support using unittest to mock asyncio coroutines.
+Patch by Lisa Roach.