summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorGuido van Rossum <guido@python.org>2015-10-05 16:15:28 (GMT)
committerGuido van Rossum <guido@python.org>2015-10-05 16:15:28 (GMT)
commitb9bf913ab32d27d221fb765fd90d64d07e926000 (patch)
tree375ff9c2e3ae0287a6b4ae8b00630f04015ba674 /Lib
parentd17e9785de023fc425142e0d348f368677b90011 (diff)
downloadcpython-b9bf913ab32d27d221fb765fd90d64d07e926000.zip
cpython-b9bf913ab32d27d221fb765fd90d64d07e926000.tar.gz
cpython-b9bf913ab32d27d221fb765fd90d64d07e926000.tar.bz2
Issue #23972: updates to asyncio datagram API. By Chris Laws.
Diffstat (limited to 'Lib')
-rw-r--r--Lib/asyncio/base_events.py174
-rw-r--r--Lib/asyncio/events.py40
-rw-r--r--Lib/test/test_asyncio/test_base_events.py140
-rw-r--r--Lib/test/test_asyncio/test_events.py52
4 files changed, 336 insertions, 70 deletions
diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py
index a50e005..af9c881 100644
--- a/Lib/asyncio/base_events.py
+++ b/Lib/asyncio/base_events.py
@@ -700,75 +700,109 @@ class BaseEventLoop(events.AbstractEventLoop):
@coroutine
def create_datagram_endpoint(self, protocol_factory,
local_addr=None, remote_addr=None, *,
- family=0, proto=0, flags=0):
+ family=0, proto=0, flags=0,
+ reuse_address=None, reuse_port=None,
+ allow_broadcast=None, sock=None):
"""Create datagram connection."""
- if not (local_addr or remote_addr):
- if family == 0:
- raise ValueError('unexpected address family')
- addr_pairs_info = (((family, proto), (None, None)),)
- else:
- # join address by (family, protocol)
- addr_infos = collections.OrderedDict()
- for idx, addr in ((0, local_addr), (1, remote_addr)):
- if addr is not None:
- assert isinstance(addr, tuple) and len(addr) == 2, (
- '2-tuple is expected')
-
- infos = yield from self.getaddrinfo(
- *addr, family=family, type=socket.SOCK_DGRAM,
- proto=proto, flags=flags)
- if not infos:
- raise OSError('getaddrinfo() returned empty list')
-
- for fam, _, pro, _, address in infos:
- key = (fam, pro)
- if key not in addr_infos:
- addr_infos[key] = [None, None]
- addr_infos[key][idx] = address
-
- # each addr has to have info for each (family, proto) pair
- addr_pairs_info = [
- (key, addr_pair) for key, addr_pair in addr_infos.items()
- if not ((local_addr and addr_pair[0] is None) or
- (remote_addr and addr_pair[1] is None))]
-
- if not addr_pairs_info:
- raise ValueError('can not get address information')
-
- exceptions = []
-
- for ((family, proto),
- (local_address, remote_address)) in addr_pairs_info:
- sock = None
+ if sock is not None:
+ if (local_addr or remote_addr or
+ family or proto or flags or
+ reuse_address or reuse_port or allow_broadcast):
+ # show the problematic kwargs in exception msg
+ opts = dict(local_addr=local_addr, remote_addr=remote_addr,
+ family=family, proto=proto, flags=flags,
+ reuse_address=reuse_address, reuse_port=reuse_port,
+ allow_broadcast=allow_broadcast)
+ problems = ', '.join(
+ '{}={}'.format(k, v) for k, v in opts.items() if v)
+ raise ValueError(
+ 'socket modifier keyword arguments can not be used '
+ 'when sock is specified. ({})'.format(problems))
+ sock.setblocking(False)
r_addr = None
- try:
- sock = socket.socket(
- family=family, type=socket.SOCK_DGRAM, proto=proto)
- sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- sock.setblocking(False)
-
- if local_addr:
- sock.bind(local_address)
- if remote_addr:
- yield from self.sock_connect(sock, remote_address)
- r_addr = remote_address
- except OSError as exc:
- if sock is not None:
- sock.close()
- exceptions.append(exc)
- except:
- if sock is not None:
- sock.close()
- raise
- else:
- break
else:
- raise exceptions[0]
+ if not (local_addr or remote_addr):
+ if family == 0:
+ raise ValueError('unexpected address family')
+ addr_pairs_info = (((family, proto), (None, None)),)
+ else:
+ # join address by (family, protocol)
+ addr_infos = collections.OrderedDict()
+ for idx, addr in ((0, local_addr), (1, remote_addr)):
+ if addr is not None:
+ assert isinstance(addr, tuple) and len(addr) == 2, (
+ '2-tuple is expected')
+
+ infos = yield from self.getaddrinfo(
+ *addr, family=family, type=socket.SOCK_DGRAM,
+ proto=proto, flags=flags)
+ if not infos:
+ raise OSError('getaddrinfo() returned empty list')
+
+ for fam, _, pro, _, address in infos:
+ key = (fam, pro)
+ if key not in addr_infos:
+ addr_infos[key] = [None, None]
+ addr_infos[key][idx] = address
+
+ # each addr has to have info for each (family, proto) pair
+ addr_pairs_info = [
+ (key, addr_pair) for key, addr_pair in addr_infos.items()
+ if not ((local_addr and addr_pair[0] is None) or
+ (remote_addr and addr_pair[1] is None))]
+
+ if not addr_pairs_info:
+ raise ValueError('can not get address information')
+
+ exceptions = []
+
+ if reuse_address is None:
+ reuse_address = os.name == 'posix' and sys.platform != 'cygwin'
+
+ for ((family, proto),
+ (local_address, remote_address)) in addr_pairs_info:
+ sock = None
+ r_addr = None
+ try:
+ sock = socket.socket(
+ family=family, type=socket.SOCK_DGRAM, proto=proto)
+ if reuse_address:
+ sock.setsockopt(
+ socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ if reuse_port:
+ if not hasattr(socket, 'SO_REUSEPORT'):
+ raise ValueError(
+ 'reuse_port not supported by socket module')
+ else:
+ sock.setsockopt(
+ socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
+ if allow_broadcast:
+ sock.setsockopt(
+ socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
+ sock.setblocking(False)
+
+ if local_addr:
+ sock.bind(local_address)
+ if remote_addr:
+ yield from self.sock_connect(sock, remote_address)
+ r_addr = remote_address
+ except OSError as exc:
+ if sock is not None:
+ sock.close()
+ exceptions.append(exc)
+ except:
+ if sock is not None:
+ sock.close()
+ raise
+ else:
+ break
+ else:
+ raise exceptions[0]
protocol = protocol_factory()
waiter = futures.Future(loop=self)
- transport = self._make_datagram_transport(sock, protocol, r_addr,
- waiter)
+ transport = self._make_datagram_transport(
+ sock, protocol, r_addr, waiter)
if self._debug:
if local_addr:
logger.info("Datagram endpoint local_addr=%r remote_addr=%r "
@@ -804,7 +838,8 @@ class BaseEventLoop(events.AbstractEventLoop):
sock=None,
backlog=100,
ssl=None,
- reuse_address=None):
+ reuse_address=None,
+ reuse_port=None):
"""Create a TCP server.
The host parameter can be a string, in that case the TCP server is bound
@@ -857,8 +892,15 @@ class BaseEventLoop(events.AbstractEventLoop):
continue
sockets.append(sock)
if reuse_address:
- sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR,
- True)
+ sock.setsockopt(
+ socket.SOL_SOCKET, socket.SO_REUSEADDR, True)
+ if reuse_port:
+ if not hasattr(socket, 'SO_REUSEPORT'):
+ raise ValueError(
+ 'reuse_port not supported by socket module')
+ else:
+ sock.setsockopt(
+ socket.SOL_SOCKET, socket.SO_REUSEPORT, True)
# Disable IPv4/IPv6 dual stack support (enabled by
# default on Linux) which makes a single socket
# listen on both address families.
diff --git a/Lib/asyncio/events.py b/Lib/asyncio/events.py
index 1e42ddd..176a846 100644
--- a/Lib/asyncio/events.py
+++ b/Lib/asyncio/events.py
@@ -297,7 +297,8 @@ class AbstractEventLoop:
def create_server(self, protocol_factory, host=None, port=None, *,
family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE,
- sock=None, backlog=100, ssl=None, reuse_address=None):
+ sock=None, backlog=100, ssl=None, reuse_address=None,
+ reuse_port=None):
"""A coroutine which creates a TCP server bound to host and port.
The return value is a Server object which can be used to stop
@@ -327,6 +328,11 @@ class AbstractEventLoop:
TIME_WAIT state, without waiting for its natural timeout to
expire. If not specified will automatically be set to True on
UNIX.
+
+ reuse_port tells the kernel to allow this endpoint to be bound to
+ the same port as other existing endpoints are bound to, so long as
+ they all set this flag when being created. This option is not
+ supported on Windows.
"""
raise NotImplementedError
@@ -358,7 +364,37 @@ class AbstractEventLoop:
def create_datagram_endpoint(self, protocol_factory,
local_addr=None, remote_addr=None, *,
- family=0, proto=0, flags=0):
+ family=0, proto=0, flags=0,
+ reuse_address=None, reuse_port=None,
+ allow_broadcast=None, sock=None):
+ """A coroutine which creates a datagram endpoint.
+
+ This method will try to establish the endpoint in the background.
+ When successful, the coroutine returns a (transport, protocol) pair.
+
+ protocol_factory must be a callable returning a protocol instance.
+
+ socket family AF_INET or socket.AF_INET6 depending on host (or
+ family if specified), socket type SOCK_DGRAM.
+
+ reuse_address tells the kernel to reuse a local socket in
+ TIME_WAIT state, without waiting for its natural timeout to
+ expire. If not specified it will automatically be set to True on
+ UNIX.
+
+ reuse_port tells the kernel to allow this endpoint to be bound to
+ the same port as other existing endpoints are bound to, so long as
+ they all set this flag when being created. This option is not
+ supported on Windows and some UNIX's. If the
+ :py:data:`~socket.SO_REUSEPORT` constant is not defined then this
+ capability is unsupported.
+
+ allow_broadcast tells the kernel to allow this endpoint to send
+ messages to the broadcast address.
+
+ sock can optionally be specified in order to use a preexisting
+ socket object.
+ """
raise NotImplementedError
# Pipes and subprocesses.
diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py
index b1f1e56..1568440 100644
--- a/Lib/test/test_asyncio/test_base_events.py
+++ b/Lib/test/test_asyncio/test_base_events.py
@@ -3,6 +3,7 @@
import errno
import logging
import math
+import os
import socket
import sys
import threading
@@ -790,11 +791,11 @@ class MyProto(asyncio.Protocol):
class MyDatagramProto(asyncio.DatagramProtocol):
done = None
- def __init__(self, create_future=False):
+ def __init__(self, create_future=False, loop=None):
self.state = 'INITIAL'
self.nbytes = 0
if create_future:
- self.done = asyncio.Future()
+ self.done = asyncio.Future(loop=loop)
def connection_made(self, transport):
self.transport = transport
@@ -1100,6 +1101,19 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
self.assertRaises(OSError, self.loop.run_until_complete, f)
@mock.patch('asyncio.base_events.socket')
+ def test_create_server_nosoreuseport(self, m_socket):
+ m_socket.getaddrinfo = socket.getaddrinfo
+ m_socket.SOCK_STREAM = socket.SOCK_STREAM
+ m_socket.SOL_SOCKET = socket.SOL_SOCKET
+ del m_socket.SO_REUSEPORT
+ m_socket.socket.return_value = mock.Mock()
+
+ f = self.loop.create_server(
+ MyProto, '0.0.0.0', 0, reuse_port=True)
+
+ self.assertRaises(ValueError, self.loop.run_until_complete, f)
+
+ @mock.patch('asyncio.base_events.socket')
def test_create_server_cant_bind(self, m_socket):
class Err(OSError):
@@ -1199,6 +1213,128 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
self.assertRaises(Err, self.loop.run_until_complete, fut)
self.assertTrue(m_sock.close.called)
+ def test_create_datagram_endpoint_sock(self):
+ sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+ fut = self.loop.create_datagram_endpoint(
+ lambda: MyDatagramProto(create_future=True, loop=self.loop),
+ sock=sock)
+ transport, protocol = self.loop.run_until_complete(fut)
+ transport.close()
+ self.loop.run_until_complete(protocol.done)
+ self.assertEqual('CLOSED', protocol.state)
+
+ def test_create_datagram_endpoint_sock_sockopts(self):
+ fut = self.loop.create_datagram_endpoint(
+ MyDatagramProto, local_addr=('127.0.0.1', 0), sock=object())
+ self.assertRaises(ValueError, self.loop.run_until_complete, fut)
+
+ fut = self.loop.create_datagram_endpoint(
+ MyDatagramProto, remote_addr=('127.0.0.1', 0), sock=object())
+ self.assertRaises(ValueError, self.loop.run_until_complete, fut)
+
+ fut = self.loop.create_datagram_endpoint(
+ MyDatagramProto, family=1, sock=object())
+ self.assertRaises(ValueError, self.loop.run_until_complete, fut)
+
+ fut = self.loop.create_datagram_endpoint(
+ MyDatagramProto, proto=1, sock=object())
+ self.assertRaises(ValueError, self.loop.run_until_complete, fut)
+
+ fut = self.loop.create_datagram_endpoint(
+ MyDatagramProto, flags=1, sock=object())
+ self.assertRaises(ValueError, self.loop.run_until_complete, fut)
+
+ fut = self.loop.create_datagram_endpoint(
+ MyDatagramProto, reuse_address=True, sock=object())
+ self.assertRaises(ValueError, self.loop.run_until_complete, fut)
+
+ fut = self.loop.create_datagram_endpoint(
+ MyDatagramProto, reuse_port=True, sock=object())
+ self.assertRaises(ValueError, self.loop.run_until_complete, fut)
+
+ fut = self.loop.create_datagram_endpoint(
+ MyDatagramProto, allow_broadcast=True, sock=object())
+ self.assertRaises(ValueError, self.loop.run_until_complete, fut)
+
+ def test_create_datagram_endpoint_sockopts(self):
+ # Socket options should not be applied unless asked for.
+ # SO_REUSEADDR defaults to on for UNIX.
+ # SO_REUSEPORT is not available on all platforms.
+
+ coro = self.loop.create_datagram_endpoint(
+ lambda: MyDatagramProto(create_future=True, loop=self.loop),
+ local_addr=('127.0.0.1', 0))
+ transport, protocol = self.loop.run_until_complete(coro)
+ sock = transport.get_extra_info('socket')
+
+ reuse_address_default_on = (
+ os.name == 'posix' and sys.platform != 'cygwin')
+ reuseport_supported = hasattr(socket, 'SO_REUSEPORT')
+
+ if reuse_address_default_on:
+ self.assertTrue(
+ sock.getsockopt(
+ socket.SOL_SOCKET, socket.SO_REUSEADDR))
+ else:
+ self.assertFalse(
+ sock.getsockopt(
+ socket.SOL_SOCKET, socket.SO_REUSEADDR))
+ if reuseport_supported:
+ self.assertFalse(
+ sock.getsockopt(
+ socket.SOL_SOCKET, socket.SO_REUSEPORT))
+ self.assertFalse(
+ sock.getsockopt(
+ socket.SOL_SOCKET, socket.SO_BROADCAST))
+
+ transport.close()
+ self.loop.run_until_complete(protocol.done)
+ self.assertEqual('CLOSED', protocol.state)
+
+ coro = self.loop.create_datagram_endpoint(
+ lambda: MyDatagramProto(create_future=True, loop=self.loop),
+ local_addr=('127.0.0.1', 0),
+ reuse_address=True,
+ reuse_port=reuseport_supported,
+ allow_broadcast=True)
+ transport, protocol = self.loop.run_until_complete(coro)
+ sock = transport.get_extra_info('socket')
+
+ self.assertTrue(
+ sock.getsockopt(
+ socket.SOL_SOCKET, socket.SO_REUSEADDR))
+ if reuseport_supported:
+ self.assertTrue(
+ sock.getsockopt(
+ socket.SOL_SOCKET, socket.SO_REUSEPORT))
+ else:
+ self.assertFalse(
+ sock.getsockopt(
+ socket.SOL_SOCKET, socket.SO_REUSEPORT))
+ self.assertTrue(
+ sock.getsockopt(
+ socket.SOL_SOCKET, socket.SO_BROADCAST))
+
+ transport.close()
+ self.loop.run_until_complete(protocol.done)
+ self.assertEqual('CLOSED', protocol.state)
+
+ @mock.patch('asyncio.base_events.socket')
+ def test_create_datagram_endpoint_nosoreuseport(self, m_socket):
+ m_socket.getaddrinfo = socket.getaddrinfo
+ m_socket.SOCK_DGRAM = socket.SOCK_DGRAM
+ m_socket.SOL_SOCKET = socket.SOL_SOCKET
+ del m_socket.SO_REUSEPORT
+ m_socket.socket.return_value = mock.Mock()
+
+ coro = self.loop.create_datagram_endpoint(
+ lambda: MyDatagramProto(loop=self.loop),
+ local_addr=('127.0.0.1', 0),
+ reuse_address=False,
+ reuse_port=True)
+
+ self.assertRaises(ValueError, self.loop.run_until_complete, coro)
+
def test_accept_connection_retry(self):
sock = mock.Mock()
sock.accept.side_effect = BlockingIOError()
diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py
index 9801d22..141fde7 100644
--- a/Lib/test/test_asyncio/test_events.py
+++ b/Lib/test/test_asyncio/test_events.py
@@ -814,6 +814,32 @@ class EventLoopTestsMixin:
# close server
server.close()
+ @unittest.skipUnless(hasattr(socket, 'SO_REUSEPORT'), 'No SO_REUSEPORT')
+ def test_create_server_reuse_port(self):
+ proto = MyProto(self.loop)
+ f = self.loop.create_server(
+ lambda: proto, '0.0.0.0', 0)
+ server = self.loop.run_until_complete(f)
+ self.assertEqual(len(server.sockets), 1)
+ sock = server.sockets[0]
+ self.assertFalse(
+ sock.getsockopt(
+ socket.SOL_SOCKET, socket.SO_REUSEPORT))
+ server.close()
+
+ test_utils.run_briefly(self.loop)
+
+ proto = MyProto(self.loop)
+ f = self.loop.create_server(
+ lambda: proto, '0.0.0.0', 0, reuse_port=True)
+ server = self.loop.run_until_complete(f)
+ self.assertEqual(len(server.sockets), 1)
+ sock = server.sockets[0]
+ self.assertTrue(
+ sock.getsockopt(
+ socket.SOL_SOCKET, socket.SO_REUSEPORT))
+ server.close()
+
def _make_unix_server(self, factory, **kwargs):
path = test_utils.gen_unix_socket_path()
self.addCleanup(lambda: os.path.exists(path) and os.unlink(path))
@@ -1264,6 +1290,32 @@ class EventLoopTestsMixin:
self.assertEqual('CLOSED', client.state)
server.transport.close()
+ def test_create_datagram_endpoint_sock(self):
+ sock = None
+ local_address = ('127.0.0.1', 0)
+ infos = self.loop.run_until_complete(
+ self.loop.getaddrinfo(
+ *local_address, type=socket.SOCK_DGRAM))
+ for family, type, proto, cname, address in infos:
+ try:
+ sock = socket.socket(family=family, type=type, proto=proto)
+ sock.setblocking(False)
+ sock.bind(address)
+ except:
+ pass
+ else:
+ break
+ else:
+ assert False, 'Can not create socket.'
+
+ f = self.loop.create_connection(
+ lambda: MyDatagramProto(loop=self.loop), sock=sock)
+ tr, pr = self.loop.run_until_complete(f)
+ self.assertIsInstance(tr, asyncio.Transport)
+ self.assertIsInstance(pr, MyDatagramProto)
+ tr.close()
+ self.loop.run_until_complete(pr.done)
+
def test_internal_fds(self):
loop = self.create_event_loop()
if not isinstance(loop, selector_events.BaseSelectorEventLoop):