summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_socketserver.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_socketserver.py')
-rw-r--r--Lib/test/test_socketserver.py112
1 files changed, 35 insertions, 77 deletions
diff --git a/Lib/test/test_socketserver.py b/Lib/test/test_socketserver.py
index 658056c..7fe746d 100644
--- a/Lib/test/test_socketserver.py
+++ b/Lib/test/test_socketserver.py
@@ -21,13 +21,16 @@ from test.test_support import TESTFN as TEST_FILE
test.test_support.requires("network")
-NREQ = 3
TEST_STR = b"hello world\n"
HOST = "localhost"
HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX")
HAVE_FORKING = hasattr(os, "fork") and os.name != "os2"
+def signal_alarm(n):
+ """Call signal.alarm when it exists (i.e. not on Windows)."""
+ if hasattr(signal, 'alarm'):
+ signal.alarm(n)
def receive(sock, n, timeout=20):
r, w, x = select.select([sock], [], [], timeout)
@@ -46,70 +49,6 @@ if HAVE_UNIX_SOCKETS:
pass
-class MyMixinServer:
- def serve_a_few(self):
- for i in range(NREQ):
- self.handle_request()
-
- def handle_error(self, request, client_address):
- self.close_request(request)
- self.server_close()
- raise
-
-def receive(sock, n, timeout=20):
- r, w, x = select.select([sock], [], [], timeout)
- if sock in r:
- return sock.recv(n)
- else:
- raise RuntimeError("timed out on %r" % (sock,))
-
-def testdgram(proto, addr):
- s = socket.socket(proto, socket.SOCK_DGRAM)
- s.sendto(teststring, addr)
- buf = data = receive(s, 100)
- while data and b'\n' not in buf:
- data = receive(s, 100)
- buf += data
- verify(buf == teststring)
- s.close()
-
-def teststream(proto, addr):
- s = socket.socket(proto, socket.SOCK_STREAM)
- s.connect(addr)
- s.sendall(teststring)
- buf = data = receive(s, 100)
- while data and b'\n' not in buf:
- data = receive(s, 100)
- buf += data
- verify(buf == teststring)
- s.close()
-
-class ServerThread(threading.Thread):
- def __init__(self, addr, svrcls, hdlrcls):
- threading.Thread.__init__(self)
- self.__addr = addr
- self.__svrcls = svrcls
- self.__hdlrcls = hdlrcls
- self.ready = threading.Event()
-
- def run(self):
- class svrcls(MyMixinServer, self.__svrcls):
- pass
- if verbose: print("thread: creating server")
- svr = svrcls(self.__addr, self.__hdlrcls)
- # We had the OS pick a port, so pull the real address out of
- # the server.
- self.addr = svr.server_address
- self.port = self.addr[1]
- if self.addr != svr.socket.getsockname():
- raise RuntimeError('server_address was %s, expected %s' %
- (self.addr, svr.socket.getsockname()))
- self.ready.set()
- if verbose: print("thread: serving three times")
- svr.serve_a_few()
- if verbose: print("thread: done")
-
-
@contextlib.contextmanager
def simple_subprocess(testcase):
pid = os.fork()
@@ -126,7 +65,7 @@ class SocketServerTest(unittest.TestCase):
"""Test all socket servers."""
def setUp(self):
- signal.alarm(20) # Kill deadlocks after 20 seconds.
+ signal_alarm(20) # Kill deadlocks after 20 seconds.
self.port_seed = 0
self.test_files = []
@@ -139,7 +78,7 @@ class SocketServerTest(unittest.TestCase):
except os.error:
pass
self.test_files[:] = []
- signal.alarm(0) # Didn't deadlock.
+ signal_alarm(0) # Didn't deadlock.
def pickaddr(self, proto):
if proto == socket.AF_INET:
@@ -166,29 +105,48 @@ class SocketServerTest(unittest.TestCase):
self.test_files.append(fn)
return fn
+ def make_server(self, addr, svrcls, hdlrbase):
+ class MyServer(svrcls):
+ def handle_error(self, request, client_address):
+ self.close_request(request)
+ self.server_close()
+ raise
- def run_server(self, svrcls, hdlrbase, testfunc):
class MyHandler(hdlrbase):
def handle(self):
line = self.rfile.readline()
self.wfile.write(line)
- addr = self.pickaddr(svrcls.address_family)
+ if verbose: print("creating server")
+ server = MyServer(addr, MyHandler)
+ self.assertEquals(server.server_address, server.socket.getsockname())
+ return server
+
+ def run_server(self, svrcls, hdlrbase, testfunc):
+ server = self.make_server(self.pickaddr(svrcls.address_family),
+ svrcls, hdlrbase)
+ # We had the OS pick a port, so pull the real address out of
+ # the server.
+ addr = server.server_address
if verbose:
print("ADDR =", addr)
print("CLASS =", svrcls)
- t = ServerThread(addr, svrcls, MyHandler)
- if verbose: print("server created")
+
+ t = threading.Thread(
+ name='%s serving' % svrcls,
+ target=server.serve_forever,
+ # Short poll interval to make the test finish quickly.
+ # Time between requests is short enough that we won't wake
+ # up spuriously too many times.
+ kwargs={'poll_interval':0.01})
+ t.setDaemon(True) # In case this function raises.
t.start()
if verbose: print("server running")
- t.ready.wait(10)
- self.assert_(t.ready.isSet(),
- "%s not ready within a reasonable time" % svrcls)
- addr = t.addr
- for i in range(NREQ):
+ for i in range(3):
if verbose: print("test client", i)
testfunc(svrcls.address_family, addr)
if verbose: print("waiting for server")
+ server.shutdown()
t.join()
if verbose: print("done")
@@ -295,4 +253,4 @@ def test_main():
if __name__ == "__main__":
test_main()
- signal.alarm(3) # Shutdown shouldn't take more than 3 seconds.
+ signal_alarm(3) # Shutdown shouldn't take more than 3 seconds.