summaryrefslogtreecommitdiffstats
path: root/Lib/asyncio/base_events.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/asyncio/base_events.py')
-rw-r--r--Lib/asyncio/base_events.py125
1 files changed, 89 insertions, 36 deletions
diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py
index 9b4b846..c58906f 100644
--- a/Lib/asyncio/base_events.py
+++ b/Lib/asyncio/base_events.py
@@ -16,6 +16,7 @@ to modify the meaning of the API call itself.
import collections
import collections.abc
import concurrent.futures
+import functools
import heapq
import itertools
import os
@@ -41,6 +42,7 @@ from . import exceptions
from . import futures
from . import protocols
from . import sslproto
+from . import staggered
from . import tasks
from . import transports
from .log import logger
@@ -159,6 +161,28 @@ def _ipaddr_info(host, port, family, type, proto):
return None
+def _interleave_addrinfos(addrinfos, first_address_family_count=1):
+ """Interleave list of addrinfo tuples by family."""
+ # Group addresses by family
+ addrinfos_by_family = collections.OrderedDict()
+ for addr in addrinfos:
+ family = addr[0]
+ if family not in addrinfos_by_family:
+ addrinfos_by_family[family] = []
+ addrinfos_by_family[family].append(addr)
+ addrinfos_lists = list(addrinfos_by_family.values())
+
+ reordered = []
+ if first_address_family_count > 1:
+ reordered.extend(addrinfos_lists[0][:first_address_family_count - 1])
+ del addrinfos_lists[0][:first_address_family_count - 1]
+ reordered.extend(
+ a for a in itertools.chain.from_iterable(
+ itertools.zip_longest(*addrinfos_lists)
+ ) if a is not None)
+ return reordered
+
+
def _run_until_complete_cb(fut):
if not fut.cancelled():
exc = fut.exception()
@@ -871,12 +895,49 @@ class BaseEventLoop(events.AbstractEventLoop):
"offset must be a non-negative integer (got {!r})".format(
offset))
+ async def _connect_sock(self, exceptions, addr_info, local_addr_infos=None):
+ """Create, bind and connect one socket."""
+ my_exceptions = []
+ exceptions.append(my_exceptions)
+ family, type_, proto, _, address = addr_info
+ sock = None
+ try:
+ sock = socket.socket(family=family, type=type_, proto=proto)
+ sock.setblocking(False)
+ if local_addr_infos is not None:
+ for _, _, _, _, laddr in local_addr_infos:
+ try:
+ sock.bind(laddr)
+ break
+ except OSError as exc:
+ msg = (
+ f'error while attempting to bind on '
+ f'address {laddr!r}: '
+ f'{exc.strerror.lower()}'
+ )
+ exc = OSError(exc.errno, msg)
+ my_exceptions.append(exc)
+ else: # all bind attempts failed
+ raise my_exceptions.pop()
+ await self.sock_connect(sock, address)
+ return sock
+ except OSError as exc:
+ my_exceptions.append(exc)
+ if sock is not None:
+ sock.close()
+ raise
+ except:
+ if sock is not None:
+ sock.close()
+ raise
+
async def create_connection(
self, protocol_factory, host=None, port=None,
*, ssl=None, family=0,
proto=0, flags=0, sock=None,
local_addr=None, server_hostname=None,
- ssl_handshake_timeout=None):
+ ssl_handshake_timeout=None,
+ happy_eyeballs_delay=None, interleave=None):
"""Connect to a TCP server.
Create a streaming transport connection to a given Internet host and
@@ -911,6 +972,10 @@ class BaseEventLoop(events.AbstractEventLoop):
raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl')
+ if happy_eyeballs_delay is not None and interleave is None:
+ # If using happy eyeballs, default to interleave addresses by family
+ interleave = 1
+
if host is not None or port is not None:
if sock is not None:
raise ValueError(
@@ -929,43 +994,31 @@ class BaseEventLoop(events.AbstractEventLoop):
flags=flags, loop=self)
if not laddr_infos:
raise OSError('getaddrinfo() returned empty list')
+ else:
+ laddr_infos = None
+
+ if interleave:
+ infos = _interleave_addrinfos(infos, interleave)
exceptions = []
- for family, type, proto, cname, address in infos:
- try:
- sock = socket.socket(family=family, type=type, proto=proto)
- sock.setblocking(False)
- if local_addr is not None:
- for _, _, _, _, laddr in laddr_infos:
- try:
- sock.bind(laddr)
- break
- except OSError as exc:
- msg = (
- f'error while attempting to bind on '
- f'address {laddr!r}: '
- f'{exc.strerror.lower()}'
- )
- exc = OSError(exc.errno, msg)
- exceptions.append(exc)
- else:
- sock.close()
- sock = None
- continue
- if self._debug:
- logger.debug("connect %r to %r", sock, address)
- await self.sock_connect(sock, address)
- except OSError as exc:
- if sock is not None:
- sock.close()
- exceptions.append(exc)
- except:
- if sock is not None:
- sock.close()
- raise
- else:
- break
- else:
+ if happy_eyeballs_delay is None:
+ # not using happy eyeballs
+ for addrinfo in infos:
+ try:
+ sock = await self._connect_sock(
+ exceptions, addrinfo, laddr_infos)
+ break
+ except OSError:
+ continue
+ else: # using happy eyeballs
+ sock, _, _ = await staggered.staggered_race(
+ (functools.partial(self._connect_sock,
+ exceptions, addrinfo, laddr_infos)
+ for addrinfo in infos),
+ happy_eyeballs_delay, loop=self)
+
+ if sock is None:
+ exceptions = [exc for sub in exceptions for exc in sub]
if len(exceptions) == 1:
raise exceptions[0]
else: