diff options
Diffstat (limited to 'Lib/multiprocessing/forkserver.py')
-rw-r--r-- | Lib/multiprocessing/forkserver.py | 214 |
1 files changed, 113 insertions, 101 deletions
diff --git a/Lib/multiprocessing/forkserver.py b/Lib/multiprocessing/forkserver.py index 0a23707..387517e 100644 --- a/Lib/multiprocessing/forkserver.py +++ b/Lib/multiprocessing/forkserver.py @@ -24,105 +24,113 @@ __all__ = ['ensure_running', 'get_inherited_fds', 'connect_to_new_process', MAXFDS_TO_SEND = 256 UNSIGNED_STRUCT = struct.Struct('Q') # large enough for pid_t -_forkserver_address = None -_forkserver_alive_fd = None -_inherited_fds = None -_lock = threading.Lock() -_preload_modules = ['__main__'] - # -# Public function +# Forkserver class # -def set_forkserver_preload(modules_names): - '''Set list of module names to try to load in forkserver process.''' - global _preload_modules - _preload_modules = modules_names - - -def get_inherited_fds(): - '''Return list of fds inherited from parent process. - - This returns None if the current process was not started by fork server. - ''' - return _inherited_fds - - -def connect_to_new_process(fds): - '''Request forkserver to create a child process. - - Returns a pair of fds (status_r, data_w). The calling process can read - the child process's pid and (eventually) its returncode from status_r. - The calling process should write to data_w the pickled preparation and - process data. - ''' - if len(fds) + 4 >= MAXFDS_TO_SEND: - raise ValueError('too many fds') - with socket.socket(socket.AF_UNIX) as client: - client.connect(_forkserver_address) - parent_r, child_w = os.pipe() - child_r, parent_w = os.pipe() - allfds = [child_r, child_w, _forkserver_alive_fd, - semaphore_tracker._semaphore_tracker_fd] - allfds += fds - try: - reduction.sendfds(client, allfds) - return parent_r, parent_w - except: - os.close(parent_r) - os.close(parent_w) - raise - finally: - os.close(child_r) - os.close(child_w) - - -def ensure_running(): - '''Make sure that a fork server is running. - - This can be called from any process. Note that usually a child - process will just reuse the forkserver started by its parent, so - ensure_running() will do nothing. - ''' - global _forkserver_address, _forkserver_alive_fd - with _lock: - if _forkserver_alive_fd is not None: - return - - assert all(type(mod) is str for mod in _preload_modules) - cmd = ('from multiprocessing.forkserver import main; ' + - 'main(%d, %d, %r, **%r)') - - if _preload_modules: - desired_keys = {'main_path', 'sys_path'} - data = spawn.get_preparation_data('ignore') - data = dict((x,y) for (x,y) in data.items() if x in desired_keys) - else: - data = {} - - with socket.socket(socket.AF_UNIX) as listener: - address = connection.arbitrary_address('AF_UNIX') - listener.bind(address) - os.chmod(address, 0o600) - listener.listen(100) - - # all client processes own the write end of the "alive" pipe; - # when they all terminate the read end becomes ready. - alive_r, alive_w = os.pipe() +class ForkServer(object): + + def __init__(self): + self._forkserver_address = None + self._forkserver_alive_fd = None + self._inherited_fds = None + self._lock = threading.Lock() + self._preload_modules = ['__main__'] + + def set_forkserver_preload(self, modules_names): + '''Set list of module names to try to load in forkserver process.''' + if not all(type(mod) is str for mod in self._preload_modules): + raise TypeError('module_names must be a list of strings') + self._preload_modules = modules_names + + def get_inherited_fds(self): + '''Return list of fds inherited from parent process. + + This returns None if the current process was not started by fork + server. + ''' + return self._inherited_fds + + def connect_to_new_process(self, fds): + '''Request forkserver to create a child process. + + Returns a pair of fds (status_r, data_w). The calling process can read + the child process's pid and (eventually) its returncode from status_r. + The calling process should write to data_w the pickled preparation and + process data. + ''' + self.ensure_running() + if len(fds) + 4 >= MAXFDS_TO_SEND: + raise ValueError('too many fds') + with socket.socket(socket.AF_UNIX) as client: + client.connect(self._forkserver_address) + parent_r, child_w = os.pipe() + child_r, parent_w = os.pipe() + allfds = [child_r, child_w, self._forkserver_alive_fd, + semaphore_tracker.getfd()] + allfds += fds try: - fds_to_pass = [listener.fileno(), alive_r] - cmd %= (listener.fileno(), alive_r, _preload_modules, data) - exe = spawn.get_executable() - args = [exe] + util._args_from_interpreter_flags() + ['-c', cmd] - pid = util.spawnv_passfds(exe, args, fds_to_pass) + reduction.sendfds(client, allfds) + return parent_r, parent_w except: - os.close(alive_w) + os.close(parent_r) + os.close(parent_w) raise finally: - os.close(alive_r) - _forkserver_address = address - _forkserver_alive_fd = alive_w + os.close(child_r) + os.close(child_w) + + def ensure_running(self): + '''Make sure that a fork server is running. + + This can be called from any process. Note that usually a child + process will just reuse the forkserver started by its parent, so + ensure_running() will do nothing. + ''' + with self._lock: + semaphore_tracker.ensure_running() + if self._forkserver_alive_fd is not None: + return + + cmd = ('from multiprocessing.forkserver import main; ' + + 'main(%d, %d, %r, **%r)') + + if self._preload_modules: + desired_keys = {'main_path', 'sys_path'} + data = spawn.get_preparation_data('ignore') + data = dict((x,y) for (x,y) in data.items() + if x in desired_keys) + else: + data = {} + + with socket.socket(socket.AF_UNIX) as listener: + address = connection.arbitrary_address('AF_UNIX') + listener.bind(address) + os.chmod(address, 0o600) + listener.listen(100) + + # all client processes own the write end of the "alive" pipe; + # when they all terminate the read end becomes ready. + alive_r, alive_w = os.pipe() + try: + fds_to_pass = [listener.fileno(), alive_r] + cmd %= (listener.fileno(), alive_r, self._preload_modules, + data) + exe = spawn.get_executable() + args = [exe] + util._args_from_interpreter_flags() + args += ['-c', cmd] + pid = util.spawnv_passfds(exe, args, fds_to_pass) + except: + os.close(alive_w) + raise + finally: + os.close(alive_r) + self._forkserver_address = address + self._forkserver_alive_fd = alive_w +# +# +# def main(listener_fd, alive_r, preload, main_path=None, sys_path=None): '''Run forkserver.''' @@ -151,8 +159,7 @@ def main(listener_fd, alive_r, preload, main_path=None, sys_path=None): handler = signal.signal(signal.SIGCHLD, signal.SIG_IGN) with socket.socket(socket.AF_UNIX, fileno=listener_fd) as listener, \ selectors.DefaultSelector() as selector: - global _forkserver_address - _forkserver_address = listener.getsockname() + _forkserver._forkserver_address = listener.getsockname() selector.register(listener, selectors.EVENT_READ) selector.register(alive_r, selectors.EVENT_READ) @@ -187,13 +194,7 @@ def main(listener_fd, alive_r, preload, main_path=None, sys_path=None): if e.errno != errno.ECONNABORTED: raise -# -# Code to bootstrap new process -# - def _serve_one(s, listener, alive_r, handler): - global _inherited_fds, _forkserver_alive_fd - # close unnecessary stuff and reset SIGCHLD handler listener.close() os.close(alive_r) @@ -203,8 +204,9 @@ def _serve_one(s, listener, alive_r, handler): fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1) s.close() assert len(fds) <= MAXFDS_TO_SEND - child_r, child_w, _forkserver_alive_fd, stfd, *_inherited_fds = fds - semaphore_tracker._semaphore_tracker_fd = stfd + (child_r, child_w, _forkserver._forkserver_alive_fd, + stfd, *_forkserver._inherited_fds) = fds + semaphore_tracker._semaphore_tracker._fd = stfd # send pid to client processes write_unsigned(child_w, os.getpid()) @@ -253,3 +255,13 @@ def write_unsigned(fd, n): if nbytes == 0: raise RuntimeError('should not get here') msg = msg[nbytes:] + +# +# +# + +_forkserver = ForkServer() +ensure_running = _forkserver.ensure_running +get_inherited_fds = _forkserver.get_inherited_fds +connect_to_new_process = _forkserver.connect_to_new_process +set_forkserver_preload = _forkserver.set_forkserver_preload |