diff options
Diffstat (limited to 'Lib/test/test_contextlib_async.py')
| -rw-r--r-- | Lib/test/test_contextlib_async.py | 212 |
1 files changed, 212 insertions, 0 deletions
diff --git a/Lib/test/test_contextlib_async.py b/Lib/test/test_contextlib_async.py new file mode 100644 index 0000000..42cc331 --- /dev/null +++ b/Lib/test/test_contextlib_async.py @@ -0,0 +1,212 @@ +import asyncio +from contextlib import asynccontextmanager +import functools +from test import support +import unittest + + +def _async_test(func): + """Decorator to turn an async function into a test case.""" + @functools.wraps(func) + def wrapper(*args, **kwargs): + coro = func(*args, **kwargs) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return loop.run_until_complete(coro) + finally: + loop.close() + asyncio.set_event_loop(None) + return wrapper + + +class AsyncContextManagerTestCase(unittest.TestCase): + + @_async_test + async def test_contextmanager_plain(self): + state = [] + @asynccontextmanager + async def woohoo(): + state.append(1) + yield 42 + state.append(999) + async with woohoo() as x: + self.assertEqual(state, [1]) + self.assertEqual(x, 42) + state.append(x) + self.assertEqual(state, [1, 42, 999]) + + @_async_test + async def test_contextmanager_finally(self): + state = [] + @asynccontextmanager + async def woohoo(): + state.append(1) + try: + yield 42 + finally: + state.append(999) + with self.assertRaises(ZeroDivisionError): + async with woohoo() as x: + self.assertEqual(state, [1]) + self.assertEqual(x, 42) + state.append(x) + raise ZeroDivisionError() + self.assertEqual(state, [1, 42, 999]) + + @_async_test + async def test_contextmanager_no_reraise(self): + @asynccontextmanager + async def whee(): + yield + ctx = whee() + await ctx.__aenter__() + # Calling __aexit__ should not result in an exception + self.assertFalse(await ctx.__aexit__(TypeError, TypeError("foo"), None)) + + @_async_test + async def test_contextmanager_trap_yield_after_throw(self): + @asynccontextmanager + async def whoo(): + try: + yield + except: + yield + ctx = whoo() + await ctx.__aenter__() + with self.assertRaises(RuntimeError): + await ctx.__aexit__(TypeError, TypeError('foo'), None) + + @_async_test + async def test_contextmanager_trap_no_yield(self): + @asynccontextmanager + async def whoo(): + if False: + yield + ctx = whoo() + with self.assertRaises(RuntimeError): + await ctx.__aenter__() + + @_async_test + async def test_contextmanager_trap_second_yield(self): + @asynccontextmanager + async def whoo(): + yield + yield + ctx = whoo() + await ctx.__aenter__() + with self.assertRaises(RuntimeError): + await ctx.__aexit__(None, None, None) + + @_async_test + async def test_contextmanager_non_normalised(self): + @asynccontextmanager + async def whoo(): + try: + yield + except RuntimeError: + raise SyntaxError + + ctx = whoo() + await ctx.__aenter__() + with self.assertRaises(SyntaxError): + await ctx.__aexit__(RuntimeError, None, None) + + @_async_test + async def test_contextmanager_except(self): + state = [] + @asynccontextmanager + async def woohoo(): + state.append(1) + try: + yield 42 + except ZeroDivisionError as e: + state.append(e.args[0]) + self.assertEqual(state, [1, 42, 999]) + async with woohoo() as x: + self.assertEqual(state, [1]) + self.assertEqual(x, 42) + state.append(x) + raise ZeroDivisionError(999) + self.assertEqual(state, [1, 42, 999]) + + @_async_test + async def test_contextmanager_except_stopiter(self): + @asynccontextmanager + async def woohoo(): + yield + + for stop_exc in (StopIteration('spam'), StopAsyncIteration('ham')): + with self.subTest(type=type(stop_exc)): + try: + async with woohoo(): + raise stop_exc + except Exception as ex: + self.assertIs(ex, stop_exc) + else: + self.fail(f'{stop_exc} was suppressed') + + @_async_test + async def test_contextmanager_wrap_runtimeerror(self): + @asynccontextmanager + async def woohoo(): + try: + yield + except Exception as exc: + raise RuntimeError(f'caught {exc}') from exc + + with self.assertRaises(RuntimeError): + async with woohoo(): + 1 / 0 + + # If the context manager wrapped StopAsyncIteration in a RuntimeError, + # we also unwrap it, because we can't tell whether the wrapping was + # done by the generator machinery or by the generator itself. + with self.assertRaises(StopAsyncIteration): + async with woohoo(): + raise StopAsyncIteration + + def _create_contextmanager_attribs(self): + def attribs(**kw): + def decorate(func): + for k,v in kw.items(): + setattr(func,k,v) + return func + return decorate + @asynccontextmanager + @attribs(foo='bar') + async def baz(spam): + """Whee!""" + yield + return baz + + def test_contextmanager_attribs(self): + baz = self._create_contextmanager_attribs() + self.assertEqual(baz.__name__,'baz') + self.assertEqual(baz.foo, 'bar') + + @support.requires_docstrings + def test_contextmanager_doc_attrib(self): + baz = self._create_contextmanager_attribs() + self.assertEqual(baz.__doc__, "Whee!") + + @support.requires_docstrings + @_async_test + async def test_instance_docstring_given_cm_docstring(self): + baz = self._create_contextmanager_attribs()(None) + self.assertEqual(baz.__doc__, "Whee!") + async with baz: + pass # suppress warning + + @_async_test + async def test_keywords(self): + # Ensure no keyword arguments are inhibited + @asynccontextmanager + async def woohoo(self, func, args, kwds): + yield (self, func, args, kwds) + async with woohoo(self=11, func=22, args=33, kwds=44) as target: + self.assertEqual(target, (11, 22, 33, 44)) + + +if __name__ == '__main__': + unittest.main() |
