diff options
Diffstat (limited to 'Lib/test/test_socketserver.py')
-rw-r--r-- | Lib/test/test_socketserver.py | 183 |
1 files changed, 179 insertions, 4 deletions
diff --git a/Lib/test/test_socketserver.py b/Lib/test/test_socketserver.py index 0d0f86f..3f4dfa1 100644 --- a/Lib/test/test_socketserver.py +++ b/Lib/test/test_socketserver.py @@ -3,12 +3,11 @@ Test suite for socketserver. """ import contextlib +import io import os import select import signal import socket -import select -import errno import tempfile import unittest import socketserver @@ -46,7 +45,7 @@ def receive(sock, n, timeout=20): else: raise RuntimeError("timed out on %r" % (sock,)) -if HAVE_UNIX_SOCKETS: +if HAVE_UNIX_SOCKETS and HAVE_FORKING: class ForkingUnixStreamServer(socketserver.ForkingMixIn, socketserver.UnixStreamServer): pass @@ -58,6 +57,7 @@ if HAVE_UNIX_SOCKETS: @contextlib.contextmanager def simple_subprocess(testcase): + """Tests that a custom child process is not waited on (Issue 1540386)""" pid = os.fork() if pid == 0: # Don't raise an exception; it would be caught by the test harness. @@ -103,7 +103,6 @@ class SocketServerTest(unittest.TestCase): class MyServer(svrcls): def handle_error(self, request, client_address): self.close_request(request) - self.server_close() raise class MyHandler(hdlrbase): @@ -279,6 +278,182 @@ class SocketServerTest(unittest.TestCase): socketserver.TCPServer((HOST, -1), socketserver.StreamRequestHandler) + def test_context_manager(self): + with socketserver.TCPServer((HOST, 0), + socketserver.StreamRequestHandler) as server: + pass + self.assertEqual(-1, server.socket.fileno()) + + +class ErrorHandlerTest(unittest.TestCase): + """Test that the servers pass normal exceptions from the handler to + handle_error(), and that exiting exceptions like SystemExit and + KeyboardInterrupt are not passed.""" + + def tearDown(self): + test.support.unlink(test.support.TESTFN) + + def test_sync_handled(self): + BaseErrorTestServer(ValueError) + self.check_result(handled=True) + + def test_sync_not_handled(self): + with self.assertRaises(SystemExit): + BaseErrorTestServer(SystemExit) + self.check_result(handled=False) + + @unittest.skipUnless(threading, 'Threading required for this test.') + def test_threading_handled(self): + ThreadingErrorTestServer(ValueError) + self.check_result(handled=True) + + @unittest.skipUnless(threading, 'Threading required for this test.') + def test_threading_not_handled(self): + ThreadingErrorTestServer(SystemExit) + self.check_result(handled=False) + + @requires_forking + def test_forking_handled(self): + ForkingErrorTestServer(ValueError) + self.check_result(handled=True) + + @requires_forking + def test_forking_not_handled(self): + ForkingErrorTestServer(SystemExit) + self.check_result(handled=False) + + def check_result(self, handled): + with open(test.support.TESTFN) as log: + expected = 'Handler called\n' + 'Error handled\n' * handled + self.assertEqual(log.read(), expected) + + +class BaseErrorTestServer(socketserver.TCPServer): + def __init__(self, exception): + self.exception = exception + super().__init__((HOST, 0), BadHandler) + with socket.create_connection(self.server_address): + pass + try: + self.handle_request() + finally: + self.server_close() + self.wait_done() + + def handle_error(self, request, client_address): + with open(test.support.TESTFN, 'a') as log: + log.write('Error handled\n') + + def wait_done(self): + pass + + +class BadHandler(socketserver.BaseRequestHandler): + def handle(self): + with open(test.support.TESTFN, 'a') as log: + log.write('Handler called\n') + raise self.server.exception('Test error') + + +class ThreadingErrorTestServer(socketserver.ThreadingMixIn, + BaseErrorTestServer): + def __init__(self, *pos, **kw): + self.done = threading.Event() + super().__init__(*pos, **kw) + + def shutdown_request(self, *pos, **kw): + super().shutdown_request(*pos, **kw) + self.done.set() + + def wait_done(self): + self.done.wait() + + +if HAVE_FORKING: + class ForkingErrorTestServer(socketserver.ForkingMixIn, BaseErrorTestServer): + def wait_done(self): + [child] = self.active_children + os.waitpid(child, 0) + self.active_children.clear() + + +class SocketWriterTest(unittest.TestCase): + def test_basics(self): + class Handler(socketserver.StreamRequestHandler): + def handle(self): + self.server.wfile = self.wfile + self.server.wfile_fileno = self.wfile.fileno() + self.server.request_fileno = self.request.fileno() + + server = socketserver.TCPServer((HOST, 0), Handler) + self.addCleanup(server.server_close) + s = socket.socket( + server.address_family, socket.SOCK_STREAM, socket.IPPROTO_TCP) + with s: + s.connect(server.server_address) + server.handle_request() + self.assertIsInstance(server.wfile, io.BufferedIOBase) + self.assertEqual(server.wfile_fileno, server.request_fileno) + + @unittest.skipUnless(threading, 'Threading required for this test.') + def test_write(self): + # Test that wfile.write() sends data immediately, and that it does + # not truncate sends when interrupted by a Unix signal + pthread_kill = test.support.get_attribute(signal, 'pthread_kill') + + class Handler(socketserver.StreamRequestHandler): + def handle(self): + self.server.sent1 = self.wfile.write(b'write data\n') + # Should be sent immediately, without requiring flush() + self.server.received = self.rfile.readline() + big_chunk = bytes(test.support.SOCK_MAX_SIZE) + self.server.sent2 = self.wfile.write(big_chunk) + + server = socketserver.TCPServer((HOST, 0), Handler) + self.addCleanup(server.server_close) + interrupted = threading.Event() + + def signal_handler(signum, frame): + interrupted.set() + + original = signal.signal(signal.SIGUSR1, signal_handler) + self.addCleanup(signal.signal, signal.SIGUSR1, original) + response1 = None + received2 = None + main_thread = threading.get_ident() + + def run_client(): + s = socket.socket(server.address_family, socket.SOCK_STREAM, + socket.IPPROTO_TCP) + with s, s.makefile('rb') as reader: + s.connect(server.server_address) + nonlocal response1 + response1 = reader.readline() + s.sendall(b'client response\n') + + reader.read(100) + # The main thread should now be blocking in a send() syscall. + # But in theory, it could get interrupted by other signals, + # and then retried. So keep sending the signal in a loop, in + # case an earlier signal happens to be delivered at an + # inconvenient moment. + while True: + pthread_kill(main_thread, signal.SIGUSR1) + if interrupted.wait(timeout=float(1)): + break + nonlocal received2 + received2 = len(reader.read()) + + background = threading.Thread(target=run_client) + background.start() + server.handle_request() + background.join() + self.assertEqual(server.sent1, len(response1)) + self.assertEqual(response1, b'write data\n') + self.assertEqual(server.received, b'client response\n') + self.assertEqual(server.sent2, test.support.SOCK_MAX_SIZE) + self.assertEqual(received2, test.support.SOCK_MAX_SIZE - 100) + class MiscTestCase(unittest.TestCase): |