summaryrefslogtreecommitdiffstats
path: root/Lib/SocketServer.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/SocketServer.py')
-rw-r--r--Lib/SocketServer.py141
1 files changed, 118 insertions, 23 deletions
diff --git a/Lib/SocketServer.py b/Lib/SocketServer.py
index 7d9b9a5..7040738 100644
--- a/Lib/SocketServer.py
+++ b/Lib/SocketServer.py
@@ -130,8 +130,13 @@ __version__ = "0.4"
import socket
+import select
import sys
import os
+try:
+ import threading
+except ImportError:
+ import dummy_threading as threading
__all__ = ["TCPServer","UDPServer","ForkingUDPServer","ForkingTCPServer",
"ThreadingUDPServer","ThreadingTCPServer","BaseRequestHandler",
@@ -149,7 +154,8 @@ class BaseServer:
Methods for the caller:
- __init__(server_address, RequestHandlerClass)
- - serve_forever()
+ - serve_forever(poll_interval=0.5)
+ - shutdown()
- handle_request() # if you do not use serve_forever()
- fileno() -> int # for select()
@@ -158,6 +164,7 @@ class BaseServer:
- server_bind()
- server_activate()
- get_request() -> request, client_address
+ - handle_timeout()
- verify_request(request, client_address)
- server_close()
- process_request(request, client_address)
@@ -171,6 +178,7 @@ class BaseServer:
Class variables that may be overridden by derived classes or
instances:
+ - timeout
- address_family
- socket_type
- allow_reuse_address
@@ -182,10 +190,14 @@ class BaseServer:
"""
+ timeout = None
+
def __init__(self, server_address, RequestHandlerClass):
"""Constructor. May be extended, do not override."""
self.server_address = server_address
self.RequestHandlerClass = RequestHandlerClass
+ self.__is_shut_down = threading.Event()
+ self.__shutdown_request = False
def server_activate(self):
"""Called by constructor to activate the server.
@@ -195,16 +207,42 @@ class BaseServer:
"""
pass
- def serve_forever(self):
- """Handle one request at a time until doomsday."""
- while 1:
- self.handle_request()
+ def serve_forever(self, poll_interval=0.5):
+ """Handle one request at a time until shutdown.
+
+ Polls for shutdown every poll_interval seconds. Ignores
+ self.timeout. If you need to do periodic tasks, do them in
+ another thread.
+ """
+ self.__is_shut_down.clear()
+ try:
+ while not self.__shutdown_request:
+ # XXX: Consider using another file descriptor or
+ # connecting to the socket to wake this up instead of
+ # polling. Polling reduces our responsiveness to a
+ # shutdown request and wastes cpu at all other times.
+ r, w, e = select.select([self], [], [], poll_interval)
+ if self in r:
+ self._handle_request_noblock()
+ finally:
+ self.__shutdown_request = False
+ self.__is_shut_down.set()
+
+ def shutdown(self):
+ """Stops the serve_forever loop.
+
+ Blocks until the loop has finished. This must be called while
+ serve_forever() is running in another thread, or it will
+ deadlock.
+ """
+ self.__shutdown_request = True
+ self.__is_shut_down.wait()
# The distinction between handling, getting, processing and
# finishing a request is fairly arbitrary. Remember:
#
# - handle_request() is the top-level call. It calls
- # get_request(), verify_request() and process_request()
+ # select, get_request(), verify_request() and process_request()
# - get_request() is different for stream or datagram sockets
# - process_request() is the place that may fork a new process
# or create a new thread to finish the request
@@ -212,7 +250,30 @@ class BaseServer:
# this constructor will handle the request all by itself
def handle_request(self):
- """Handle one request, possibly blocking."""
+ """Handle one request, possibly blocking.
+
+ Respects self.timeout.
+ """
+ # Support people who used socket.settimeout() to escape
+ # handle_request before self.timeout was available.
+ timeout = self.socket.gettimeout()
+ if timeout is None:
+ timeout = self.timeout
+ elif self.timeout is not None:
+ timeout = min(timeout, self.timeout)
+ fd_sets = select.select([self], [], [], timeout)
+ if not fd_sets[0]:
+ self.handle_timeout()
+ return
+ self._handle_request_noblock()
+
+ def _handle_request_noblock(self):
+ """Handle one request, without blocking.
+
+ I assume that select.select has returned that the socket is
+ readable before this function was called, so there should be
+ no risk of blocking in get_request().
+ """
try:
request, client_address = self.get_request()
except socket.error:
@@ -224,6 +285,13 @@ class BaseServer:
self.handle_error(request, client_address)
self.close_request(request)
+ def handle_timeout(self):
+ """Called if no new request arrives within self.timeout.
+
+ Overridden by ForkingMixIn.
+ """
+ pass
+
def verify_request(self, request, client_address):
"""Verify the request. May be overridden.
@@ -279,8 +347,9 @@ class TCPServer(BaseServer):
Methods for the caller:
- - __init__(server_address, RequestHandlerClass)
- - serve_forever()
+ - __init__(server_address, RequestHandlerClass, bind_and_activate=True)
+ - serve_forever(poll_interval=0.5)
+ - shutdown()
- handle_request() # if you don't use serve_forever()
- fileno() -> int # for select()
@@ -289,6 +358,7 @@ class TCPServer(BaseServer):
- server_bind()
- server_activate()
- get_request() -> request, client_address
+ - handle_timeout()
- verify_request(request, client_address)
- process_request(request, client_address)
- close_request(request)
@@ -301,6 +371,7 @@ class TCPServer(BaseServer):
Class variables that may be overridden by derived classes or
instances:
+ - timeout
- address_family
- socket_type
- request_queue_size (only for stream sockets)
@@ -322,13 +393,14 @@ class TCPServer(BaseServer):
allow_reuse_address = False
- def __init__(self, server_address, RequestHandlerClass):
+ def __init__(self, server_address, RequestHandlerClass, bind_and_activate=True):
"""Constructor. May be extended, do not override."""
BaseServer.__init__(self, server_address, RequestHandlerClass)
self.socket = socket.socket(self.address_family,
self.socket_type)
- self.server_bind()
- self.server_activate()
+ if bind_and_activate:
+ self.server_bind()
+ self.server_activate()
def server_bind(self):
"""Called by constructor to bind the socket.
@@ -404,25 +476,49 @@ class ForkingMixIn:
"""Mix-in class to handle each request in a new process."""
+ timeout = 300
active_children = None
max_children = 40
def collect_children(self):
- """Internal routine to wait for died children."""
- while self.active_children:
- if len(self.active_children) < self.max_children:
- options = os.WNOHANG
- else:
- # If the maximum number of children are already
- # running, block while waiting for a child to exit
- options = 0
+ """Internal routine to wait for children that have exited."""
+ if self.active_children is None: return
+ while len(self.active_children) >= self.max_children:
+ # XXX: This will wait for any child process, not just ones
+ # spawned by this library. This could confuse other
+ # libraries that expect to be able to wait for their own
+ # children.
try:
- pid, status = os.waitpid(0, options)
+ pid, status = os.waitpid(0, 0)
except os.error:
pid = None
- if not pid: break
+ if pid not in self.active_children: continue
self.active_children.remove(pid)
+ # XXX: This loop runs more system calls than it ought
+ # to. There should be a way to put the active_children into a
+ # process group and then use os.waitpid(-pgid) to wait for any
+ # of that set, but I couldn't find a way to allocate pgids
+ # that couldn't collide.
+ for child in self.active_children:
+ try:
+ pid, status = os.waitpid(child, os.WNOHANG)
+ except os.error:
+ pid = None
+ if not pid: continue
+ try:
+ self.active_children.remove(pid)
+ except ValueError, e:
+ raise ValueError('%s. x=%d and list=%r' % (e.message, pid,
+ self.active_children))
+
+ def handle_timeout(self):
+ """Wait for zombies after self.timeout seconds of inactivity.
+
+ May be extended, do not override.
+ """
+ self.collect_children()
+
def process_request(self, request, client_address):
"""Fork a new subprocess to process the request."""
self.collect_children()
@@ -469,7 +565,6 @@ class ThreadingMixIn:
def process_request(self, request, client_address):
"""Start a new thread to process the request."""
- import threading
t = threading.Thread(target = self.process_request_thread,
args = (request, client_address))
if self.daemon_threads: