diff options
-rw-r--r-- | Lib/socketserver.py | 73 | ||||
-rw-r--r-- | Lib/test/test_socketserver.py | 24 | ||||
-rw-r--r-- | Misc/NEWS.d/next/Library/2020-06-12-21-23-20.bpo-37193.wJximU.rst | 2 |
3 files changed, 13 insertions, 86 deletions
diff --git a/Lib/socketserver.py b/Lib/socketserver.py index 6859b69..57c1ae6 100644 --- a/Lib/socketserver.py +++ b/Lib/socketserver.py @@ -128,7 +128,6 @@ import selectors import os import sys import threading -import contextlib from io import BufferedIOBase from time import monotonic as time @@ -629,55 +628,6 @@ if hasattr(os, "fork"): self.collect_children(blocking=self.block_on_close) -class _Threads(list): - """ - Joinable list of all non-daemon threads. - """ - def __init__(self): - self._lock = threading.Lock() - - def append(self, thread): - if thread.daemon: - return - with self._lock: - super().append(thread) - - def remove(self, thread): - with self._lock: - # should not happen, but safe to ignore - with contextlib.suppress(ValueError): - super().remove(thread) - - def remove_current(self): - """Remove a current non-daemon thread.""" - thread = threading.current_thread() - if not thread.daemon: - self.remove(thread) - - def pop_all(self): - with self._lock: - self[:], result = [], self[:] - return result - - def join(self): - for thread in self.pop_all(): - thread.join() - - -class _NoThreads: - """ - Degenerate version of _Threads. - """ - def append(self, thread): - pass - - def join(self): - pass - - def remove_current(self): - pass - - class ThreadingMixIn: """Mix-in class to handle each request in a new thread.""" @@ -686,9 +636,9 @@ class ThreadingMixIn: daemon_threads = False # If true, server_close() waits until all non-daemonic threads terminate. block_on_close = True - # Threads object + # For non-daemonic threads, list of threading.Threading objects # used by server_close() to wait for all threads completion. - _threads = _NoThreads() + _threads = None def process_request_thread(self, request, client_address): """Same as in BaseServer but as a thread. @@ -701,24 +651,27 @@ class ThreadingMixIn: except Exception: self.handle_error(request, client_address) finally: - try: - self.shutdown_request(request) - finally: - self._threads.remove_current() + self.shutdown_request(request) def process_request(self, request, client_address): """Start a new thread to process the request.""" - if self.block_on_close: - vars(self).setdefault('_threads', _Threads()) t = threading.Thread(target = self.process_request_thread, args = (request, client_address)) t.daemon = self.daemon_threads - self._threads.append(t) + if not t.daemon and self.block_on_close: + if self._threads is None: + self._threads = [] + self._threads.append(t) t.start() def server_close(self): super().server_close() - self._threads.join() + if self.block_on_close: + threads = self._threads + self._threads = None + if threads: + for thread in threads: + thread.join() if hasattr(os, "fork"): diff --git a/Lib/test/test_socketserver.py b/Lib/test/test_socketserver.py index 1944795f..7cdd115 100644 --- a/Lib/test/test_socketserver.py +++ b/Lib/test/test_socketserver.py @@ -277,13 +277,6 @@ class SocketServerTest(unittest.TestCase): t.join() s.server_close() - def test_close_immediately(self): - class MyServer(socketserver.ThreadingMixIn, socketserver.TCPServer): - pass - - server = MyServer((HOST, 0), lambda: None) - server.server_close() - def test_tcpserver_bind_leak(self): # Issue #22435: the server socket wouldn't be closed if bind()/listen() # failed. @@ -498,23 +491,6 @@ class MiscTestCase(unittest.TestCase): self.assertEqual(server.shutdown_called, 1) server.server_close() - def test_threads_reaped(self): - """ - In #37193, users reported a memory leak - due to the saving of every request thread. Ensure that the - threads are cleaned up after the requests complete. - """ - class MyServer(socketserver.ThreadingMixIn, socketserver.TCPServer): - pass - - server = MyServer((HOST, 0), socketserver.StreamRequestHandler) - for n in range(10): - with socket.create_connection(server.server_address): - server.handle_request() - [thread.join() for thread in server._threads] - self.assertEqual(len(server._threads), 0) - server.server_close() - if __name__ == "__main__": unittest.main() diff --git a/Misc/NEWS.d/next/Library/2020-06-12-21-23-20.bpo-37193.wJximU.rst b/Misc/NEWS.d/next/Library/2020-06-12-21-23-20.bpo-37193.wJximU.rst deleted file mode 100644 index fbf56d3..0000000 --- a/Misc/NEWS.d/next/Library/2020-06-12-21-23-20.bpo-37193.wJximU.rst +++ /dev/null @@ -1,2 +0,0 @@ -Fixed memory leak in ``socketserver.ThreadingMixIn`` introduced in Python -3.7. |