summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/asyncio/proactor_events.py116
-rw-r--r--Lib/test/test_asyncio/test_proactor_events.py14
2 files changed, 106 insertions, 24 deletions
diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py
index ce226b9..979bc25 100644
--- a/Lib/asyncio/proactor_events.py
+++ b/Lib/asyncio/proactor_events.py
@@ -24,12 +24,14 @@ class _ProactorBasePipeTransport(transports.BaseTransport):
self._sock = sock
self._protocol = protocol
self._server = server
- self._buffer = []
+ self._buffer = None # None or bytearray.
self._read_fut = None
self._write_fut = None
self._conn_lost = 0
self._closing = False # Set when close() called.
self._eof_written = False
+ self._protocol_paused = False
+ self.set_write_buffer_limits()
if self._server is not None:
self._server.attach(self)
self._loop.call_soon(self._protocol.connection_made, self)
@@ -63,7 +65,7 @@ class _ProactorBasePipeTransport(transports.BaseTransport):
if self._read_fut:
self._read_fut.cancel()
self._write_fut = self._read_fut = None
- self._buffer = []
+ self._buffer = None
self._loop.call_soon(self._call_connection_lost, exc)
def _call_connection_lost(self, exc):
@@ -82,6 +84,53 @@ class _ProactorBasePipeTransport(transports.BaseTransport):
server.detach(self)
self._server = None
+ # XXX The next four methods are nearly identical to corresponding
+ # ones in _SelectorTransport. Maybe refactor buffer management to
+ # share the implementations? (Also these are really only needed
+ # by _ProactorWritePipeTransport but since _buffer is defined on
+ # the base class I am putting it here for now.)
+
+ def _maybe_pause_protocol(self):
+ size = self.get_write_buffer_size()
+ if size <= self._high_water:
+ return
+ if not self._protocol_paused:
+ self._protocol_paused = True
+ try:
+ self._protocol.pause_writing()
+ except Exception:
+ logger.exception('pause_writing() failed')
+
+ def _maybe_resume_protocol(self):
+ if (self._protocol_paused and
+ self.get_write_buffer_size() <= self._low_water):
+ self._protocol_paused = False
+ try:
+ self._protocol.resume_writing()
+ except Exception:
+ logger.exception('resume_writing() failed')
+
+ def set_write_buffer_limits(self, high=None, low=None):
+ if high is None:
+ if low is None:
+ high = 64*1024
+ else:
+ high = 4*low
+ if low is None:
+ low = high // 4
+ if not high >= low >= 0:
+ raise ValueError('high (%r) must be >= low (%r) must be >= 0' %
+ (high, low))
+ self._high_water = high
+ self._low_water = low
+
+ def get_write_buffer_size(self):
+ # NOTE: This doesn't take into account data already passed to
+ # send() even if send() hasn't finished yet.
+ if not self._buffer:
+ return 0
+ return len(self._buffer)
+
class _ProactorReadPipeTransport(_ProactorBasePipeTransport,
transports.ReadTransport):
@@ -95,12 +144,15 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport,
self._loop.call_soon(self._loop_reading)
def pause_reading(self):
- assert not self._closing, 'Cannot pause_reading() when closing'
- assert not self._paused, 'Already paused'
+ if self._closing:
+ raise RuntimeError('Cannot pause_reading() when closing')
+ if self._paused:
+ raise RuntimeError('Already paused')
self._paused = True
def resume_reading(self):
- assert self._paused, 'Not paused'
+ if not self._paused:
+ raise RuntimeError('Not paused')
self._paused = False
if self._closing:
return
@@ -155,9 +207,11 @@ class _ProactorWritePipeTransport(_ProactorBasePipeTransport,
"""Transport for write pipes."""
def write(self, data):
- assert isinstance(data, bytes), repr(data)
+ if not isinstance(data, (bytes, bytearray, memoryview)):
+ raise TypeError('data argument must be byte-ish (%r)',
+ type(data))
if self._eof_written:
- raise IOError('write_eof() already called')
+ raise RuntimeError('write_eof() already called')
if not data:
return
@@ -167,26 +221,53 @@ class _ProactorWritePipeTransport(_ProactorBasePipeTransport,
logger.warning('socket.send() raised exception.')
self._conn_lost += 1
return
- self._buffer.append(data)
- if self._write_fut is None:
- self._loop_writing()
- def _loop_writing(self, f=None):
+ # Observable states:
+ # 1. IDLE: _write_fut and _buffer both None
+ # 2. WRITING: _write_fut set; _buffer None
+ # 3. BACKED UP: _write_fut set; _buffer a bytearray
+ # We always copy the data, so the caller can't modify it
+ # while we're still waiting for the I/O to happen.
+ if self._write_fut is None: # IDLE -> WRITING
+ assert self._buffer is None
+ # Pass a copy, except if it's already immutable.
+ self._loop_writing(data=bytes(data))
+ # XXX Should we pause the protocol at this point
+ # if len(data) > self._high_water? (That would
+ # require keeping track of the number of bytes passed
+ # to a send() that hasn't finished yet.)
+ elif not self._buffer: # WRITING -> BACKED UP
+ # Make a mutable copy which we can extend.
+ self._buffer = bytearray(data)
+ self._maybe_pause_protocol()
+ else: # BACKED UP
+ # Append to buffer (also copies).
+ self._buffer.extend(data)
+ self._maybe_pause_protocol()
+
+ def _loop_writing(self, f=None, data=None):
try:
assert f is self._write_fut
self._write_fut = None
if f:
f.result()
- data = b''.join(self._buffer)
- self._buffer = []
+ if data is None:
+ data = self._buffer
+ self._buffer = None
if not data:
if self._closing:
self._loop.call_soon(self._call_connection_lost, None)
if self._eof_written:
self._sock.shutdown(socket.SHUT_WR)
- return
- self._write_fut = self._loop._proactor.send(self._sock, data)
- self._write_fut.add_done_callback(self._loop_writing)
+ else:
+ self._write_fut = self._loop._proactor.send(self._sock, data)
+ self._write_fut.add_done_callback(self._loop_writing)
+ # Now that we've reduced the buffer size, tell the
+ # protocol to resume writing if it was paused. Note that
+ # we do this last since the callback is called immediately
+ # and it may add more data to the buffer (even causing the
+ # protocol to be paused again).
+ self._maybe_resume_protocol()
except ConnectionResetError as exc:
self._force_close(exc)
except OSError as exc:
@@ -330,7 +411,8 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
self._csock.send(b'x')
def _start_serving(self, protocol_factory, sock, ssl=None, server=None):
- assert not ssl, 'IocpEventLoop is incompatible with SSL.'
+ if ssl:
+ raise ValueError('IocpEventLoop is incompatible with SSL.')
def loop(f=None):
try:
diff --git a/Lib/test/test_asyncio/test_proactor_events.py b/Lib/test/test_asyncio/test_proactor_events.py
index 5a2a51c..9964f42 100644
--- a/Lib/test/test_asyncio/test_proactor_events.py
+++ b/Lib/test/test_asyncio/test_proactor_events.py
@@ -111,8 +111,8 @@ class ProactorSocketTransportTests(unittest.TestCase):
tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
tr._loop_writing = unittest.mock.Mock()
tr.write(b'data')
- self.assertEqual(tr._buffer, [b'data'])
- self.assertTrue(tr._loop_writing.called)
+ self.assertEqual(tr._buffer, None)
+ tr._loop_writing.assert_called_with(data=b'data')
def test_write_no_data(self):
tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
@@ -124,12 +124,12 @@ class ProactorSocketTransportTests(unittest.TestCase):
tr._write_fut = unittest.mock.Mock()
tr._loop_writing = unittest.mock.Mock()
tr.write(b'data')
- self.assertEqual(tr._buffer, [b'data'])
+ self.assertEqual(tr._buffer, b'data')
self.assertFalse(tr._loop_writing.called)
def test_loop_writing(self):
tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
- tr._buffer = [b'da', b'ta']
+ tr._buffer = bytearray(b'data')
tr._loop_writing()
self.loop._proactor.send.assert_called_with(self.sock, b'data')
self.loop._proactor.send.return_value.add_done_callback.\
@@ -150,7 +150,7 @@ class ProactorSocketTransportTests(unittest.TestCase):
tr.write(b'data')
tr.write(b'data')
tr.write(b'data')
- self.assertEqual(tr._buffer, [])
+ self.assertEqual(tr._buffer, None)
m_log.warning.assert_called_with('socket.send() raised exception.')
def test_loop_writing_stop(self):
@@ -226,7 +226,7 @@ class ProactorSocketTransportTests(unittest.TestCase):
write_fut.cancel.assert_called_with()
test_utils.run_briefly(self.loop)
self.protocol.connection_lost.assert_called_with(None)
- self.assertEqual([], tr._buffer)
+ self.assertEqual(None, tr._buffer)
self.assertEqual(tr._conn_lost, 1)
def test_force_close_idempotent(self):
@@ -243,7 +243,7 @@ class ProactorSocketTransportTests(unittest.TestCase):
test_utils.run_briefly(self.loop)
self.protocol.connection_lost.assert_called_with(None)
- self.assertEqual([], tr._buffer)
+ self.assertEqual(None, tr._buffer)
def test_call_connection_lost(self):
tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)