From b507cbaac5921023c17068b616efdbbecbd89920 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 29 Jan 2015 00:35:56 +0100 Subject: asyncio: Fix SSLProtocol.eof_received() Wake-up the waiter if it is not done yet. --- Lib/asyncio/sslproto.py | 4 ++++ Lib/test/test_asyncio/test_sslproto.py | 40 ++++++++++++++++++++++++---------- 2 files changed, 33 insertions(+), 11 deletions(-) diff --git a/Lib/asyncio/sslproto.py b/Lib/asyncio/sslproto.py index f2b856c..26937c8 100644 --- a/Lib/asyncio/sslproto.py +++ b/Lib/asyncio/sslproto.py @@ -489,6 +489,10 @@ class SSLProtocol(protocols.Protocol): try: if self._loop.get_debug(): logger.debug("%r received EOF", self) + + if self._waiter is not None and not self._waiter.done(): + self._waiter.set_exception(ConnectionResetError()) + if not self._in_handshake: keep_open = self._app_protocol.eof_received() if keep_open: diff --git a/Lib/test/test_asyncio/test_sslproto.py b/Lib/test/test_asyncio/test_sslproto.py index b1a61c4..148e30d 100644 --- a/Lib/test/test_asyncio/test_sslproto.py +++ b/Lib/test/test_asyncio/test_sslproto.py @@ -12,21 +12,36 @@ from asyncio import sslproto from asyncio import test_utils +@unittest.skipIf(ssl is None, 'No ssl module') class SslProtoHandshakeTests(test_utils.TestCase): def setUp(self): self.loop = asyncio.new_event_loop() self.set_event_loop(self.loop) - @unittest.skipIf(ssl is None, 'No ssl module') + def ssl_protocol(self, waiter=None): + sslcontext = test_utils.dummy_ssl_context() + app_proto = asyncio.Protocol() + return sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter) + + def connection_made(self, ssl_proto, do_handshake=None): + transport = mock.Mock() + sslpipe = mock.Mock() + sslpipe.shutdown.return_value = b'' + if do_handshake: + sslpipe.do_handshake.side_effect = do_handshake + else: + def mock_handshake(callback): + return [] + sslpipe.do_handshake.side_effect = mock_handshake + with mock.patch('asyncio.sslproto._SSLPipe', return_value=sslpipe): + ssl_proto.connection_made(transport) + def test_cancel_handshake(self): # Python issue #23197: cancelling an handshake must not raise an # exception or log an error, even if the handshake failed - sslcontext = test_utils.dummy_ssl_context() - app_proto = asyncio.Protocol() waiter = asyncio.Future(loop=self.loop) - ssl_proto = sslproto.SSLProtocol(self.loop, app_proto, sslcontext, - waiter) + ssl_proto = self.ssl_protocol(waiter) handshake_fut = asyncio.Future(loop=self.loop) def do_handshake(callback): @@ -36,12 +51,7 @@ class SslProtoHandshakeTests(test_utils.TestCase): return [] waiter.cancel() - transport = mock.Mock() - sslpipe = mock.Mock() - sslpipe.shutdown.return_value = b'' - sslpipe.do_handshake.side_effect = do_handshake - with mock.patch('asyncio.sslproto._SSLPipe', return_value=sslpipe): - ssl_proto.connection_made(transport) + self.connection_made(ssl_proto, do_handshake) with test_utils.disable_logger(): self.loop.run_until_complete(handshake_fut) @@ -49,6 +59,14 @@ class SslProtoHandshakeTests(test_utils.TestCase): # Close the transport ssl_proto._app_transport.close() + def test_eof_received_waiter(self): + waiter = asyncio.Future(loop=self.loop) + ssl_proto = self.ssl_protocol(waiter) + self.connection_made(ssl_proto) + ssl_proto.eof_received() + test_utils.run_briefly(self.loop) + self.assertIsInstance(waiter.exception(), ConnectionResetError) + if __name__ == '__main__': unittest.main() -- cgit v0.12 From f07801bb17f8089dc8b8a4d2beafba7c497af900 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 29 Jan 2015 00:36:35 +0100 Subject: asyncio: SSL transports now clear their reference to the waiter * Rephrase also the comment explaining why the waiter is not awaken immediatly. * SSLProtocol.eof_received() doesn't instanciate ConnectionResetError exception directly, it will be done by Future.set_exception(). The exception is not used if the waiter was cancelled or if there is no waiter. --- Lib/asyncio/proactor_events.py | 2 +- Lib/asyncio/selector_events.py | 27 ++++++++++++++++----------- Lib/asyncio/sslproto.py | 20 +++++++++++++------- Lib/asyncio/unix_events.py | 4 ++-- 4 files changed, 32 insertions(+), 21 deletions(-) diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py index ed17062..0f533a5 100644 --- a/Lib/asyncio/proactor_events.py +++ b/Lib/asyncio/proactor_events.py @@ -38,7 +38,7 @@ class _ProactorBasePipeTransport(transports._FlowControlMixin, self._server._attach() self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: - # wait until protocol.connection_made() has been called + # only wake up the waiter when connection_made() has been called self._loop.call_soon(waiter._set_result_unless_cancelled, None) def __repr__(self): diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py index 24f8461..42d88f5 100644 --- a/Lib/asyncio/selector_events.py +++ b/Lib/asyncio/selector_events.py @@ -581,7 +581,7 @@ class _SelectorSocketTransport(_SelectorTransport): self._loop.add_reader(self._sock_fd, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: - # wait until protocol.connection_made() has been called + # only wake up the waiter when connection_made() has been called self._loop.call_soon(waiter._set_result_unless_cancelled, None) def pause_reading(self): @@ -732,6 +732,16 @@ class _SelectorSslTransport(_SelectorTransport): start_time = None self._on_handshake(start_time) + def _wakeup_waiter(self, exc=None): + if self._waiter is None: + return + if not self._waiter.cancelled(): + if exc is not None: + self._waiter.set_exception(exc) + else: + self._waiter.set_result(None) + self._waiter = None + def _on_handshake(self, start_time): try: self._sock.do_handshake() @@ -750,8 +760,7 @@ class _SelectorSslTransport(_SelectorTransport): self._loop.remove_reader(self._sock_fd) self._loop.remove_writer(self._sock_fd) self._sock.close() - if self._waiter is not None and not self._waiter.cancelled(): - self._waiter.set_exception(exc) + self._wakeup_waiter(exc) if isinstance(exc, Exception): return else: @@ -774,9 +783,7 @@ class _SelectorSslTransport(_SelectorTransport): "on matching the hostname", self, exc_info=True) self._sock.close() - if (self._waiter is not None - and not self._waiter.cancelled()): - self._waiter.set_exception(exc) + self._wakeup_waiter(exc) return # Add extra info that becomes available after handshake. @@ -789,10 +796,8 @@ class _SelectorSslTransport(_SelectorTransport): self._write_wants_read = False self._loop.add_reader(self._sock_fd, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) - if self._waiter is not None: - # wait until protocol.connection_made() has been called - self._loop.call_soon(self._waiter._set_result_unless_cancelled, - None) + # only wake up the waiter when connection_made() has been called + self._loop.call_soon(self._wakeup_waiter) if self._loop.get_debug(): dt = self._loop.time() - start_time @@ -924,7 +929,7 @@ class _SelectorDatagramTransport(_SelectorTransport): self._loop.add_reader(self._sock_fd, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: - # wait until protocol.connection_made() has been called + # only wake up the waiter when connection_made() has been called self._loop.call_soon(waiter._set_result_unless_cancelled, None) def get_write_buffer_size(self): diff --git a/Lib/asyncio/sslproto.py b/Lib/asyncio/sslproto.py index 26937c8..fc809b9 100644 --- a/Lib/asyncio/sslproto.py +++ b/Lib/asyncio/sslproto.py @@ -418,6 +418,16 @@ class SSLProtocol(protocols.Protocol): self._in_shutdown = False self._transport = None + def _wakeup_waiter(self, exc=None): + if self._waiter is None: + return + if not self._waiter.cancelled(): + if exc is not None: + self._waiter.set_exception(exc) + else: + self._waiter.set_result(None) + self._waiter = None + def connection_made(self, transport): """Called when the low-level connection is made. @@ -490,8 +500,7 @@ class SSLProtocol(protocols.Protocol): if self._loop.get_debug(): logger.debug("%r received EOF", self) - if self._waiter is not None and not self._waiter.done(): - self._waiter.set_exception(ConnectionResetError()) + self._wakeup_waiter(ConnectionResetError) if not self._in_handshake: keep_open = self._app_protocol.eof_received() @@ -556,8 +565,7 @@ class SSLProtocol(protocols.Protocol): self, exc_info=True) self._transport.close() if isinstance(exc, Exception): - if self._waiter is not None and not self._waiter.cancelled(): - self._waiter.set_exception(exc) + self._wakeup_waiter(exc) return else: raise @@ -572,9 +580,7 @@ class SSLProtocol(protocols.Protocol): compression=sslobj.compression(), ) self._app_protocol.connection_made(self._app_transport) - if self._waiter is not None: - # wait until protocol.connection_made() has been called - self._waiter._set_result_unless_cancelled(None) + self._wakeup_waiter() self._session_established = True # In case transport.write() was already called. Don't call # immediatly _process_write_backlog(), but schedule it: diff --git a/Lib/asyncio/unix_events.py b/Lib/asyncio/unix_events.py index 97f9add..67973f1 100644 --- a/Lib/asyncio/unix_events.py +++ b/Lib/asyncio/unix_events.py @@ -301,7 +301,7 @@ class _UnixReadPipeTransport(transports.ReadTransport): self._loop.add_reader(self._fileno, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: - # wait until protocol.connection_made() has been called + # only wake up the waiter when connection_made() has been called self._loop.call_soon(waiter._set_result_unless_cancelled, None) def __repr__(self): @@ -409,7 +409,7 @@ class _UnixWritePipeTransport(transports._FlowControlMixin, self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: - # wait until protocol.connection_made() has been called + # only wake up the waiter when connection_made() has been called self._loop.call_soon(waiter._set_result_unless_cancelled, None) def __repr__(self): -- cgit v0.12 From fa73779b0a54211e99bd1e76511f30352c7055e9 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 29 Jan 2015 00:36:51 +0100 Subject: asyncio: Fix _SelectorSocketTransport constructor Only start reading when connection_made() has been called: protocol.data_received() must not be called before protocol.connection_made(). --- Lib/asyncio/selector_events.py | 4 +++- Lib/test/test_asyncio/test_selector_events.py | 16 +++++++++++----- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py index 42d88f5..f499629 100644 --- a/Lib/asyncio/selector_events.py +++ b/Lib/asyncio/selector_events.py @@ -578,8 +578,10 @@ class _SelectorSocketTransport(_SelectorTransport): self._eof = False self._paused = False - self._loop.add_reader(self._sock_fd, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) + # only start reading when connection_made() has been called + self._loop.call_soon(self._loop.add_reader, + self._sock_fd, self._read_ready) if waiter is not None: # only wake up the waiter when connection_made() has been called self._loop.call_soon(waiter._set_result_unless_cancelled, None) diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py index ad86ada..5152616 100644 --- a/Lib/test/test_asyncio/test_selector_events.py +++ b/Lib/test/test_asyncio/test_selector_events.py @@ -59,6 +59,7 @@ class BaseSelectorEventLoopTests(test_utils.TestCase): def test_make_socket_transport(self): m = mock.Mock() self.loop.add_reader = mock.Mock() + self.loop.add_reader._is_coroutine = False transport = self.loop._make_socket_transport(m, asyncio.Protocol()) self.assertIsInstance(transport, _SelectorSocketTransport) close_transport(transport) @@ -67,6 +68,7 @@ class BaseSelectorEventLoopTests(test_utils.TestCase): def test_make_ssl_transport(self): m = mock.Mock() self.loop.add_reader = mock.Mock() + self.loop.add_reader._is_coroutine = False self.loop.add_writer = mock.Mock() self.loop.remove_reader = mock.Mock() self.loop.remove_writer = mock.Mock() @@ -770,20 +772,24 @@ class SelectorSocketTransportTests(test_utils.TestCase): return transport def test_ctor(self): - tr = self.socket_transport() + waiter = asyncio.Future(loop=self.loop) + tr = self.socket_transport(waiter=waiter) + self.loop.run_until_complete(waiter) + self.loop.assert_reader(7, tr._read_ready) test_utils.run_briefly(self.loop) self.protocol.connection_made.assert_called_with(tr) def test_ctor_with_waiter(self): - fut = asyncio.Future(loop=self.loop) + waiter = asyncio.Future(loop=self.loop) + self.socket_transport(waiter=waiter) + self.loop.run_until_complete(waiter) - self.socket_transport(waiter=fut) - test_utils.run_briefly(self.loop) - self.assertIsNone(fut.result()) + self.assertIsNone(waiter.result()) def test_pause_resume_reading(self): tr = self.socket_transport() + test_utils.run_briefly(self.loop) self.assertFalse(tr._paused) self.loop.assert_reader(7, tr._read_ready) tr.pause_reading() -- cgit v0.12