summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorAndrew Svetlov <andrew.svetlov@gmail.com>2019-05-28 09:52:15 (GMT)
committerMiss Islington (bot) <31488909+miss-islington@users.noreply.github.com>2019-05-28 09:52:15 (GMT)
commitbafd4b5ac83b6cc0b7455290a04c4bfad34bdc90 (patch)
treebfb330fd3530eec1781d35b4b0c8339f93018951 /Lib
parent9ee2c264c37a71bd1c60f6032c50630b87e3c611 (diff)
downloadcpython-bafd4b5ac83b6cc0b7455290a04c4bfad34bdc90.zip
cpython-bafd4b5ac83b6cc0b7455290a04c4bfad34bdc90.tar.gz
cpython-bafd4b5ac83b6cc0b7455290a04c4bfad34bdc90.tar.bz2
bpo-29883: Asyncio proactor udp (GH-13440)
Follow-up for #1067 https://bugs.python.org/issue29883
Diffstat (limited to 'Lib')
-rw-r--r--Lib/asyncio/proactor_events.py169
-rw-r--r--Lib/asyncio/windows_events.py46
-rw-r--r--Lib/test/test_asyncio/test_events.py9
-rw-r--r--Lib/test/test_asyncio/test_proactor_events.py278
4 files changed, 477 insertions, 25 deletions
diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py
index 6a53b2e..9b8ae06 100644
--- a/Lib/asyncio/proactor_events.py
+++ b/Lib/asyncio/proactor_events.py
@@ -11,6 +11,7 @@ import os
import socket
import warnings
import signal
+import collections
from . import base_events
from . import constants
@@ -23,6 +24,24 @@ from . import trsock
from .log import logger
+def _set_socket_extra(transport, sock):
+ transport._extra['socket'] = trsock.TransportSocket(sock)
+
+ try:
+ transport._extra['sockname'] = sock.getsockname()
+ except socket.error:
+ if transport._loop.get_debug():
+ logger.warning(
+ "getsockname() failed on %r", sock, exc_info=True)
+
+ if 'peername' not in transport._extra:
+ try:
+ transport._extra['peername'] = sock.getpeername()
+ except socket.error:
+ # UDP sockets may not have a peer name
+ transport._extra['peername'] = None
+
+
class _ProactorBasePipeTransport(transports._FlowControlMixin,
transports.BaseTransport):
"""Base class for pipe and socket transports."""
@@ -430,6 +449,134 @@ class _ProactorWritePipeTransport(_ProactorBaseWritePipeTransport):
self.close()
+class _ProactorDatagramTransport(_ProactorBasePipeTransport):
+ max_size = 256 * 1024
+ def __init__(self, loop, sock, protocol, address=None,
+ waiter=None, extra=None):
+ self._address = address
+ self._empty_waiter = None
+ # We don't need to call _protocol.connection_made() since our base
+ # constructor does it for us.
+ super().__init__(loop, sock, protocol, waiter=waiter, extra=extra)
+
+ # The base constructor sets _buffer = None, so we set it here
+ self._buffer = collections.deque()
+ self._loop.call_soon(self._loop_reading)
+
+ def _set_extra(self, sock):
+ _set_socket_extra(self, sock)
+
+ def get_write_buffer_size(self):
+ return sum(len(data) for data, _ in self._buffer)
+
+ def abort(self):
+ self._force_close(None)
+
+ def sendto(self, data, addr=None):
+ if not isinstance(data, (bytes, bytearray, memoryview)):
+ raise TypeError('data argument must be bytes-like object (%r)',
+ type(data))
+
+ if not data:
+ return
+
+ if self._address is not None and addr not in (None, self._address):
+ raise ValueError(
+ f'Invalid address: must be None or {self._address}')
+
+ if self._conn_lost and self._address:
+ if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES:
+ logger.warning('socket.sendto() raised exception.')
+ self._conn_lost += 1
+ return
+
+ # Ensure that what we buffer is immutable.
+ self._buffer.append((bytes(data), addr))
+
+ if self._write_fut is None:
+ # No current write operations are active, kick one off
+ self._loop_writing()
+ # else: A write operation is already kicked off
+
+ self._maybe_pause_protocol()
+
+ def _loop_writing(self, fut=None):
+ try:
+ if self._conn_lost:
+ return
+
+ assert fut is self._write_fut
+ self._write_fut = None
+ if fut:
+ # We are in a _loop_writing() done callback, get the result
+ fut.result()
+
+ if not self._buffer or (self._conn_lost and self._address):
+ # The connection has been closed
+ if self._closing:
+ self._loop.call_soon(self._call_connection_lost, None)
+ return
+
+ data, addr = self._buffer.popleft()
+ if self._address is not None:
+ self._write_fut = self._loop._proactor.send(self._sock,
+ data)
+ else:
+ self._write_fut = self._loop._proactor.sendto(self._sock,
+ data,
+ addr=addr)
+ except OSError as exc:
+ self._protocol.error_received(exc)
+ except Exception as exc:
+ self._fatal_error(exc, 'Fatal write error on datagram transport')
+ else:
+ self._write_fut.add_done_callback(self._loop_writing)
+ self._maybe_resume_protocol()
+
+ def _loop_reading(self, fut=None):
+ data = None
+ try:
+ if self._conn_lost:
+ return
+
+ assert self._read_fut is fut or (self._read_fut is None and
+ self._closing)
+
+ self._read_fut = None
+ if fut is not None:
+ res = fut.result()
+
+ if self._closing:
+ # since close() has been called we ignore any read data
+ data = None
+ return
+
+ if self._address is not None:
+ data, addr = res, self._address
+ else:
+ data, addr = res
+
+ if self._conn_lost:
+ return
+ if self._address is not None:
+ self._read_fut = self._loop._proactor.recv(self._sock,
+ self.max_size)
+ else:
+ self._read_fut = self._loop._proactor.recvfrom(self._sock,
+ self.max_size)
+ except OSError as exc:
+ self._protocol.error_received(exc)
+ except exceptions.CancelledError:
+ if not self._closing:
+ raise
+ else:
+ if self._read_fut is not None:
+ self._read_fut.add_done_callback(self._loop_reading)
+ finally:
+ if data:
+ self._protocol.datagram_received(data, addr)
+
+
class _ProactorDuplexPipeTransport(_ProactorReadPipeTransport,
_ProactorBaseWritePipeTransport,
transports.Transport):
@@ -455,22 +602,7 @@ class _ProactorSocketTransport(_ProactorReadPipeTransport,
base_events._set_nodelay(sock)
def _set_extra(self, sock):
- self._extra['socket'] = trsock.TransportSocket(sock)
-
- try:
- self._extra['sockname'] = sock.getsockname()
- except (socket.error, AttributeError):
- if self._loop.get_debug():
- logger.warning(
- "getsockname() failed on %r", sock, exc_info=True)
-
- if 'peername' not in self._extra:
- try:
- self._extra['peername'] = sock.getpeername()
- except (socket.error, AttributeError):
- if self._loop.get_debug():
- logger.warning("getpeername() failed on %r",
- sock, exc_info=True)
+ _set_socket_extra(self, sock)
def can_write_eof(self):
return True
@@ -515,6 +647,11 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
extra=extra, server=server)
return ssl_protocol._app_transport
+ def _make_datagram_transport(self, sock, protocol,
+ address=None, waiter=None, extra=None):
+ return _ProactorDatagramTransport(self, sock, protocol, address,
+ waiter, extra)
+
def _make_duplex_pipe_transport(self, sock, protocol, waiter=None,
extra=None):
return _ProactorDuplexPipeTransport(self,
diff --git a/Lib/asyncio/windows_events.py b/Lib/asyncio/windows_events.py
index 61b40ba..ac51109 100644
--- a/Lib/asyncio/windows_events.py
+++ b/Lib/asyncio/windows_events.py
@@ -483,6 +483,44 @@ class IocpProactor:
return self._register(ov, conn, finish_recv)
+ def recvfrom(self, conn, nbytes, flags=0):
+ self._register_with_iocp(conn)
+ ov = _overlapped.Overlapped(NULL)
+ try:
+ ov.WSARecvFrom(conn.fileno(), nbytes, flags)
+ except BrokenPipeError:
+ return self._result((b'', None))
+
+ def finish_recv(trans, key, ov):
+ try:
+ return ov.getresult()
+ except OSError as exc:
+ if exc.winerror in (_overlapped.ERROR_NETNAME_DELETED,
+ _overlapped.ERROR_OPERATION_ABORTED):
+ raise ConnectionResetError(*exc.args)
+ else:
+ raise
+
+ return self._register(ov, conn, finish_recv)
+
+ def sendto(self, conn, buf, flags=0, addr=None):
+ self._register_with_iocp(conn)
+ ov = _overlapped.Overlapped(NULL)
+
+ ov.WSASendTo(conn.fileno(), buf, flags, addr)
+
+ def finish_send(trans, key, ov):
+ try:
+ return ov.getresult()
+ except OSError as exc:
+ if exc.winerror in (_overlapped.ERROR_NETNAME_DELETED,
+ _overlapped.ERROR_OPERATION_ABORTED):
+ raise ConnectionResetError(*exc.args)
+ else:
+ raise
+
+ return self._register(ov, conn, finish_send)
+
def send(self, conn, buf, flags=0):
self._register_with_iocp(conn)
ov = _overlapped.Overlapped(NULL)
@@ -532,6 +570,14 @@ class IocpProactor:
return future
def connect(self, conn, address):
+ if conn.type == socket.SOCK_DGRAM:
+ # WSAConnect will complete immediately for UDP sockets so we don't
+ # need to register any IOCP operation
+ _overlapped.WSAConnect(conn.fileno(), address)
+ fut = self._loop.create_future()
+ fut.set_result(None)
+ return fut
+
self._register_with_iocp(conn)
# The socket needs to be locally bound before we call ConnectEx().
try:
diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py
index e89db99..045654e 100644
--- a/Lib/test/test_asyncio/test_events.py
+++ b/Lib/test/test_asyncio/test_events.py
@@ -1249,11 +1249,6 @@ class EventLoopTestsMixin:
server.transport.close()
def test_create_datagram_endpoint_sock(self):
- if (sys.platform == 'win32' and
- isinstance(self.loop, proactor_events.BaseProactorEventLoop)):
- raise unittest.SkipTest(
- 'UDP is not supported with proactor event loops')
-
sock = None
local_address = ('127.0.0.1', 0)
infos = self.loop.run_until_complete(
@@ -2004,10 +1999,6 @@ if sys.platform == 'win32':
def test_writer_callback_cancel(self):
raise unittest.SkipTest("IocpEventLoop does not have add_writer()")
- def test_create_datagram_endpoint(self):
- raise unittest.SkipTest(
- "IocpEventLoop does not have create_datagram_endpoint()")
-
def test_remove_fds_after_closing(self):
raise unittest.SkipTest("IocpEventLoop does not have add_reader()")
else:
diff --git a/Lib/test/test_asyncio/test_proactor_events.py b/Lib/test/test_asyncio/test_proactor_events.py
index 5952ccc..2e9995d 100644
--- a/Lib/test/test_asyncio/test_proactor_events.py
+++ b/Lib/test/test_asyncio/test_proactor_events.py
@@ -4,6 +4,7 @@ import io
import socket
import unittest
import sys
+from collections import deque
from unittest import mock
import asyncio
@@ -12,6 +13,7 @@ from asyncio.proactor_events import BaseProactorEventLoop
from asyncio.proactor_events import _ProactorSocketTransport
from asyncio.proactor_events import _ProactorWritePipeTransport
from asyncio.proactor_events import _ProactorDuplexPipeTransport
+from asyncio.proactor_events import _ProactorDatagramTransport
from test import support
from test.test_asyncio import utils as test_utils
@@ -725,6 +727,208 @@ class ProactorSocketTransportBufferedProtoTests(test_utils.TestCase):
self.assertFalse(tr.is_reading())
+class ProactorDatagramTransportTests(test_utils.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.loop = self.new_test_loop()
+ self.proactor = mock.Mock()
+ self.loop._proactor = self.proactor
+ self.protocol = test_utils.make_test_protocol(asyncio.DatagramProtocol)
+ self.sock = mock.Mock(spec_set=socket.socket)
+ self.sock.fileno.return_value = 7
+
+ def datagram_transport(self, address=None):
+ self.sock.getpeername.side_effect = None if address else OSError
+ transport = _ProactorDatagramTransport(self.loop, self.sock,
+ self.protocol,
+ address=address)
+ self.addCleanup(close_transport, transport)
+ return transport
+
+ def test_sendto(self):
+ data = b'data'
+ transport = self.datagram_transport()
+ transport.sendto(data, ('0.0.0.0', 1234))
+ self.assertTrue(self.proactor.sendto.called)
+ self.proactor.sendto.assert_called_with(
+ self.sock, data, addr=('0.0.0.0', 1234))
+
+ def test_sendto_bytearray(self):
+ data = bytearray(b'data')
+ transport = self.datagram_transport()
+ transport.sendto(data, ('0.0.0.0', 1234))
+ self.assertTrue(self.proactor.sendto.called)
+ self.proactor.sendto.assert_called_with(
+ self.sock, b'data', addr=('0.0.0.0', 1234))
+
+ def test_sendto_memoryview(self):
+ data = memoryview(b'data')
+ transport = self.datagram_transport()
+ transport.sendto(data, ('0.0.0.0', 1234))
+ self.assertTrue(self.proactor.sendto.called)
+ self.proactor.sendto.assert_called_with(
+ self.sock, b'data', addr=('0.0.0.0', 1234))
+
+ def test_sendto_no_data(self):
+ transport = self.datagram_transport()
+ transport._buffer.append((b'data', ('0.0.0.0', 12345)))
+ transport.sendto(b'', ())
+ self.assertFalse(self.sock.sendto.called)
+ self.assertEqual(
+ [(b'data', ('0.0.0.0', 12345))], list(transport._buffer))
+
+ def test_sendto_buffer(self):
+ transport = self.datagram_transport()
+ transport._buffer.append((b'data1', ('0.0.0.0', 12345)))
+ transport._write_fut = object()
+ transport.sendto(b'data2', ('0.0.0.0', 12345))
+ self.assertFalse(self.proactor.sendto.called)
+ self.assertEqual(
+ [(b'data1', ('0.0.0.0', 12345)),
+ (b'data2', ('0.0.0.0', 12345))],
+ list(transport._buffer))
+
+ def test_sendto_buffer_bytearray(self):
+ data2 = bytearray(b'data2')
+ transport = self.datagram_transport()
+ transport._buffer.append((b'data1', ('0.0.0.0', 12345)))
+ transport._write_fut = object()
+ transport.sendto(data2, ('0.0.0.0', 12345))
+ self.assertFalse(self.proactor.sendto.called)
+ self.assertEqual(
+ [(b'data1', ('0.0.0.0', 12345)),
+ (b'data2', ('0.0.0.0', 12345))],
+ list(transport._buffer))
+ self.assertIsInstance(transport._buffer[1][0], bytes)
+
+ def test_sendto_buffer_memoryview(self):
+ data2 = memoryview(b'data2')
+ transport = self.datagram_transport()
+ transport._buffer.append((b'data1', ('0.0.0.0', 12345)))
+ transport._write_fut = object()
+ transport.sendto(data2, ('0.0.0.0', 12345))
+ self.assertFalse(self.proactor.sendto.called)
+ self.assertEqual(
+ [(b'data1', ('0.0.0.0', 12345)),
+ (b'data2', ('0.0.0.0', 12345))],
+ list(transport._buffer))
+ self.assertIsInstance(transport._buffer[1][0], bytes)
+
+ @mock.patch('asyncio.proactor_events.logger')
+ def test_sendto_exception(self, m_log):
+ data = b'data'
+ err = self.proactor.sendto.side_effect = RuntimeError()
+
+ transport = self.datagram_transport()
+ transport._fatal_error = mock.Mock()
+ transport.sendto(data, ())
+
+ self.assertTrue(transport._fatal_error.called)
+ transport._fatal_error.assert_called_with(
+ err,
+ 'Fatal write error on datagram transport')
+ transport._conn_lost = 1
+
+ transport._address = ('123',)
+ transport.sendto(data)
+ transport.sendto(data)
+ transport.sendto(data)
+ transport.sendto(data)
+ transport.sendto(data)
+ m_log.warning.assert_called_with('socket.sendto() raised exception.')
+
+ def test_sendto_error_received(self):
+ data = b'data'
+
+ self.sock.sendto.side_effect = ConnectionRefusedError
+
+ transport = self.datagram_transport()
+ transport._fatal_error = mock.Mock()
+ transport.sendto(data, ())
+
+ self.assertEqual(transport._conn_lost, 0)
+ self.assertFalse(transport._fatal_error.called)
+
+ def test_sendto_error_received_connected(self):
+ data = b'data'
+
+ self.proactor.send.side_effect = ConnectionRefusedError
+
+ transport = self.datagram_transport(address=('0.0.0.0', 1))
+ transport._fatal_error = mock.Mock()
+ transport.sendto(data)
+
+ self.assertFalse(transport._fatal_error.called)
+ self.assertTrue(self.protocol.error_received.called)
+
+ def test_sendto_str(self):
+ transport = self.datagram_transport()
+ self.assertRaises(TypeError, transport.sendto, 'str', ())
+
+ def test_sendto_connected_addr(self):
+ transport = self.datagram_transport(address=('0.0.0.0', 1))
+ self.assertRaises(
+ ValueError, transport.sendto, b'str', ('0.0.0.0', 2))
+
+ def test_sendto_closing(self):
+ transport = self.datagram_transport(address=(1,))
+ transport.close()
+ self.assertEqual(transport._conn_lost, 1)
+ transport.sendto(b'data', (1,))
+ self.assertEqual(transport._conn_lost, 2)
+
+ def test__loop_writing_closing(self):
+ transport = self.datagram_transport()
+ transport._closing = True
+ transport._loop_writing()
+ self.assertIsNone(transport._write_fut)
+ test_utils.run_briefly(self.loop)
+ self.sock.close.assert_called_with()
+ self.protocol.connection_lost.assert_called_with(None)
+
+ def test__loop_writing_exception(self):
+ err = self.proactor.sendto.side_effect = RuntimeError()
+
+ transport = self.datagram_transport()
+ transport._fatal_error = mock.Mock()
+ transport._buffer.append((b'data', ()))
+ transport._loop_writing()
+
+ transport._fatal_error.assert_called_with(
+ err,
+ 'Fatal write error on datagram transport')
+
+ def test__loop_writing_error_received(self):
+ self.proactor.sendto.side_effect = ConnectionRefusedError
+
+ transport = self.datagram_transport()
+ transport._fatal_error = mock.Mock()
+ transport._buffer.append((b'data', ()))
+ transport._loop_writing()
+
+ self.assertFalse(transport._fatal_error.called)
+
+ def test__loop_writing_error_received_connection(self):
+ self.proactor.send.side_effect = ConnectionRefusedError
+
+ transport = self.datagram_transport(address=('0.0.0.0', 1))
+ transport._fatal_error = mock.Mock()
+ transport._buffer.append((b'data', ()))
+ transport._loop_writing()
+
+ self.assertFalse(transport._fatal_error.called)
+ self.assertTrue(self.protocol.error_received.called)
+
+ @mock.patch('asyncio.base_events.logger.error')
+ def test_fatal_error_connected(self, m_exc):
+ transport = self.datagram_transport(address=('0.0.0.0', 1))
+ err = ConnectionRefusedError()
+ transport._fatal_error(err)
+ self.assertFalse(self.protocol.error_received.called)
+ m_exc.assert_not_called()
+
+
class BaseProactorEventLoopTests(test_utils.TestCase):
def setUp(self):
@@ -864,6 +1068,80 @@ class BaseProactorEventLoopTests(test_utils.TestCase):
self.assertFalse(sock2.close.called)
self.assertFalse(future2.cancel.called)
+ def datagram_transport(self):
+ self.protocol = test_utils.make_test_protocol(asyncio.DatagramProtocol)
+ return self.loop._make_datagram_transport(self.sock, self.protocol)
+
+ def test_make_datagram_transport(self):
+ tr = self.datagram_transport()
+ self.assertIsInstance(tr, _ProactorDatagramTransport)
+ close_transport(tr)
+
+ def test_datagram_loop_writing(self):
+ tr = self.datagram_transport()
+ tr._buffer.appendleft((b'data', ('127.0.0.1', 12068)))
+ tr._loop_writing()
+ self.loop._proactor.sendto.assert_called_with(self.sock, b'data', addr=('127.0.0.1', 12068))
+ self.loop._proactor.sendto.return_value.add_done_callback.\
+ assert_called_with(tr._loop_writing)
+
+ close_transport(tr)
+
+ def test_datagram_loop_reading(self):
+ tr = self.datagram_transport()
+ tr._loop_reading()
+ self.loop._proactor.recvfrom.assert_called_with(self.sock, 256 * 1024)
+ self.assertFalse(self.protocol.datagram_received.called)
+ self.assertFalse(self.protocol.error_received.called)
+ close_transport(tr)
+
+ def test_datagram_loop_reading_data(self):
+ res = asyncio.Future(loop=self.loop)
+ res.set_result((b'data', ('127.0.0.1', 12068)))
+
+ tr = self.datagram_transport()
+ tr._read_fut = res
+ tr._loop_reading(res)
+ self.loop._proactor.recvfrom.assert_called_with(self.sock, 256 * 1024)
+ self.protocol.datagram_received.assert_called_with(b'data', ('127.0.0.1', 12068))
+ close_transport(tr)
+
+ def test_datagram_loop_reading_no_data(self):
+ res = asyncio.Future(loop=self.loop)
+ res.set_result((b'', ('127.0.0.1', 12068)))
+
+ tr = self.datagram_transport()
+ self.assertRaises(AssertionError, tr._loop_reading, res)
+
+ tr.close = mock.Mock()
+ tr._read_fut = res
+ tr._loop_reading(res)
+ self.assertTrue(self.loop._proactor.recvfrom.called)
+ self.assertFalse(self.protocol.error_received.called)
+ self.assertFalse(tr.close.called)
+ close_transport(tr)
+
+ def test_datagram_loop_reading_aborted(self):
+ err = self.loop._proactor.recvfrom.side_effect = ConnectionAbortedError()
+
+ tr = self.datagram_transport()
+ tr._fatal_error = mock.Mock()
+ tr._protocol.error_received = mock.Mock()
+ tr._loop_reading()
+ tr._protocol.error_received.assert_called_with(err)
+ close_transport(tr)
+
+ def test_datagram_loop_writing_aborted(self):
+ err = self.loop._proactor.sendto.side_effect = ConnectionAbortedError()
+
+ tr = self.datagram_transport()
+ tr._fatal_error = mock.Mock()
+ tr._protocol.error_received = mock.Mock()
+ tr._buffer.appendleft((b'Hello', ('127.0.0.1', 12068)))
+ tr._loop_writing()
+ tr._protocol.error_received.assert_called_with(err)
+ close_transport(tr)
+
@unittest.skipIf(sys.platform != 'win32',
'Proactor is supported on Windows only')