diff options
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/socketserver.py | 24 | ||||
-rw-r--r-- | Lib/test/test_socketserver.py | 79 | ||||
-rw-r--r-- | Lib/wsgiref/simple_server.py | 17 |
3 files changed, 107 insertions, 13 deletions
diff --git a/Lib/socketserver.py b/Lib/socketserver.py index c6d38c7..41a3766 100644 --- a/Lib/socketserver.py +++ b/Lib/socketserver.py @@ -132,6 +132,7 @@ try: import threading except ImportError: import dummy_threading as threading +from io import BufferedIOBase from time import monotonic as time __all__ = ["BaseServer", "TCPServer", "UDPServer", @@ -743,7 +744,10 @@ class StreamRequestHandler(BaseRequestHandler): self.connection.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True) self.rfile = self.connection.makefile('rb', self.rbufsize) - self.wfile = self.connection.makefile('wb', self.wbufsize) + if self.wbufsize == 0: + self.wfile = _SocketWriter(self.connection) + else: + self.wfile = self.connection.makefile('wb', self.wbufsize) def finish(self): if not self.wfile.closed: @@ -756,6 +760,24 @@ class StreamRequestHandler(BaseRequestHandler): self.wfile.close() self.rfile.close() +class _SocketWriter(BufferedIOBase): + """Simple writable BufferedIOBase implementation for a socket + + Does not hold data in a buffer, avoiding any need to call flush().""" + + def __init__(self, sock): + self._sock = sock + + def writable(self): + return True + + def write(self, b): + self._sock.sendall(b) + with memoryview(b) as view: + return view.nbytes + + def fileno(self): + return self._sock.fileno() class DatagramRequestHandler(BaseRequestHandler): diff --git a/Lib/test/test_socketserver.py b/Lib/test/test_socketserver.py index 9a90729..3f4dfa1 100644 --- a/Lib/test/test_socketserver.py +++ b/Lib/test/test_socketserver.py @@ -3,6 +3,7 @@ Test suite for socketserver. """ import contextlib +import io import os import select import signal @@ -376,6 +377,84 @@ if HAVE_FORKING: self.active_children.clear() +class SocketWriterTest(unittest.TestCase): + def test_basics(self): + class Handler(socketserver.StreamRequestHandler): + def handle(self): + self.server.wfile = self.wfile + self.server.wfile_fileno = self.wfile.fileno() + self.server.request_fileno = self.request.fileno() + + server = socketserver.TCPServer((HOST, 0), Handler) + self.addCleanup(server.server_close) + s = socket.socket( + server.address_family, socket.SOCK_STREAM, socket.IPPROTO_TCP) + with s: + s.connect(server.server_address) + server.handle_request() + self.assertIsInstance(server.wfile, io.BufferedIOBase) + self.assertEqual(server.wfile_fileno, server.request_fileno) + + @unittest.skipUnless(threading, 'Threading required for this test.') + def test_write(self): + # Test that wfile.write() sends data immediately, and that it does + # not truncate sends when interrupted by a Unix signal + pthread_kill = test.support.get_attribute(signal, 'pthread_kill') + + class Handler(socketserver.StreamRequestHandler): + def handle(self): + self.server.sent1 = self.wfile.write(b'write data\n') + # Should be sent immediately, without requiring flush() + self.server.received = self.rfile.readline() + big_chunk = bytes(test.support.SOCK_MAX_SIZE) + self.server.sent2 = self.wfile.write(big_chunk) + + server = socketserver.TCPServer((HOST, 0), Handler) + self.addCleanup(server.server_close) + interrupted = threading.Event() + + def signal_handler(signum, frame): + interrupted.set() + + original = signal.signal(signal.SIGUSR1, signal_handler) + self.addCleanup(signal.signal, signal.SIGUSR1, original) + response1 = None + received2 = None + main_thread = threading.get_ident() + + def run_client(): + s = socket.socket(server.address_family, socket.SOCK_STREAM, + socket.IPPROTO_TCP) + with s, s.makefile('rb') as reader: + s.connect(server.server_address) + nonlocal response1 + response1 = reader.readline() + s.sendall(b'client response\n') + + reader.read(100) + # The main thread should now be blocking in a send() syscall. + # But in theory, it could get interrupted by other signals, + # and then retried. So keep sending the signal in a loop, in + # case an earlier signal happens to be delivered at an + # inconvenient moment. + while True: + pthread_kill(main_thread, signal.SIGUSR1) + if interrupted.wait(timeout=float(1)): + break + nonlocal received2 + received2 = len(reader.read()) + + background = threading.Thread(target=run_client) + background.start() + server.handle_request() + background.join() + self.assertEqual(server.sent1, len(response1)) + self.assertEqual(response1, b'write data\n') + self.assertEqual(server.received, b'client response\n') + self.assertEqual(server.sent2, test.support.SOCK_MAX_SIZE) + self.assertEqual(received2, test.support.SOCK_MAX_SIZE - 100) + + class MiscTestCase(unittest.TestCase): def test_all(self): diff --git a/Lib/wsgiref/simple_server.py b/Lib/wsgiref/simple_server.py index da74d7b..f71563a 100644 --- a/Lib/wsgiref/simple_server.py +++ b/Lib/wsgiref/simple_server.py @@ -11,7 +11,6 @@ module. See also the BaseHTTPServer module docs for other API information. """ from http.server import BaseHTTPRequestHandler, HTTPServer -from io import BufferedWriter import sys import urllib.parse from wsgiref.handlers import SimpleHandler @@ -127,17 +126,11 @@ class WSGIRequestHandler(BaseHTTPRequestHandler): if not self.parse_request(): # An error code has been sent, just exit return - # Avoid passing the raw file object wfile, which can do partial - # writes (Issue 24291) - stdout = BufferedWriter(self.wfile) - try: - handler = ServerHandler( - self.rfile, stdout, self.get_stderr(), self.get_environ() - ) - handler.request_handler = self # backpointer for logging - handler.run(self.server.get_app()) - finally: - stdout.detach() + handler = ServerHandler( + self.rfile, self.wfile, self.get_stderr(), self.get_environ() + ) + handler.request_handler = self # backpointer for logging + handler.run(self.server.get_app()) |