summaryrefslogtreecommitdiffstats
path: root/Doc/includes/mp_distributing.py
blob: ef1e862dd7ceb51ae18980a920081e771720e3bc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
#
# Module to allow spawning of processes on foreign hosts
#
# Depends on `multiprocessing` package -- tested with `processing-0.60`
#
# Copyright (c) 2006-2008, R Oudkerk
# All rights reserved.
#

__all__ = ['Cluster', 'Host', 'get_logger', 'current_process']

#
# Imports
#

import sys
import os
import tarfile
import shutil
import subprocess
import logging
import itertools
import queue

try:
    import pickle as pickle
except ImportError:
    import pickle

from multiprocessing import Process, current_process, cpu_count
from multiprocessing import util, managers, connection, forking, pool

#
# Logging
#

def get_logger():
    '''
    Return the logger object used by this module
    '''
    return _logger

# Logger shared by the whole module; `get_logger()` hands it out to callers.
_logger = logging.getLogger('distributing')
# Keep records away from the root logger's handlers.  The attribute must be
# spelled `propagate` -- logging silently ignores any other attribute name.
_logger.propagate = False

# Attach a stream handler which uses the same record format as multiprocessing
_formatter = logging.Formatter(util.DEFAULT_LOGGING_FORMAT)
_handler = logging.StreamHandler()
_handler.setFormatter(_formatter)
_logger.addHandler(_handler)

# Short aliases used by the rest of this module
info = _logger.info
debug = _logger.debug

#
# Get number of cpus
#

# `cpu_count()` raises NotImplementedError when the number of cpus cannot be
# determined; fall back to a single slot in that case.
try:
    slot_count = cpu_count()
except NotImplementedError:
    slot_count = 1

#
# Manager type which spawns subprocesses
#

class HostManager(managers.SyncManager):
    '''
    Manager through which processes are created on the (presumably
    foreign) host where the manager's server is running
    '''
    def __init__(self, address, authkey):
        managers.SyncManager.__init__(self, address, authkey)
        # placeholder -- `from_address()` replaces it with a name built
        # from the manager's address
        self._name = 'Host-unknown'

    def Process(self, group=None, target=None, name=None, args=(), kwargs={}):
        '''
        Create a process object through this manager and return it

        `target`, `args` and `kwargs` are pickled and handed to
        `_RemoteProcess()` together with the base name of the main
        module's file (or None if the main module has no `__file__`).
        `group` is accepted but not used.  If `name` is None a name is
        derived from the manager's name and the identity of the new
        process.
        '''
        main_module = sys.modules['__main__']
        main_path = (os.path.basename(main_module.__file__)
                     if hasattr(main_module, '__file__') else None)
        payload = pickle.dumps((target, args, kwargs))
        proc = self._RemoteProcess(payload, main_path)
        if name is None:
            # strip the 'Host-' prefix so only the address part remains
            host_label = self._name.split('Host-')[-1]
            identity = ':'.join(str(part) for part in proc.get_identity())
            name = (host_label + '/Process-%s') % identity
        # NOTE(review): relies on the returned object providing
        # `set_name()` -- confirm against `RemoteProcess`
        proc.set_name(name)
        return proc

    @classmethod
    def from_address(cls, address, authkey):
        '''
        Return a manager object for a server which is already running at
        `address`

        The manager is marked as started and its `shutdown` attribute is
        set to a finalizer which sends a 'shutdown' request to the server.
        '''
        host = cls(address, authkey)
        # presumably a check that the server can be reached -- the result
        # of the request is ignored
        managers.transact(address, authkey, 'dummy')
        host._state.value = managers.State.STARTED
        host._name = 'Host-%s:%s' % host.address
        finalizer_args = (host._address, host._authkey, host._name)
        host.shutdown = util.Finalize(
            host, HostManager._finalize_host,
            args=finalizer_args, exitpriority=-10
            )
        return host

    @staticmethod
    def _finalize_host(address, authkey, name):
        # `name` is not used here
        managers.transact(address, authkey, 'shutdown')

    def __repr__(self):
        return '<Host(%s)>' % (self._name,)

#
# Process subclass representing a process on (possibly) a remote machine
#

class RemoteProcess(Process):
    '''
    Represents a process started on a remote host

    `data` is a pickle of a `(target, args, kwargs)` tuple which is only
    unpickled when the process is bootstrapped.  `main_path` is the bare
    file name (no directory part) of the main module, or a false value.
    '''
    def __init__(self, data, main_path):
        assert not main_path or os.path.basename(main_path) == main_path
        Process.__init__(self)
        self._data = data
        self._main_path = main_path

    def _bootstrap(self):
        forking.prepare({'main_path': self._main_path})
        # NOTE: unpickling can execute arbitrary code, so `data` must come
        # from a trusted source
        self._target, self._args, self._kwargs = pickle.loads(self._data)
        return Process._bootstrap(self)

    def get_identity(self):
        '''
        Return the identity tuple of the process
        '''
        return self._identity

    def set_name(self, name):
        '''
        Set the name of the process

        `Process` only offers the `name` attribute; `HostManager.Process()`
        calls `set_name()`, so a method is provided here.
        '''
        self.name = name

# Register `RemoteProcess` with `HostManager` under the typeid
# '_RemoteProcess'; `HostManager.Process()` calls `self._RemoteProcess()`.
HostManager.register('_RemoteProcess', RemoteProcess)

#
# A Pool class that uses a cluster
#

class DistributedPool(pool.Pool):
    '''
    Pool which gets its processes and its queues from a cluster
    '''

    def __init__(self, cluster, processes=None, initializer=None, initargs=()):
        self._cluster = cluster
        self.Process = cluster.Process
        # when `processes` is not given (or zero) use the size of the cluster
        worker_count = processes or len(cluster)
        pool.Pool.__init__(self, worker_count, initializer, initargs)

    def _setup_queues(self):
        # both queues are created through the cluster
        tasks = self._cluster._SettableQueue()
        results = self._cluster._SettableQueue()
        self._inqueue, self._outqueue = tasks, results
        self._quick_put = tasks.put
        self._quick_get = results.get

    @staticmethod
    def _help_stuff_finish(inqueue, task_handler, size):
        # replace whatever is in the queue by `size` copies of None;
        # `task_handler` is not used
        sentinels = [None] * size
        inqueue.set_contents(sentinels)

#
# Manager type which starts host managers on other machines
#

def LocalProcess(**kwds):
    '''
    Return a new (unstarted) `Process` whose name is prefixed by
    "localhost/"

    All keyword arguments are passed unchanged to `Process`.
    '''
    p = Process(**kwds)
    # `Process` has a `name` attribute but no `set_name()` method, so
    # assign to the attribute
    p.name = 'localhost/' + p.name
    return p

