diff options
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/SocketServer.py | 93 | ||||
-rw-r--r-- | Lib/test/test_socketserver.py | 76 |
2 files changed, 95 insertions, 74 deletions
diff --git a/Lib/SocketServer.py b/Lib/SocketServer.py index 0a194e5..2c41fbb 100644 --- a/Lib/SocketServer.py +++ b/Lib/SocketServer.py @@ -130,8 +130,13 @@ __version__ = "0.4" import socket +import select import sys import os +try: + import threading +except ImportError: + import dummy_threading as threading __all__ = ["TCPServer","UDPServer","ForkingUDPServer","ForkingTCPServer", "ThreadingUDPServer","ThreadingTCPServer","BaseRequestHandler", @@ -149,7 +154,8 @@ class BaseServer: Methods for the caller: - __init__(server_address, RequestHandlerClass) - - serve_forever() + - serve_forever(poll_interval=0.5) + - shutdown() - handle_request() # if you do not use serve_forever() - fileno() -> int # for select() @@ -190,6 +196,8 @@ class BaseServer: """Constructor. May be extended, do not override.""" self.server_address = server_address self.RequestHandlerClass = RequestHandlerClass + self.__is_shut_down = threading.Event() + self.__serving = False def server_activate(self): """Called by constructor to activate the server. @@ -199,27 +207,73 @@ class BaseServer: """ pass - def serve_forever(self): - """Handle one request at a time until doomsday.""" - while 1: - self.handle_request() + def serve_forever(self, poll_interval=0.5): + """Handle one request at a time until shutdown. + + Polls for shutdown every poll_interval seconds. Ignores + self.timeout. If you need to do periodic tasks, do them in + another thread. + """ + self.__serving = True + self.__is_shut_down.clear() + while self.__serving: + # XXX: Consider using another file descriptor or + # connecting to the socket to wake this up instead of + # polling. Polling reduces our responsiveness to a + # shutdown request and wastes cpu at all other times. + r, w, e = select.select([self], [], [], poll_interval) + if r: + self._handle_request_noblock() + self.__is_shut_down.set() + + def shutdown(self): + """Stops the serve_forever loop. + + Blocks until the loop has finished. This must be called while + serve_forever() is running in another thread, or it will + deadlock. + """ + self.__serving = False + self.__is_shut_down.wait() # The distinction between handling, getting, processing and # finishing a request is fairly arbitrary. Remember: # # - handle_request() is the top-level call. It calls - # await_request(), verify_request() and process_request() - # - get_request(), called by await_request(), is different for - # stream or datagram sockets + # select, get_request(), verify_request() and process_request() + # - get_request() is different for stream or datagram sockets # - process_request() is the place that may fork a new process # or create a new thread to finish the request # - finish_request() instantiates the request handler class; # this constructor will handle the request all by itself def handle_request(self): - """Handle one request, possibly blocking.""" + """Handle one request, possibly blocking. + + Respects self.timeout. + """ + # Support people who used socket.settimeout() to escape + # handle_request before self.timeout was available. + timeout = self.socket.gettimeout() + if timeout is None: + timeout = self.timeout + elif self.timeout is not None: + timeout = min(timeout, self.timeout) + fd_sets = select.select([self], [], [], timeout) + if not fd_sets[0]: + self.handle_timeout() + return + self._handle_request_noblock() + + def _handle_request_noblock(self): + """Handle one request, without blocking. + + I assume that select.select has returned that the socket is + readable before this function was called, so there should be + no risk of blocking in get_request(). + """ try: - request, client_address = self.await_request() + request, client_address = self.get_request() except socket.error: return if self.verify_request(request, client_address): @@ -229,21 +283,6 @@ class BaseServer: self.handle_error(request, client_address) self.close_request(request) - def await_request(self): - """Call get_request or handle_timeout, observing self.timeout. - - Returns value from get_request() or raises socket.timeout exception if - timeout was exceeded. - """ - if self.timeout is not None: - # If timeout == 0, you're responsible for your own fd magic. - import select - fd_sets = select.select([self], [], [], self.timeout) - if not fd_sets[0]: - self.handle_timeout() - raise socket.timeout("Listening timed out") - return self.get_request() - def handle_timeout(self): """Called if no new request arrives within self.timeout. @@ -307,7 +346,8 @@ class TCPServer(BaseServer): Methods for the caller: - __init__(server_address, RequestHandlerClass, bind_and_activate=True) - - serve_forever() + - serve_forever(poll_interval=0.5) + - shutdown() - handle_request() # if you don't use serve_forever() - fileno() -> int # for select() @@ -523,7 +563,6 @@ class ThreadingMixIn: def process_request(self, request, client_address): """Start a new thread to process the request.""" - import threading t = threading.Thread(target = self.process_request_thread, args = (request, client_address)) if self.daemon_threads: diff --git a/Lib/test/test_socketserver.py b/Lib/test/test_socketserver.py index 92e5d04..bd25f57 100644 --- a/Lib/test/test_socketserver.py +++ b/Lib/test/test_socketserver.py @@ -21,7 +21,6 @@ from test.test_support import TESTFN as TEST_FILE test.test_support.requires("network") -NREQ = 3 TEST_STR = "hello world\n" HOST = "localhost" @@ -50,43 +49,6 @@ if HAVE_UNIX_SOCKETS: pass -class MyMixinServer: - def serve_a_few(self): - for i in range(NREQ): - self.handle_request() - - def handle_error(self, request, client_address): - self.close_request(request) - self.server_close() - raise - - -class ServerThread(threading.Thread): - def __init__(self, addr, svrcls, hdlrcls): - threading.Thread.__init__(self) - self.__addr = addr - self.__svrcls = svrcls - self.__hdlrcls = hdlrcls - self.ready = threading.Event() - - def run(self): - class svrcls(MyMixinServer, self.__svrcls): - pass - if verbose: print "thread: creating server" - svr = svrcls(self.__addr, self.__hdlrcls) - # We had the OS pick a port, so pull the real address out of - # the server. - self.addr = svr.server_address - self.port = self.addr[1] - if self.addr != svr.socket.getsockname(): - raise RuntimeError('server_address was %s, expected %s' % - (self.addr, svr.socket.getsockname())) - self.ready.set() - if verbose: print "thread: serving three times" - svr.serve_a_few() - if verbose: print "thread: done" - - @contextlib.contextmanager def simple_subprocess(testcase): pid = os.fork() @@ -143,28 +105,48 @@ class SocketServerTest(unittest.TestCase): self.test_files.append(fn) return fn - def run_server(self, svrcls, hdlrbase, testfunc): + def make_server(self, addr, svrcls, hdlrbase): + class MyServer(svrcls): + def handle_error(self, request, client_address): + self.close_request(request) + self.server_close() + raise + class MyHandler(hdlrbase): def handle(self): line = self.rfile.readline() self.wfile.write(line) - addr = self.pickaddr(svrcls.address_family) + if verbose: print "creating server" + server = MyServer(addr, MyHandler) + self.assertEquals(server.server_address, server.socket.getsockname()) + return server + + def run_server(self, svrcls, hdlrbase, testfunc): + server = self.make_server(self.pickaddr(svrcls.address_family), + svrcls, hdlrbase) + # We had the OS pick a port, so pull the real address out of + # the server. + addr = server.server_address if verbose: + print "server created" print "ADDR =", addr print "CLASS =", svrcls - t = ServerThread(addr, svrcls, MyHandler) - if verbose: print "server created" + t = threading.Thread( + name='%s serving' % svrcls, + target=server.serve_forever, + # Short poll interval to make the test finish quickly. + # Time between requests is short enough that we won't wake + # up spuriously too many times. + kwargs={'poll_interval':0.01}) + t.setDaemon(True) # In case this function raises. t.start() if verbose: print "server running" - t.ready.wait(10) - self.assert_(t.ready.isSet(), - "%s not ready within a reasonable time" % svrcls) - addr = t.addr - for i in range(NREQ): + for i in range(3): if verbose: print "test client", i testfunc(svrcls.address_family, addr) if verbose: print "waiting for server" + server.shutdown() t.join() if verbose: print "done" |