summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorJason Fried <fried@fb.com>2021-09-23 21:36:03 (GMT)
committerGitHub <noreply@github.com>2021-09-23 21:36:03 (GMT)
commit86b833badd3d6864868404ead2f8c7994d24f85c (patch)
tree0610a6e8c1b1ca5e527181d58d29227b749ab484 /Lib
parentaf90b5498b8c6acd67b50fdad007d26dfd1c5823 (diff)
downloadcpython-86b833badd3d6864868404ead2f8c7994d24f85c.zip
cpython-86b833badd3d6864868404ead2f8c7994d24f85c.tar.gz
cpython-86b833badd3d6864868404ead2f8c7994d24f85c.tar.bz2
bpo-38415: Allow using @asynccontextmanager-made ctx managers as decorators (GH-16667)
Diffstat (limited to 'Lib')
-rw-r--r--Lib/contextlib.py8
-rw-r--r--Lib/test/test_contextlib_async.py76
2 files changed, 84 insertions, 0 deletions
diff --git a/Lib/contextlib.py b/Lib/contextlib.py
index 8343d7e..1384d89 100644
--- a/Lib/contextlib.py
+++ b/Lib/contextlib.py
@@ -191,6 +191,14 @@ class _AsyncGeneratorContextManager(
):
"""Helper for @asynccontextmanager decorator."""
+ def __call__(self, func):
+ @wraps(func)
+ async def inner(*args, **kwds):
+ async with self.__class__(self.func, self.args, self.kwds):
+ return await func(*args, **kwds)
+
+ return inner
+
async def __aenter__(self):
# do not keep args and kwds alive unnecessarily
# they are only needed for recreation, which is not possible anymore
diff --git a/Lib/test/test_contextlib_async.py b/Lib/test/test_contextlib_async.py
index 74fddef..c738bf3 100644
--- a/Lib/test/test_contextlib_async.py
+++ b/Lib/test/test_contextlib_async.py
@@ -318,6 +318,82 @@ class AsyncContextManagerTestCase(unittest.TestCase):
self.assertEqual(ncols, 10)
self.assertEqual(depth, 0)
+ @_async_test
+ async def test_decorator(self):
+ entered = False
+
+ @asynccontextmanager
+ async def context():
+ nonlocal entered
+ entered = True
+ yield
+ entered = False
+
+ @context()
+ async def test():
+ self.assertTrue(entered)
+
+ self.assertFalse(entered)
+ await test()
+ self.assertFalse(entered)
+
+ @_async_test
+ async def test_decorator_with_exception(self):
+ entered = False
+
+ @asynccontextmanager
+ async def context():
+ nonlocal entered
+ try:
+ entered = True
+ yield
+ finally:
+ entered = False
+
+ @context()
+ async def test():
+ self.assertTrue(entered)
+ raise NameError('foo')
+
+ self.assertFalse(entered)
+ with self.assertRaisesRegex(NameError, 'foo'):
+ await test()
+ self.assertFalse(entered)
+
+ @_async_test
+ async def test_decorating_method(self):
+
+ @asynccontextmanager
+ async def context():
+ yield
+
+
+ class Test(object):
+
+ @context()
+ async def method(self, a, b, c=None):
+ self.a = a
+ self.b = b
+ self.c = c
+
+ # these tests are for argument passing when used as a decorator
+ test = Test()
+ await test.method(1, 2)
+ self.assertEqual(test.a, 1)
+ self.assertEqual(test.b, 2)
+ self.assertEqual(test.c, None)
+
+ test = Test()
+ await test.method('a', 'b', 'c')
+ self.assertEqual(test.a, 'a')
+ self.assertEqual(test.b, 'b')
+ self.assertEqual(test.c, 'c')
+
+ test = Test()
+ await test.method(a=1, b=2)
+ self.assertEqual(test.a, 1)
+ self.assertEqual(test.b, 2)
+
class AclosingTestCase(unittest.TestCase):