summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorJelle Zijlstra <jelle.zijlstra@gmail.com>2017-05-01 01:25:58 (GMT)
committerYury Selivanov <yselivanov@gmail.com>2017-05-01 01:25:58 (GMT)
commit2e624690bd74071358566300b7ef0bc45f444a30 (patch)
treef96176f5997f38c00974854907b586ce887981a3 /Lib
parent9dc2b3809f38be2e403ee264958106badfda142d (diff)
downloadcpython-2e624690bd74071358566300b7ef0bc45f444a30.zip
cpython-2e624690bd74071358566300b7ef0bc45f444a30.tar.gz
cpython-2e624690bd74071358566300b7ef0bc45f444a30.tar.bz2
bpo-29679: Implement @contextlib.asynccontextmanager (#360)
Diffstat (limited to 'Lib')
-rw-r--r--Lib/contextlib.py99
-rw-r--r--Lib/test/test_contextlib_async.py212
2 files changed, 305 insertions, 6 deletions
diff --git a/Lib/contextlib.py b/Lib/contextlib.py
index 5e47054..c53b35e 100644
--- a/Lib/contextlib.py
+++ b/Lib/contextlib.py
@@ -4,9 +4,9 @@ import sys
from collections import deque
from functools import wraps
-__all__ = ["contextmanager", "closing", "AbstractContextManager",
- "ContextDecorator", "ExitStack", "redirect_stdout",
- "redirect_stderr", "suppress"]
+__all__ = ["asynccontextmanager", "contextmanager", "closing",
+ "AbstractContextManager", "ContextDecorator", "ExitStack",
+ "redirect_stdout", "redirect_stderr", "suppress"]
class AbstractContextManager(abc.ABC):
@@ -54,8 +54,8 @@ class ContextDecorator(object):
return inner
-class _GeneratorContextManager(ContextDecorator, AbstractContextManager):
- """Helper for @contextmanager decorator."""
+class _GeneratorContextManagerBase:
+ """Shared functionality for @contextmanager and @asynccontextmanager."""
def __init__(self, func, args, kwds):
self.gen = func(*args, **kwds)
@@ -71,6 +71,12 @@ class _GeneratorContextManager(ContextDecorator, AbstractContextManager):
# for the class instead.
# See http://bugs.python.org/issue19404 for more details.
+
+class _GeneratorContextManager(_GeneratorContextManagerBase,
+ AbstractContextManager,
+ ContextDecorator):
+ """Helper for @contextmanager decorator."""
+
def _recreate_cm(self):
# _GCM instances are one-shot context managers, so the
# CM must be recreated each time a decorated function is
@@ -121,12 +127,61 @@ class _GeneratorContextManager(ContextDecorator, AbstractContextManager):
# fixes the impedance mismatch between the throw() protocol
# and the __exit__() protocol.
#
+ # This cannot use 'except BaseException as exc' (as in the
+ # async implementation) to maintain compatibility with
+ # Python 2, where old-style class exceptions are not caught
+ # by 'except BaseException'.
if sys.exc_info()[1] is value:
return False
raise
raise RuntimeError("generator didn't stop after throw()")
+class _AsyncGeneratorContextManager(_GeneratorContextManagerBase):
+ """Helper for @asynccontextmanager."""
+
+ async def __aenter__(self):
+ try:
+ return await self.gen.__anext__()
+ except StopAsyncIteration:
+ raise RuntimeError("generator didn't yield") from None
+
+ async def __aexit__(self, typ, value, traceback):
+ if typ is None:
+ try:
+ await self.gen.__anext__()
+ except StopAsyncIteration:
+ return
+ else:
+ raise RuntimeError("generator didn't stop")
+ else:
+ if value is None:
+ value = typ()
+ # See _GeneratorContextManager.__exit__ for comments on subtleties
+ # in this implementation
+ try:
+ await self.gen.athrow(typ, value, traceback)
+ raise RuntimeError("generator didn't stop after throw()")
+ except StopAsyncIteration as exc:
+ return exc is not value
+ except RuntimeError as exc:
+ if exc is value:
+ return False
+ # Avoid suppressing if a StopIteration exception
+ # was passed to throw() and later wrapped into a RuntimeError
+ # (see PEP 479 for sync generators; async generators also
+ # have this behavior). But do this only if the exception wrapped
+ # by the RuntimeError is actully Stop(Async)Iteration (see
+ # issue29692).
+ if isinstance(value, (StopIteration, StopAsyncIteration)):
+ if exc.__cause__ is value:
+ return False
+ raise
+ except BaseException as exc:
+ if exc is not value:
+ raise
+
+
def contextmanager(func):
"""@contextmanager decorator.
@@ -153,7 +208,6 @@ def contextmanager(func):
<body>
finally:
<cleanup>
-
"""
@wraps(func)
def helper(*args, **kwds):
@@ -161,6 +215,39 @@ def contextmanager(func):
return helper
+def asynccontextmanager(func):
+ """@asynccontextmanager decorator.
+
+ Typical usage:
+
+ @asynccontextmanager
+ async def some_async_generator(<arguments>):
+ <setup>
+ try:
+ yield <value>
+ finally:
+ <cleanup>
+
+ This makes this:
+
+ async with some_async_generator(<arguments>) as <variable>:
+ <body>
+
+ equivalent to this:
+
+ <setup>
+ try:
+ <variable> = <value>
+ <body>
+ finally:
+ <cleanup>
+ """
+ @wraps(func)
+ def helper(*args, **kwds):
+ return _AsyncGeneratorContextManager(func, args, kwds)
+ return helper
+
+
class closing(AbstractContextManager):
"""Context to automatically close something at the end of a block.
diff --git a/Lib/test/test_contextlib_async.py b/Lib/test/test_contextlib_async.py
new file mode 100644
index 0000000..42cc331
--- /dev/null
+++ b/Lib/test/test_contextlib_async.py
@@ -0,0 +1,212 @@
+import asyncio
+from contextlib import asynccontextmanager
+import functools
+from test import support
+import unittest
+
+
+def _async_test(func):
+ """Decorator to turn an async function into a test case."""
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ coro = func(*args, **kwargs)
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+ try:
+ return loop.run_until_complete(coro)
+ finally:
+ loop.close()
+ asyncio.set_event_loop(None)
+ return wrapper
+
+
+class AsyncContextManagerTestCase(unittest.TestCase):
+
+ @_async_test
+ async def test_contextmanager_plain(self):
+ state = []
+ @asynccontextmanager
+ async def woohoo():
+ state.append(1)
+ yield 42
+ state.append(999)
+ async with woohoo() as x:
+ self.assertEqual(state, [1])
+ self.assertEqual(x, 42)
+ state.append(x)
+ self.assertEqual(state, [1, 42, 999])
+
+ @_async_test
+ async def test_contextmanager_finally(self):
+ state = []
+ @asynccontextmanager
+ async def woohoo():
+ state.append(1)
+ try:
+ yield 42
+ finally:
+ state.append(999)
+ with self.assertRaises(ZeroDivisionError):
+ async with woohoo() as x:
+ self.assertEqual(state, [1])
+ self.assertEqual(x, 42)
+ state.append(x)
+ raise ZeroDivisionError()
+ self.assertEqual(state, [1, 42, 999])
+
+ @_async_test
+ async def test_contextmanager_no_reraise(self):
+ @asynccontextmanager
+ async def whee():
+ yield
+ ctx = whee()
+ await ctx.__aenter__()
+ # Calling __aexit__ should not result in an exception
+ self.assertFalse(await ctx.__aexit__(TypeError, TypeError("foo"), None))
+
+ @_async_test
+ async def test_contextmanager_trap_yield_after_throw(self):
+ @asynccontextmanager
+ async def whoo():
+ try:
+ yield
+ except:
+ yield
+ ctx = whoo()
+ await ctx.__aenter__()
+ with self.assertRaises(RuntimeError):
+ await ctx.__aexit__(TypeError, TypeError('foo'), None)
+
+ @_async_test
+ async def test_contextmanager_trap_no_yield(self):
+ @asynccontextmanager
+ async def whoo():
+ if False:
+ yield
+ ctx = whoo()
+ with self.assertRaises(RuntimeError):
+ await ctx.__aenter__()
+
+ @_async_test
+ async def test_contextmanager_trap_second_yield(self):
+ @asynccontextmanager
+ async def whoo():
+ yield
+ yield
+ ctx = whoo()
+ await ctx.__aenter__()
+ with self.assertRaises(RuntimeError):
+ await ctx.__aexit__(None, None, None)
+
+ @_async_test
+ async def test_contextmanager_non_normalised(self):
+ @asynccontextmanager
+ async def whoo():
+ try:
+ yield
+ except RuntimeError:
+ raise SyntaxError
+
+ ctx = whoo()
+ await ctx.__aenter__()
+ with self.assertRaises(SyntaxError):
+ await ctx.__aexit__(RuntimeError, None, None)
+
+ @_async_test
+ async def test_contextmanager_except(self):
+ state = []
+ @asynccontextmanager
+ async def woohoo():
+ state.append(1)
+ try:
+ yield 42
+ except ZeroDivisionError as e:
+ state.append(e.args[0])
+ self.assertEqual(state, [1, 42, 999])
+ async with woohoo() as x:
+ self.assertEqual(state, [1])
+ self.assertEqual(x, 42)
+ state.append(x)
+ raise ZeroDivisionError(999)
+ self.assertEqual(state, [1, 42, 999])
+
+ @_async_test
+ async def test_contextmanager_except_stopiter(self):
+ @asynccontextmanager
+ async def woohoo():
+ yield
+
+ for stop_exc in (StopIteration('spam'), StopAsyncIteration('ham')):
+ with self.subTest(type=type(stop_exc)):
+ try:
+ async with woohoo():
+ raise stop_exc
+ except Exception as ex:
+ self.assertIs(ex, stop_exc)
+ else:
+ self.fail(f'{stop_exc} was suppressed')
+
+ @_async_test
+ async def test_contextmanager_wrap_runtimeerror(self):
+ @asynccontextmanager
+ async def woohoo():
+ try:
+ yield
+ except Exception as exc:
+ raise RuntimeError(f'caught {exc}') from exc
+
+ with self.assertRaises(RuntimeError):
+ async with woohoo():
+ 1 / 0
+
+ # If the context manager wrapped StopAsyncIteration in a RuntimeError,
+ # we also unwrap it, because we can't tell whether the wrapping was
+ # done by the generator machinery or by the generator itself.
+ with self.assertRaises(StopAsyncIteration):
+ async with woohoo():
+ raise StopAsyncIteration
+
+ def _create_contextmanager_attribs(self):
+ def attribs(**kw):
+ def decorate(func):
+ for k,v in kw.items():
+ setattr(func,k,v)
+ return func
+ return decorate
+ @asynccontextmanager
+ @attribs(foo='bar')
+ async def baz(spam):
+ """Whee!"""
+ yield
+ return baz
+
+ def test_contextmanager_attribs(self):
+ baz = self._create_contextmanager_attribs()
+ self.assertEqual(baz.__name__,'baz')
+ self.assertEqual(baz.foo, 'bar')
+
+ @support.requires_docstrings
+ def test_contextmanager_doc_attrib(self):
+ baz = self._create_contextmanager_attribs()
+ self.assertEqual(baz.__doc__, "Whee!")
+
+ @support.requires_docstrings
+ @_async_test
+ async def test_instance_docstring_given_cm_docstring(self):
+ baz = self._create_contextmanager_attribs()(None)
+ self.assertEqual(baz.__doc__, "Whee!")
+ async with baz:
+ pass # suppress warning
+
+ @_async_test
+ async def test_keywords(self):
+ # Ensure no keyword arguments are inhibited
+ @asynccontextmanager
+ async def woohoo(self, func, args, kwds):
+ yield (self, func, args, kwds)
+ async with woohoo(self=11, func=22, args=33, kwds=44) as target:
+ self.assertEqual(target, (11, 22, 33, 44))
+
+
+if __name__ == '__main__':
+ unittest.main()