summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorYury Selivanov <yselivanov@sprymix.com>2015-06-22 16:19:30 (GMT)
committerYury Selivanov <yselivanov@sprymix.com>2015-06-22 16:19:30 (GMT)
commit5376ba9630e45ad177150ae68c9712640330a2fc (patch)
tree68eacabe0721f40098654fe5f2e0b0e3391d95c1 /Lib
parentcd881b850c95cdb410620f3acc6ebf37e5467192 (diff)
downloadcpython-5376ba9630e45ad177150ae68c9712640330a2fc.zip
cpython-5376ba9630e45ad177150ae68c9712640330a2fc.tar.gz
cpython-5376ba9630e45ad177150ae68c9712640330a2fc.tar.bz2
Issue #24400: Introduce a distinct type for 'async def' coroutines.
Summary of changes: 1. Coroutines now have a distinct, separate from generators type at the C level: PyGen_Type, and a new typedef PyCoroObject. PyCoroObject shares the initial segment of struct layout with PyGenObject, making it possible to reuse existing generators machinery. The new type is exposed as 'types.CoroutineType'. As a consequence of having a new type, CO_GENERATOR flag is no longer applied to coroutines. 2. Having a separate type for coroutines made it possible to add an __await__ method to the type. Although it is not used by the interpreter (see details on that below), it makes coroutines naturally (without using __instancecheck__) conform to collections.abc.Coroutine and collections.abc.Awaitable ABCs. [The __instancecheck__ is still used for generator-based coroutines, as we don't want to add __await__ for generators.] 3. Add new opcode: GET_YIELD_FROM_ITER. The opcode is needed to allow passing native coroutines to the YIELD_FROM opcode. Before this change, 'yield from o' expression was compiled to: (o) GET_ITER LOAD_CONST YIELD_FROM Now, we use GET_YIELD_FROM_ITER instead of GET_ITER. The reason for adding a new opcode is that GET_ITER is used in some contexts (such as 'for .. in' loops) where passing a coroutine object is invalid. 4. Add two new introspection functions to the inspec module: getcoroutinestate(c) and getcoroutinelocals(c). 5. inspect.iscoroutine(o) is updated to test if 'o' is a native coroutine object. Before this commit it used abc.Coroutine, and it was requested to update inspect.isgenerator(o) to use abc.Generator; it was decided, however, that inspect functions should really be tailored for checking for native types. 6. sys.set_coroutine_wrapper(w) API is updated to work with only native coroutines. Since types.coroutine decorator supports any type of callables now, it would be confusing that it does not work for all types of coroutines. 7. Exceptions logic in generators C implementation was updated to raise clearer messages for coroutines: Before: TypeError("generator raised StopIteration") After: TypeError("coroutine raised StopIteration")
Diffstat (limited to 'Lib')
-rw-r--r--Lib/_collections_abc.py27
-rw-r--r--Lib/asyncio/coroutines.py48
-rw-r--r--Lib/importlib/_bootstrap_external.py3
-rw-r--r--Lib/inspect.py50
-rw-r--r--Lib/opcode.py1
-rw-r--r--Lib/test/test_coroutines.py160
-rw-r--r--Lib/test/test_inspect.py71
-rw-r--r--Lib/test/test_types.py54
-rw-r--r--Lib/types.py85
9 files changed, 394 insertions, 105 deletions
diff --git a/Lib/_collections_abc.py b/Lib/_collections_abc.py
index a02b219..ba6a9b8 100644
--- a/Lib/_collections_abc.py
+++ b/Lib/_collections_abc.py
@@ -52,6 +52,12 @@ dict_items = type({}.items())
## misc ##
mappingproxy = type(type.__dict__)
generator = type((lambda: (yield))())
+## coroutine ##
+async def _coro(): pass
+_coro = _coro()
+coroutine = type(_coro)
+_coro.close() # Prevent ResourceWarning
+del _coro
### ONE-TRICK PONIES ###
@@ -78,17 +84,15 @@ class Hashable(metaclass=ABCMeta):
class _AwaitableMeta(ABCMeta):
def __instancecheck__(cls, instance):
- # 0x80 = CO_COROUTINE
- # 0x100 = CO_ITERABLE_COROUTINE
- # We don't want to import 'inspect' module, as
- # a dependency for 'collections.abc'.
- CO_COROUTINES = 0x80 | 0x100
-
- if (isinstance(instance, generator) and
- instance.gi_code.co_flags & CO_COROUTINES):
-
+ # This hook is needed because we can't add
+ # '__await__' method to generator objects, and
+ # we can't register GeneratorType on Awaitable.
+ # NB: 0x100 = CO_ITERABLE_COROUTINE
+ # (We don't want to import 'inspect' module, as
+ # a dependency for 'collections.abc')
+ if (instance.__class__ is generator and
+ instance.gi_code.co_flags & 0x100):
return True
-
return super().__instancecheck__(instance)
@@ -159,6 +163,9 @@ class Coroutine(Awaitable):
return NotImplemented
+Coroutine.register(coroutine)
+
+
class AsyncIterable(metaclass=ABCMeta):
__slots__ = ()
diff --git a/Lib/asyncio/coroutines.py b/Lib/asyncio/coroutines.py
index edb6806..4fc46a5 100644
--- a/Lib/asyncio/coroutines.py
+++ b/Lib/asyncio/coroutines.py
@@ -34,30 +34,20 @@ _DEBUG = (not sys.flags.ignore_environment
try:
- types.coroutine
+ _types_coroutine = types.coroutine
except AttributeError:
- native_coroutine_support = False
-else:
- native_coroutine_support = True
+ _types_coroutine = None
try:
- _iscoroutinefunction = inspect.iscoroutinefunction
+ _inspect_iscoroutinefunction = inspect.iscoroutinefunction
except AttributeError:
- _iscoroutinefunction = lambda func: False
+ _inspect_iscoroutinefunction = lambda func: False
try:
- inspect.CO_COROUTINE
-except AttributeError:
- _is_native_coro_code = lambda code: False
-else:
- _is_native_coro_code = lambda code: (code.co_flags &
- inspect.CO_COROUTINE)
-
-try:
- from collections.abc import Coroutine as CoroutineABC, \
- Awaitable as AwaitableABC
+ from collections.abc import Coroutine as _CoroutineABC, \
+ Awaitable as _AwaitableABC
except ImportError:
- CoroutineABC = AwaitableABC = None
+ _CoroutineABC = _AwaitableABC = None
# Check for CPython issue #21209
@@ -89,10 +79,7 @@ def debug_wrapper(gen):
# We only wrap here coroutines defined via 'async def' syntax.
# Generator-based coroutines are wrapped in @coroutine
# decorator.
- if _is_native_coro_code(gen.gi_code):
- return CoroWrapper(gen, None)
- else:
- return gen
+ return CoroWrapper(gen, None)
class CoroWrapper:
@@ -177,8 +164,7 @@ def coroutine(func):
If the coroutine is not yielded from before it is destroyed,
an error message is logged.
"""
- is_coroutine = _iscoroutinefunction(func)
- if is_coroutine and _is_native_coro_code(func.__code__):
+ if _inspect_iscoroutinefunction(func):
# In Python 3.5 that's all we need to do for coroutines
# defiend with "async def".
# Wrapping in CoroWrapper will happen via
@@ -193,7 +179,7 @@ def coroutine(func):
res = func(*args, **kw)
if isinstance(res, futures.Future) or inspect.isgenerator(res):
res = yield from res
- elif AwaitableABC is not None:
+ elif _AwaitableABC is not None:
# If 'func' returns an Awaitable (new in 3.5) we
# want to run it.
try:
@@ -201,15 +187,15 @@ def coroutine(func):
except AttributeError:
pass
else:
- if isinstance(res, AwaitableABC):
+ if isinstance(res, _AwaitableABC):
res = yield from await_meth()
return res
if not _DEBUG:
- if native_coroutine_support:
- wrapper = types.coroutine(coro)
- else:
+ if _types_coroutine is None:
wrapper = coro
+ else:
+ wrapper = _types_coroutine(coro)
else:
@functools.wraps(func)
def wrapper(*args, **kwds):
@@ -231,12 +217,12 @@ def coroutine(func):
def iscoroutinefunction(func):
"""Return True if func is a decorated coroutine function."""
return (getattr(func, '_is_coroutine', False) or
- _iscoroutinefunction(func))
+ _inspect_iscoroutinefunction(func))
_COROUTINE_TYPES = (types.GeneratorType, CoroWrapper)
-if CoroutineABC is not None:
- _COROUTINE_TYPES += (CoroutineABC,)
+if _CoroutineABC is not None:
+ _COROUTINE_TYPES += (_CoroutineABC,)
def iscoroutine(obj):
diff --git a/Lib/importlib/_bootstrap_external.py b/Lib/importlib/_bootstrap_external.py
index b3c10b9..3508ce9 100644
--- a/Lib/importlib/_bootstrap_external.py
+++ b/Lib/importlib/_bootstrap_external.py
@@ -222,12 +222,13 @@ _code_type = type(_write_atomic.__code__)
# Python 3.5a0 3320 (matrix multiplication operator)
# Python 3.5b1 3330 (PEP 448: Additional Unpacking Generalizations)
# Python 3.5b2 3340 (fix dictionary display evaluation order #11205)
+# Python 3.5b2 3350 (add GET_YIELD_FROM_ITER opcode #24400)
#
# MAGIC must change whenever the bytecode emitted by the compiler may no
# longer be understood by older implementations of the eval loop (usually
# due to the addition of new opcodes).
-MAGIC_NUMBER = (3340).to_bytes(2, 'little') + b'\r\n'
+MAGIC_NUMBER = (3350).to_bytes(2, 'little') + b'\r\n'
_RAW_MAGIC_NUMBER = int.from_bytes(MAGIC_NUMBER, 'little') # For import.c
_PYCACHE = '__pycache__'
diff --git a/Lib/inspect.py b/Lib/inspect.py
index 25ddd26..6285a6c 100644
--- a/Lib/inspect.py
+++ b/Lib/inspect.py
@@ -175,8 +175,7 @@ def isgeneratorfunction(object):
See help(isfunction) for attributes listing."""
return bool((isfunction(object) or ismethod(object)) and
- object.__code__.co_flags & CO_GENERATOR and
- not object.__code__.co_flags & CO_COROUTINE)
+ object.__code__.co_flags & CO_GENERATOR)
def iscoroutinefunction(object):
"""Return true if the object is a coroutine function.
@@ -185,8 +184,7 @@ def iscoroutinefunction(object):
or generators decorated with "types.coroutine".
"""
return bool((isfunction(object) or ismethod(object)) and
- object.__code__.co_flags & (CO_ITERABLE_COROUTINE |
- CO_COROUTINE))
+ object.__code__.co_flags & CO_COROUTINE)
def isawaitable(object):
"""Return true if the object can be used in "await" expression."""
@@ -207,12 +205,11 @@ def isgenerator(object):
send resumes the generator and "sends" a value that becomes
the result of the current yield-expression
throw used to raise an exception inside the generator"""
- return (isinstance(object, types.GeneratorType) and
- not object.gi_code.co_flags & CO_COROUTINE)
+ return isinstance(object, types.GeneratorType)
def iscoroutine(object):
"""Return true if the object is a coroutine."""
- return isinstance(object, collections.abc.Coroutine)
+ return isinstance(object, types.CoroutineType)
def istraceback(object):
"""Return true if the object is a traceback.
@@ -1598,6 +1595,45 @@ def getgeneratorlocals(generator):
else:
return {}
+
+# ------------------------------------------------ coroutine introspection
+
+CORO_CREATED = 'CORO_CREATED'
+CORO_RUNNING = 'CORO_RUNNING'
+CORO_SUSPENDED = 'CORO_SUSPENDED'
+CORO_CLOSED = 'CORO_CLOSED'
+
+def getcoroutinestate(coroutine):
+ """Get current state of a coroutine object.
+
+ Possible states are:
+ CORO_CREATED: Waiting to start execution.
+ CORO_RUNNING: Currently being executed by the interpreter.
+ CORO_SUSPENDED: Currently suspended at an await expression.
+ CORO_CLOSED: Execution has completed.
+ """
+ if coroutine.cr_running:
+ return CORO_RUNNING
+ if coroutine.cr_frame is None:
+ return CORO_CLOSED
+ if coroutine.cr_frame.f_lasti == -1:
+ return CORO_CREATED
+ return CORO_SUSPENDED
+
+
+def getcoroutinelocals(coroutine):
+ """
+ Get the mapping of coroutine local variables to their current values.
+
+ A dict is returned, with the keys the local variable names and values the
+ bound values."""
+ frame = getattr(coroutine, "cr_frame", None)
+ if frame is not None:
+ return frame.f_locals
+ else:
+ return {}
+
+
###############################################################################
### Function Signature Object (PEP 362)
###############################################################################
diff --git a/Lib/opcode.py b/Lib/opcode.py
index c7b3443..4c826a7 100644
--- a/Lib/opcode.py
+++ b/Lib/opcode.py
@@ -103,6 +103,7 @@ def_op('BINARY_XOR', 65)
def_op('BINARY_OR', 66)
def_op('INPLACE_POWER', 67)
def_op('GET_ITER', 68)
+def_op('GET_YIELD_FROM_ITER', 69)
def_op('PRINT_EXPR', 70)
def_op('LOAD_BUILD_CLASS', 71)
diff --git a/Lib/test/test_coroutines.py b/Lib/test/test_coroutines.py
index e3a3304..8d2b1a3 100644
--- a/Lib/test/test_coroutines.py
+++ b/Lib/test/test_coroutines.py
@@ -24,7 +24,7 @@ class AsyncYield:
def run_async(coro):
- assert coro.__class__ is types.GeneratorType
+ assert coro.__class__ in {types.GeneratorType, types.CoroutineType}
buffer = []
result = None
@@ -37,6 +37,25 @@ def run_async(coro):
return buffer, result
+def run_async__await__(coro):
+ assert coro.__class__ is types.CoroutineType
+ aw = coro.__await__()
+ buffer = []
+ result = None
+ i = 0
+ while True:
+ try:
+ if i % 2:
+ buffer.append(next(aw))
+ else:
+ buffer.append(aw.send(None))
+ i += 1
+ except StopIteration as ex:
+ result = ex.args[0] if ex.args else None
+ break
+ return buffer, result
+
+
@contextlib.contextmanager
def silence_coro_gc():
with warnings.catch_warnings():
@@ -121,22 +140,24 @@ class CoroutineTest(unittest.TestCase):
return 10
f = foo()
- self.assertIsInstance(f, types.GeneratorType)
- self.assertTrue(bool(foo.__code__.co_flags & 0x80))
- self.assertTrue(bool(foo.__code__.co_flags & 0x20))
- self.assertTrue(bool(f.gi_code.co_flags & 0x80))
- self.assertTrue(bool(f.gi_code.co_flags & 0x20))
+ self.assertIsInstance(f, types.CoroutineType)
+ self.assertTrue(bool(foo.__code__.co_flags & inspect.CO_COROUTINE))
+ self.assertFalse(bool(foo.__code__.co_flags & inspect.CO_GENERATOR))
+ self.assertTrue(bool(f.cr_code.co_flags & inspect.CO_COROUTINE))
+ self.assertFalse(bool(f.cr_code.co_flags & inspect.CO_GENERATOR))
self.assertEqual(run_async(f), ([], 10))
+ self.assertEqual(run_async__await__(foo()), ([], 10))
+
def bar(): pass
- self.assertFalse(bool(bar.__code__.co_flags & 0x80))
+ self.assertFalse(bool(bar.__code__.co_flags & inspect.CO_COROUTINE))
def test_func_2(self):
async def foo():
raise StopIteration
with self.assertRaisesRegex(
- RuntimeError, "generator raised StopIteration"):
+ RuntimeError, "coroutine raised StopIteration"):
run_async(foo())
@@ -152,7 +173,7 @@ class CoroutineTest(unittest.TestCase):
raise StopIteration
check = lambda: self.assertRaisesRegex(
- TypeError, "coroutine-objects do not support iteration")
+ TypeError, "'coroutine' object is not iterable")
with check():
list(foo())
@@ -166,9 +187,6 @@ class CoroutineTest(unittest.TestCase):
with check():
iter(foo())
- with check():
- next(foo())
-
with silence_coro_gc(), check():
for i in foo():
pass
@@ -185,7 +203,7 @@ class CoroutineTest(unittest.TestCase):
await bar()
check = lambda: self.assertRaisesRegex(
- TypeError, "coroutine-objects do not support iteration")
+ TypeError, "'coroutine' object is not iterable")
with check():
for el in foo(): pass
@@ -221,7 +239,7 @@ class CoroutineTest(unittest.TestCase):
with silence_coro_gc(), self.assertRaisesRegex(
TypeError,
- "cannot 'yield from' a coroutine object from a generator"):
+ "cannot 'yield from' a coroutine object in a non-coroutine generator"):
list(foo())
@@ -244,6 +262,98 @@ class CoroutineTest(unittest.TestCase):
foo()
support.gc_collect()
+ def test_func_10(self):
+ N = 0
+
+ @types.coroutine
+ def gen():
+ nonlocal N
+ try:
+ a = yield
+ yield (a ** 2)
+ except ZeroDivisionError:
+ N += 100
+ raise
+ finally:
+ N += 1
+
+ async def foo():
+ await gen()
+
+ coro = foo()
+ aw = coro.__await__()
+ self.assertIs(aw, iter(aw))
+ next(aw)
+ self.assertEqual(aw.send(10), 100)
+
+ self.assertEqual(N, 0)
+ aw.close()
+ self.assertEqual(N, 1)
+
+ coro = foo()
+ aw = coro.__await__()
+ next(aw)
+ with self.assertRaises(ZeroDivisionError):
+ aw.throw(ZeroDivisionError, None, None)
+ self.assertEqual(N, 102)
+
+ def test_func_11(self):
+ async def func(): pass
+ coro = func()
+ # Test that PyCoro_Type and _PyCoroWrapper_Type types were properly
+ # initialized
+ self.assertIn('__await__', dir(coro))
+ self.assertIn('__iter__', dir(coro.__await__()))
+ self.assertIn('coroutine_wrapper', repr(coro.__await__()))
+ coro.close() # avoid RuntimeWarning
+
+ def test_func_12(self):
+ async def g():
+ i = me.send(None)
+ await foo
+ me = g()
+ with self.assertRaisesRegex(ValueError,
+ "coroutine already executing"):
+ me.send(None)
+
+ def test_func_13(self):
+ async def g():
+ pass
+ with self.assertRaisesRegex(
+ TypeError,
+ "can't send non-None value to a just-started coroutine"):
+
+ g().send('spam')
+
+ def test_func_14(self):
+ @types.coroutine
+ def gen():
+ yield
+ async def coro():
+ try:
+ await gen()
+ except GeneratorExit:
+ await gen()
+ c = coro()
+ c.send(None)
+ with self.assertRaisesRegex(RuntimeError,
+ "coroutine ignored GeneratorExit"):
+ c.close()
+
+ def test_corotype_1(self):
+ ct = types.CoroutineType
+ self.assertIn('into coroutine', ct.send.__doc__)
+ self.assertIn('inside coroutine', ct.close.__doc__)
+ self.assertIn('in coroutine', ct.throw.__doc__)
+ self.assertIn('of the coroutine', ct.__dict__['__name__'].__doc__)
+ self.assertIn('of the coroutine', ct.__dict__['__qualname__'].__doc__)
+ self.assertEqual(ct.__name__, 'coroutine')
+
+ async def f(): pass
+ c = f()
+ self.assertIn('coroutine object', repr(c))
+ c.close()
+
def test_await_1(self):
async def foo():
@@ -262,6 +372,7 @@ class CoroutineTest(unittest.TestCase):
await AsyncYieldFrom([1, 2, 3])
self.assertEqual(run_async(foo()), ([1, 2, 3], None))
+ self.assertEqual(run_async__await__(foo()), ([1, 2, 3], None))
def test_await_4(self):
async def bar():
@@ -1015,6 +1126,27 @@ class SysSetCoroWrapperTest(unittest.TestCase):
finally:
sys.set_coroutine_wrapper(None)
+ def test_set_wrapper_4(self):
+ @types.coroutine
+ def foo():
+ return 'spam'
+
+ wrapped = None
+ def wrap(gen):
+ nonlocal wrapped
+ wrapped = gen
+ return gen
+
+ sys.set_coroutine_wrapper(wrap)
+ try:
+ foo()
+ self.assertIs(
+ wrapped, None,
+ "generator-based coroutine was wrapped via "
+ "sys.set_coroutine_wrapper")
+ finally:
+ sys.set_coroutine_wrapper(None)
+
class CAPITest(unittest.TestCase):
diff --git a/Lib/test/test_inspect.py b/Lib/test/test_inspect.py
index 4695da8..39fa484 100644
--- a/Lib/test/test_inspect.py
+++ b/Lib/test/test_inspect.py
@@ -141,9 +141,9 @@ class TestPredicates(IsTestBase):
gen_coro = gen_coroutine_function_example(1)
coro = coroutine_function_example(1)
- self.assertTrue(
+ self.assertFalse(
inspect.iscoroutinefunction(gen_coroutine_function_example))
- self.assertTrue(inspect.iscoroutine(gen_coro))
+ self.assertFalse(inspect.iscoroutine(gen_coro))
self.assertTrue(
inspect.isgeneratorfunction(gen_coroutine_function_example))
@@ -1737,6 +1737,70 @@ class TestGetGeneratorState(unittest.TestCase):
self.assertRaises(TypeError, inspect.getgeneratorlocals, (2,3))
+class TestGetCoroutineState(unittest.TestCase):
+
+ def setUp(self):
+ @types.coroutine
+ def number_coroutine():
+ for number in range(5):
+ yield number
+ async def coroutine():
+ await number_coroutine()
+ self.coroutine = coroutine()
+
+ def tearDown(self):
+ self.coroutine.close()
+
+ def _coroutinestate(self):
+ return inspect.getcoroutinestate(self.coroutine)
+
+ def test_created(self):
+ self.assertEqual(self._coroutinestate(), inspect.CORO_CREATED)
+
+ def test_suspended(self):
+ self.coroutine.send(None)
+ self.assertEqual(self._coroutinestate(), inspect.CORO_SUSPENDED)
+
+ def test_closed_after_exhaustion(self):
+ while True:
+ try:
+ self.coroutine.send(None)
+ except StopIteration:
+ break
+
+ self.assertEqual(self._coroutinestate(), inspect.CORO_CLOSED)
+
+ def test_closed_after_immediate_exception(self):
+ with self.assertRaises(RuntimeError):
+ self.coroutine.throw(RuntimeError)
+ self.assertEqual(self._coroutinestate(), inspect.CORO_CLOSED)
+
+ def test_easy_debugging(self):
+ # repr() and str() of a coroutine state should contain the state name
+ names = 'CORO_CREATED CORO_RUNNING CORO_SUSPENDED CORO_CLOSED'.split()
+ for name in names:
+ state = getattr(inspect, name)
+ self.assertIn(name, repr(state))
+ self.assertIn(name, str(state))
+
+ def test_getcoroutinelocals(self):
+ @types.coroutine
+ def gencoro():
+ yield
+
+ gencoro = gencoro()
+ async def func(a=None):
+ b = 'spam'
+ await gencoro
+
+ coro = func()
+ self.assertEqual(inspect.getcoroutinelocals(coro),
+ {'a': None, 'gencoro': gencoro})
+ coro.send(None)
+ self.assertEqual(inspect.getcoroutinelocals(coro),
+ {'a': None, 'gencoro': gencoro, 'b': 'spam'})
+
+
class MySignature(inspect.Signature):
# Top-level to make it picklable;
# used in test_signature_object_pickle
@@ -3494,7 +3558,8 @@ def test_main():
TestNoEOL, TestSignatureObject, TestSignatureBind, TestParameterObject,
TestBoundArguments, TestSignaturePrivateHelpers,
TestSignatureDefinitions,
- TestGetClosureVars, TestUnwrap, TestMain, TestReload
+ TestGetClosureVars, TestUnwrap, TestMain, TestReload,
+ TestGetCoroutineState
)
if __name__ == "__main__":
diff --git a/Lib/test/test_types.py b/Lib/test/test_types.py
index 17ec645..5b971d1 100644
--- a/Lib/test/test_types.py
+++ b/Lib/test/test_types.py
@@ -1205,10 +1205,28 @@ class CoroutineTests(unittest.TestCase):
def test_wrong_func(self):
@types.coroutine
def foo():
- pass
- with self.assertRaisesRegex(TypeError,
- 'callable wrapped .* non-coroutine'):
- foo()
+ return 'spam'
+ self.assertEqual(foo(), 'spam')
+
+ def test_async_def(self):
+ # Test that types.coroutine passes 'async def' coroutines
+ # without modification
+
+ async def foo(): pass
+ foo_code = foo.__code__
+ foo_flags = foo.__code__.co_flags
+ decorated_foo = types.coroutine(foo)
+ self.assertIs(foo, decorated_foo)
+ self.assertEqual(foo.__code__.co_flags, foo_flags)
+ self.assertIs(decorated_foo.__code__, foo_code)
+
+ foo_coro = foo()
+ @types.coroutine
+ def bar(): return foo_coro
+ coro = bar()
+ self.assertIs(foo_coro, coro)
+ self.assertEqual(coro.cr_code.co_flags, foo_flags)
+ coro.close()
def test_duck_coro(self):
class CoroLike:
@@ -1221,6 +1239,23 @@ class CoroutineTests(unittest.TestCase):
@types.coroutine
def foo():
return coro
+ self.assertIs(foo(), coro)
+ self.assertIs(foo().__await__(), coro)
+
+ def test_duck_corogen(self):
+ class CoroGenLike:
+ def send(self): pass
+ def throw(self): pass
+ def close(self): pass
+ def __await__(self): return self
+ def __iter__(self): return self
+ def __next__(self): pass
+
+ coro = CoroGenLike()
+ @types.coroutine
+ def foo():
+ return coro
+ self.assertIs(foo(), coro)
self.assertIs(foo().__await__(), coro)
def test_duck_gen(self):
@@ -1236,7 +1271,7 @@ class CoroutineTests(unittest.TestCase):
def foo():
return gen
self.assertIs(foo().__await__(), gen)
-
+ self.assertTrue(isinstance(foo(), collections.abc.Coroutine))
with self.assertRaises(AttributeError):
foo().gi_code
@@ -1251,6 +1286,7 @@ class CoroutineTests(unittest.TestCase):
'gi_running', 'gi_frame'):
self.assertIs(getattr(foo(), name),
getattr(gen, name))
+ self.assertIs(foo().cr_code, gen.gi_code)
def test_genfunc(self):
def gen():
@@ -1259,7 +1295,13 @@ class CoroutineTests(unittest.TestCase):
self.assertFalse(isinstance(gen(), collections.abc.Coroutine))
self.assertFalse(isinstance(gen(), collections.abc.Awaitable))
- self.assertIs(types.coroutine(gen), gen)
+ gen_code = gen.__code__
+ decorated_gen = types.coroutine(gen)
+ self.assertIs(decorated_gen, gen)
+ self.assertIsNot(decorated_gen.__code__, gen_code)
+
+ decorated_gen2 = types.coroutine(decorated_gen)
+ self.assertIs(decorated_gen2.__code__, decorated_gen.__code__)
self.assertTrue(gen.__code__.co_flags & inspect.CO_ITERABLE_COROUTINE)
self.assertFalse(gen.__code__.co_flags & inspect.CO_COROUTINE)
diff --git a/Lib/types.py b/Lib/types.py
index 0a87c2f..dc1b040 100644
--- a/Lib/types.py
+++ b/Lib/types.py
@@ -19,6 +19,11 @@ def _g():
yield 1
GeneratorType = type(_g())
+async def _c(): pass
+_c = _c()
+CoroutineType = type(_c)
+_c.close() # Prevent ResourceWarning
+
class _C:
def _m(self): pass
MethodType = type(_C()._m)
@@ -40,7 +45,7 @@ except TypeError:
GetSetDescriptorType = type(FunctionType.__code__)
MemberDescriptorType = type(FunctionType.__globals__)
-del sys, _f, _g, _C, # Not for export
+del sys, _f, _g, _C, _c, # Not for export
# Provide a PEP 3115 compliant mechanism for class creation
@@ -164,29 +169,33 @@ import collections.abc as _collections_abc
def coroutine(func):
"""Convert regular generator function to a coroutine."""
- # We don't want to import 'dis' or 'inspect' just for
- # these constants.
- CO_GENERATOR = 0x20
- CO_ITERABLE_COROUTINE = 0x100
-
if not callable(func):
raise TypeError('types.coroutine() expects a callable')
- if (isinstance(func, FunctionType) and
- isinstance(getattr(func, '__code__', None), CodeType) and
- (func.__code__.co_flags & CO_GENERATOR)):
-
- # TODO: Implement this in C.
- co = func.__code__
- func.__code__ = CodeType(
- co.co_argcount, co.co_kwonlyargcount, co.co_nlocals,
- co.co_stacksize,
- co.co_flags | CO_ITERABLE_COROUTINE,
- co.co_code,
- co.co_consts, co.co_names, co.co_varnames, co.co_filename,
- co.co_name, co.co_firstlineno, co.co_lnotab, co.co_freevars,
- co.co_cellvars)
- return func
+ if (func.__class__ is FunctionType and
+ getattr(func, '__code__', None).__class__ is CodeType):
+
+ co_flags = func.__code__.co_flags
+
+ # Check if 'func' is a coroutine function.
+ # (0x180 == CO_COROUTINE | CO_ITERABLE_COROUTINE)
+ if co_flags & 0x180:
+ return func
+
+ # Check if 'func' is a generator function.
+ # (0x20 == CO_GENERATOR)
+ if co_flags & 0x20:
+ # TODO: Implement this in C.
+ co = func.__code__
+ func.__code__ = CodeType(
+ co.co_argcount, co.co_kwonlyargcount, co.co_nlocals,
+ co.co_stacksize,
+ co.co_flags | 0x100, # 0x100 == CO_ITERABLE_COROUTINE
+ co.co_code,
+ co.co_consts, co.co_names, co.co_varnames, co.co_filename,
+ co.co_name, co.co_firstlineno, co.co_lnotab, co.co_freevars,
+ co.co_cellvars)
+ return func
# The following code is primarily to support functions that
# return generator-like objects (for instance generators
@@ -195,11 +204,14 @@ def coroutine(func):
class GeneratorWrapper:
def __init__(self, gen):
self.__wrapped__ = gen
- self.send = gen.send
- self.throw = gen.throw
- self.close = gen.close
self.__name__ = getattr(gen, '__name__', None)
self.__qualname__ = getattr(gen, '__qualname__', None)
+ def send(self, val):
+ return self.__wrapped__.send(val)
+ def throw(self, *args):
+ return self.__wrapped__.throw(*args)
+ def close(self):
+ return self.__wrapped__.close()
@property
def gi_code(self):
return self.__wrapped__.gi_code
@@ -209,24 +221,31 @@ def coroutine(func):
@property
def gi_running(self):
return self.__wrapped__.gi_running
+ cr_code = gi_code
+ cr_frame = gi_frame
+ cr_running = gi_running
def __next__(self):
return next(self.__wrapped__)
def __iter__(self):
return self.__wrapped__
- __await__ = __iter__
+ def __await__(self):
+ return self.__wrapped__
@_functools.wraps(func)
def wrapped(*args, **kwargs):
coro = func(*args, **kwargs)
- if coro.__class__ is GeneratorType:
+ if coro.__class__ is CoroutineType:
+ # 'coro' is a native coroutine object.
+ return coro
+ if (coro.__class__ is GeneratorType or
+ (isinstance(coro, _collections_abc.Generator) and
+ not isinstance(coro, _collections_abc.Coroutine))):
+ # 'coro' is either a pure Python generator iterator, or it
+ # implements collections.abc.Generator (and does not implement
+ # collections.abc.Coroutine).
return GeneratorWrapper(coro)
- # slow checks
- if not isinstance(coro, _collections_abc.Coroutine):
- if isinstance(coro, _collections_abc.Generator):
- return GeneratorWrapper(coro)
- raise TypeError(
- 'callable wrapped with types.coroutine() returned '
- 'non-coroutine: {!r}'.format(coro))
+ # 'coro' is either an instance of collections.abc.Coroutine or
+ # some other object -- pass it through.
return coro
return wrapped