summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
Diffstat (limited to 'Lib')
-rw-r--r--Lib/SocketServer.py93
-rw-r--r--Lib/test/test_socketserver.py76
2 files changed, 95 insertions, 74 deletions
diff --git a/Lib/SocketServer.py b/Lib/SocketServer.py
index 0a194e5..2c41fbb 100644
--- a/Lib/SocketServer.py
+++ b/Lib/SocketServer.py
@@ -130,8 +130,13 @@ __version__ = "0.4"
import socket
+import select
import sys
import os
+try:
+ import threading
+except ImportError:
+ import dummy_threading as threading
__all__ = ["TCPServer","UDPServer","ForkingUDPServer","ForkingTCPServer",
"ThreadingUDPServer","ThreadingTCPServer","BaseRequestHandler",
@@ -149,7 +154,8 @@ class BaseServer:
Methods for the caller:
- __init__(server_address, RequestHandlerClass)
- - serve_forever()
+ - serve_forever(poll_interval=0.5)
+ - shutdown()
- handle_request() # if you do not use serve_forever()
- fileno() -> int # for select()
@@ -190,6 +196,8 @@ class BaseServer:
"""Constructor. May be extended, do not override."""
self.server_address = server_address
self.RequestHandlerClass = RequestHandlerClass
+ self.__is_shut_down = threading.Event()
+ self.__serving = False
def server_activate(self):
"""Called by constructor to activate the server.
@@ -199,27 +207,73 @@ class BaseServer:
"""
pass
- def serve_forever(self):
- """Handle one request at a time until doomsday."""
- while 1:
- self.handle_request()
+ def serve_forever(self, poll_interval=0.5):
+ """Handle one request at a time until shutdown.
+
+ Polls for shutdown every poll_interval seconds. Ignores
+ self.timeout. If you need to do periodic tasks, do them in
+ another thread.
+ """
+ self.__serving = True
+ self.__is_shut_down.clear()
+ while self.__serving:
+ # XXX: Consider using another file descriptor or
+ # connecting to the socket to wake this up instead of
+ # polling. Polling reduces our responsiveness to a
+ # shutdown request and wastes cpu at all other times.
+ r, w, e = select.select([self], [], [], poll_interval)
+ if r:
+ self._handle_request_noblock()
+ self.__is_shut_down.set()
+
+ def shutdown(self):
+ """Stops the serve_forever loop.
+
+ Blocks until the loop has finished. This must be called while
+ serve_forever() is running in another thread, or it will
+ deadlock.
+ """
+ self.__serving = False
+ self.__is_shut_down.wait()
# The distinction between handling, getting, processing and
# finishing a request is fairly arbitrary. Remember:
#
# - handle_request() is the top-level call. It calls
- # await_request(), verify_request() and process_request()
- # - get_request(), called by await_request(), is different for
- # stream or datagram sockets
+ # select, get_request(), verify_request() and process_request()
+ # - get_request() is different for stream or datagram sockets
# - process_request() is the place that may fork a new process
# or create a new thread to finish the request
# - finish_request() instantiates the request handler class;
# this constructor will handle the request all by itself
def handle_request(self):
- """Handle one request, possibly blocking."""
+ """Handle one request, possibly blocking.
+
+ Respects self.timeout.
+ """
+ # Support people who used socket.settimeout() to escape
+ # handle_request before self.timeout was available.
+ timeout = self.socket.gettimeout()
+ if timeout is None:
+ timeout = self.timeout
+ elif self.timeout is not None:
+ timeout = min(timeout, self.timeout)
+ fd_sets = select.select([self], [], [], timeout)
+ if not fd_sets[0]:
+ self.handle_timeout()
+ return
+ self._handle_request_noblock()
+
+ def _handle_request_noblock(self):
+ """Handle one request, without blocking.
+
+ I assume that select.select has returned that the socket is
+ readable before this function was called, so there should be
+ no risk of blocking in get_request().
+ """
try:
- request, client_address = self.await_request()
+ request, client_address = self.get_request()
except socket.error:
return
if self.verify_request(request, client_address):
@@ -229,21 +283,6 @@ class BaseServer:
self.handle_error(request, client_address)
self.close_request(request)
- def await_request(self):
- """Call get_request or handle_timeout, observing self.timeout.
-
- Returns value from get_request() or raises socket.timeout exception if
- timeout was exceeded.
- """
- if self.timeout is not None:
- # If timeout == 0, you're responsible for your own fd magic.
- import select
- fd_sets = select.select([self], [], [], self.timeout)
- if not fd_sets[0]:
- self.handle_timeout()
- raise socket.timeout("Listening timed out")
- return self.get_request()
-
def handle_timeout(self):
"""Called if no new request arrives within self.timeout.
@@ -307,7 +346,8 @@ class TCPServer(BaseServer):
Methods for the caller:
- __init__(server_address, RequestHandlerClass, bind_and_activate=True)
- - serve_forever()
+ - serve_forever(poll_interval=0.5)
+ - shutdown()
- handle_request() # if you don't use serve_forever()
- fileno() -> int # for select()
@@ -523,7 +563,6 @@ class ThreadingMixIn:
def process_request(self, request, client_address):
"""Start a new thread to process the request."""
- import threading
t = threading.Thread(target = self.process_request_thread,
args = (request, client_address))
if self.daemon_threads:
diff --git a/Lib/test/test_socketserver.py b/Lib/test/test_socketserver.py
index 92e5d04..bd25f57 100644
--- a/Lib/test/test_socketserver.py
+++ b/Lib/test/test_socketserver.py
@@ -21,7 +21,6 @@ from test.test_support import TESTFN as TEST_FILE
test.test_support.requires("network")
-NREQ = 3
TEST_STR = "hello world\n"
HOST = "localhost"
@@ -50,43 +49,6 @@ if HAVE_UNIX_SOCKETS:
pass
-class MyMixinServer:
- def serve_a_few(self):
- for i in range(NREQ):
- self.handle_request()
-
- def handle_error(self, request, client_address):
- self.close_request(request)
- self.server_close()
- raise
-
-
-class ServerThread(threading.Thread):
- def __init__(self, addr, svrcls, hdlrcls):
- threading.Thread.__init__(self)
- self.__addr = addr
- self.__svrcls = svrcls
- self.__hdlrcls = hdlrcls
- self.ready = threading.Event()
-
- def run(self):
- class svrcls(MyMixinServer, self.__svrcls):
- pass
- if verbose: print "thread: creating server"
- svr = svrcls(self.__addr, self.__hdlrcls)
- # We had the OS pick a port, so pull the real address out of
- # the server.
- self.addr = svr.server_address
- self.port = self.addr[1]
- if self.addr != svr.socket.getsockname():
- raise RuntimeError('server_address was %s, expected %s' %
- (self.addr, svr.socket.getsockname()))
- self.ready.set()
- if verbose: print "thread: serving three times"
- svr.serve_a_few()
- if verbose: print "thread: done"
-
-
@contextlib.contextmanager
def simple_subprocess(testcase):
pid = os.fork()
@@ -143,28 +105,48 @@ class SocketServerTest(unittest.TestCase):
self.test_files.append(fn)
return fn
- def run_server(self, svrcls, hdlrbase, testfunc):
+ def make_server(self, addr, svrcls, hdlrbase):
+ class MyServer(svrcls):
+ def handle_error(self, request, client_address):
+ self.close_request(request)
+ self.server_close()
+ raise
+
class MyHandler(hdlrbase):
def handle(self):
line = self.rfile.readline()
self.wfile.write(line)
- addr = self.pickaddr(svrcls.address_family)
+ if verbose: print "creating server"
+ server = MyServer(addr, MyHandler)
+ self.assertEquals(server.server_address, server.socket.getsockname())
+ return server
+
+ def run_server(self, svrcls, hdlrbase, testfunc):
+ server = self.make_server(self.pickaddr(svrcls.address_family),
+ svrcls, hdlrbase)
+ # We had the OS pick a port, so pull the real address out of
+ # the server.
+ addr = server.server_address
if verbose:
+ print "server created"
print "ADDR =", addr
print "CLASS =", svrcls
- t = ServerThread(addr, svrcls, MyHandler)
- if verbose: print "server created"
+ t = threading.Thread(
+ name='%s serving' % svrcls,
+ target=server.serve_forever,
+ # Short poll interval to make the test finish quickly.
+ # Time between requests is short enough that we won't wake
+ # up spuriously too many times.
+ kwargs={'poll_interval':0.01})
+ t.setDaemon(True) # In case this function raises.
t.start()
if verbose: print "server running"
- t.ready.wait(10)
- self.assert_(t.ready.isSet(),
- "%s not ready within a reasonable time" % svrcls)
- addr = t.addr
- for i in range(NREQ):
+ for i in range(3):
if verbose: print "test client", i
testfunc(svrcls.address_family, addr)
if verbose: print "waiting for server"
+ server.shutdown()
t.join()
if verbose: print "done"