diff options
author | Jason Fried <fried@fb.com> | 2021-09-23 21:36:03 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-09-23 21:36:03 (GMT) |
commit | 86b833badd3d6864868404ead2f8c7994d24f85c (patch) | |
tree | 0610a6e8c1b1ca5e527181d58d29227b749ab484 /Lib | |
parent | af90b5498b8c6acd67b50fdad007d26dfd1c5823 (diff) | |
download | cpython-86b833badd3d6864868404ead2f8c7994d24f85c.zip cpython-86b833badd3d6864868404ead2f8c7994d24f85c.tar.gz cpython-86b833badd3d6864868404ead2f8c7994d24f85c.tar.bz2 |
bpo-38415: Allow using @asynccontextmanager-made ctx managers as decorators (GH-16667)
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/contextlib.py | 8 | ||||
-rw-r--r-- | Lib/test/test_contextlib_async.py | 76 |
2 files changed, 84 insertions, 0 deletions
diff --git a/Lib/contextlib.py b/Lib/contextlib.py index 8343d7e..1384d89 100644 --- a/Lib/contextlib.py +++ b/Lib/contextlib.py @@ -191,6 +191,14 @@ class _AsyncGeneratorContextManager( ): """Helper for @asynccontextmanager decorator.""" + def __call__(self, func): + @wraps(func) + async def inner(*args, **kwds): + async with self.__class__(self.func, self.args, self.kwds): + return await func(*args, **kwds) + + return inner + async def __aenter__(self): # do not keep args and kwds alive unnecessarily # they are only needed for recreation, which is not possible anymore diff --git a/Lib/test/test_contextlib_async.py b/Lib/test/test_contextlib_async.py index 74fddef..c738bf3 100644 --- a/Lib/test/test_contextlib_async.py +++ b/Lib/test/test_contextlib_async.py @@ -318,6 +318,82 @@ class AsyncContextManagerTestCase(unittest.TestCase): self.assertEqual(ncols, 10) self.assertEqual(depth, 0) + @_async_test + async def test_decorator(self): + entered = False + + @asynccontextmanager + async def context(): + nonlocal entered + entered = True + yield + entered = False + + @context() + async def test(): + self.assertTrue(entered) + + self.assertFalse(entered) + await test() + self.assertFalse(entered) + + @_async_test + async def test_decorator_with_exception(self): + entered = False + + @asynccontextmanager + async def context(): + nonlocal entered + try: + entered = True + yield + finally: + entered = False + + @context() + async def test(): + self.assertTrue(entered) + raise NameError('foo') + + self.assertFalse(entered) + with self.assertRaisesRegex(NameError, 'foo'): + await test() + self.assertFalse(entered) + + @_async_test + async def test_decorating_method(self): + + @asynccontextmanager + async def context(): + yield + + + class Test(object): + + @context() + async def method(self, a, b, c=None): + self.a = a + self.b = b + self.c = c + + # these tests are for argument passing when used as a decorator + test = Test() + await test.method(1, 2) + self.assertEqual(test.a, 1) + self.assertEqual(test.b, 2) + self.assertEqual(test.c, None) + + test = Test() + await test.method('a', 'b', 'c') + self.assertEqual(test.a, 'a') + self.assertEqual(test.b, 'b') + self.assertEqual(test.c, 'c') + + test = Test() + await test.method(a=1, b=2) + self.assertEqual(test.a, 1) + self.assertEqual(test.b, 2) + class AclosingTestCase(unittest.TestCase): |