diff options
Diffstat (limited to 'Lib/asyncio/base_events.py')
-rw-r--r-- | Lib/asyncio/base_events.py | 125 |
1 files changed, 89 insertions, 36 deletions
diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index 9b4b846..c58906f 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -16,6 +16,7 @@ to modify the meaning of the API call itself. import collections import collections.abc import concurrent.futures +import functools import heapq import itertools import os @@ -41,6 +42,7 @@ from . import exceptions from . import futures from . import protocols from . import sslproto +from . import staggered from . import tasks from . import transports from .log import logger @@ -159,6 +161,28 @@ def _ipaddr_info(host, port, family, type, proto): return None +def _interleave_addrinfos(addrinfos, first_address_family_count=1): + """Interleave list of addrinfo tuples by family.""" + # Group addresses by family + addrinfos_by_family = collections.OrderedDict() + for addr in addrinfos: + family = addr[0] + if family not in addrinfos_by_family: + addrinfos_by_family[family] = [] + addrinfos_by_family[family].append(addr) + addrinfos_lists = list(addrinfos_by_family.values()) + + reordered = [] + if first_address_family_count > 1: + reordered.extend(addrinfos_lists[0][:first_address_family_count - 1]) + del addrinfos_lists[0][:first_address_family_count - 1] + reordered.extend( + a for a in itertools.chain.from_iterable( + itertools.zip_longest(*addrinfos_lists) + ) if a is not None) + return reordered + + def _run_until_complete_cb(fut): if not fut.cancelled(): exc = fut.exception() @@ -871,12 +895,49 @@ class BaseEventLoop(events.AbstractEventLoop): "offset must be a non-negative integer (got {!r})".format( offset)) + async def _connect_sock(self, exceptions, addr_info, local_addr_infos=None): + """Create, bind and connect one socket.""" + my_exceptions = [] + exceptions.append(my_exceptions) + family, type_, proto, _, address = addr_info + sock = None + try: + sock = socket.socket(family=family, type=type_, proto=proto) + sock.setblocking(False) + if local_addr_infos is not None: + for _, _, _, _, laddr in local_addr_infos: + try: + sock.bind(laddr) + break + except OSError as exc: + msg = ( + f'error while attempting to bind on ' + f'address {laddr!r}: ' + f'{exc.strerror.lower()}' + ) + exc = OSError(exc.errno, msg) + my_exceptions.append(exc) + else: # all bind attempts failed + raise my_exceptions.pop() + await self.sock_connect(sock, address) + return sock + except OSError as exc: + my_exceptions.append(exc) + if sock is not None: + sock.close() + raise + except: + if sock is not None: + sock.close() + raise + async def create_connection( self, protocol_factory, host=None, port=None, *, ssl=None, family=0, proto=0, flags=0, sock=None, local_addr=None, server_hostname=None, - ssl_handshake_timeout=None): + ssl_handshake_timeout=None, + happy_eyeballs_delay=None, interleave=None): """Connect to a TCP server. Create a streaming transport connection to a given Internet host and @@ -911,6 +972,10 @@ class BaseEventLoop(events.AbstractEventLoop): raise ValueError( 'ssl_handshake_timeout is only meaningful with ssl') + if happy_eyeballs_delay is not None and interleave is None: + # If using happy eyeballs, default to interleave addresses by family + interleave = 1 + if host is not None or port is not None: if sock is not None: raise ValueError( @@ -929,43 +994,31 @@ class BaseEventLoop(events.AbstractEventLoop): flags=flags, loop=self) if not laddr_infos: raise OSError('getaddrinfo() returned empty list') + else: + laddr_infos = None + + if interleave: + infos = _interleave_addrinfos(infos, interleave) exceptions = [] - for family, type, proto, cname, address in infos: - try: - sock = socket.socket(family=family, type=type, proto=proto) - sock.setblocking(False) - if local_addr is not None: - for _, _, _, _, laddr in laddr_infos: - try: - sock.bind(laddr) - break - except OSError as exc: - msg = ( - f'error while attempting to bind on ' - f'address {laddr!r}: ' - f'{exc.strerror.lower()}' - ) - exc = OSError(exc.errno, msg) - exceptions.append(exc) - else: - sock.close() - sock = None - continue - if self._debug: - logger.debug("connect %r to %r", sock, address) - await self.sock_connect(sock, address) - except OSError as exc: - if sock is not None: - sock.close() - exceptions.append(exc) - except: - if sock is not None: - sock.close() - raise - else: - break - else: + if happy_eyeballs_delay is None: + # not using happy eyeballs + for addrinfo in infos: + try: + sock = await self._connect_sock( + exceptions, addrinfo, laddr_infos) + break + except OSError: + continue + else: # using happy eyeballs + sock, _, _ = await staggered.staggered_race( + (functools.partial(self._connect_sock, + exceptions, addrinfo, laddr_infos) + for addrinfo in infos), + happy_eyeballs_delay, loop=self) + + if sock is None: + exceptions = [exc for sub in exceptions for exc in sub] if len(exceptions) == 1: raise exceptions[0] else: |