class Cluster(managers.SyncManager):
    '''
    Represents collection of slots running on various hosts.

    `Cluster` is a subclass of `SyncManager` so it allows creation of
    various types of shared objects.

    After `start()` a cluster can be indexed, iterated over and passed to
    `len()`; all of these operate on its list of slots.
    '''
    def __init__(self, hostlist, modules):
        '''
        `hostlist` is a sequence of host objects (see `Host`).

        `modules` is a sequence of names of already imported modules whose
        source files are to be copied to the hosts; the name of this
        module is added if it is missing.  The caller's sequence is not
        modified.
        '''
        managers.SyncManager.__init__(self, address=('localhost', 0))
        self._hostlist = hostlist
        # work on a copy so that the caller's list is left alone
        self._modules = list(modules)
        if __name__ not in self._modules:
            self._modules.append(__name__)
        files = [sys.modules[name].__file__ for name in self._modules]
        for i, path in enumerate(files):
            # refer to the source file rather than the compiled file
            if path.endswith(('.pyc', '.pyo')):
                files[i] = path[:-4] + '.py'
        self._files = [os.path.abspath(path) for path in files]

    def start(self):
        '''
        Start the cluster's own manager, start a manager for every host
        and build the list of slots
        '''
        managers.SyncManager.start(self)

        l = connection.Listener(family='AF_INET', authkey=self._authkey)
        try:
            for i, host in enumerate(self._hostlist):
                host._start_manager(i, self._authkey, l.address, self._files)

            for host in self._hostlist:
                if host.hostname != 'localhost':
                    # the reports may arrive in any order -- the index
                    # which is sent back identifies the host
                    conn = l.accept()
                    try:
                        i, address, cpus = conn.recv()
                    finally:
                        conn.close()
                    other_host = self._hostlist[i]
                    other_host.manager = HostManager.from_address(
                        address, self._authkey
                        )
                    other_host.slots = other_host.slots or cpus
                    other_host.Process = other_host.manager.Process
                else:
                    host.slots = host.slots or slot_count
                    host.Process = LocalProcess
        finally:
            # the listener is only needed while the hosts report back
            l.close()

        self._slotlist = [
            Slot(host) for host in self._hostlist for i in range(host.slots)
            ]
        self._slot_iterator = itertools.cycle(self._slotlist)
        # keep the `shutdown` instance attribute left behind by the base
        # class and delete it so that the `shutdown()` method below is used
        self._base_shutdown = self.shutdown
        del self.shutdown

    def shutdown(self):
        '''
        Shut down the managers of all non-local hosts, then the cluster's
        own manager
        '''
        for host in self._hostlist:
            if host.hostname != 'localhost':
                host.manager.shutdown()
        self._base_shutdown()

    def Process(self, group=None, target=None, name=None, args=(), kwargs={}):
        '''
        Create a process using the next slot, cycling through the slots
        '''
        slot = next(self._slot_iterator)
        return slot.Process(
            group=group, target=target, name=name, args=args, kwargs=kwargs
            )

    def Pool(self, processes=None, initializer=None, initargs=()):
        '''
        Return a `DistributedPool` which uses this cluster
        '''
        return DistributedPool(self, processes, initializer, initargs)

    def __getitem__(self, i):
        return self._slotlist[i]

    def __len__(self):
        return len(self._slotlist)

    def __iter__(self):
        return iter(self._slotlist)

#
# Queue subclass used by distributed pool
#

class SettableQueue(queue.Queue):
    '''
    Queue whose whole contents can be replaced using `set_contents()`
    '''
    def empty(self):
        return not self.queue
    def full(self):
        return self.maxsize > 0 and len(self.queue) == self.maxsize
    def set_contents(self, contents):
        '''
        Discard the items in the queue, put the items of `contents` in
        their place and wake up all threads waiting for an item
        '''
        # length of contents must be at least as large as the number of
        # threads which have potentially called get()
        with self.not_empty:
            self.queue.clear()
            self.queue.extend(contents)
            self.not_empty.notify_all()

# Register `SettableQueue` with `Cluster` under the typeid '_SettableQueue';
# `DistributedPool._setup_queues()` calls `cluster._SettableQueue()`.
Cluster.register('_SettableQueue', SettableQueue)

#
# Class representing a notional cpu in the cluster
#

class Slot(object):
    '''
    A notional cpu: keeps a reference to its host and to that host's
    `Process` callable
    '''
    def __init__(self, host):
        self.host, self.Process = host, host.Process

#
# Host
#

class Host(object):
    '''
    Represents a host to use as a node in a cluster.

    `hostname` gives the name of the host.  If hostname is not
    "localhost" then ssh is used to log in to the host.  To log in as
    a different user use a host name of the form
    "username@somewhere.org"

    `slots` is used to specify the number of slots for processes on
    the host.  This affects how often processes will be allocated to
    this host.  Normally this should be equal to the number of cpus on
    that host.
    '''
    def __init__(self, hostname, slots=None):
        self.hostname = hostname
        self.slots = slots

    def _start_manager(self, index, authkey, address, files):
        '''
        Start a host manager on this host; does nothing for "localhost"

        `files` are copied to a temporary directory on the host, then
        `main()` of the copied `distributing` module is run there over ssh
        and a pickled dict of start-up data is written to its stdin.
        `index` and `address` are part of that data.
        '''
        if self.hostname != 'localhost':
            tempdir = copy_to_remote_temporary_directory(self.hostname, files)
            debug('startup files copied to %s:%s', self.hostname, tempdir)
            # the double quotes are part of the argument given to ssh
            p = subprocess.Popen(
                ['ssh', self.hostname, 'python', '-c',
                 '"import os; os.chdir(%r); '
                 'from distributing import main; main()"' % tempdir],
                stdin=subprocess.PIPE
                )
            # send the key as a plain bytes object: `str()` of a bytes key
            # gives its repr ("b'...'") instead of the key itself
            data = dict(
                name='BoostrappingHost', index=index,
                dist_log_level=_logger.getEffectiveLevel(),
                dir=tempdir, authkey=bytes(authkey), parent_address=address
                )
            pickle.dump(data, p.stdin, pickle.HIGHEST_PROTOCOL)
            p.stdin.close()

