summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_contextlib_async.py
diff options
context:
space:
mode:
authorJelle Zijlstra <jelle.zijlstra@gmail.com>2017-05-01 01:25:58 (GMT)
committerYury Selivanov <yselivanov@gmail.com>2017-05-01 01:25:58 (GMT)
commit2e624690bd74071358566300b7ef0bc45f444a30 (patch)
treef96176f5997f38c00974854907b586ce887981a3 /Lib/test/test_contextlib_async.py
parent9dc2b3809f38be2e403ee264958106badfda142d (diff)
downloadcpython-2e624690bd74071358566300b7ef0bc45f444a30.zip
cpython-2e624690bd74071358566300b7ef0bc45f444a30.tar.gz
cpython-2e624690bd74071358566300b7ef0bc45f444a30.tar.bz2
bpo-29679: Implement @contextlib.asynccontextmanager (#360)
Diffstat (limited to 'Lib/test/test_contextlib_async.py')
-rw-r--r--Lib/test/test_contextlib_async.py212
1 files changed, 212 insertions, 0 deletions
diff --git a/Lib/test/test_contextlib_async.py b/Lib/test/test_contextlib_async.py
new file mode 100644
index 0000000..42cc331
--- /dev/null
+++ b/Lib/test/test_contextlib_async.py
@@ -0,0 +1,212 @@
+import asyncio
+from contextlib import asynccontextmanager
+import functools
+from test import support
+import unittest
+
+
+def _async_test(func):
+ """Decorator to turn an async function into a test case."""
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ coro = func(*args, **kwargs)
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+ try:
+ return loop.run_until_complete(coro)
+ finally:
+ loop.close()
+ asyncio.set_event_loop(None)
+ return wrapper
+
+
+class AsyncContextManagerTestCase(unittest.TestCase):
+
+ @_async_test
+ async def test_contextmanager_plain(self):
+ state = []
+ @asynccontextmanager
+ async def woohoo():
+ state.append(1)
+ yield 42
+ state.append(999)
+ async with woohoo() as x:
+ self.assertEqual(state, [1])
+ self.assertEqual(x, 42)
+ state.append(x)
+ self.assertEqual(state, [1, 42, 999])
+
+ @_async_test
+ async def test_contextmanager_finally(self):
+ state = []
+ @asynccontextmanager
+ async def woohoo():
+ state.append(1)
+ try:
+ yield 42
+ finally:
+ state.append(999)
+ with self.assertRaises(ZeroDivisionError):
+ async with woohoo() as x:
+ self.assertEqual(state, [1])
+ self.assertEqual(x, 42)
+ state.append(x)
+ raise ZeroDivisionError()
+ self.assertEqual(state, [1, 42, 999])
+
+ @_async_test
+ async def test_contextmanager_no_reraise(self):
+ @asynccontextmanager
+ async def whee():
+ yield
+ ctx = whee()
+ await ctx.__aenter__()
+ # Calling __aexit__ should not result in an exception
+ self.assertFalse(await ctx.__aexit__(TypeError, TypeError("foo"), None))
+
+ @_async_test
+ async def test_contextmanager_trap_yield_after_throw(self):
+ @asynccontextmanager
+ async def whoo():
+ try:
+ yield
+ except:
+ yield
+ ctx = whoo()
+ await ctx.__aenter__()
+ with self.assertRaises(RuntimeError):
+ await ctx.__aexit__(TypeError, TypeError('foo'), None)
+
+ @_async_test
+ async def test_contextmanager_trap_no_yield(self):
+ @asynccontextmanager
+ async def whoo():
+ if False:
+ yield
+ ctx = whoo()
+ with self.assertRaises(RuntimeError):
+ await ctx.__aenter__()
+
+ @_async_test
+ async def test_contextmanager_trap_second_yield(self):
+ @asynccontextmanager
+ async def whoo():
+ yield
+ yield
+ ctx = whoo()
+ await ctx.__aenter__()
+ with self.assertRaises(RuntimeError):
+ await ctx.__aexit__(None, None, None)
+
+ @_async_test
+ async def test_contextmanager_non_normalised(self):
+ @asynccontextmanager
+ async def whoo():
+ try:
+ yield
+ except RuntimeError:
+ raise SyntaxError
+
+ ctx = whoo()
+ await ctx.__aenter__()
+ with self.assertRaises(SyntaxError):
+ await ctx.__aexit__(RuntimeError, None, None)
+
+ @_async_test
+ async def test_contextmanager_except(self):
+ state = []
+ @asynccontextmanager
+ async def woohoo():
+ state.append(1)
+ try:
+ yield 42
+ except ZeroDivisionError as e:
+ state.append(e.args[0])
+ self.assertEqual(state, [1, 42, 999])
+ async with woohoo() as x:
+ self.assertEqual(state, [1])
+ self.assertEqual(x, 42)
+ state.append(x)
+ raise ZeroDivisionError(999)
+ self.assertEqual(state, [1, 42, 999])
+
+ @_async_test
+ async def test_contextmanager_except_stopiter(self):
+ @asynccontextmanager
+ async def woohoo():
+ yield
+
+ for stop_exc in (StopIteration('spam'), StopAsyncIteration('ham')):
+ with self.subTest(type=type(stop_exc)):
+ try:
+ async with woohoo():
+ raise stop_exc
+ except Exception as ex:
+ self.assertIs(ex, stop_exc)
+ else:
+ self.fail(f'{stop_exc} was suppressed')
+
+ @_async_test
+ async def test_contextmanager_wrap_runtimeerror(self):
+ @asynccontextmanager
+ async def woohoo():
+ try:
+ yield
+ except Exception as exc:
+ raise RuntimeError(f'caught {exc}') from exc
+
+ with self.assertRaises(RuntimeError):
+ async with woohoo():
+ 1 / 0
+
+ # If the context manager wrapped StopAsyncIteration in a RuntimeError,
+ # we also unwrap it, because we can't tell whether the wrapping was
+ # done by the generator machinery or by the generator itself.
+ with self.assertRaises(StopAsyncIteration):
+ async with woohoo():
+ raise StopAsyncIteration
+
+ def _create_contextmanager_attribs(self):
+ def attribs(**kw):
+ def decorate(func):
+ for k,v in kw.items():
+ setattr(func,k,v)
+ return func
+ return decorate
+ @asynccontextmanager
+ @attribs(foo='bar')
+ async def baz(spam):
+ """Whee!"""
+ yield
+ return baz
+
+ def test_contextmanager_attribs(self):
+ baz = self._create_contextmanager_attribs()
+ self.assertEqual(baz.__name__,'baz')
+ self.assertEqual(baz.foo, 'bar')
+
+ @support.requires_docstrings
+ def test_contextmanager_doc_attrib(self):
+ baz = self._create_contextmanager_attribs()
+ self.assertEqual(baz.__doc__, "Whee!")
+
+ @support.requires_docstrings
+ @_async_test
+ async def test_instance_docstring_given_cm_docstring(self):
+ baz = self._create_contextmanager_attribs()(None)
+ self.assertEqual(baz.__doc__, "Whee!")
+ async with baz:
+ pass # suppress warning
+
+ @_async_test
+ async def test_keywords(self):
+ # Ensure no keyword arguments are inhibited
+ @asynccontextmanager
+ async def woohoo(self, func, args, kwds):
+ yield (self, func, args, kwds)
+ async with woohoo(self=11, func=22, args=33, kwds=44) as target:
+ self.assertEqual(target, (11, 22, 33, 44))
+
+
+if __name__ == '__main__':
+ unittest.main()