summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYury Selivanov <yselivanov@sprymix.com>2016-03-02 15:49:16 (GMT)
committerYury Selivanov <yselivanov@sprymix.com>2016-03-02 15:49:16 (GMT)
commitdce63234c55db7395ccc62d5e6e96c19696871e8 (patch)
tree3453ea06db7753a0aa633c4f2f5969f99e22bed9
parent0c6a34409ea9cf9a82171da09801f086d472aa89 (diff)
downloadcpython-dce63234c55db7395ccc62d5e6e96c19696871e8.zip
cpython-dce63234c55db7395ccc62d5e6e96c19696871e8.tar.gz
cpython-dce63234c55db7395ccc62d5e6e96c19696871e8.tar.bz2
asyncio: Fix @coroutine to recognize CoroWrapper (issue #25647)
Patch by Vladimir Rutsky.
-rw-r--r--Lib/asyncio/coroutines.py3
-rw-r--r--Lib/test/test_asyncio/test_tasks.py24
2 files changed, 26 insertions, 1 deletions
diff --git a/Lib/asyncio/coroutines.py b/Lib/asyncio/coroutines.py
index 27ab42a..71bc6fb 100644
--- a/Lib/asyncio/coroutines.py
+++ b/Lib/asyncio/coroutines.py
@@ -204,7 +204,8 @@ def coroutine(func):
@functools.wraps(func)
def coro(*args, **kw):
res = func(*args, **kw)
- if isinstance(res, futures.Future) or inspect.isgenerator(res):
+ if isinstance(res, futures.Future) or inspect.isgenerator(res) or \
+ isinstance(res, CoroWrapper):
res = yield from res
elif _AwaitableABC is not None:
# If 'func' returns an Awaitable (new in 3.5) we
diff --git a/Lib/test/test_asyncio/test_tasks.py b/Lib/test/test_asyncio/test_tasks.py
index c9d49f0..acceb9b 100644
--- a/Lib/test/test_asyncio/test_tasks.py
+++ b/Lib/test/test_asyncio/test_tasks.py
@@ -1794,6 +1794,30 @@ class TaskTests(test_utils.TestCase):
self.assertRegex(message, re.compile(regex, re.DOTALL))
+ def test_return_coroutine_from_coroutine(self):
+ """Return of @asyncio.coroutine()-wrapped function generator object
+ from @asyncio.coroutine()-wrapped function should have same effect as
+ returning generator object or Future."""
+ def check():
+ @asyncio.coroutine
+ def outer_coro():
+ @asyncio.coroutine
+ def inner_coro():
+ return 1
+
+ return inner_coro()
+
+ result = self.loop.run_until_complete(outer_coro())
+ self.assertEqual(result, 1)
+
+ # Test with debug flag cleared.
+ with set_coroutine_debug(False):
+ check()
+
+ # Test with debug flag set.
+ with set_coroutine_debug(True):
+ check()
+
def test_task_source_traceback(self):
self.loop.set_debug(True)