summaryrefslogtreecommitdiffstats
path: root/Lib/multiprocessing/forkserver.py
blob: 11df38285c6d82c9162df3c4b195fbc22c598c4b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
import errno
import os
import select
import signal
import socket
import struct
import sys
import threading

from . import connection
from . import process
from . import reduction
from . import semaphore_tracker
from . import spawn
from . import util

__all__ = ['ensure_running', 'get_inherited_fds', 'connect_to_new_process',
           'set_forkserver_preload']

#
# Limits, wire format and module-level state
#

# Upper bound on fds per request: connect_to_new_process() rejects requests
# whose total reaches it, and _serve_one() asserts it received no more.
MAXFDS_TO_SEND = 256
# Encoding of the pid and exit code written to the status pipe.
UNSIGNED_STRUCT = struct.Struct('Q')     # large enough for pid_t

# Address of the fork server's AF_UNIX listening socket; assigned by
# ensure_running() in clients and by main() inside the server.
_forkserver_address = None
# Write end of the "alive" pipe; None until ensure_running() starts a
# server (or _serve_one() installs the one sent by the client).
_forkserver_alive_fd = None
# Extra fds handed to a child by the fork server; None otherwise.
_inherited_fds = None
# Serializes ensure_running() between threads of this process.
_lock = threading.Lock()
# Module names the fork server tries to import before serving requests.
_preload_modules = ['__main__']

#
# Public functions
#

def set_forkserver_preload(modules_names):
    """Choose the module names the fork server should try to import.

    The sequence is stored as given (no copy is taken).  It only affects a
    fork server launched afterwards by ensure_running(); an already running
    server keeps whatever it preloaded.
    """
    global _preload_modules
    _preload_modules = modules_names


def get_inherited_fds():
    """Return the extra fds this process received from the fork server.

    These are the descriptors a client passed to connect_to_new_process()
    after the fixed ones.  The result is None unless this process was
    started by the fork server.
    """
    return _inherited_fds


def connect_to_new_process(fds):
    '''Request forkserver to create a child process.

    Returns a pair of fds (status_r, data_w).  The calling process can read
    the child process's pid and (eventually) its returncode from status_r.
    The calling process should write to data_w the pickled preparation and
    process data.

    *fds* are sent to the child after four fixed descriptors (the child's
    ends of the two pipes, the "alive" fd and the semaphore tracker fd).
    Raises ValueError if the total would reach MAXFDS_TO_SEND.  On any
    failure after the pipes are created, every pipe fd created here is
    closed before the exception propagates.
    '''
    if len(fds) + 4 >= MAXFDS_TO_SEND:
        raise ValueError('too many fds')
    with socket.socket(socket.AF_UNIX) as client:
        client.connect(_forkserver_address)
        parent_r, child_w = os.pipe()
        try:
            child_r, parent_w = os.pipe()
        except BaseException:
            # don't leak the first pipe if creating the second one fails
            os.close(parent_r)
            os.close(child_w)
            raise
        try:
            allfds = [child_r, child_w, _forkserver_alive_fd,
                      semaphore_tracker._semaphore_tracker_fd]
            allfds += fds
            reduction.sendfds(client, allfds)
            return parent_r, parent_w
        except BaseException:
            os.close(parent_r)
            os.close(parent_w)
            raise
        finally:
            # our copies of the child's pipe ends are only needed until
            # they have been sent; close them on success and failure alike
            os.close(child_r)
            os.close(child_w)


def ensure_running():
    '''Make sure that a fork server is running.

    This can be called from any process.  Note that usually a child
    process will just reuse the forkserver started by its parent, so
    ensure_running() will do nothing.

    Otherwise it does the following:
    - binds a new AF_UNIX listening socket (mode 0o600);
    - spawns a new interpreter that runs main() with the listener and the
      read end of a fresh "alive" pipe;
    - records the socket address in _forkserver_address and the pipe's
      write end in _forkserver_alive_fd.
    '''
    global _forkserver_address, _forkserver_alive_fd
    with _lock:
        if _forkserver_alive_fd is not None:
            return

        assert all(type(mod) is str for mod in _preload_modules)
        cmd = ('from multiprocessing.forkserver import main; ' +
               'main(%d, %d, %r, **%r)')

        if _preload_modules:
            # forward only the keyword arguments main() accepts
            desired_keys = {'main_path', 'sys_path'}
            data = spawn.get_preparation_data('ignore')
            data = {k: v for k, v in data.items() if k in desired_keys}
        else:
            data = {}

        with socket.socket(socket.AF_UNIX) as listener:
            address = connection.arbitrary_address('AF_UNIX')
            listener.bind(address)
            os.chmod(address, 0o600)
            listener.listen(100)

            # all client processes own the write end of the "alive" pipe;
            # when they all terminate the read end becomes ready.
            alive_r, alive_w = util.pipe()
            try:
                fds_to_pass = [listener.fileno(), alive_r]
                cmd %= (listener.fileno(), alive_r, _preload_modules, data)
                exe = spawn.get_executable()
                args = [exe] + util._args_from_interpreter_flags() + ['-c', cmd]
                # The server's pid is not kept: the server exits by itself
                # once EOF appears on alive_r (see main()).
                util.spawnv_passfds(exe, args, fds_to_pass)
            except BaseException:
                os.close(alive_w)
                raise
            finally:
                # only the spawned server needs the read end
                os.close(alive_r)
            _forkserver_address = address
            _forkserver_alive_fd = alive_w


