From 90ecfe65e681f6e7f901a101ad1e549e339ea10d Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Mon, 11 May 2015 13:48:16 -0400 Subject: asyncio: Sync with github repo --- Lib/asyncio/base_events.py | 44 ++++++++++++++++++++------- Lib/asyncio/events.py | 10 ++++++- Lib/asyncio/queues.py | 7 ++++- Lib/selectors.py | 49 ++++++++++++++++--------------- Lib/test/test_asyncio/test_base_events.py | 36 +++++++++++++++++++++++ 5 files changed, 110 insertions(+), 36 deletions(-) diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index eb867cd..efbb9f4 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -197,6 +197,7 @@ class BaseEventLoop(events.AbstractEventLoop): # exceed this duration in seconds, the slow callback/task is logged. self.slow_callback_duration = 0.1 self._current_handle = None + self._task_factory = None def __repr__(self): return ('<%s running=%s closed=%s debug=%s>' @@ -209,11 +210,32 @@ class BaseEventLoop(events.AbstractEventLoop): Return a task object. """ self._check_closed() - task = tasks.Task(coro, loop=self) - if task._source_traceback: - del task._source_traceback[-1] + if self._task_factory is None: + task = tasks.Task(coro, loop=self) + if task._source_traceback: + del task._source_traceback[-1] + else: + task = self._task_factory(self, coro) return task + def set_task_factory(self, factory): + """Set a task factory that will be used by loop.create_task(). + + If factory is None the default task factory will be set. + + If factory is a callable, it should have a signature matching + '(loop, coro)', where 'loop' will be a reference to the active + event loop, 'coro' will be a coroutine object. The callable + must return a Future. + """ + if factory is not None and not callable(factory): + raise TypeError('task factory must be a callable or None') + self._task_factory = factory + + def get_task_factory(self): + """Return a task factory, or None if the default one is in use.""" + return self._task_factory + def _make_socket_transport(self, sock, protocol, waiter=None, *, extra=None, server=None): """Create socket transport.""" @@ -465,25 +487,25 @@ class BaseEventLoop(events.AbstractEventLoop): self._write_to_self() return handle - def run_in_executor(self, executor, callback, *args): - if (coroutines.iscoroutine(callback) - or coroutines.iscoroutinefunction(callback)): + def run_in_executor(self, executor, func, *args): + if (coroutines.iscoroutine(func) + or coroutines.iscoroutinefunction(func)): raise TypeError("coroutines cannot be used with run_in_executor()") self._check_closed() - if isinstance(callback, events.Handle): + if isinstance(func, events.Handle): assert not args - assert not isinstance(callback, events.TimerHandle) - if callback._cancelled: + assert not isinstance(func, events.TimerHandle) + if func._cancelled: f = futures.Future(loop=self) f.set_result(None) return f - callback, args = callback._callback, callback._args + func, args = func._callback, func._args if executor is None: executor = self._default_executor if executor is None: executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS) self._default_executor = executor - return futures.wrap_future(executor.submit(callback, *args), loop=self) + return futures.wrap_future(executor.submit(func, *args), loop=self) def set_default_executor(self, executor): self._default_executor = executor diff --git a/Lib/asyncio/events.py b/Lib/asyncio/events.py index 3b907c6..496075b 100644 --- a/Lib/asyncio/events.py +++ b/Lib/asyncio/events.py @@ -277,7 +277,7 @@ class AbstractEventLoop: def call_soon_threadsafe(self, callback, *args): raise NotImplementedError - def run_in_executor(self, executor, callback, *args): + def run_in_executor(self, executor, func, *args): raise NotImplementedError def set_default_executor(self, executor): @@ -438,6 +438,14 @@ class AbstractEventLoop: def remove_signal_handler(self, sig): raise NotImplementedError + # Task factory. + + def set_task_factory(self, factory): + raise NotImplementedError + + def get_task_factory(self): + raise NotImplementedError + # Error handlers. def set_exception_handler(self, handler): diff --git a/Lib/asyncio/queues.py b/Lib/asyncio/queues.py index 50543c8..ed11662 100644 --- a/Lib/asyncio/queues.py +++ b/Lib/asyncio/queues.py @@ -1,6 +1,7 @@ """Queues""" -__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty'] +__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty', + 'JoinableQueue'] import collections import heapq @@ -286,3 +287,7 @@ class LifoQueue(Queue): def _get(self): return self._queue.pop() + + +JoinableQueue = Queue +"""Deprecated alias for Queue.""" diff --git a/Lib/selectors.py b/Lib/selectors.py index e17ea36..6d569c3 100644 --- a/Lib/selectors.py +++ b/Lib/selectors.py @@ -310,7 +310,10 @@ class SelectSelector(_BaseSelectorImpl): def select(self, timeout=None): timeout = None if timeout is None else max(timeout, 0) ready = [] - r, w, _ = self._select(self._readers, self._writers, [], timeout) + try: + r, w, _ = self._select(self._readers, self._writers, [], timeout) + except InterruptedError: + return ready r = set(r) w = set(w) for fd in r | w: @@ -359,10 +362,11 @@ if hasattr(select, 'poll'): # poll() has a resolution of 1 millisecond, round away from # zero to wait *at least* timeout seconds. timeout = math.ceil(timeout * 1e3) - - fd_event_list = self._poll.poll(timeout) - ready = [] + try: + fd_event_list = self._poll.poll(timeout) + except InterruptedError: + return ready for fd, event in fd_event_list: events = 0 if event & ~select.POLLIN: @@ -423,9 +427,11 @@ if hasattr(select, 'epoll'): # FD is registered. max_ev = max(len(self._fd_to_key), 1) - fd_event_list = self._epoll.poll(timeout, max_ev) - ready = [] + try: + fd_event_list = self._epoll.poll(timeout, max_ev) + except InterruptedError: + return ready for fd, event in fd_event_list: events = 0 if event & ~select.EPOLLIN: @@ -439,10 +445,8 @@ if hasattr(select, 'epoll'): return ready def close(self): - try: - self._epoll.close() - finally: - super().close() + self._epoll.close() + super().close() if hasattr(select, 'devpoll'): @@ -481,10 +485,11 @@ if hasattr(select, 'devpoll'): # devpoll() has a resolution of 1 millisecond, round away from # zero to wait *at least* timeout seconds. timeout = math.ceil(timeout * 1e3) - - fd_event_list = self._devpoll.poll(timeout) - ready = [] + try: + fd_event_list = self._devpoll.poll(timeout) + except InterruptedError: + return ready for fd, event in fd_event_list: events = 0 if event & ~select.POLLIN: @@ -498,10 +503,8 @@ if hasattr(select, 'devpoll'): return ready def close(self): - try: - self._devpoll.close() - finally: - super().close() + self._devpoll.close() + super().close() if hasattr(select, 'kqueue'): @@ -552,9 +555,11 @@ if hasattr(select, 'kqueue'): def select(self, timeout=None): timeout = None if timeout is None else max(timeout, 0) max_ev = len(self._fd_to_key) - kev_list = self._kqueue.control(None, max_ev, timeout) - ready = [] + try: + kev_list = self._kqueue.control(None, max_ev, timeout) + except InterruptedError: + return ready for kev in kev_list: fd = kev.ident flag = kev.filter @@ -570,10 +575,8 @@ if hasattr(select, 'kqueue'): return ready def close(self): - try: - self._kqueue.close() - finally: - super().close() + self._kqueue.close() + super().close() # Choose the best implementation, roughly: diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py index 4d36f23..aaa8e67 100644 --- a/Lib/test/test_asyncio/test_base_events.py +++ b/Lib/test/test_asyncio/test_base_events.py @@ -623,6 +623,42 @@ class BaseEventLoopTests(test_utils.TestCase): self.assertIs(type(_context['context']['exception']), ZeroDivisionError) + def test_set_task_factory_invalid(self): + with self.assertRaisesRegex( + TypeError, 'task factory must be a callable or None'): + + self.loop.set_task_factory(1) + + self.assertIsNone(self.loop.get_task_factory()) + + def test_set_task_factory(self): + self.loop._process_events = mock.Mock() + + class MyTask(asyncio.Task): + pass + + @asyncio.coroutine + def coro(): + pass + + factory = lambda loop, coro: MyTask(coro, loop=loop) + + self.assertIsNone(self.loop.get_task_factory()) + self.loop.set_task_factory(factory) + self.assertIs(self.loop.get_task_factory(), factory) + + task = self.loop.create_task(coro()) + self.assertTrue(isinstance(task, MyTask)) + self.loop.run_until_complete(task) + + self.loop.set_task_factory(None) + self.assertIsNone(self.loop.get_task_factory()) + + task = self.loop.create_task(coro()) + self.assertTrue(isinstance(task, asyncio.Task)) + self.assertFalse(isinstance(task, MyTask)) + self.loop.run_until_complete(task) + def test_env_var_debug(self): code = '\n'.join(( 'import asyncio', -- cgit v0.12