diff options
author | Yury Selivanov <yselivanov@sprymix.com> | 2015-05-29 20:19:18 (GMT) |
---|---|---|
committer | Yury Selivanov <yselivanov@sprymix.com> | 2015-05-29 20:19:18 (GMT) |
commit | 13f7723d8163f29465817e5aec17e8394b2c28ef (patch) | |
tree | 5e5740a64dec8b384381812b4980f750aa835ef4 | |
parent | c565cd5d1b53f8b117f613517b71cf60813a2d65 (diff) | |
download | cpython-13f7723d8163f29465817e5aec17e8394b2c28ef.zip cpython-13f7723d8163f29465817e5aec17e8394b2c28ef.tar.gz cpython-13f7723d8163f29465817e5aec17e8394b2c28ef.tar.bz2 |
Issue 24316: Wrap gen objects returned from callables in types.coroutine
-rw-r--r-- | Lib/test/test_types.py | 45 | ||||
-rw-r--r-- | Lib/types.py | 40 |
2 files changed, 70 insertions, 15 deletions
diff --git a/Lib/test/test_types.py b/Lib/test/test_types.py index 956214d..17ec645 100644 --- a/Lib/test/test_types.py +++ b/Lib/test/test_types.py @@ -1206,28 +1206,51 @@ class CoroutineTests(unittest.TestCase): @types.coroutine def foo(): pass - @types.coroutine - def gen(): - def _gen(): yield - return _gen() - - for sample in (foo, gen): - with self.assertRaisesRegex(TypeError, - 'callable wrapped .* non-coroutine'): - sample() + with self.assertRaisesRegex(TypeError, + 'callable wrapped .* non-coroutine'): + foo() def test_duck_coro(self): class CoroLike: def send(self): pass def throw(self): pass def close(self): pass - def __await__(self): pass + def __await__(self): return self coro = CoroLike() @types.coroutine def foo(): return coro - self.assertIs(coro, foo()) + self.assertIs(foo().__await__(), coro) + + def test_duck_gen(self): + class GenLike: + def send(self): pass + def throw(self): pass + def close(self): pass + def __iter__(self): return self + def __next__(self): pass + + gen = GenLike() + @types.coroutine + def foo(): + return gen + self.assertIs(foo().__await__(), gen) + + with self.assertRaises(AttributeError): + foo().gi_code + + def test_gen(self): + def gen(): yield + gen = gen() + @types.coroutine + def foo(): return gen + self.assertIs(foo().__await__(), gen) + + for name in ('__name__', '__qualname__', 'gi_code', + 'gi_running', 'gi_frame'): + self.assertIs(getattr(foo(), name), + getattr(gen, name)) def test_genfunc(self): def gen(): diff --git a/Lib/types.py b/Lib/types.py index e9cc794..0a87c2f 100644 --- a/Lib/types.py +++ b/Lib/types.py @@ -166,32 +166,64 @@ def coroutine(func): # We don't want to import 'dis' or 'inspect' just for # these constants. - _CO_GENERATOR = 0x20 - _CO_ITERABLE_COROUTINE = 0x100 + CO_GENERATOR = 0x20 + CO_ITERABLE_COROUTINE = 0x100 if not callable(func): raise TypeError('types.coroutine() expects a callable') if (isinstance(func, FunctionType) and isinstance(getattr(func, '__code__', None), CodeType) and - (func.__code__.co_flags & _CO_GENERATOR)): + (func.__code__.co_flags & CO_GENERATOR)): # TODO: Implement this in C. co = func.__code__ func.__code__ = CodeType( co.co_argcount, co.co_kwonlyargcount, co.co_nlocals, co.co_stacksize, - co.co_flags | _CO_ITERABLE_COROUTINE, + co.co_flags | CO_ITERABLE_COROUTINE, co.co_code, co.co_consts, co.co_names, co.co_varnames, co.co_filename, co.co_name, co.co_firstlineno, co.co_lnotab, co.co_freevars, co.co_cellvars) return func + # The following code is primarily to support functions that + # return generator-like objects (for instance generators + # compiled with Cython). + + class GeneratorWrapper: + def __init__(self, gen): + self.__wrapped__ = gen + self.send = gen.send + self.throw = gen.throw + self.close = gen.close + self.__name__ = getattr(gen, '__name__', None) + self.__qualname__ = getattr(gen, '__qualname__', None) + @property + def gi_code(self): + return self.__wrapped__.gi_code + @property + def gi_frame(self): + return self.__wrapped__.gi_frame + @property + def gi_running(self): + return self.__wrapped__.gi_running + def __next__(self): + return next(self.__wrapped__) + def __iter__(self): + return self.__wrapped__ + __await__ = __iter__ + @_functools.wraps(func) def wrapped(*args, **kwargs): coro = func(*args, **kwargs) + if coro.__class__ is GeneratorType: + return GeneratorWrapper(coro) + # slow checks if not isinstance(coro, _collections_abc.Coroutine): + if isinstance(coro, _collections_abc.Generator): + return GeneratorWrapper(coro) raise TypeError( 'callable wrapped with types.coroutine() returned ' 'non-coroutine: {!r}'.format(coro)) |