summaryrefslogtreecommitdiffstats
path: root/Lib/multiprocessing/forkserver.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/multiprocessing/forkserver.py')
-rw-r--r--Lib/multiprocessing/forkserver.py214
1 files changed, 113 insertions, 101 deletions
diff --git a/Lib/multiprocessing/forkserver.py b/Lib/multiprocessing/forkserver.py
index 0a23707..387517e 100644
--- a/Lib/multiprocessing/forkserver.py
+++ b/Lib/multiprocessing/forkserver.py
@@ -24,105 +24,113 @@ __all__ = ['ensure_running', 'get_inherited_fds', 'connect_to_new_process',
MAXFDS_TO_SEND = 256
UNSIGNED_STRUCT = struct.Struct('Q') # large enough for pid_t
-_forkserver_address = None
-_forkserver_alive_fd = None
-_inherited_fds = None
-_lock = threading.Lock()
-_preload_modules = ['__main__']
-
#
-# Public function
+# Forkserver class
#
-def set_forkserver_preload(modules_names):
- '''Set list of module names to try to load in forkserver process.'''
- global _preload_modules
- _preload_modules = modules_names
-
-
-def get_inherited_fds():
- '''Return list of fds inherited from parent process.
-
- This returns None if the current process was not started by fork server.
- '''
- return _inherited_fds
-
-
-def connect_to_new_process(fds):
- '''Request forkserver to create a child process.
-
- Returns a pair of fds (status_r, data_w). The calling process can read
- the child process's pid and (eventually) its returncode from status_r.
- The calling process should write to data_w the pickled preparation and
- process data.
- '''
- if len(fds) + 4 >= MAXFDS_TO_SEND:
- raise ValueError('too many fds')
- with socket.socket(socket.AF_UNIX) as client:
- client.connect(_forkserver_address)
- parent_r, child_w = os.pipe()
- child_r, parent_w = os.pipe()
- allfds = [child_r, child_w, _forkserver_alive_fd,
- semaphore_tracker._semaphore_tracker_fd]
- allfds += fds
- try:
- reduction.sendfds(client, allfds)
- return parent_r, parent_w
- except:
- os.close(parent_r)
- os.close(parent_w)
- raise
- finally:
- os.close(child_r)
- os.close(child_w)
-
-
-def ensure_running():
- '''Make sure that a fork server is running.
-
- This can be called from any process. Note that usually a child
- process will just reuse the forkserver started by its parent, so
- ensure_running() will do nothing.
- '''
- global _forkserver_address, _forkserver_alive_fd
- with _lock:
- if _forkserver_alive_fd is not None:
- return
-
- assert all(type(mod) is str for mod in _preload_modules)
- cmd = ('from multiprocessing.forkserver import main; ' +
- 'main(%d, %d, %r, **%r)')
-
- if _preload_modules:
- desired_keys = {'main_path', 'sys_path'}
- data = spawn.get_preparation_data('ignore')
- data = dict((x,y) for (x,y) in data.items() if x in desired_keys)
- else:
- data = {}
-
- with socket.socket(socket.AF_UNIX) as listener:
- address = connection.arbitrary_address('AF_UNIX')
- listener.bind(address)
- os.chmod(address, 0o600)
- listener.listen(100)
-
- # all client processes own the write end of the "alive" pipe;
- # when they all terminate the read end becomes ready.
- alive_r, alive_w = os.pipe()
+class ForkServer(object):
+
+ def __init__(self):
+ self._forkserver_address = None
+ self._forkserver_alive_fd = None
+ self._inherited_fds = None
+ self._lock = threading.Lock()
+ self._preload_modules = ['__main__']
+
+ def set_forkserver_preload(self, modules_names):
+ '''Set list of module names to try to load in forkserver process.'''
+ if not all(type(mod) is str for mod in self._preload_modules):
+ raise TypeError('module_names must be a list of strings')
+ self._preload_modules = modules_names
+
+ def get_inherited_fds(self):
+ '''Return list of fds inherited from parent process.
+
+ This returns None if the current process was not started by fork
+ server.
+ '''
+ return self._inherited_fds
+
+ def connect_to_new_process(self, fds):
+ '''Request forkserver to create a child process.
+
+ Returns a pair of fds (status_r, data_w). The calling process can read
+ the child process's pid and (eventually) its returncode from status_r.
+ The calling process should write to data_w the pickled preparation and
+ process data.
+ '''
+ self.ensure_running()
+ if len(fds) + 4 >= MAXFDS_TO_SEND:
+ raise ValueError('too many fds')
+ with socket.socket(socket.AF_UNIX) as client:
+ client.connect(self._forkserver_address)
+ parent_r, child_w = os.pipe()
+ child_r, parent_w = os.pipe()
+ allfds = [child_r, child_w, self._forkserver_alive_fd,
+ semaphore_tracker.getfd()]
+ allfds += fds
try:
- fds_to_pass = [listener.fileno(), alive_r]
- cmd %= (listener.fileno(), alive_r, _preload_modules, data)
- exe = spawn.get_executable()
- args = [exe] + util._args_from_interpreter_flags() + ['-c', cmd]
- pid = util.spawnv_passfds(exe, args, fds_to_pass)
+ reduction.sendfds(client, allfds)
+ return parent_r, parent_w
except:
- os.close(alive_w)
+ os.close(parent_r)
+ os.close(parent_w)
raise
finally:
- os.close(alive_r)
- _forkserver_address = address
- _forkserver_alive_fd = alive_w
+ os.close(child_r)
+ os.close(child_w)
+
+ def ensure_running(self):
+ '''Make sure that a fork server is running.
+
+ This can be called from any process. Note that usually a child
+ process will just reuse the forkserver started by its parent, so
+ ensure_running() will do nothing.
+ '''
+ with self._lock:
+ semaphore_tracker.ensure_running()
+ if self._forkserver_alive_fd is not None:
+ return
+
+ cmd = ('from multiprocessing.forkserver import main; ' +
+ 'main(%d, %d, %r, **%r)')
+
+ if self._preload_modules:
+ desired_keys = {'main_path', 'sys_path'}
+ data = spawn.get_preparation_data('ignore')
+ data = dict((x,y) for (x,y) in data.items()
+ if x in desired_keys)
+ else:
+ data = {}
+
+ with socket.socket(socket.AF_UNIX) as listener:
+ address = connection.arbitrary_address('AF_UNIX')
+ listener.bind(address)
+ os.chmod(address, 0o600)
+ listener.listen(100)
+
+ # all client processes own the write end of the "alive" pipe;
+ # when they all terminate the read end becomes ready.
+ alive_r, alive_w = os.pipe()
+ try:
+ fds_to_pass = [listener.fileno(), alive_r]
+ cmd %= (listener.fileno(), alive_r, self._preload_modules,
+ data)
+ exe = spawn.get_executable()
+ args = [exe] + util._args_from_interpreter_flags()
+ args += ['-c', cmd]
+ pid = util.spawnv_passfds(exe, args, fds_to_pass)
+ except:
+ os.close(alive_w)
+ raise
+ finally:
+ os.close(alive_r)
+ self._forkserver_address = address
+ self._forkserver_alive_fd = alive_w
+#
+#
+#
def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
'''Run forkserver.'''
@@ -151,8 +159,7 @@ def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
handler = signal.signal(signal.SIGCHLD, signal.SIG_IGN)
with socket.socket(socket.AF_UNIX, fileno=listener_fd) as listener, \
selectors.DefaultSelector() as selector:
- global _forkserver_address
- _forkserver_address = listener.getsockname()
+ _forkserver._forkserver_address = listener.getsockname()
selector.register(listener, selectors.EVENT_READ)
selector.register(alive_r, selectors.EVENT_READ)
@@ -187,13 +194,7 @@ def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
if e.errno != errno.ECONNABORTED:
raise
-#
-# Code to bootstrap new process
-#
-
def _serve_one(s, listener, alive_r, handler):
- global _inherited_fds, _forkserver_alive_fd
-
# close unnecessary stuff and reset SIGCHLD handler
listener.close()
os.close(alive_r)
@@ -203,8 +204,9 @@ def _serve_one(s, listener, alive_r, handler):
fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1)
s.close()
assert len(fds) <= MAXFDS_TO_SEND
- child_r, child_w, _forkserver_alive_fd, stfd, *_inherited_fds = fds
- semaphore_tracker._semaphore_tracker_fd = stfd
+ (child_r, child_w, _forkserver._forkserver_alive_fd,
+ stfd, *_forkserver._inherited_fds) = fds
+ semaphore_tracker._semaphore_tracker._fd = stfd
# send pid to client processes
write_unsigned(child_w, os.getpid())
@@ -253,3 +255,13 @@ def write_unsigned(fd, n):
if nbytes == 0:
raise RuntimeError('should not get here')
msg = msg[nbytes:]
+
+#
+#
+#
+
+_forkserver = ForkServer()
+ensure_running = _forkserver.ensure_running
+get_inherited_fds = _forkserver.get_inherited_fds
+connect_to_new_process = _forkserver.connect_to_new_process
+set_forkserver_preload = _forkserver.set_forkserver_preload