summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYury Selivanov <yselivanov@sprymix.com>2015-05-29 13:06:05 (GMT)
committerYury Selivanov <yselivanov@sprymix.com>2015-05-29 13:06:05 (GMT)
commitc565cd5d1b53f8b117f613517b71cf60813a2d65 (patch)
treeb52c7a7e71373712789cb004838fbdb0b9fabc91
parent56fc61402533dc550244efe3e860242872f35bad (diff)
downloadcpython-c565cd5d1b53f8b117f613517b71cf60813a2d65.zip
cpython-c565cd5d1b53f8b117f613517b71cf60813a2d65.tar.gz
cpython-c565cd5d1b53f8b117f613517b71cf60813a2d65.tar.bz2
Issue 24316: Fix types.coroutine() to accept objects from Cython
-rw-r--r--Lib/test/test_types.py32
-rw-r--r--Lib/types.py66
2 files changed, 72 insertions, 26 deletions
diff --git a/Lib/test/test_types.py b/Lib/test/test_types.py
index ccaf414..956214d 100644
--- a/Lib/test/test_types.py
+++ b/Lib/test/test_types.py
@@ -1196,11 +1196,39 @@ class CoroutineTests(unittest.TestCase):
pass
def bar(): pass
- samples = [Foo, Foo(), bar, None, int, 1]
+ samples = [None, 1, object()]
for sample in samples:
- with self.assertRaisesRegex(TypeError, 'expects a generator'):
+ with self.assertRaisesRegex(TypeError,
+ 'types.coroutine.*expects a callable'):
types.coroutine(sample)
+ def test_wrong_func(self):
+ @types.coroutine
+ def foo():
+ pass
+ @types.coroutine
+ def gen():
+ def _gen(): yield
+ return _gen()
+
+ for sample in (foo, gen):
+ with self.assertRaisesRegex(TypeError,
+ 'callable wrapped .* non-coroutine'):
+ sample()
+
+ def test_duck_coro(self):
+ class CoroLike:
+ def send(self): pass
+ def throw(self): pass
+ def close(self): pass
+ def __await__(self): pass
+
+ coro = CoroLike()
+ @types.coroutine
+ def foo():
+ return coro
+ self.assertIs(coro, foo())
+
def test_genfunc(self):
def gen():
yield
diff --git a/Lib/types.py b/Lib/types.py
index 49e4d04..e9cc794 100644
--- a/Lib/types.py
+++ b/Lib/types.py
@@ -43,30 +43,6 @@ MemberDescriptorType = type(FunctionType.__globals__)
del sys, _f, _g, _C, # Not for export
-_CO_GENERATOR = 0x20
-_CO_ITERABLE_COROUTINE = 0x100
-
-def coroutine(func):
- """Convert regular generator function to a coroutine."""
-
- # TODO: Implement this in C.
-
- if (not isinstance(func, (FunctionType, MethodType)) or
- not isinstance(getattr(func, '__code__', None), CodeType) or
- not (func.__code__.co_flags & _CO_GENERATOR)):
- raise TypeError('coroutine() expects a generator function')
-
- co = func.__code__
- func.__code__ = CodeType(
- co.co_argcount, co.co_kwonlyargcount, co.co_nlocals, co.co_stacksize,
- co.co_flags | _CO_ITERABLE_COROUTINE,
- co.co_code,
- co.co_consts, co.co_names, co.co_varnames, co.co_filename, co.co_name,
- co.co_firstlineno, co.co_lnotab, co.co_freevars, co.co_cellvars)
-
- return func
-
-
# Provide a PEP 3115 compliant mechanism for class creation
def new_class(name, bases=(), kwds=None, exec_body=None):
"""Create a class object dynamically using the appropriate metaclass."""
@@ -182,4 +158,46 @@ class DynamicClassAttribute:
return result
+import functools as _functools
+import collections.abc as _collections_abc
+
+def coroutine(func):
+ """Convert regular generator function to a coroutine."""
+
+ # We don't want to import 'dis' or 'inspect' just for
+ # these constants.
+ _CO_GENERATOR = 0x20
+ _CO_ITERABLE_COROUTINE = 0x100
+
+ if not callable(func):
+ raise TypeError('types.coroutine() expects a callable')
+
+ if (isinstance(func, FunctionType) and
+ isinstance(getattr(func, '__code__', None), CodeType) and
+ (func.__code__.co_flags & _CO_GENERATOR)):
+
+ # TODO: Implement this in C.
+ co = func.__code__
+ func.__code__ = CodeType(
+ co.co_argcount, co.co_kwonlyargcount, co.co_nlocals,
+ co.co_stacksize,
+ co.co_flags | _CO_ITERABLE_COROUTINE,
+ co.co_code,
+ co.co_consts, co.co_names, co.co_varnames, co.co_filename,
+ co.co_name, co.co_firstlineno, co.co_lnotab, co.co_freevars,
+ co.co_cellvars)
+ return func
+
+ @_functools.wraps(func)
+ def wrapped(*args, **kwargs):
+ coro = func(*args, **kwargs)
+ if not isinstance(coro, _collections_abc.Coroutine):
+ raise TypeError(
+ 'callable wrapped with types.coroutine() returned '
+ 'non-coroutine: {!r}'.format(coro))
+ return coro
+
+ return wrapped
+
+
__all__ = [n for n in globals() if n[:1] != '_']