def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
    '''Run forkserver.

    listener_fd -- fd of the AF_UNIX listening socket created by
                   ensure_running()
    alive_r     -- read end of the "alive" pipe; EOF on it means no client
                   process is left, and the server exits
    preload     -- module names to import before serving requests
    main_path   -- if given and '__main__' is in *preload*, path of the
                   main module imported via spawn.import_main_path()
    sys_path    -- accepted as a keyword argument but unused here

    Forks one child per accepted connection.  Only returns by raising
    SystemExit (or an unexpected OSError).
    '''
    if preload:
        # Modules imported here are already loaded in every child forked
        # below.
        if '__main__' in preload and main_path is not None:
            # NOTE(review): _inheriting is a flag consulted elsewhere in
            # multiprocessing while the main module is imported; confirm
            # its exact effect there.
            process.current_process()._inheriting = True
            try:
                spawn.import_main_path(main_path)
            finally:
                del process.current_process()._inheriting
        for modname in preload:
            try:
                __import__(modname)
            except ImportError:
                # preloading is best-effort: missing modules are skipped
                pass

    # close sys.stdin and replace it with os.devnull (errors are ignored)
    if sys.stdin is not None:
        try:
            sys.stdin.close()
            sys.stdin = open(os.devnull)
        except (OSError, ValueError):
            pass

    # ignoring SIGCHLD means no need to reap zombie processes
    # (the previous handler is kept so _serve_one() can restore it)
    handler = signal.signal(signal.SIGCHLD, signal.SIG_IGN)
    with socket.socket(socket.AF_UNIX, fileno=listener_fd) as listener:
        # inside the server, record the address from the inherited socket
        global _forkserver_address
        _forkserver_address = listener.getsockname()
        readers = [listener, alive_r]

        while True:
            try:
                rfds, wfds, xfds = select.select(readers, [], [])

                if alive_r in rfds:
                    # EOF because no more client processes left
                    # NOTE(review): this read sits inside an assert, so it
                    # is skipped under -O; SystemExit is raised either way.
                    assert os.read(alive_r, 1) == b''
                    raise SystemExit

                assert listener in rfds
                with listener.accept()[0] as s:
                    # Exit status of the forked child.  It is always 1
                    # because _serve_one()'s result is not captured; the
                    # real return code goes to the client over the status
                    # pipe instead.
                    code = 1
                    if os.fork() == 0:
                        # child: serve this one request, then leave via
                        # os._exit() so it never re-enters this loop
                        try:
                            _serve_one(s, listener, alive_r, handler)
                        except Exception:
                            sys.excepthook(*sys.exc_info())
                            sys.stderr.flush()
                        finally:
                            os._exit(code)
                    # parent: the with-statement closes its copy of s

            except InterruptedError:
                # interrupted by a signal: just retry
                pass
            except OSError as e:
                # ECONNABORTED: a client went away during accept(); keep
                # serving.  Anything else is fatal.
                if e.errno != errno.ECONNABORTED:
                    raise

#
# Code to bootstrap new process
#

def _serve_one(s, listener, alive_r, handler):
    '''Run one requested child process inside a fork of the server.

    s        -- connection accepted by main() from a client
    listener -- the server's listening socket (closed here)
    alive_r  -- the server's read end of the "alive" pipe (closed here)
    handler  -- SIGCHLD handler to restore (main() set SIG_IGN)

    Steps:
    - receive the fds sent by connect_to_new_process() and install the
      alive fd, semaphore tracker fd and inherited fds as module state;
    - write this process's pid to the status pipe;
    - run spawn._main() on the data pipe;
    - write spawn._main()'s return value to the status pipe.
    '''
    global _inherited_fds, _forkserver_alive_fd

    # close unnecessary stuff and reset SIGCHLD handler
    listener.close()
    os.close(alive_r)
    signal.signal(signal.SIGCHLD, handler)

    # receive the fds sent by the client's connect_to_new_process();
    # one more than the limit is requested so that an oversized message
    # trips the assert below (which is stripped under -O)
    fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1)
    s.close()
    assert len(fds) <= MAXFDS_TO_SEND
    # order matches allfds in connect_to_new_process(): the child's pipe
    # ends, the alive fd, the semaphore tracker fd, then the caller's fds
    child_r, child_w, _forkserver_alive_fd, stfd, *_inherited_fds = fds
    semaphore_tracker._semaphore_tracker_fd = stfd

    # report this child's pid on the status pipe (read by the client)
    write_unsigned(child_w, os.getpid())

    # reseed random number generator; otherwise every forked child would
    # inherit the server's identical random state
    if 'random' in sys.modules:
        import random
        random.seed()

    # run process object received over pipe
    code = spawn._main(child_r)

    # write the exit code to the pipe
    write_unsigned(child_w, code)

#
# Read and write unsigned numbers
#

def read_unsigned(fd):
    """Read one UNSIGNED_STRUCT-encoded integer from *fd* and return it.

    Short reads are continued until the whole record has arrived, and
    reads interrupted by a signal (InterruptedError) are restarted.
    Raises EOFError if end-of-file is hit before the record is complete.
    """
    size = UNSIGNED_STRUCT.size
    buf = b''
    while len(buf) < size:
        try:
            chunk = os.read(fd, size - len(buf))
        except InterruptedError:
            continue
        if not chunk:
            raise EOFError('unexpected EOF')
        buf += chunk
    return UNSIGNED_STRUCT.unpack(buf)[0]

def write_unsigned(fd, n):
    """Write *n* to *fd* encoded with UNSIGNED_STRUCT.

    Partial writes are continued, and writes interrupted by a signal
    (InterruptedError) are retried, until every byte has been written.
    A zero-byte write raises RuntimeError.
    """
    remaining = UNSIGNED_STRUCT.pack(n)
    while remaining:
        try:
            written = os.write(fd, remaining)
        except InterruptedError:
            continue
        if written == 0:
            raise RuntimeError('should not get here')
        remaining = remaining[written:]