summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
Diffstat (limited to 'Lib')
-rw-r--r--Lib/contextlib.py8
-rw-r--r--Lib/test/test_contextlib_async.py76
2 files changed, 84 insertions, 0 deletions
diff --git a/Lib/contextlib.py b/Lib/contextlib.py
index 8343d7e..1384d89 100644
--- a/Lib/contextlib.py
+++ b/Lib/contextlib.py
@@ -191,6 +191,14 @@ class _AsyncGeneratorContextManager(
):
"""Helper for @asynccontextmanager decorator."""
+ def __call__(self, func):
+ @wraps(func)
+ async def inner(*args, **kwds):
+ async with self.__class__(self.func, self.args, self.kwds):
+ return await func(*args, **kwds)
+
+ return inner
+
async def __aenter__(self):
# do not keep args and kwds alive unnecessarily
# they are only needed for recreation, which is not possible anymore
diff --git a/Lib/test/test_contextlib_async.py b/Lib/test/test_contextlib_async.py
index 74fddef..c738bf3 100644
--- a/Lib/test/test_contextlib_async.py
+++ b/Lib/test/test_contextlib_async.py
@@ -318,6 +318,82 @@ class AsyncContextManagerTestCase(unittest.TestCase):
self.assertEqual(ncols, 10)
self.assertEqual(depth, 0)
+ @_async_test
+ async def test_decorator(self):
+ entered = False
+
+ @asynccontextmanager
+ async def context():
+ nonlocal entered
+ entered = True
+ yield
+ entered = False
+
+ @context()
+ async def test():
+ self.assertTrue(entered)
+
+ self.assertFalse(entered)
+ await test()
+ self.assertFalse(entered)
+
+ @_async_test
+ async def test_decorator_with_exception(self):
+ entered = False
+
+ @asynccontextmanager
+ async def context():
+ nonlocal entered
+ try:
+ entered = True
+ yield
+ finally:
+ entered = False
+
+ @context()
+ async def test():
+ self.assertTrue(entered)
+ raise NameError('foo')
+
+ self.assertFalse(entered)
+ with self.assertRaisesRegex(NameError, 'foo'):
+ await test()
+ self.assertFalse(entered)
+
+ @_async_test
+ async def test_decorating_method(self):
+
+ @asynccontextmanager
+ async def context():
+ yield
+
+
+ class Test(object):
+
+ @context()
+ async def method(self, a, b, c=None):
+ self.a = a
+ self.b = b
+ self.c = c
+
+ # these tests are for argument passing when used as a decorator
+ test = Test()
+ await test.method(1, 2)
+ self.assertEqual(test.a, 1)
+ self.assertEqual(test.b, 2)
+ self.assertEqual(test.c, None)
+
+ test = Test()
+ await test.method('a', 'b', 'c')
+ self.assertEqual(test.a, 'a')
+ self.assertEqual(test.b, 'b')
+ self.assertEqual(test.c, 'c')
+
+ test = Test()
+ await test.method(a=1, b=2)
+ self.assertEqual(test.a, 1)
+ self.assertEqual(test.b, 2)
+
class AclosingTestCase(unittest.TestCase):