summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/asyncio/base_events.py44
-rw-r--r--Lib/asyncio/events.py10
-rw-r--r--Lib/asyncio/queues.py7
-rw-r--r--Lib/selectors.py49
-rw-r--r--Lib/test/test_asyncio/test_base_events.py36
5 files changed, 110 insertions, 36 deletions
diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py
index eb867cd..efbb9f4 100644
--- a/Lib/asyncio/base_events.py
+++ b/Lib/asyncio/base_events.py
@@ -197,6 +197,7 @@ class BaseEventLoop(events.AbstractEventLoop):
# exceed this duration in seconds, the slow callback/task is logged.
self.slow_callback_duration = 0.1
self._current_handle = None
+ self._task_factory = None
def __repr__(self):
return ('<%s running=%s closed=%s debug=%s>'
@@ -209,11 +210,32 @@ class BaseEventLoop(events.AbstractEventLoop):
Return a task object.
"""
self._check_closed()
- task = tasks.Task(coro, loop=self)
- if task._source_traceback:
- del task._source_traceback[-1]
+ if self._task_factory is None:
+ task = tasks.Task(coro, loop=self)
+ if task._source_traceback:
+ del task._source_traceback[-1]
+ else:
+ task = self._task_factory(self, coro)
return task
+ def set_task_factory(self, factory):
+ """Set a task factory that will be used by loop.create_task().
+
+ If factory is None the default task factory will be set.
+
+ If factory is a callable, it should have a signature matching
+ '(loop, coro)', where 'loop' will be a reference to the active
+ event loop, 'coro' will be a coroutine object. The callable
+ must return a Future.
+ """
+ if factory is not None and not callable(factory):
+ raise TypeError('task factory must be a callable or None')
+ self._task_factory = factory
+
+ def get_task_factory(self):
+ """Return a task factory, or None if the default one is in use."""
+ return self._task_factory
+
def _make_socket_transport(self, sock, protocol, waiter=None, *,
extra=None, server=None):
"""Create socket transport."""
@@ -465,25 +487,25 @@ class BaseEventLoop(events.AbstractEventLoop):
self._write_to_self()
return handle
- def run_in_executor(self, executor, callback, *args):
- if (coroutines.iscoroutine(callback)
- or coroutines.iscoroutinefunction(callback)):
+ def run_in_executor(self, executor, func, *args):
+ if (coroutines.iscoroutine(func)
+ or coroutines.iscoroutinefunction(func)):
raise TypeError("coroutines cannot be used with run_in_executor()")
self._check_closed()
- if isinstance(callback, events.Handle):
+ if isinstance(func, events.Handle):
assert not args
- assert not isinstance(callback, events.TimerHandle)
- if callback._cancelled:
+ assert not isinstance(func, events.TimerHandle)
+ if func._cancelled:
f = futures.Future(loop=self)
f.set_result(None)
return f
- callback, args = callback._callback, callback._args
+ func, args = func._callback, func._args
if executor is None:
executor = self._default_executor
if executor is None:
executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS)
self._default_executor = executor
- return futures.wrap_future(executor.submit(callback, *args), loop=self)
+ return futures.wrap_future(executor.submit(func, *args), loop=self)
def set_default_executor(self, executor):
self._default_executor = executor
diff --git a/Lib/asyncio/events.py b/Lib/asyncio/events.py
index 3b907c6..496075b 100644
--- a/Lib/asyncio/events.py
+++ b/Lib/asyncio/events.py
@@ -277,7 +277,7 @@ class AbstractEventLoop:
def call_soon_threadsafe(self, callback, *args):
raise NotImplementedError
- def run_in_executor(self, executor, callback, *args):
+ def run_in_executor(self, executor, func, *args):
raise NotImplementedError
def set_default_executor(self, executor):
@@ -438,6 +438,14 @@ class AbstractEventLoop:
def remove_signal_handler(self, sig):
raise NotImplementedError
+ # Task factory.
+
+ def set_task_factory(self, factory):
+ raise NotImplementedError
+
+ def get_task_factory(self):
+ raise NotImplementedError
+
# Error handlers.
def set_exception_handler(self, handler):
diff --git a/Lib/asyncio/queues.py b/Lib/asyncio/queues.py
index 50543c8..ed11662 100644
--- a/Lib/asyncio/queues.py
+++ b/Lib/asyncio/queues.py
@@ -1,6 +1,7 @@
"""Queues"""
-__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty']
+__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty',
+ 'JoinableQueue']
import collections
import heapq
@@ -286,3 +287,7 @@ class LifoQueue(Queue):
def _get(self):
return self._queue.pop()
+
+
+JoinableQueue = Queue
+"""Deprecated alias for Queue."""
diff --git a/Lib/selectors.py b/Lib/selectors.py
index e17ea36..6d569c3 100644
--- a/Lib/selectors.py
+++ b/Lib/selectors.py
@@ -310,7 +310,10 @@ class SelectSelector(_BaseSelectorImpl):
def select(self, timeout=None):
timeout = None if timeout is None else max(timeout, 0)
ready = []
- r, w, _ = self._select(self._readers, self._writers, [], timeout)
+ try:
+ r, w, _ = self._select(self._readers, self._writers, [], timeout)
+ except InterruptedError:
+ return ready
r = set(r)
w = set(w)
for fd in r | w:
@@ -359,10 +362,11 @@ if hasattr(select, 'poll'):
# poll() has a resolution of 1 millisecond, round away from
# zero to wait *at least* timeout seconds.
timeout = math.ceil(timeout * 1e3)
-
- fd_event_list = self._poll.poll(timeout)
-
ready = []
+ try:
+ fd_event_list = self._poll.poll(timeout)
+ except InterruptedError:
+ return ready
for fd, event in fd_event_list:
events = 0
if event & ~select.POLLIN:
@@ -423,9 +427,11 @@ if hasattr(select, 'epoll'):
# FD is registered.
max_ev = max(len(self._fd_to_key), 1)
- fd_event_list = self._epoll.poll(timeout, max_ev)
-
ready = []
+ try:
+ fd_event_list = self._epoll.poll(timeout, max_ev)
+ except InterruptedError:
+ return ready
for fd, event in fd_event_list:
events = 0
if event & ~select.EPOLLIN:
@@ -439,10 +445,8 @@ if hasattr(select, 'epoll'):
return ready
def close(self):
- try:
- self._epoll.close()
- finally:
- super().close()
+ self._epoll.close()
+ super().close()
if hasattr(select, 'devpoll'):
@@ -481,10 +485,11 @@ if hasattr(select, 'devpoll'):
# devpoll() has a resolution of 1 millisecond, round away from
# zero to wait *at least* timeout seconds.
timeout = math.ceil(timeout * 1e3)
-
- fd_event_list = self._devpoll.poll(timeout)
-
ready = []
+ try:
+ fd_event_list = self._devpoll.poll(timeout)
+ except InterruptedError:
+ return ready
for fd, event in fd_event_list:
events = 0
if event & ~select.POLLIN:
@@ -498,10 +503,8 @@ if hasattr(select, 'devpoll'):
return ready
def close(self):
- try:
- self._devpoll.close()
- finally:
- super().close()
+ self._devpoll.close()
+ super().close()
if hasattr(select, 'kqueue'):
@@ -552,9 +555,11 @@ if hasattr(select, 'kqueue'):
def select(self, timeout=None):
timeout = None if timeout is None else max(timeout, 0)
max_ev = len(self._fd_to_key)
- kev_list = self._kqueue.control(None, max_ev, timeout)
-
ready = []
+ try:
+ kev_list = self._kqueue.control(None, max_ev, timeout)
+ except InterruptedError:
+ return ready
for kev in kev_list:
fd = kev.ident
flag = kev.filter
@@ -570,10 +575,8 @@ if hasattr(select, 'kqueue'):
return ready
def close(self):
- try:
- self._kqueue.close()
- finally:
- super().close()
+ self._kqueue.close()
+ super().close()
# Choose the best implementation, roughly:
diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py
index 4d36f23..aaa8e67 100644
--- a/Lib/test/test_asyncio/test_base_events.py
+++ b/Lib/test/test_asyncio/test_base_events.py
@@ -623,6 +623,42 @@ class BaseEventLoopTests(test_utils.TestCase):
self.assertIs(type(_context['context']['exception']),
ZeroDivisionError)
+ def test_set_task_factory_invalid(self):
+ with self.assertRaisesRegex(
+ TypeError, 'task factory must be a callable or None'):
+
+ self.loop.set_task_factory(1)
+
+ self.assertIsNone(self.loop.get_task_factory())
+
+ def test_set_task_factory(self):
+ self.loop._process_events = mock.Mock()
+
+ class MyTask(asyncio.Task):
+ pass
+
+ @asyncio.coroutine
+ def coro():
+ pass
+
+ factory = lambda loop, coro: MyTask(coro, loop=loop)
+
+ self.assertIsNone(self.loop.get_task_factory())
+ self.loop.set_task_factory(factory)
+ self.assertIs(self.loop.get_task_factory(), factory)
+
+ task = self.loop.create_task(coro())
+ self.assertTrue(isinstance(task, MyTask))
+ self.loop.run_until_complete(task)
+
+ self.loop.set_task_factory(None)
+ self.assertIsNone(self.loop.get_task_factory())
+
+ task = self.loop.create_task(coro())
+ self.assertTrue(isinstance(task, asyncio.Task))
+ self.assertFalse(isinstance(task, MyTask))
+ self.loop.run_until_complete(task)
+
def test_env_var_debug(self):
code = '\n'.join((
'import asyncio',