Diffstat (limited to 'Lib/asyncio/base_events.py')
-rw-r--r--  Lib/asyncio/base_events.py  295
1 file changed, 193 insertions(+), 102 deletions(-)
diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py
index 4505732..b3e318e 100644
--- a/Lib/asyncio/base_events.py
+++ b/Lib/asyncio/base_events.py
@@ -13,13 +13,10 @@ conscious design decision, leaving the door open for keyword arguments
to modify the meaning of the API call itself.
"""
-
import collections
import concurrent.futures
-import functools
import heapq
import inspect
-import ipaddress
import itertools
import logging
import os
@@ -30,6 +27,7 @@ import time
import traceback
import sys
import warnings
+import weakref
from . import compat
from . import coroutines
@@ -43,9 +41,6 @@ from .log import logger
__all__ = ['BaseEventLoop']
-# Argument for default thread pool executor creation.
-_MAX_WORKERS = 5
-
# Minimum number of _scheduled timer handles before cleanup of
# cancelled handles is performed.
_MIN_SCHEDULED_TIMER_HANDLES = 100
@@ -54,6 +49,12 @@ _MIN_SCHEDULED_TIMER_HANDLES = 100
# before cleanup of cancelled handles is performed.
_MIN_CANCELLED_TIMER_HANDLES_FRACTION = 0.5
+# Exceptions which must not call the exception handler in fatal error
+# methods (_fatal_error())
+_FATAL_ERROR_IGNORE = (BrokenPipeError,
+                       ConnectionResetError, ConnectionAbortedError)
+
+
def _format_handle(handle):
     cb = handle._callback
     if inspect.ismethod(cb) and isinstance(cb.__self__, tasks.Task):
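
The new _FATAL_ERROR_IGNORE tuple is consulted by the transports' _fatal_error() methods so that routine disconnect errors are not reported through the loop's exception handler. A minimal sketch of that pattern (the _SketchTransport class below is illustrative, not part of asyncio):

    import asyncio

    # Connection-teardown errors in this tuple are expected during normal peer
    # disconnects, so a transport should not report them via the handler.
    _FATAL_ERROR_IGNORE = (BrokenPipeError,
                           ConnectionResetError, ConnectionAbortedError)


    class _SketchTransport:
        """Illustrative only: mirrors how asyncio transports use the tuple."""

        def __init__(self, loop):
            self._loop = loop

        def _fatal_error(self, exc, message='Fatal error on transport'):
            if not isinstance(exc, _FATAL_ERROR_IGNORE):
                # Only unexpected errors reach the loop's exception handler.
                self._loop.call_exception_handler({
                    'message': message,
                    'exception': exc,
                    'transport': self,
                })
            # A real transport would now force-close the connection.
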
@@ -72,6 +73,17 @@ def _format_pipe(fd):
         return repr(fd)
+def _set_reuseport(sock):
+    if not hasattr(socket, 'SO_REUSEPORT'):
+        raise ValueError('reuse_port not supported by socket module')
+    else:
+        try:
+            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
+        except OSError:
+            raise ValueError('reuse_port not supported by socket module, '
+                             'SO_REUSEPORT defined but not implemented.')
+
+
# Linux's sock.type is a bitmask that can include extra info about socket.
_SOCKET_TYPE_MASK = 0
if hasattr(socket, 'SOCK_NONBLOCK'):
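
Both create_datagram_endpoint() and create_server() now funnel reuse_port handling through _set_reuseport(). A hedged usage sketch, assuming the availability check happens at runtime (the local address and port are arbitrary):

    import asyncio

    loop = asyncio.get_event_loop()

    async def main():
        try:
            # reuse_port is only honored where socket.SO_REUSEPORT exists;
            # elsewhere _set_reuseport() raises ValueError as shown above.
            transport, protocol = await loop.create_datagram_endpoint(
                asyncio.DatagramProtocol,
                local_addr=('127.0.0.1', 0),
                reuse_port=True)
            transport.close()
        except ValueError as exc:
            print('reuse_port unavailable on this platform:', exc)

    loop.run_until_complete(main())
    loop.close()
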
@@ -80,12 +92,14 @@ if hasattr(socket, 'SOCK_CLOEXEC'):
     _SOCKET_TYPE_MASK |= socket.SOCK_CLOEXEC
-@functools.lru_cache(maxsize=1024)
 def _ipaddr_info(host, port, family, type, proto):
-    # Try to skip getaddrinfo if "host" is already an IP. Since getaddrinfo
-    # blocks on an exclusive lock on some platforms, users might handle name
-    # resolution in their own code and pass in resolved IPs.
-    if proto not in {0, socket.IPPROTO_TCP, socket.IPPROTO_UDP} or host is None:
+    # Try to skip getaddrinfo if "host" is already an IP. Users might have
+    # handled name resolution in their own code and pass in resolved IPs.
+    if not hasattr(socket, 'inet_pton'):
+        return
+
+    if proto not in {0, socket.IPPROTO_TCP, socket.IPPROTO_UDP} or \
+            host is None:
         return None
     type &= ~_SOCKET_TYPE_MASK
@@ -96,59 +110,55 @@ def _ipaddr_info(host, port, family, type, proto):
     else:
         return None
-    if hasattr(socket, 'inet_pton'):
-        if family == socket.AF_UNSPEC:
-            afs = [socket.AF_INET, socket.AF_INET6]
-        else:
-            afs = [family]
-
-        for af in afs:
-            # Linux's inet_pton doesn't accept an IPv6 zone index after host,
-            # like '::1%lo0', so strip it. If we happen to make an invalid
-            # address look valid, we fail later in sock.connect or sock.bind.
-            try:
-                if af == socket.AF_INET6:
-                    socket.inet_pton(af, host.partition('%')[0])
-                else:
-                    socket.inet_pton(af, host)
-                return af, type, proto, '', (host, port)
-            except OSError:
-                pass
-
-        # "host" is not an IP address.
-        return None
-
-    # No inet_pton. (On Windows it's only available since Python 3.4.)
-    # Even though getaddrinfo with AI_NUMERICHOST would be non-blocking, it
-    # still requires a lock on some platforms, and waiting for that lock could
-    # block the event loop. Use ipaddress instead, it's just text parsing.
-    try:
-        addr = ipaddress.IPv4Address(host)
-    except ValueError:
+    if port is None:
+        port = 0
+    elif isinstance(port, bytes) and port == b'':
+        port = 0
+    elif isinstance(port, str) and port == '':
+        port = 0
+    else:
+        # If port's a service name like "http", don't skip getaddrinfo.
         try:
-            addr = ipaddress.IPv6Address(host.partition('%')[0])
-        except ValueError:
+            port = int(port)
+        except (TypeError, ValueError):
             return None
-    af = socket.AF_INET if addr.version == 4 else socket.AF_INET6
-    if family not in (socket.AF_UNSPEC, af):
-        # "host" is wrong IP version for "family".
-        return None
+    if family == socket.AF_UNSPEC:
+        afs = [socket.AF_INET, socket.AF_INET6]
+    else:
+        afs = [family]
-    return af, type, proto, '', (host, port)
+    if isinstance(host, bytes):
+        host = host.decode('idna')
+    if '%' in host:
+        # Linux's inet_pton doesn't accept an IPv6 zone index after host,
+        # like '::1%lo0'.
+        return None
+    for af in afs:
+        try:
+            socket.inet_pton(af, host)
+            # The host has already been resolved.
+            return af, type, proto, '', (host, port)
+        except OSError:
+            pass
-def _check_resolved_address(sock, address):
-    # Ensure that the address is already resolved to avoid the trap of hanging
-    # the entire event loop when the address requires doing a DNS lookup.
+    # "host" is not an IP address.
+    return None
-    if hasattr(socket, 'AF_UNIX') and sock.family == socket.AF_UNIX:
-        return
+def _ensure_resolved(address, *, family=0, type=socket.SOCK_STREAM, proto=0,
+                     flags=0, loop):
     host, port = address[:2]
-    if _ipaddr_info(host, port, sock.family, sock.type, sock.proto) is None:
-        raise ValueError("address must be resolved (IP address),"
-                         " got host %r" % host)
+    info = _ipaddr_info(host, port, family, type, proto)
+    if info is not None:
+        # "host" is already a resolved IP.
+        fut = loop.create_future()
+        fut.set_result([info])
+        return fut
+    else:
+        return loop.getaddrinfo(host, port, family=family, type=type,
+                                proto=proto, flags=flags)
def _run_until_complete_cb(fut):
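
Roughly how the reworked fast path behaves; _ipaddr_info() and _ensure_resolved() are private helpers, so this snippet is illustrative only and may differ across platforms (it assumes socket.inet_pton() is available):

    import socket
    from asyncio import base_events

    # A literal IP with a numeric (or empty/None) port skips getaddrinfo():
    print(base_events._ipaddr_info('127.0.0.1', 80, socket.AF_UNSPEC,
                                   socket.SOCK_STREAM, socket.IPPROTO_TCP))
    # e.g. (AF_INET, SOCK_STREAM, 6, '', ('127.0.0.1', 80))

    # A hostname or a service-name port still needs a real lookup, so the
    # helper returns None and _ensure_resolved() falls back to getaddrinfo().
    print(base_events._ipaddr_info('example.com', 'http', socket.AF_UNSPEC,
                                   socket.SOCK_STREAM, socket.IPPROTO_TCP))
    # None
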
@@ -203,7 +213,7 @@ class Server(events.AbstractServer):
     def wait_closed(self):
         if self.sockets is None or self._waiters is None:
             return
-        waiter = futures.Future(loop=self._loop)
+        waiter = self._loop.create_future()
         self._waiters.append(waiter)
         yield from waiter
@@ -232,11 +242,26 @@ class BaseEventLoop(events.AbstractEventLoop):
         self._task_factory = None
         self._coroutine_wrapper_set = False
+        if hasattr(sys, 'get_asyncgen_hooks'):
+            # Python >= 3.6
+            # A weak set of all asynchronous generators that are
+            # being iterated by the loop.
+            self._asyncgens = weakref.WeakSet()
+        else:
+            self._asyncgens = None
+
+        # Set to True when `loop.shutdown_asyncgens` is called.
+        self._asyncgens_shutdown_called = False
+
     def __repr__(self):
         return ('<%s running=%s closed=%s debug=%s>'
                 % (self.__class__.__name__, self.is_running(),
                    self.is_closed(), self.get_debug()))
+    def create_future(self):
+        """Create a Future object attached to the loop."""
+        return futures.Future(loop=self)
+
     def create_task(self, coro):
         """Schedule a coroutine object.
@@ -319,6 +344,48 @@ class BaseEventLoop(events.AbstractEventLoop):
         if self._closed:
             raise RuntimeError('Event loop is closed')
+    def _asyncgen_finalizer_hook(self, agen):
+        self._asyncgens.discard(agen)
+        if not self.is_closed():
+            self.create_task(agen.aclose())
+
+    def _asyncgen_firstiter_hook(self, agen):
+        if self._asyncgens_shutdown_called:
+            warnings.warn(
+                "asynchronous generator {!r} was scheduled after "
+                "loop.shutdown_asyncgens() call".format(agen),
+                ResourceWarning, source=self)
+
+        self._asyncgens.add(agen)
+
+    @coroutine
+    def shutdown_asyncgens(self):
+        """Shutdown all active asynchronous generators."""
+        self._asyncgens_shutdown_called = True
+
+        if self._asyncgens is None or not len(self._asyncgens):
+            # If Python version is <3.6 or we don't have any asynchronous
+            # generators alive.
+            return
+
+        closing_agens = list(self._asyncgens)
+        self._asyncgens.clear()
+
+        shutdown_coro = tasks.gather(
+            *[ag.aclose() for ag in closing_agens],
+            return_exceptions=True,
+            loop=self)
+
+        results = yield from shutdown_coro
+        for result, agen in zip(results, closing_agens):
+            if isinstance(result, Exception):
+                self.call_exception_handler({
+                    'message': 'an error occurred during closing of '
+                               'asynchronous generator {!r}'.format(agen),
+                    'exception': result,
+                    'asyncgen': agen
+                })
+
     def run_forever(self):
         """Run until stop() is called."""
         self._check_closed()
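
On Python 3.6 the loop installs these hooks via sys.set_asyncgen_hooks() (see run_forever() below), and applications are expected to call loop.shutdown_asyncgens() before closing the loop. A minimal sketch of that shutdown order, assuming Python 3.6+:

    import asyncio

    async def ticker():
        # Asynchronous generator (Python 3.6+); if it is abandoned, the loop's
        # finalizer hook or shutdown_asyncgens() takes care of aclose().
        n = 0
        while True:
            yield n
            n += 1
            await asyncio.sleep(0.01)

    async def consume_a_few():
        async for value in ticker():
            if value >= 2:
                break   # abandon the generator without closing it explicitly

    loop = asyncio.get_event_loop()
    try:
        loop.run_until_complete(consume_a_few())
        # Close any async generators that are still alive, then close the loop.
        loop.run_until_complete(loop.shutdown_asyncgens())
    finally:
        loop.close()
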
@@ -326,6 +393,10 @@ class BaseEventLoop(events.AbstractEventLoop):
             raise RuntimeError('Event loop is running.')
         self._set_coroutine_wrapper(self._debug)
         self._thread_id = threading.get_ident()
+        if self._asyncgens is not None:
+            old_agen_hooks = sys.get_asyncgen_hooks()
+            sys.set_asyncgen_hooks(firstiter=self._asyncgen_firstiter_hook,
+                                   finalizer=self._asyncgen_finalizer_hook)
         try:
             while True:
                 self._run_once()
@@ -335,6 +406,8 @@ class BaseEventLoop(events.AbstractEventLoop):
             self._stopping = False
             self._thread_id = None
             self._set_coroutine_wrapper(False)
+            if self._asyncgens is not None:
+                sys.set_asyncgen_hooks(*old_agen_hooks)
     def run_until_complete(self, future):
         """Run until the Future is done.
@@ -349,7 +422,7 @@ class BaseEventLoop(events.AbstractEventLoop):
"""
self._check_closed()
- new_task = not isinstance(future, futures.Future)
+ new_task = not futures.isfuture(future)
future = tasks.ensure_future(future, loop=self)
if new_task:
# An exception is raised if the future didn't complete, so there
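
futures.isfuture() accepts anything that follows the Future protocol, including Task, which is why run_until_complete() now uses it instead of a strict isinstance() check. The public counterpart is asyncio.isfuture(), available in 3.5.3+/3.6:

    import asyncio

    loop = asyncio.get_event_loop()

    fut = loop.create_future()
    task = loop.create_task(asyncio.sleep(0))
    coro = asyncio.sleep(0)

    print(asyncio.isfuture(fut))    # True
    print(asyncio.isfuture(task))   # True: Task is a Future subclass
    print(asyncio.isfuture(coro))   # False: bare coroutines get wrapped in a Task

    loop.run_until_complete(asyncio.gather(task, coro))
    loop.close()
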
@@ -529,15 +602,18 @@ class BaseEventLoop(events.AbstractEventLoop):
         if isinstance(func, events.Handle):
             assert not args
             assert not isinstance(func, events.TimerHandle)
+            warnings.warn(
+                "Passing Handle to loop.run_in_executor() is deprecated",
+                DeprecationWarning)
             if func._cancelled:
-                f = futures.Future(loop=self)
+                f = self.create_future()
                 f.set_result(None)
                 return f
             func, args = func._callback, func._args
         if executor is None:
             executor = self._default_executor
             if executor is None:
-                executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS)
+                executor = concurrent.futures.ThreadPoolExecutor()
                 self._default_executor = executor
         return futures.wrap_future(executor.submit(func, *args), loop=self)
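
With the hard-coded _MAX_WORKERS gone, the default executor created here is a plain concurrent.futures.ThreadPoolExecutor, which sizes its pool from os.cpu_count(). Typical use, with an illustrative blocking function:

    import asyncio
    import hashlib

    def blocking_digest(data: bytes) -> str:
        # CPU-ish work that should not run on the event loop thread.
        return hashlib.sha256(data).hexdigest()

    async def main(loop):
        # executor=None selects (and lazily creates) the loop's default
        # ThreadPoolExecutor.
        digest = await loop.run_in_executor(None, blocking_digest, b'x' * 10_000_000)
        print(digest)

    loop = asyncio.get_event_loop()
    loop.run_until_complete(main(loop))
    loop.close()
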
@@ -571,12 +647,7 @@ class BaseEventLoop(events.AbstractEventLoop):
     def getaddrinfo(self, host, port, *,
                     family=0, type=0, proto=0, flags=0):
-        info = _ipaddr_info(host, port, family, type, proto)
-        if info is not None:
-            fut = futures.Future(loop=self)
-            fut.set_result([info])
-            return fut
-        elif self._debug:
+        if self._debug:
             return self.run_in_executor(None, self._getaddrinfo_debug,
                                         host, port, family, type, proto, flags)
         else:
@@ -625,14 +696,14 @@ class BaseEventLoop(events.AbstractEventLoop):
                 raise ValueError(
                     'host/port and sock can not be specified at the same time')
-            f1 = self.getaddrinfo(
-                host, port, family=family,
-                type=socket.SOCK_STREAM, proto=proto, flags=flags)
+            f1 = _ensure_resolved((host, port), family=family,
+                                  type=socket.SOCK_STREAM, proto=proto,
+                                  flags=flags, loop=self)
             fs = [f1]
             if local_addr is not None:
-                f2 = self.getaddrinfo(
-                    *local_addr, family=family,
-                    type=socket.SOCK_STREAM, proto=proto, flags=flags)
+                f2 = _ensure_resolved(local_addr, family=family,
+                                      type=socket.SOCK_STREAM, proto=proto,
+                                      flags=flags, loop=self)
                 fs.append(f2)
             else:
                 f2 = None
@@ -698,8 +769,6 @@ class BaseEventLoop(events.AbstractEventLoop):
                 raise ValueError(
                     'host and port was not specified and no sock specified')
-        sock.setblocking(False)
-
         transport, protocol = yield from self._create_connection_transport(
             sock, protocol_factory, ssl, server_hostname)
         if self._debug:
@@ -712,14 +781,17 @@ class BaseEventLoop(events.AbstractEventLoop):
     @coroutine
     def _create_connection_transport(self, sock, protocol_factory, ssl,
-                                     server_hostname):
+                                     server_hostname, server_side=False):
+
+        sock.setblocking(False)
+
         protocol = protocol_factory()
-        waiter = futures.Future(loop=self)
+        waiter = self.create_future()
         if ssl:
             sslcontext = None if isinstance(ssl, bool) else ssl
             transport = self._make_ssl_transport(
                 sock, protocol, sslcontext, waiter,
-                server_side=False, server_hostname=server_hostname)
+                server_side=server_side, server_hostname=server_hostname)
         else:
             transport = self._make_socket_transport(sock, protocol, waiter)
@@ -767,9 +839,9 @@ class BaseEventLoop(events.AbstractEventLoop):
                 assert isinstance(addr, tuple) and len(addr) == 2, (
                     '2-tuple is expected')
-                infos = yield from self.getaddrinfo(
-                    *addr, family=family, type=socket.SOCK_DGRAM,
-                    proto=proto, flags=flags)
+                infos = yield from _ensure_resolved(
+                    addr, family=family, type=socket.SOCK_DGRAM,
+                    proto=proto, flags=flags, loop=self)
                 if not infos:
                     raise OSError('getaddrinfo() returned empty list')
@@ -804,12 +876,7 @@ class BaseEventLoop(events.AbstractEventLoop):
                     sock.setsockopt(
                         socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                 if reuse_port:
-                    if not hasattr(socket, 'SO_REUSEPORT'):
-                        raise ValueError(
-                            'reuse_port not supported by socket module')
-                    else:
-                        sock.setsockopt(
-                            socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
+                    _set_reuseport(sock)
                 if allow_broadcast:
                     sock.setsockopt(
                         socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
@@ -834,7 +901,7 @@ class BaseEventLoop(events.AbstractEventLoop):
             raise exceptions[0]
         protocol = protocol_factory()
-        waiter = futures.Future(loop=self)
+        waiter = self.create_future()
         transport = self._make_datagram_transport(
             sock, protocol, r_addr, waiter)
         if self._debug:
@@ -857,9 +924,9 @@ class BaseEventLoop(events.AbstractEventLoop):
     @coroutine
     def _create_server_getaddrinfo(self, host, port, family, flags):
-        infos = yield from self.getaddrinfo(host, port, family=family,
+        infos = yield from _ensure_resolved((host, port), family=family,
                                             type=socket.SOCK_STREAM,
-                                            flags=flags)
+                                            flags=flags, loop=self)
         if not infos:
             raise OSError('getaddrinfo({!r}) returned empty list'.format(host))
         return infos
@@ -880,7 +947,10 @@ class BaseEventLoop(events.AbstractEventLoop):
         to host and port.
         The host parameter can also be a sequence of strings and in that case
-        the TCP server is bound to all hosts of the sequence.
+        the TCP server is bound to all hosts of the sequence. If a host
+        appears multiple times (possibly indirectly e.g. when hostnames
+        resolve to the same IP address), the server is only bound once to that
+        host.
         Return a Server object which can be used to stop the service.
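
Illustrative use of the sequence form described above; with the de-duplication added below, hosts that resolve to the same address (for example 'localhost' and '127.0.0.1' on many systems) produce a single listening socket. The port is arbitrary:

    import asyncio

    loop = asyncio.get_event_loop()

    coro = loop.create_server(asyncio.Protocol,
                              host=['127.0.0.1', 'localhost'],
                              port=8888)
    server = loop.run_until_complete(coro)

    # Duplicate resolutions are bound only once.
    print('listening on', [s.getsockname() for s in server.sockets])

    server.close()
    loop.run_until_complete(server.wait_closed())
    loop.close()
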
@@ -909,7 +979,7 @@ class BaseEventLoop(events.AbstractEventLoop):
                                                   flags=flags)
                   for host in hosts]
             infos = yield from tasks.gather(*fs, loop=self)
-            infos = itertools.chain.from_iterable(infos)
+            infos = set(itertools.chain.from_iterable(infos))
             completed = False
             try:
@@ -929,12 +999,7 @@ class BaseEventLoop(events.AbstractEventLoop):
                     sock.setsockopt(
                         socket.SOL_SOCKET, socket.SO_REUSEADDR, True)
                 if reuse_port:
-                    if not hasattr(socket, 'SO_REUSEPORT'):
-                        raise ValueError(
-                            'reuse_port not supported by socket module')
-                    else:
-                        sock.setsockopt(
-                            socket.SOL_SOCKET, socket.SO_REUSEPORT, True)
+                    _set_reuseport(sock)
                 # Disable IPv4/IPv6 dual stack support (enabled by
                 # default on Linux) which makes a single socket
                 # listen on both address families.
@@ -962,15 +1027,34 @@ class BaseEventLoop(events.AbstractEventLoop):
         for sock in sockets:
             sock.listen(backlog)
             sock.setblocking(False)
-            self._start_serving(protocol_factory, sock, ssl, server)
+            self._start_serving(protocol_factory, sock, ssl, server, backlog)
         if self._debug:
             logger.info("%r is serving", server)
         return server
     @coroutine
+    def connect_accepted_socket(self, protocol_factory, sock, *, ssl=None):
+        """Handle an accepted connection.
+
+        This is used by servers that accept connections outside of
+        asyncio but that use asyncio to handle connections.
+
+        This method is a coroutine. When completed, the coroutine
+        returns a (transport, protocol) pair.
+        """
+        transport, protocol = yield from self._create_connection_transport(
+            sock, protocol_factory, ssl, '', server_side=True)
+        if self._debug:
+            # Get the socket from the transport because SSL transport closes
+            # the old socket and creates a new SSL socket
+            sock = transport.get_extra_info('socket')
+            logger.debug("%r handled: (%r, %r)", sock, transport, protocol)
+        return transport, protocol
+
+    @coroutine
     def connect_read_pipe(self, protocol_factory, pipe):
         protocol = protocol_factory()
-        waiter = futures.Future(loop=self)
+        waiter = self.create_future()
         transport = self._make_read_pipe_transport(pipe, protocol, waiter)
         try:
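
A rough end-to-end sketch of connect_accepted_socket(): the listening and accepting happen on ordinary blocking sockets outside asyncio, and only the accepted connection is handed to the loop. The Echo protocol and the timings here are illustrative:

    import asyncio
    import socket

    class Echo(asyncio.Protocol):
        def connection_made(self, transport):
            transport.write(b'hello from asyncio\n')
            transport.close()

    loop = asyncio.get_event_loop()

    # The listening socket is managed outside of asyncio.
    listener = socket.socket()
    listener.bind(('127.0.0.1', 0))
    listener.listen(1)

    # A throwaway client so accept() has something to return.
    client = socket.create_connection(listener.getsockname())
    conn, _ = listener.accept()

    async def handle(conn):
        # Hand the already-accepted connection over to the event loop.
        transport, protocol = await loop.connect_accepted_socket(Echo, conn)
        await asyncio.sleep(0.1)   # give the transport time to flush and close

    loop.run_until_complete(handle(conn))
    print(client.recv(100))        # b'hello from asyncio\n'

    client.close()
    listener.close()
    loop.close()
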
@@ -987,7 +1071,7 @@ class BaseEventLoop(events.AbstractEventLoop):
     @coroutine
     def connect_write_pipe(self, protocol_factory, pipe):
         protocol = protocol_factory()
-        waiter = futures.Future(loop=self)
+        waiter = self.create_future()
         transport = self._make_write_pipe_transport(pipe, protocol, waiter)
         try:
@@ -1036,7 +1120,7 @@ class BaseEventLoop(events.AbstractEventLoop):
         transport = yield from self._make_subprocess_transport(
             protocol, cmd, True, stdin, stdout, stderr, bufsize, **kwargs)
         if self._debug:
-            logger.info('%s: %r' % (debug_log, transport))
+            logger.info('%s: %r', debug_log, transport)
         return transport, protocol
@coroutine
@@ -1066,9 +1150,14 @@ class BaseEventLoop(events.AbstractEventLoop):
             protocol, popen_args, False, stdin, stdout, stderr,
             bufsize, **kwargs)
         if self._debug:
-            logger.info('%s: %r' % (debug_log, transport))
+            logger.info('%s: %r', debug_log, transport)
         return transport, protocol
+    def get_exception_handler(self):
+        """Return an exception handler, or None if the default one is in use.
+        """
+        return self._exception_handler
+
     def set_exception_handler(self, handler):
         """Set handler as the new event loop exception handler.
@@ -1141,7 +1230,9 @@ class BaseEventLoop(events.AbstractEventLoop):
         - 'handle' (optional): Handle instance;
         - 'protocol' (optional): Protocol instance;
         - 'transport' (optional): Transport instance;
-        - 'socket' (optional): Socket instance.
+        - 'socket' (optional): Socket instance;
+        - 'asyncgen' (optional): Asynchronous generator that caused
+              the exception.
         New keys maybe introduced in the future.
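
An illustrative handler that inspects the new optional key; 'asyncgen' only appears when shutdown_asyncgens() fails to close an asynchronous generator cleanly (Python 3.6+):

    import asyncio

    def handler(loop, context):
        agen = context.get('asyncgen')
        if agen is not None:
            print('async generator failed to close cleanly:', agen)
        loop.default_exception_handler(context)

    async def broken():
        try:
            yield 1
        finally:
            raise RuntimeError('cleanup failed')   # surfaces via shutdown_asyncgens()

    pending = []

    async def main():
        agen = broken()
        pending.append(agen)      # keep a reference so the loop closes it for us
        await agen.__anext__()    # start the generator, then abandon it

    loop = asyncio.get_event_loop()
    loop.set_exception_handler(handler)
    loop.run_until_complete(main())
    loop.run_until_complete(loop.shutdown_asyncgens())
    loop.close()
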