diff options
author | Yury Selivanov <yselivanov@sprymix.com> | 2015-05-29 13:06:05 (GMT) |
---|---|---|
committer | Yury Selivanov <yselivanov@sprymix.com> | 2015-05-29 13:06:05 (GMT) |
commit | c565cd5d1b53f8b117f613517b71cf60813a2d65 (patch) | |
tree | b52c7a7e71373712789cb004838fbdb0b9fabc91 | |
parent | 56fc61402533dc550244efe3e860242872f35bad (diff) | |
download | cpython-c565cd5d1b53f8b117f613517b71cf60813a2d65.zip cpython-c565cd5d1b53f8b117f613517b71cf60813a2d65.tar.gz cpython-c565cd5d1b53f8b117f613517b71cf60813a2d65.tar.bz2 |
Issue 24316: Fix types.coroutine() to accept objects from Cython
-rw-r--r-- | Lib/test/test_types.py | 32 | ||||
-rw-r--r-- | Lib/types.py | 66 |
2 files changed, 72 insertions, 26 deletions
diff --git a/Lib/test/test_types.py b/Lib/test/test_types.py index ccaf414..956214d 100644 --- a/Lib/test/test_types.py +++ b/Lib/test/test_types.py @@ -1196,11 +1196,39 @@ class CoroutineTests(unittest.TestCase): pass def bar(): pass - samples = [Foo, Foo(), bar, None, int, 1] + samples = [None, 1, object()] for sample in samples: - with self.assertRaisesRegex(TypeError, 'expects a generator'): + with self.assertRaisesRegex(TypeError, + 'types.coroutine.*expects a callable'): types.coroutine(sample) + def test_wrong_func(self): + @types.coroutine + def foo(): + pass + @types.coroutine + def gen(): + def _gen(): yield + return _gen() + + for sample in (foo, gen): + with self.assertRaisesRegex(TypeError, + 'callable wrapped .* non-coroutine'): + sample() + + def test_duck_coro(self): + class CoroLike: + def send(self): pass + def throw(self): pass + def close(self): pass + def __await__(self): pass + + coro = CoroLike() + @types.coroutine + def foo(): + return coro + self.assertIs(coro, foo()) + def test_genfunc(self): def gen(): yield diff --git a/Lib/types.py b/Lib/types.py index 49e4d04..e9cc794 100644 --- a/Lib/types.py +++ b/Lib/types.py @@ -43,30 +43,6 @@ MemberDescriptorType = type(FunctionType.__globals__) del sys, _f, _g, _C, # Not for export -_CO_GENERATOR = 0x20 -_CO_ITERABLE_COROUTINE = 0x100 - -def coroutine(func): - """Convert regular generator function to a coroutine.""" - - # TODO: Implement this in C. - - if (not isinstance(func, (FunctionType, MethodType)) or - not isinstance(getattr(func, '__code__', None), CodeType) or - not (func.__code__.co_flags & _CO_GENERATOR)): - raise TypeError('coroutine() expects a generator function') - - co = func.__code__ - func.__code__ = CodeType( - co.co_argcount, co.co_kwonlyargcount, co.co_nlocals, co.co_stacksize, - co.co_flags | _CO_ITERABLE_COROUTINE, - co.co_code, - co.co_consts, co.co_names, co.co_varnames, co.co_filename, co.co_name, - co.co_firstlineno, co.co_lnotab, co.co_freevars, co.co_cellvars) - - return func - - # Provide a PEP 3115 compliant mechanism for class creation def new_class(name, bases=(), kwds=None, exec_body=None): """Create a class object dynamically using the appropriate metaclass.""" @@ -182,4 +158,46 @@ class DynamicClassAttribute: return result +import functools as _functools +import collections.abc as _collections_abc + +def coroutine(func): + """Convert regular generator function to a coroutine.""" + + # We don't want to import 'dis' or 'inspect' just for + # these constants. + _CO_GENERATOR = 0x20 + _CO_ITERABLE_COROUTINE = 0x100 + + if not callable(func): + raise TypeError('types.coroutine() expects a callable') + + if (isinstance(func, FunctionType) and + isinstance(getattr(func, '__code__', None), CodeType) and + (func.__code__.co_flags & _CO_GENERATOR)): + + # TODO: Implement this in C. + co = func.__code__ + func.__code__ = CodeType( + co.co_argcount, co.co_kwonlyargcount, co.co_nlocals, + co.co_stacksize, + co.co_flags | _CO_ITERABLE_COROUTINE, + co.co_code, + co.co_consts, co.co_names, co.co_varnames, co.co_filename, + co.co_name, co.co_firstlineno, co.co_lnotab, co.co_freevars, + co.co_cellvars) + return func + + @_functools.wraps(func) + def wrapped(*args, **kwargs): + coro = func(*args, **kwargs) + if not isinstance(coro, _collections_abc.Coroutine): + raise TypeError( + 'callable wrapped with types.coroutine() returned ' + 'non-coroutine: {!r}'.format(coro)) + return coro + + return wrapped + + __all__ = [n for n in globals() if n[:1] != '_'] |