diff options
Diffstat (limited to 'Lib/asyncio/base_events.py')
| -rw-r--r-- | Lib/asyncio/base_events.py | 295 |
1 files changed, 193 insertions, 102 deletions
diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index 4505732..b3e318e 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -13,13 +13,10 @@ conscious design decision, leaving the door open for keyword arguments to modify the meaning of the API call itself. """ - import collections import concurrent.futures -import functools import heapq import inspect -import ipaddress import itertools import logging import os @@ -30,6 +27,7 @@ import time import traceback import sys import warnings +import weakref from . import compat from . import coroutines @@ -43,9 +41,6 @@ from .log import logger __all__ = ['BaseEventLoop'] -# Argument for default thread pool executor creation. -_MAX_WORKERS = 5 - # Minimum number of _scheduled timer handles before cleanup of # cancelled handles is performed. _MIN_SCHEDULED_TIMER_HANDLES = 100 @@ -54,6 +49,12 @@ _MIN_SCHEDULED_TIMER_HANDLES = 100 # before cleanup of cancelled handles is performed. _MIN_CANCELLED_TIMER_HANDLES_FRACTION = 0.5 +# Exceptions which must not call the exception handler in fatal error +# methods (_fatal_error()) +_FATAL_ERROR_IGNORE = (BrokenPipeError, + ConnectionResetError, ConnectionAbortedError) + + def _format_handle(handle): cb = handle._callback if inspect.ismethod(cb) and isinstance(cb.__self__, tasks.Task): @@ -72,6 +73,17 @@ def _format_pipe(fd): return repr(fd) +def _set_reuseport(sock): + if not hasattr(socket, 'SO_REUSEPORT'): + raise ValueError('reuse_port not supported by socket module') + else: + try: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + except OSError: + raise ValueError('reuse_port not supported by socket module, ' + 'SO_REUSEPORT defined but not implemented.') + + # Linux's sock.type is a bitmask that can include extra info about socket. _SOCKET_TYPE_MASK = 0 if hasattr(socket, 'SOCK_NONBLOCK'): @@ -80,12 +92,14 @@ if hasattr(socket, 'SOCK_CLOEXEC'): _SOCKET_TYPE_MASK |= socket.SOCK_CLOEXEC -@functools.lru_cache(maxsize=1024) def _ipaddr_info(host, port, family, type, proto): - # Try to skip getaddrinfo if "host" is already an IP. Since getaddrinfo - # blocks on an exclusive lock on some platforms, users might handle name - # resolution in their own code and pass in resolved IPs. - if proto not in {0, socket.IPPROTO_TCP, socket.IPPROTO_UDP} or host is None: + # Try to skip getaddrinfo if "host" is already an IP. Users might have + # handled name resolution in their own code and pass in resolved IPs. + if not hasattr(socket, 'inet_pton'): + return + + if proto not in {0, socket.IPPROTO_TCP, socket.IPPROTO_UDP} or \ + host is None: return None type &= ~_SOCKET_TYPE_MASK @@ -96,59 +110,55 @@ def _ipaddr_info(host, port, family, type, proto): else: return None - if hasattr(socket, 'inet_pton'): - if family == socket.AF_UNSPEC: - afs = [socket.AF_INET, socket.AF_INET6] - else: - afs = [family] - - for af in afs: - # Linux's inet_pton doesn't accept an IPv6 zone index after host, - # like '::1%lo0', so strip it. If we happen to make an invalid - # address look valid, we fail later in sock.connect or sock.bind. - try: - if af == socket.AF_INET6: - socket.inet_pton(af, host.partition('%')[0]) - else: - socket.inet_pton(af, host) - return af, type, proto, '', (host, port) - except OSError: - pass - - # "host" is not an IP address. - return None - - # No inet_pton. (On Windows it's only available since Python 3.4.) - # Even though getaddrinfo with AI_NUMERICHOST would be non-blocking, it - # still requires a lock on some platforms, and waiting for that lock could - # block the event loop. Use ipaddress instead, it's just text parsing. - try: - addr = ipaddress.IPv4Address(host) - except ValueError: + if port is None: + port = 0 + elif isinstance(port, bytes) and port == b'': + port = 0 + elif isinstance(port, str) and port == '': + port = 0 + else: + # If port's a service name like "http", don't skip getaddrinfo. try: - addr = ipaddress.IPv6Address(host.partition('%')[0]) - except ValueError: + port = int(port) + except (TypeError, ValueError): return None - af = socket.AF_INET if addr.version == 4 else socket.AF_INET6 - if family not in (socket.AF_UNSPEC, af): - # "host" is wrong IP version for "family". - return None + if family == socket.AF_UNSPEC: + afs = [socket.AF_INET, socket.AF_INET6] + else: + afs = [family] - return af, type, proto, '', (host, port) + if isinstance(host, bytes): + host = host.decode('idna') + if '%' in host: + # Linux's inet_pton doesn't accept an IPv6 zone index after host, + # like '::1%lo0'. + return None + for af in afs: + try: + socket.inet_pton(af, host) + # The host has already been resolved. + return af, type, proto, '', (host, port) + except OSError: + pass -def _check_resolved_address(sock, address): - # Ensure that the address is already resolved to avoid the trap of hanging - # the entire event loop when the address requires doing a DNS lookup. + # "host" is not an IP address. + return None - if hasattr(socket, 'AF_UNIX') and sock.family == socket.AF_UNIX: - return +def _ensure_resolved(address, *, family=0, type=socket.SOCK_STREAM, proto=0, + flags=0, loop): host, port = address[:2] - if _ipaddr_info(host, port, sock.family, sock.type, sock.proto) is None: - raise ValueError("address must be resolved (IP address)," - " got host %r" % host) + info = _ipaddr_info(host, port, family, type, proto) + if info is not None: + # "host" is already a resolved IP. + fut = loop.create_future() + fut.set_result([info]) + return fut + else: + return loop.getaddrinfo(host, port, family=family, type=type, + proto=proto, flags=flags) def _run_until_complete_cb(fut): @@ -203,7 +213,7 @@ class Server(events.AbstractServer): def wait_closed(self): if self.sockets is None or self._waiters is None: return - waiter = futures.Future(loop=self._loop) + waiter = self._loop.create_future() self._waiters.append(waiter) yield from waiter @@ -232,11 +242,26 @@ class BaseEventLoop(events.AbstractEventLoop): self._task_factory = None self._coroutine_wrapper_set = False + if hasattr(sys, 'get_asyncgen_hooks'): + # Python >= 3.6 + # A weak set of all asynchronous generators that are + # being iterated by the loop. + self._asyncgens = weakref.WeakSet() + else: + self._asyncgens = None + + # Set to True when `loop.shutdown_asyncgens` is called. + self._asyncgens_shutdown_called = False + def __repr__(self): return ('<%s running=%s closed=%s debug=%s>' % (self.__class__.__name__, self.is_running(), self.is_closed(), self.get_debug())) + def create_future(self): + """Create a Future object attached to the loop.""" + return futures.Future(loop=self) + def create_task(self, coro): """Schedule a coroutine object. @@ -319,6 +344,48 @@ class BaseEventLoop(events.AbstractEventLoop): if self._closed: raise RuntimeError('Event loop is closed') + def _asyncgen_finalizer_hook(self, agen): + self._asyncgens.discard(agen) + if not self.is_closed(): + self.create_task(agen.aclose()) + + def _asyncgen_firstiter_hook(self, agen): + if self._asyncgens_shutdown_called: + warnings.warn( + "asynchronous generator {!r} was scheduled after " + "loop.shutdown_asyncgens() call".format(agen), + ResourceWarning, source=self) + + self._asyncgens.add(agen) + + @coroutine + def shutdown_asyncgens(self): + """Shutdown all active asynchronous generators.""" + self._asyncgens_shutdown_called = True + + if self._asyncgens is None or not len(self._asyncgens): + # If Python version is <3.6 or we don't have any asynchronous + # generators alive. + return + + closing_agens = list(self._asyncgens) + self._asyncgens.clear() + + shutdown_coro = tasks.gather( + *[ag.aclose() for ag in closing_agens], + return_exceptions=True, + loop=self) + + results = yield from shutdown_coro + for result, agen in zip(results, closing_agens): + if isinstance(result, Exception): + self.call_exception_handler({ + 'message': 'an error occurred during closing of ' + 'asynchronous generator {!r}'.format(agen), + 'exception': result, + 'asyncgen': agen + }) + def run_forever(self): """Run until stop() is called.""" self._check_closed() @@ -326,6 +393,10 @@ class BaseEventLoop(events.AbstractEventLoop): raise RuntimeError('Event loop is running.') self._set_coroutine_wrapper(self._debug) self._thread_id = threading.get_ident() + if self._asyncgens is not None: + old_agen_hooks = sys.get_asyncgen_hooks() + sys.set_asyncgen_hooks(firstiter=self._asyncgen_firstiter_hook, + finalizer=self._asyncgen_finalizer_hook) try: while True: self._run_once() @@ -335,6 +406,8 @@ class BaseEventLoop(events.AbstractEventLoop): self._stopping = False self._thread_id = None self._set_coroutine_wrapper(False) + if self._asyncgens is not None: + sys.set_asyncgen_hooks(*old_agen_hooks) def run_until_complete(self, future): """Run until the Future is done. @@ -349,7 +422,7 @@ class BaseEventLoop(events.AbstractEventLoop): """ self._check_closed() - new_task = not isinstance(future, futures.Future) + new_task = not futures.isfuture(future) future = tasks.ensure_future(future, loop=self) if new_task: # An exception is raised if the future didn't complete, so there @@ -529,15 +602,18 @@ class BaseEventLoop(events.AbstractEventLoop): if isinstance(func, events.Handle): assert not args assert not isinstance(func, events.TimerHandle) + warnings.warn( + "Passing Handle to loop.run_in_executor() is deprecated", + DeprecationWarning) if func._cancelled: - f = futures.Future(loop=self) + f = self.create_future() f.set_result(None) return f func, args = func._callback, func._args if executor is None: executor = self._default_executor if executor is None: - executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS) + executor = concurrent.futures.ThreadPoolExecutor() self._default_executor = executor return futures.wrap_future(executor.submit(func, *args), loop=self) @@ -571,12 +647,7 @@ class BaseEventLoop(events.AbstractEventLoop): def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0): - info = _ipaddr_info(host, port, family, type, proto) - if info is not None: - fut = futures.Future(loop=self) - fut.set_result([info]) - return fut - elif self._debug: + if self._debug: return self.run_in_executor(None, self._getaddrinfo_debug, host, port, family, type, proto, flags) else: @@ -625,14 +696,14 @@ class BaseEventLoop(events.AbstractEventLoop): raise ValueError( 'host/port and sock can not be specified at the same time') - f1 = self.getaddrinfo( - host, port, family=family, - type=socket.SOCK_STREAM, proto=proto, flags=flags) + f1 = _ensure_resolved((host, port), family=family, + type=socket.SOCK_STREAM, proto=proto, + flags=flags, loop=self) fs = [f1] if local_addr is not None: - f2 = self.getaddrinfo( - *local_addr, family=family, - type=socket.SOCK_STREAM, proto=proto, flags=flags) + f2 = _ensure_resolved(local_addr, family=family, + type=socket.SOCK_STREAM, proto=proto, + flags=flags, loop=self) fs.append(f2) else: f2 = None @@ -698,8 +769,6 @@ class BaseEventLoop(events.AbstractEventLoop): raise ValueError( 'host and port was not specified and no sock specified') - sock.setblocking(False) - transport, protocol = yield from self._create_connection_transport( sock, protocol_factory, ssl, server_hostname) if self._debug: @@ -712,14 +781,17 @@ class BaseEventLoop(events.AbstractEventLoop): @coroutine def _create_connection_transport(self, sock, protocol_factory, ssl, - server_hostname): + server_hostname, server_side=False): + + sock.setblocking(False) + protocol = protocol_factory() - waiter = futures.Future(loop=self) + waiter = self.create_future() if ssl: sslcontext = None if isinstance(ssl, bool) else ssl transport = self._make_ssl_transport( sock, protocol, sslcontext, waiter, - server_side=False, server_hostname=server_hostname) + server_side=server_side, server_hostname=server_hostname) else: transport = self._make_socket_transport(sock, protocol, waiter) @@ -767,9 +839,9 @@ class BaseEventLoop(events.AbstractEventLoop): assert isinstance(addr, tuple) and len(addr) == 2, ( '2-tuple is expected') - infos = yield from self.getaddrinfo( - *addr, family=family, type=socket.SOCK_DGRAM, - proto=proto, flags=flags) + infos = yield from _ensure_resolved( + addr, family=family, type=socket.SOCK_DGRAM, + proto=proto, flags=flags, loop=self) if not infos: raise OSError('getaddrinfo() returned empty list') @@ -804,12 +876,7 @@ class BaseEventLoop(events.AbstractEventLoop): sock.setsockopt( socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) if reuse_port: - if not hasattr(socket, 'SO_REUSEPORT'): - raise ValueError( - 'reuse_port not supported by socket module') - else: - sock.setsockopt( - socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + _set_reuseport(sock) if allow_broadcast: sock.setsockopt( socket.SOL_SOCKET, socket.SO_BROADCAST, 1) @@ -834,7 +901,7 @@ class BaseEventLoop(events.AbstractEventLoop): raise exceptions[0] protocol = protocol_factory() - waiter = futures.Future(loop=self) + waiter = self.create_future() transport = self._make_datagram_transport( sock, protocol, r_addr, waiter) if self._debug: @@ -857,9 +924,9 @@ class BaseEventLoop(events.AbstractEventLoop): @coroutine def _create_server_getaddrinfo(self, host, port, family, flags): - infos = yield from self.getaddrinfo(host, port, family=family, + infos = yield from _ensure_resolved((host, port), family=family, type=socket.SOCK_STREAM, - flags=flags) + flags=flags, loop=self) if not infos: raise OSError('getaddrinfo({!r}) returned empty list'.format(host)) return infos @@ -880,7 +947,10 @@ class BaseEventLoop(events.AbstractEventLoop): to host and port. The host parameter can also be a sequence of strings and in that case - the TCP server is bound to all hosts of the sequence. + the TCP server is bound to all hosts of the sequence. If a host + appears multiple times (possibly indirectly e.g. when hostnames + resolve to the same IP address), the server is only bound once to that + host. Return a Server object which can be used to stop the service. @@ -909,7 +979,7 @@ class BaseEventLoop(events.AbstractEventLoop): flags=flags) for host in hosts] infos = yield from tasks.gather(*fs, loop=self) - infos = itertools.chain.from_iterable(infos) + infos = set(itertools.chain.from_iterable(infos)) completed = False try: @@ -929,12 +999,7 @@ class BaseEventLoop(events.AbstractEventLoop): sock.setsockopt( socket.SOL_SOCKET, socket.SO_REUSEADDR, True) if reuse_port: - if not hasattr(socket, 'SO_REUSEPORT'): - raise ValueError( - 'reuse_port not supported by socket module') - else: - sock.setsockopt( - socket.SOL_SOCKET, socket.SO_REUSEPORT, True) + _set_reuseport(sock) # Disable IPv4/IPv6 dual stack support (enabled by # default on Linux) which makes a single socket # listen on both address families. @@ -962,15 +1027,34 @@ class BaseEventLoop(events.AbstractEventLoop): for sock in sockets: sock.listen(backlog) sock.setblocking(False) - self._start_serving(protocol_factory, sock, ssl, server) + self._start_serving(protocol_factory, sock, ssl, server, backlog) if self._debug: logger.info("%r is serving", server) return server @coroutine + def connect_accepted_socket(self, protocol_factory, sock, *, ssl=None): + """Handle an accepted connection. + + This is used by servers that accept connections outside of + asyncio but that use asyncio to handle connections. + + This method is a coroutine. When completed, the coroutine + returns a (transport, protocol) pair. + """ + transport, protocol = yield from self._create_connection_transport( + sock, protocol_factory, ssl, '', server_side=True) + if self._debug: + # Get the socket from the transport because SSL transport closes + # the old socket and creates a new SSL socket + sock = transport.get_extra_info('socket') + logger.debug("%r handled: (%r, %r)", sock, transport, protocol) + return transport, protocol + + @coroutine def connect_read_pipe(self, protocol_factory, pipe): protocol = protocol_factory() - waiter = futures.Future(loop=self) + waiter = self.create_future() transport = self._make_read_pipe_transport(pipe, protocol, waiter) try: @@ -987,7 +1071,7 @@ class BaseEventLoop(events.AbstractEventLoop): @coroutine def connect_write_pipe(self, protocol_factory, pipe): protocol = protocol_factory() - waiter = futures.Future(loop=self) + waiter = self.create_future() transport = self._make_write_pipe_transport(pipe, protocol, waiter) try: @@ -1036,7 +1120,7 @@ class BaseEventLoop(events.AbstractEventLoop): transport = yield from self._make_subprocess_transport( protocol, cmd, True, stdin, stdout, stderr, bufsize, **kwargs) if self._debug: - logger.info('%s: %r' % (debug_log, transport)) + logger.info('%s: %r', debug_log, transport) return transport, protocol @coroutine @@ -1066,9 +1150,14 @@ class BaseEventLoop(events.AbstractEventLoop): protocol, popen_args, False, stdin, stdout, stderr, bufsize, **kwargs) if self._debug: - logger.info('%s: %r' % (debug_log, transport)) + logger.info('%s: %r', debug_log, transport) return transport, protocol + def get_exception_handler(self): + """Return an exception handler, or None if the default one is in use. + """ + return self._exception_handler + def set_exception_handler(self, handler): """Set handler as the new event loop exception handler. @@ -1141,7 +1230,9 @@ class BaseEventLoop(events.AbstractEventLoop): - 'handle' (optional): Handle instance; - 'protocol' (optional): Protocol instance; - 'transport' (optional): Transport instance; - - 'socket' (optional): Socket instance. + - 'socket' (optional): Socket instance; + - 'asyncgen' (optional): Asynchronous generator that caused + the exception. New keys maybe introduced in the future. |
