diff options
-rw-r--r-- | Lib/asyncio/proactor_events.py | 116 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_proactor_events.py | 14 |
2 files changed, 106 insertions, 24 deletions
diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py index ce226b9..979bc25 100644 --- a/Lib/asyncio/proactor_events.py +++ b/Lib/asyncio/proactor_events.py @@ -24,12 +24,14 @@ class _ProactorBasePipeTransport(transports.BaseTransport): self._sock = sock self._protocol = protocol self._server = server - self._buffer = [] + self._buffer = None # None or bytearray. self._read_fut = None self._write_fut = None self._conn_lost = 0 self._closing = False # Set when close() called. self._eof_written = False + self._protocol_paused = False + self.set_write_buffer_limits() if self._server is not None: self._server.attach(self) self._loop.call_soon(self._protocol.connection_made, self) @@ -63,7 +65,7 @@ class _ProactorBasePipeTransport(transports.BaseTransport): if self._read_fut: self._read_fut.cancel() self._write_fut = self._read_fut = None - self._buffer = [] + self._buffer = None self._loop.call_soon(self._call_connection_lost, exc) def _call_connection_lost(self, exc): @@ -82,6 +84,53 @@ class _ProactorBasePipeTransport(transports.BaseTransport): server.detach(self) self._server = None + # XXX The next four methods are nearly identical to corresponding + # ones in _SelectorTransport. Maybe refactor buffer management to + # share the implementations? (Also these are really only needed + # by _ProactorWritePipeTransport but since _buffer is defined on + # the base class I am putting it here for now.) + + def _maybe_pause_protocol(self): + size = self.get_write_buffer_size() + if size <= self._high_water: + return + if not self._protocol_paused: + self._protocol_paused = True + try: + self._protocol.pause_writing() + except Exception: + logger.exception('pause_writing() failed') + + def _maybe_resume_protocol(self): + if (self._protocol_paused and + self.get_write_buffer_size() <= self._low_water): + self._protocol_paused = False + try: + self._protocol.resume_writing() + except Exception: + logger.exception('resume_writing() failed') + + def set_write_buffer_limits(self, high=None, low=None): + if high is None: + if low is None: + high = 64*1024 + else: + high = 4*low + if low is None: + low = high // 4 + if not high >= low >= 0: + raise ValueError('high (%r) must be >= low (%r) must be >= 0' % + (high, low)) + self._high_water = high + self._low_water = low + + def get_write_buffer_size(self): + # NOTE: This doesn't take into account data already passed to + # send() even if send() hasn't finished yet. + if not self._buffer: + return 0 + return len(self._buffer) + class _ProactorReadPipeTransport(_ProactorBasePipeTransport, transports.ReadTransport): @@ -95,12 +144,15 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport, self._loop.call_soon(self._loop_reading) def pause_reading(self): - assert not self._closing, 'Cannot pause_reading() when closing' - assert not self._paused, 'Already paused' + if self._closing: + raise RuntimeError('Cannot pause_reading() when closing') + if self._paused: + raise RuntimeError('Already paused') self._paused = True def resume_reading(self): - assert self._paused, 'Not paused' + if not self._paused: + raise RuntimeError('Not paused') self._paused = False if self._closing: return @@ -155,9 +207,11 @@ class _ProactorWritePipeTransport(_ProactorBasePipeTransport, """Transport for write pipes.""" def write(self, data): - assert isinstance(data, bytes), repr(data) + if not isinstance(data, (bytes, bytearray, memoryview)): + raise TypeError('data argument must be byte-ish (%r)', + type(data)) if self._eof_written: - raise IOError('write_eof() already called') + raise RuntimeError('write_eof() already called') if not data: return @@ -167,26 +221,53 @@ class _ProactorWritePipeTransport(_ProactorBasePipeTransport, logger.warning('socket.send() raised exception.') self._conn_lost += 1 return - self._buffer.append(data) - if self._write_fut is None: - self._loop_writing() - def _loop_writing(self, f=None): + # Observable states: + # 1. IDLE: _write_fut and _buffer both None + # 2. WRITING: _write_fut set; _buffer None + # 3. BACKED UP: _write_fut set; _buffer a bytearray + # We always copy the data, so the caller can't modify it + # while we're still waiting for the I/O to happen. + if self._write_fut is None: # IDLE -> WRITING + assert self._buffer is None + # Pass a copy, except if it's already immutable. + self._loop_writing(data=bytes(data)) + # XXX Should we pause the protocol at this point + # if len(data) > self._high_water? (That would + # require keeping track of the number of bytes passed + # to a send() that hasn't finished yet.) + elif not self._buffer: # WRITING -> BACKED UP + # Make a mutable copy which we can extend. + self._buffer = bytearray(data) + self._maybe_pause_protocol() + else: # BACKED UP + # Append to buffer (also copies). + self._buffer.extend(data) + self._maybe_pause_protocol() + + def _loop_writing(self, f=None, data=None): try: assert f is self._write_fut self._write_fut = None if f: f.result() - data = b''.join(self._buffer) - self._buffer = [] + if data is None: + data = self._buffer + self._buffer = None if not data: if self._closing: self._loop.call_soon(self._call_connection_lost, None) if self._eof_written: self._sock.shutdown(socket.SHUT_WR) - return - self._write_fut = self._loop._proactor.send(self._sock, data) - self._write_fut.add_done_callback(self._loop_writing) + else: + self._write_fut = self._loop._proactor.send(self._sock, data) + self._write_fut.add_done_callback(self._loop_writing) + # Now that we've reduced the buffer size, tell the + # protocol to resume writing if it was paused. Note that + # we do this last since the callback is called immediately + # and it may add more data to the buffer (even causing the + # protocol to be paused again). + self._maybe_resume_protocol() except ConnectionResetError as exc: self._force_close(exc) except OSError as exc: @@ -330,7 +411,8 @@ class BaseProactorEventLoop(base_events.BaseEventLoop): self._csock.send(b'x') def _start_serving(self, protocol_factory, sock, ssl=None, server=None): - assert not ssl, 'IocpEventLoop is incompatible with SSL.' + if ssl: + raise ValueError('IocpEventLoop is incompatible with SSL.') def loop(f=None): try: diff --git a/Lib/test/test_asyncio/test_proactor_events.py b/Lib/test/test_asyncio/test_proactor_events.py index 5a2a51c..9964f42 100644 --- a/Lib/test/test_asyncio/test_proactor_events.py +++ b/Lib/test/test_asyncio/test_proactor_events.py @@ -111,8 +111,8 @@ class ProactorSocketTransportTests(unittest.TestCase): tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) tr._loop_writing = unittest.mock.Mock() tr.write(b'data') - self.assertEqual(tr._buffer, [b'data']) - self.assertTrue(tr._loop_writing.called) + self.assertEqual(tr._buffer, None) + tr._loop_writing.assert_called_with(data=b'data') def test_write_no_data(self): tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) @@ -124,12 +124,12 @@ class ProactorSocketTransportTests(unittest.TestCase): tr._write_fut = unittest.mock.Mock() tr._loop_writing = unittest.mock.Mock() tr.write(b'data') - self.assertEqual(tr._buffer, [b'data']) + self.assertEqual(tr._buffer, b'data') self.assertFalse(tr._loop_writing.called) def test_loop_writing(self): tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) - tr._buffer = [b'da', b'ta'] + tr._buffer = bytearray(b'data') tr._loop_writing() self.loop._proactor.send.assert_called_with(self.sock, b'data') self.loop._proactor.send.return_value.add_done_callback.\ @@ -150,7 +150,7 @@ class ProactorSocketTransportTests(unittest.TestCase): tr.write(b'data') tr.write(b'data') tr.write(b'data') - self.assertEqual(tr._buffer, []) + self.assertEqual(tr._buffer, None) m_log.warning.assert_called_with('socket.send() raised exception.') def test_loop_writing_stop(self): @@ -226,7 +226,7 @@ class ProactorSocketTransportTests(unittest.TestCase): write_fut.cancel.assert_called_with() test_utils.run_briefly(self.loop) self.protocol.connection_lost.assert_called_with(None) - self.assertEqual([], tr._buffer) + self.assertEqual(None, tr._buffer) self.assertEqual(tr._conn_lost, 1) def test_force_close_idempotent(self): @@ -243,7 +243,7 @@ class ProactorSocketTransportTests(unittest.TestCase): test_utils.run_briefly(self.loop) self.protocol.connection_lost.assert_called_with(None) - self.assertEqual([], tr._buffer) + self.assertEqual(None, tr._buffer) def test_call_connection_lost(self): tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) |