#
# Copy files to remote directory, returning name of directory
#

# Script which `copy_to_remote_temporary_directory()` passes to `python -c`
# over ssh.  The double quotes at both ends are part of the string, so the
# script itself must not contain any double quotes.  The tar stream is read
# from the binary buffer of stdin where there is one, and `print()` is
# called as a function so the script is valid Python 3.
unzip_code = '''"
import tempfile, os, sys, tarfile
tempdir = tempfile.mkdtemp(prefix='distrib-')
os.chdir(tempdir)
tf = tarfile.open(fileobj=getattr(sys.stdin, 'buffer', sys.stdin), mode='r|gz')
for ti in tf:
    tf.extract(ti)
print(tempdir)
"'''

def copy_to_remote_temporary_directory(host, files):
    '''
    Copy `files` to a new temporary directory on `host` using ssh

    The files are sent as a gzipped tar stream to `unzip_code` running on
    the host; each file is stored under its base name.  Returns what that
    script wrote to its stdout with trailing whitespace removed, which is
    the name of the directory (a bytes object, as read from the pipe).
    '''
    p = subprocess.Popen(
        ['ssh', host, 'python', '-c', unzip_code],
        stdout=subprocess.PIPE, stdin=subprocess.PIPE
        )
    try:
        with tarfile.open(fileobj=p.stdin, mode='w|gz') as tf:
            for name in files:
                tf.add(name, os.path.basename(name))
    finally:
        # closing stdin tells the remote script that the archive is complete
        p.stdin.close()
    try:
        return p.stdout.read().rstrip()
    finally:
        # release the pipe and reap the ssh process
        p.stdout.close()
        p.wait()

#
# Code which runs a host manager
#

def main():
    '''
    Run a host manager; does not return until the server stops serving

    Reads a pickled dict of start-up data from stdin, creates a server for
    `HostManager`, reports the server's address and the number of cpus to
    the parent and then serves requests.  The directory named by the
    'dir' entry of the data is removed by a finalizer.
    '''
    # get data from parent over stdin
    # (pickle needs a binary stream, so use the buffer of stdin if it has
    # one.  NOTE: unpickling can execute arbitrary code -- stdin must only
    # be fed by the trusted parent process.)
    data = pickle.load(getattr(sys.stdin, 'buffer', sys.stdin))
    sys.stdin.close()

    # set some stuff
    _logger.setLevel(data['dist_log_level'])
    forking.prepare(data)

    # create server for a `HostManager` object
    server = managers.Server(HostManager._registry, ('', 0), data['authkey'])
    current_process()._server = server

    # report server address and number of cpus back to parent
    conn = connection.Client(data['parent_address'], authkey=data['authkey'])
    conn.send((data['index'], server.address, slot_count))
    conn.close()

    # set name etc
    # (process objects have a `name` attribute but no `set_name()` method)
    current_process().name = 'Host-%s:%s' % server.address
    util._run_after_forkers()

    # register a cleanup function
    def cleanup(directory):
        debug('removing directory %s', directory)
        shutil.rmtree(directory)
        debug('shutting down host manager')
    util.Finalize(None, cleanup, args=[data['dir']], exitpriority=0)

    # start host manager
    debug('remote host manager starting in %s', data['dir'])
    server.serve_forever()