summaryrefslogtreecommitdiffstats
path: root/Lib/asyncio/streams.py
diff options
context:
space:
mode:
authorAndrew Svetlov <andrew.svetlov@gmail.com>2018-09-12 18:43:04 (GMT)
committerGitHub <noreply@github.com>2018-09-12 18:43:04 (GMT)
commita5d1eb8d8b7add31b5f5d9bbb31cee1a491b2c08 (patch)
tree8ffce2f8bcedaea78a0f0eb9c7e1c25f0a32707a /Lib/asyncio/streams.py
parentaca819fb494d4801b3e5b5b507b17cab772c1b40 (diff)
downloadcpython-a5d1eb8d8b7add31b5f5d9bbb31cee1a491b2c08.zip
cpython-a5d1eb8d8b7add31b5f5d9bbb31cee1a491b2c08.tar.gz
cpython-a5d1eb8d8b7add31b5f5d9bbb31cee1a491b2c08.tar.bz2
bpo-34638: Store a weak reference to stream reader to break strong references loop (GH-9201)
Store a weak reference to stream readerfor breaking strong references It breaks the strong reference loop between reader and protocol and allows to detect and close the socket if the stream is deleted (garbage collected)
Diffstat (limited to 'Lib/asyncio/streams.py')
-rw-r--r--Lib/asyncio/streams.py91
1 files changed, 81 insertions, 10 deletions
diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py
index 9dab49b..e7fb22e 100644
--- a/Lib/asyncio/streams.py
+++ b/Lib/asyncio/streams.py
@@ -3,6 +3,8 @@ __all__ = (
'open_connection', 'start_server')
import socket
+import sys
+import weakref
if hasattr(socket, 'AF_UNIX'):
__all__ += ('open_unix_connection', 'start_unix_server')
@@ -10,6 +12,7 @@ if hasattr(socket, 'AF_UNIX'):
from . import coroutines
from . import events
from . import exceptions
+from . import format_helpers
from . import protocols
from .log import logger
from .tasks import sleep
@@ -186,46 +189,106 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
call inappropriate methods of the protocol.)
"""
+ _source_traceback = None
+
def __init__(self, stream_reader, client_connected_cb=None, loop=None):
super().__init__(loop=loop)
- self._stream_reader = stream_reader
+ if stream_reader is not None:
+ self._stream_reader_wr = weakref.ref(stream_reader,
+ self._on_reader_gc)
+ self._source_traceback = stream_reader._source_traceback
+ else:
+ self._stream_reader_wr = None
+ if client_connected_cb is not None:
+ # This is a stream created by the `create_server()` function.
+ # Keep a strong reference to the reader until a connection
+ # is established.
+ self._strong_reader = stream_reader
+ self._reject_connection = False
self._stream_writer = None
+ self._transport = None
self._client_connected_cb = client_connected_cb
self._over_ssl = False
self._closed = self._loop.create_future()
+ def _on_reader_gc(self, wr):
+ transport = self._transport
+ if transport is not None:
+ # connection_made was called
+ context = {
+ 'message': ('An open stream object is being garbage '
+ 'collected; call "stream.close()" explicitly.')
+ }
+ if self._source_traceback:
+ context['source_traceback'] = self._source_traceback
+ self._loop.call_exception_handler(context)
+ transport.abort()
+ else:
+ self._reject_connection = True
+ self._stream_reader_wr = None
+
+ def _untrack_reader(self):
+ self._stream_reader_wr = None
+
+ @property
+ def _stream_reader(self):
+ if self._stream_reader_wr is None:
+ return None
+ return self._stream_reader_wr()
+
def connection_made(self, transport):
- self._stream_reader.set_transport(transport)
+ if self._reject_connection:
+ context = {
+ 'message': ('An open stream was garbage collected prior to '
+ 'establishing network connection; '
+ 'call "stream.close()" explicitly.')
+ }
+ if self._source_traceback:
+ context['source_traceback'] = self._source_traceback
+ self._loop.call_exception_handler(context)
+ transport.abort()
+ return
+ self._transport = transport
+ reader = self._stream_reader
+ if reader is not None:
+ reader.set_transport(transport)
self._over_ssl = transport.get_extra_info('sslcontext') is not None
if self._client_connected_cb is not None:
self._stream_writer = StreamWriter(transport, self,
- self._stream_reader,
+ reader,
self._loop)
- res = self._client_connected_cb(self._stream_reader,
+ res = self._client_connected_cb(reader,
self._stream_writer)
if coroutines.iscoroutine(res):
self._loop.create_task(res)
+ self._strong_reader = None
def connection_lost(self, exc):
- if self._stream_reader is not None:
+ reader = self._stream_reader
+ if reader is not None:
if exc is None:
- self._stream_reader.feed_eof()
+ reader.feed_eof()
else:
- self._stream_reader.set_exception(exc)
+ reader.set_exception(exc)
if not self._closed.done():
if exc is None:
self._closed.set_result(None)
else:
self._closed.set_exception(exc)
super().connection_lost(exc)
- self._stream_reader = None
+ self._stream_reader_wr = None
self._stream_writer = None
+ self._transport = None
def data_received(self, data):
- self._stream_reader.feed_data(data)
+ reader = self._stream_reader
+ if reader is not None:
+ reader.feed_data(data)
def eof_received(self):
- self._stream_reader.feed_eof()
+ reader = self._stream_reader
+ if reader is not None:
+ reader.feed_eof()
if self._over_ssl:
# Prevent a warning in SSLProtocol.eof_received:
# "returning true from eof_received()
@@ -282,6 +345,9 @@ class StreamWriter:
return self._transport.can_write_eof()
def close(self):
+ # a reader can be garbage collected
+ # after connection closing
+ self._protocol._untrack_reader()
return self._transport.close()
def is_closing(self):
@@ -318,6 +384,8 @@ class StreamWriter:
class StreamReader:
+ _source_traceback = None
+
def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
# The line length limit is a security feature;
# it also doubles as half the buffer limit.
@@ -336,6 +404,9 @@ class StreamReader:
self._exception = None
self._transport = None
self._paused = False
+ if self._loop.get_debug():
+ self._source_traceback = format_helpers.extract_stack(
+ sys._getframe(1))
def __repr__(self):
info = ['StreamReader']