summaryrefslogtreecommitdiffstats
path: root/Python/pyhash.c
blob: 97cb54759b6183ead7ee41935f21db2899987e7d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
/* Set of hash utility functions to help maintain the invariant that
    if a==b then hash(a)==hash(b)

   All the utility functions (_Py_Hash*()) return "-1" to signify an error.
*/
#include "Python.h"

#ifdef __APPLE__
#  include <libkern/OSByteOrder.h>
#elif defined(HAVE_LE64TOH) && defined(HAVE_ENDIAN_H)
#  include <endian.h>
#elif defined(HAVE_LE64TOH) && defined(HAVE_SYS_ENDIAN_H)
#  include <sys/endian.h>
#endif

#ifdef __cplusplus
extern "C" {
#endif

/* Global hash secret.  The hash functions in this file read its
   djbx33a.suffix, fnv.prefix / fnv.suffix and siphash.k0 / siphash.k1
   members; nothing in this file writes to it. */
_Py_HashSecret_t _Py_HashSecret;

/* Descriptor of the active hash algorithm.  With Py_HASH_EXTERNAL the
   definition must be supplied by another translation unit; otherwise this
   is a forward declaration of the file-local definition that appears below,
   next to the selected algorithm (fnv or siphash24). */
#if Py_HASH_ALGORITHM == Py_HASH_EXTERNAL
extern PyHash_FuncDef PyHash_Func;
#else
static PyHash_FuncDef PyHash_Func;
#endif

/* Count _Py_HashBytes() calls (only when built with Py_HASH_STATS).

   hashstats[n] counts calls with len == n for 1 <= n <= Py_HASH_STATS_MAX;
   hashstats[0] counts calls whose len is larger than Py_HASH_STATS_MAX.
   Zero-length inputs are not counted because _Py_HashBytes() returns before
   touching the counters.  The table is printed by _PyHash_Fini(). */
#ifdef Py_HASH_STATS
#define Py_HASH_STATS_MAX 32
static Py_ssize_t hashstats[Py_HASH_STATS_MAX + 1] = {0};
#endif

/* For numeric types, the hash of a number x is based on the reduction
   of x modulo the prime P = 2**_PyHASH_BITS - 1.  It's designed so that
   hash(x) == hash(y) whenever x and y are numerically equal, even if
   x and y have different types.

   A quick summary of the hashing strategy:

   (1) First define the 'reduction of x modulo P' for any rational
   number x; this is a standard extension of the usual notion of
   reduction modulo P for integers.  If x == p/q (written in lowest
   terms), the reduction is interpreted as the reduction of p times
   the inverse of the reduction of q, all modulo P; if q is exactly
   divisible by P then define the reduction to be infinity.  So we've
   got a well-defined map

      reduce : { rational numbers } -> { 0, 1, 2, ..., P-1, infinity }.

   (2) Now for a rational number x, define hash(x) by:

      reduce(x)   if x >= 0
      -reduce(-x) if x < 0

   If the result of the reduction is infinity (this is impossible for
   integers, floats and Decimals) then use the predefined hash value
   _PyHASH_INF for x >= 0, or -_PyHASH_INF for x < 0, instead.
   _PyHASH_INF, -_PyHASH_INF and _PyHASH_NAN are also used for the
   hashes of float and Decimal infinities and nans.

   A selling point for the above strategy is that it makes it possible
   to compute hashes of decimal and binary floating-point numbers
   efficiently, even if the exponent of the binary or decimal number
   is large.  The key point is that

      reduce(x * y) == reduce(x) * reduce(y) (modulo _PyHASH_MODULUS)

   provided that {reduce(x), reduce(y)} != {0, infinity}.  The reduction of a
   binary or decimal float is never infinity, since the denominator is a power
   of 2 (for binary) or a divisor of a power of 10 (for decimal).  So we have,
   for nonnegative x,

      reduce(x * 2**e) == reduce(x) * reduce(2**e) % _PyHASH_MODULUS

      reduce(x * 10**e) == reduce(x) * reduce(10**e) % _PyHASH_MODULUS

   and reduce(10**e) can be computed efficiently by the usual modular
   exponentiation algorithm.  For reduce(2**e) it's even better: since
   P is of the form 2**n-1, reduce(2**e) is 2**(e mod n), and multiplication
   by 2**(e mod n) modulo 2**n-1 just amounts to a rotation of bits.

   */

Py_hash_t
_Py_HashDouble(double v)
{
    int e, sign;
    double m;
    Py_uhash_t x, y;

    if (!Py_IS_FINITE(v)) {
        if (Py_IS_INFINITY(v))
            return v > 0 ? _PyHASH_INF : -_PyHASH_INF;
        else
            return _PyHASH_NAN;
    }

    m = frexp(v, &e);

    sign = 1;
    if (m < 0) {
        sign = -1;
        m = -m;
    }

    /* process 28 bits at a time;  this should work well both for binary
       and hexadecimal floating point. */
    x = 0;
    while (m) {
        x = ((x << 28) & _PyHASH_MODULUS) | x >> (_PyHASH_BITS - 28);
        m *= 268435456.0;  /* 2**28 */
        e -= 28;
        y = (Py_uhash_t)m;  /* pull out integer part */
        m -= y;
        x += y;
        if (x >= _PyHASH_MODULUS)
            x -= _PyHASH_MODULUS;
    }

    /* adjust for the exponent;  first reduce it modulo _PyHASH_BITS */
    e = e >= 0 ? e % _PyHASH_BITS : _PyHASH_BITS-1-((-1-e) % _PyHASH_BITS);
    x = ((x << e) & _PyHASH_MODULUS) | x >> (_PyHASH_BITS - e);

    x = x * sign;
    if (x == (Py_uhash_t)-1)
        x = (Py_uhash_t)-2;
    return (Py_hash_t)x;
}

Py_hash_t
_Py_HashPointer(void *p)
{
    Py_hash_t x;
    size_t y = (size_t)p;
    /* bottom 3 or 4 bits are likely to be 0; rotate y by 4 to avoid
       excessive hash collisions for dicts and sets */
    y = (y >> 4) | (y << (8 * SIZEOF_VOID_P - 4));
    x = (Py_hash_t)y;
    if (x == -1)
        x = -2;
    return x;
}

/* Hash len bytes starting at src.

   An empty input hashes to 0.  When Py_HASH_CUTOFF is enabled, inputs
   shorter than the cutoff are hashed inline with DJBX33A and mixed with
   the length and _Py_HashSecret.djbx33a.suffix; all other inputs go
   through PyHash_Func.hash().  The value -1 is never returned; it is
   replaced by -2. */
Py_hash_t
_Py_HashBytes(const void *src, Py_ssize_t len)
{
    Py_hash_t result;

    /* The empty string hashes to 0 instead of (prefix ^ suffix), which
       slightly obfuscates the hash secret. */
    if (len == 0) {
        return 0;
    }

#ifdef Py_HASH_STATS
    if (len <= Py_HASH_STATS_MAX) {
        hashstats[len]++;
    }
    else {
        hashstats[0]++;
    }
#endif

#if Py_HASH_CUTOFF > 0
    if (len < Py_HASH_CUTOFF) {
        /* Inline DJBX33A for very small inputs: acc = acc * 33 + byte,
           starting from 5381. */
        const unsigned char *cursor = src;
        Py_uhash_t acc = 5381;

        switch(len) {
            case 7: acc = acc * 33 + *cursor++; /* fallthrough */
            case 6: acc = acc * 33 + *cursor++; /* fallthrough */
            case 5: acc = acc * 33 + *cursor++; /* fallthrough */
            case 4: acc = acc * 33 + *cursor++; /* fallthrough */
            case 3: acc = acc * 33 + *cursor++; /* fallthrough */
            case 2: acc = acc * 33 + *cursor++; /* fallthrough */
            case 1: acc = acc * 33 + *cursor++; break;
            default:
                assert(0);
        }
        acc ^= len;
        acc ^= (Py_uhash_t) _Py_HashSecret.djbx33a.suffix;
        result = (Py_hash_t)acc;
    }
    else
#endif /* Py_HASH_CUTOFF */
    {
        result = PyHash_Func.hash(src, len);
    }

    return (result == -1) ? -2 : result;
}

/* Print the _Py_HashBytes() call statistics to stderr.

   One row per input length 1..Py_HASH_STATS_MAX with the number of calls
   and the running total, followed by a ">" row for the overflow bucket
   hashstats[0].  Does nothing unless built with Py_HASH_STATS. */
void
_PyHash_Fini(void)
{
#ifdef Py_HASH_STATS
    int bucket = 0;
    Py_ssize_t cumulative = 0;

    fprintf(stderr, "len   calls    total\n");
    while (++bucket <= Py_HASH_STATS_MAX) {
        const Py_ssize_t calls = hashstats[bucket];

        cumulative += calls;
        fprintf(stderr,
                "%2i %8" PY_FORMAT_SIZE_T "d %8" PY_FORMAT_SIZE_T "d\n",
                bucket, calls, cumulative);
    }

    /* overflow bucket comes last so the final total covers every call */
    cumulative += hashstats[0];
    fprintf(stderr, ">  %8" PY_FORMAT_SIZE_T "d %8" PY_FORMAT_SIZE_T "d\n",
            hashstats[0], cumulative);
#endif
}

/* Return a pointer to the descriptor of the hash algorithm in use
   (the PyHash_Func object of this file, or the external one). */
PyHash_FuncDef *
PyHash_GetFuncDef(void)
{
    PyHash_FuncDef *active = &PyHash_Func;

    return active;
}

/* Optimized memcpy() for Windows
 *
 * PY_UHASH_CPY(dst, src) copies SIZEOF_PY_UHASH_T bytes from src to dst.
 * With MSVC the copy is spelled out byte by byte; elsewhere it is a plain
 * memcpy().  Both arguments must be byte-addressable (unsigned char
 * pointers or arrays).  The MSVC variants evaluate each argument several
 * times, so do not pass expressions with side effects.
 */
#ifdef _MSC_VER
#  if SIZEOF_PY_UHASH_T == 4
#    define PY_UHASH_CPY(dst, src) do {                                    \
       (dst)[0] = (src)[0]; (dst)[1] = (src)[1];                           \
       (dst)[2] = (src)[2]; (dst)[3] = (src)[3];                           \
       } while(0)
#  elif SIZEOF_PY_UHASH_T == 8
#    define PY_UHASH_CPY(dst, src) do {                                    \
       (dst)[0] = (src)[0]; (dst)[1] = (src)[1];                           \
       (dst)[2] = (src)[2]; (dst)[3] = (src)[3];                           \
       (dst)[4] = (src)[4]; (dst)[5] = (src)[5];                           \
       (dst)[6] = (src)[6]; (dst)[7] = (src)[7];                           \
       } while(0)
#  else
#    error SIZEOF_PY_UHASH_T must be 4 or 8
#  endif /* SIZEOF_PY_UHASH_T */
#else /* not Windows */
#  define PY_UHASH_CPY(dst, src) memcpy((dst), (src), SIZEOF_PY_UHASH_T)
#endif /* _MSC_VER */


#if Py_HASH_ALGORITHM == Py_HASH_FNV
/* **************************************************************************
 * Modified Fowler-Noll-Vo (FNV) hash function
 */
/* Hash len bytes at src with the modified FNV scheme, keyed by
 * _Py_HashSecret.fnv.prefix and _Py_HashSecret.fnv.suffix.
 *
 * The input is consumed as zero or more whole Py_uhash_t-sized words
 * followed by 1 .. SIZEOF_PY_UHASH_T trailing bytes mixed in one at a time.
 *
 * Returns 0 when len <= 0 (the same value _Py_HashBytes() uses for the
 * empty string) and never returns -1.
 */
static Py_hash_t
fnv(const void *src, Py_ssize_t len)
{
    const unsigned char *p = src;
    Py_uhash_t x;
    Py_ssize_t remainder, blocks;
    union {
        Py_uhash_t value;
        unsigned char bytes[SIZEOF_PY_UHASH_T];
    } block;

#ifdef Py_DEBUG
    assert(_Py_HashSecret_Initialized);
#endif
    if (len <= 0) {
        /* Nothing to read.  Without this guard the code below would
         * dereference p, compute blocks == -1 and loop over memory that
         * does not belong to the input. */
        return 0;
    }

    remainder = len % SIZEOF_PY_UHASH_T;
    if (remainder == 0) {
        /* Process at least one block byte by byte to reduce hash collisions
         * for strings with common prefixes. */
        remainder = SIZEOF_PY_UHASH_T;
    }
    blocks = (len - remainder) / SIZEOF_PY_UHASH_T;

    x = (Py_uhash_t) _Py_HashSecret.fnv.prefix;
    x ^= (Py_uhash_t) *p << 7;
    /* whole words first */
    while (blocks--) {
        PY_UHASH_CPY(block.bytes, p);
        x = (_PyHASH_MULTIPLIER * x) ^ block.value;
        p += SIZEOF_PY_UHASH_T;
    }
    /* add remainder */
    for (; remainder > 0; remainder--)
        x = (_PyHASH_MULTIPLIER * x) ^ (Py_uhash_t) *p++;
    x ^= (Py_uhash_t) len;
    x ^= (Py_uhash_t) _Py_HashSecret.fnv.suffix;
    /* -1 is reserved for errors */
    if (x == (Py_uhash_t)-1) {
        x = (Py_uhash_t)-2;
    }
    return (Py_hash_t)x;
}

/* Descriptor for the FNV algorithm: the hash callback, its name, and two
   bit counts derived from the size of Py_hash_t (8x and 16x its size in
   bytes).  NOTE(review): the two counts are presumably the hash width and
   the seed width in bits, as laid out by PyHash_FuncDef in pyhash.h --
   confirm there. */
static PyHash_FuncDef PyHash_Func = {fnv, "fnv", 8 * SIZEOF_PY_HASH_T,
                                     16 * SIZEOF_PY_HASH_T};

#endif /* Py_HASH_ALGORITHM == Py_HASH_FNV */


#if Py_HASH_ALGORITHM == Py_HASH_SIPHASH24
/* **************************************************************************
 <MIT License>
 Copyright (c) 2013  Marek Majkowski <marek@popcount.org>

 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
 in the Software without restriction, including without limitation the rights
 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 copies of the Software, and to permit persons to whom the Software is
 furnished to do so, subject to the following conditions:

 The above copyright notice and this permission notice shall be included in
 all copies or substantial portions of the Software.

 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 </MIT License>

 Original location:
    https://github.com/majek/csiphash/

 Solution inspired by code from:
    Samuel Neves (supercop/crypto_auth/siphash24/little)
    djb (supercop/crypto_auth/siphash24/little2)
    Jean-Philippe Aumasson (https://131002.net/siphash/siphash24.c)

 Modified for Python by Christian Heimes:
    - C89 / MSVC compatibility
    - PY_UINT64_T, PY_UINT32_T and PY_UINT8_T
    - _rotl64() on Windows
    - letoh64() fallback
*/

/* Byte type used by siphash24() to address the individual bytes of the
   input and of a 64-bit word. */
typedef unsigned char PY_UINT8_T;

/* byte swap little endian to host endian
 * Endian conversion not only ensures that the hash function returns the same
 * value on all platforms. It is also required for a good dispersion of
 * the hash values' least significant bits.
 *
 * le64toh() is only used when configure found it (HAVE_LE64TOH) together
 * with one of the headers that declare it; this mirrors the #include
 * conditions at the top of this file.  Otherwise a portable byte swap
 * is used.
 */
#if PY_LITTLE_ENDIAN
#  define _le64toh(x) ((PY_UINT64_T)(x))
#elif defined(__APPLE__)
#  define _le64toh(x) OSSwapLittleToHostInt64(x)
#elif defined(HAVE_LE64TOH) && \
      (defined(HAVE_ENDIAN_H) || defined(HAVE_SYS_ENDIAN_H))
#  define _le64toh(x) le64toh(x)
#else
#  define _le64toh(x) (((PY_UINT64_T)(x) << 56) | \
                      (((PY_UINT64_T)(x) << 40) & 0xff000000000000ULL) | \
                      (((PY_UINT64_T)(x) << 24) & 0xff0000000000ULL) | \
                      (((PY_UINT64_T)(x) << 8)  & 0xff00000000ULL) | \
                      (((PY_UINT64_T)(x) >> 8)  & 0xff000000ULL) | \
                      (((PY_UINT64_T)(x) >> 24) & 0xff0000ULL) | \
                      (((PY_UINT64_T)(x) >> 40) & 0xff00ULL) | \
                      ((PY_UINT64_T)(x)  >> 56))
#endif


/* Rotate the 64-bit value x left by b bits.  For the portable form b must
 * be in 1..63 (a shift by 64 would be undefined); the callers below only
 * use the constants 13, 16, 17, 21 and 32. */
#ifdef _MSC_VER
#  define ROTATE(x, b)  _rotl64(x, b)
#else
#  define ROTATE(x, b) ((PY_UINT64_T)( ((x) << (b)) | ( (x) >> (64 - (b))) ))
#endif

/* One add/rotate/xor half round over four 64-bit state words.  a, b, c and
 * d must be modifiable lvalues; s and t are the rotation amounts for b and
 * d.  Wrapped in do { } while (0) so that it behaves as a single statement. */
#define HALF_ROUND(a,b,c,d,s,t)         \
    do {                                \
        (a) += (b); (c) += (d);         \
        (b) = ROTATE(b, s) ^ (a);       \
        (d) = ROTATE(d, t) ^ (c);       \
        (a) = ROTATE(a, 32);            \
    } while (0)

/* Two full rounds (four half rounds) over the state words v0..v3. */
#define DOUBLE_ROUND(v0,v1,v2,v3)       \
    do {                                \
        HALF_ROUND(v0,v1,v2,v3,13,16);  \
        HALF_ROUND(v2,v1,v0,v3,17,21);  \
        HALF_ROUND(v0,v1,v2,v3,13,16);  \
        HALF_ROUND(v2,v1,v0,v3,17,21);  \
    } while (0)


/* SipHash-2-4 over src_sz bytes at src, keyed by _Py_HashSecret.siphash.k0
 * and _Py_HashSecret.siphash.k1.
 *
 * The input is consumed as little-endian 64-bit words, with one
 * DOUBLE_ROUND per word.  The final word is built from the remaining 0..7
 * bytes plus the low 8 bits of the length in its top byte, and is followed
 * by two more DOUBLE_ROUNDs for finalization.
 *
 * Input words are loaded with memcpy() rather than through a casted
 * PY_UINT64_T pointer, because src carries no alignment guarantee.
 *
 * The result is v0 ^ v1 ^ v2 ^ v3 converted to Py_hash_t.  It may be -1;
 * _Py_HashBytes() remaps that value.
 */
static Py_hash_t
siphash24(const void *src, Py_ssize_t src_sz) {
    PY_UINT64_T k0 = _le64toh(_Py_HashSecret.siphash.k0);
    PY_UINT64_T k1 = _le64toh(_Py_HashSecret.siphash.k1);
    PY_UINT64_T b = (PY_UINT64_T)src_sz << 56;
    const PY_UINT8_T *in = (const PY_UINT8_T *)src;

    PY_UINT64_T v0 = k0 ^ 0x736f6d6570736575ULL;
    PY_UINT64_T v1 = k1 ^ 0x646f72616e646f6dULL;
    PY_UINT64_T v2 = k0 ^ 0x6c7967656e657261ULL;
    PY_UINT64_T v3 = k1 ^ 0x7465646279746573ULL;

    PY_UINT64_T t;
    PY_UINT8_T *pt;

    /* compression: one 64-bit little-endian word per iteration */
    while (src_sz >= 8) {
        PY_UINT64_T mi;
        memcpy(&mi, in, sizeof(mi));
        mi = _le64toh(mi);
        in += 8;
        src_sz -= 8;
        v3 ^= mi;
        DOUBLE_ROUND(v0,v1,v2,v3);
        v0 ^= mi;
    }

    /* gather the 0..7 trailing bytes into the low bytes of t */
    t = 0;
    pt = (PY_UINT8_T *)&t;
    switch (src_sz) {
        case 7: pt[6] = in[6]; /* fall through */
        case 6: pt[5] = in[5]; /* fall through */
        case 5: pt[4] = in[4]; /* fall through */
        case 4: memcpy(pt, in, sizeof(PY_UINT32_T)); break;
        case 3: pt[2] = in[2]; /* fall through */
        case 2: pt[1] = in[1]; /* fall through */
        case 1: pt[0] = in[0]; break;
        default: break;  /* no trailing bytes */
    }
    b |= _le64toh(t);

    v3 ^= b;
    DOUBLE_ROUND(v0,v1,v2,v3);
    v0 ^= b;
    /* finalization */
    v2 ^= 0xff;
    DOUBLE_ROUND(v0,v1,v2,v3);
    DOUBLE_ROUND(v0,v1,v2,v3);

    /* modified */
    t = (v0 ^ v1) ^ (v2 ^ v3);
    return (Py_hash_t)t;
}

/* Descriptor for the SipHash-2-4 algorithm: the hash callback, its name,
   and the constants 64 and 128.  NOTE(review): presumably the hash width
   and the seed width in bits (siphash24() uses two 64-bit keys, k0 and
   k1) -- confirm against PyHash_FuncDef in pyhash.h. */
static PyHash_FuncDef PyHash_Func = {siphash24, "siphash24", 64, 128};

#endif /* Py_HASH_ALGORITHM == Py_HASH_SIPHASH24 */

#ifdef __cplusplus
}
#endif
16) self.address = self.listener.address self.id_to_obj = {'0': (None, ())} self.id_to_refcount = {} self.mutex = threading.RLock() def serve_forever(self): ''' Run the server forever ''' self.stop_event = threading.Event() current_process()._manager_server = self try: accepter = threading.Thread(target=self.accepter) accepter.daemon = True accepter.start() try: while not self.stop_event.is_set(): self.stop_event.wait(1) except (KeyboardInterrupt, SystemExit): pass finally: if sys.stdout != sys.__stdout__: util.debug('resetting stdout, stderr') sys.stdout = sys.__stdout__ sys.stderr = sys.__stderr__ sys.exit(0) def accepter(self): while True: try: c = self.listener.accept() except (OSError, IOError): continue t = threading.Thread(target=self.handle_request, args=(c,)) t.daemon = True t.start() def handle_request(self, c): ''' Handle a new connection ''' funcname = result = request = None try: connection.deliver_challenge(c, self.authkey) connection.answer_challenge(c, self.authkey) request = c.recv() ignore, funcname, args, kwds = request assert funcname in self.public, '%r unrecognized' % funcname func = getattr(self, funcname) except Exception: msg = ('#TRACEBACK', format_exc()) else: try: result = func(c, *args, **kwds) except Exception: msg = ('#TRACEBACK', format_exc()) else: msg = ('#RETURN', result) try: c.send(msg) except Exception as e: try: c.send(('#TRACEBACK', format_exc())) except Exception: pass util.info('Failure to send message: %r', msg) util.info(' ... request was %r', request) util.info(' ... 
exception was %r', e) c.close() def serve_client(self, conn): ''' Handle requests from the proxies in a particular process/thread ''' util.debug('starting server thread to service %r', threading.current_thread().name) recv = conn.recv send = conn.send id_to_obj = self.id_to_obj while not self.stop_event.is_set(): try: methodname = obj = None request = recv() ident, methodname, args, kwds = request obj, exposed, gettypeid = id_to_obj[ident] if methodname not in exposed: raise AttributeError( 'method %r of %r object is not in exposed=%r' % (methodname, type(obj), exposed) ) function = getattr(obj, methodname) try: res = function(*args, **kwds) except Exception as e: msg = ('#ERROR', e) else: typeid = gettypeid and gettypeid.get(methodname, None) if typeid: rident, rexposed = self.create(conn, typeid, res) token = Token(typeid, self.address, rident) msg = ('#PROXY', (rexposed, token)) else: msg = ('#RETURN', res) except AttributeError: if methodname is None: msg = ('#TRACEBACK', format_exc()) else: try: fallback_func = self.fallback_mapping[methodname] result = fallback_func( self, conn, ident, obj, *args, **kwds ) msg = ('#RETURN', result) except Exception: msg = ('#TRACEBACK', format_exc()) except EOFError: util.debug('got EOF -- exiting thread serving %r', threading.current_thread().name) sys.exit(0) except Exception: msg = ('#TRACEBACK', format_exc()) try: try: send(msg) except Exception as e: send(('#UNSERIALIZABLE', repr(msg))) except Exception as e: util.info('exception in thread serving %r', threading.current_thread().name) util.info(' ... message was %r', msg) util.info(' ... 
exception was %r', e) conn.close() sys.exit(1) def fallback_getvalue(self, conn, ident, obj): return obj def fallback_str(self, conn, ident, obj): return str(obj) def fallback_repr(self, conn, ident, obj): return repr(obj) fallback_mapping = { '__str__':fallback_str, '__repr__':fallback_repr, '#GETVALUE':fallback_getvalue } def dummy(self, c): pass def debug_info(self, c): ''' Return some info --- useful to spot problems with refcounting ''' self.mutex.acquire() try: result = [] keys = list(self.id_to_obj.keys()) keys.sort() for ident in keys: if ident != '0': result.append(' %s: refcount=%s\n %s' % (ident, self.id_to_refcount[ident], str(self.id_to_obj[ident][0])[:75])) return '\n'.join(result) finally: self.mutex.release() def number_of_objects(self, c): ''' Number of shared objects ''' return len(self.id_to_obj) - 1 # don't count ident='0' def shutdown(self, c): ''' Shutdown this process ''' try: util.debug('manager received shutdown message') c.send(('#RETURN', None)) except: import traceback traceback.print_exc() finally: self.stop_event.set() def create(self, c, typeid, *args, **kwds): ''' Create a new shared object and return its id ''' self.mutex.acquire() try: callable, exposed, method_to_typeid, proxytype = \ self.registry[typeid] if callable is None: assert len(args) == 1 and not kwds obj = args[0] else: obj = callable(*args, **kwds) if exposed is None: exposed = public_methods(obj) if method_to_typeid is not None: assert type(method_to_typeid) is dict exposed = list(exposed) + list(method_to_typeid) ident = '%x' % id(obj) # convert to string because xmlrpclib # only has 32 bit signed integers util.debug('%r callable returned object with id %r', typeid, ident) self.id_to_obj[ident] = (obj, set(exposed), method_to_typeid) if ident not in self.id_to_refcount: self.id_to_refcount[ident] = 0 # increment the reference count immediately, to avoid # this object being garbage collected before a Proxy # object for it can be created. 
The caller of create() # is responsible for doing a decref once the Proxy object # has been created. self.incref(c, ident) return ident, tuple(exposed) finally: self.mutex.release() def get_methods(self, c, token): ''' Return the methods of the shared object indicated by token ''' return tuple(self.id_to_obj[token.id][1]) def accept_connection(self, c, name): ''' Spawn a new thread to serve this connection ''' threading.current_thread().name = name c.send(('#RETURN', None)) self.serve_client(c) def incref(self, c, ident): self.mutex.acquire() try: self.id_to_refcount[ident] += 1 finally: self.mutex.release() def decref(self, c, ident): self.mutex.acquire() try: assert self.id_to_refcount[ident] >= 1 self.id_to_refcount[ident] -= 1 if self.id_to_refcount[ident] == 0: del self.id_to_obj[ident], self.id_to_refcount[ident] util.debug('disposing of obj with id %r', ident) finally: self.mutex.release() # # Class to represent state of a manager # class State(object): __slots__ = ['value'] INITIAL = 0 STARTED = 1 SHUTDOWN = 2 # # Mapping from serializer name to Listener and Client types # listener_client = { 'pickle' : (connection.Listener, connection.Client), 'xmlrpclib' : (connection.XmlListener, connection.XmlClient) } # # Definition of BaseManager # class BaseManager(object): ''' Base class for managers ''' _registry = {} _Server = Server def __init__(self, address=None, authkey=None, serializer='pickle'): if authkey is None: authkey = current_process().authkey self._address = address # XXX not final address if eg ('', 0) self._authkey = AuthenticationString(authkey) self._state = State() self._state.value = State.INITIAL self._serializer = serializer self._Listener, self._Client = listener_client[serializer] def get_server(self): ''' Return server object with serve_forever() method and address attribute ''' assert self._state.value == State.INITIAL return Server(self._registry, self._address, self._authkey, self._serializer) def connect(self): ''' Connect manager 
object to the server process ''' Listener, Client = listener_client[self._serializer] conn = Client(self._address, authkey=self._authkey) dispatch(conn, None, 'dummy') self._state.value = State.STARTED def start(self, initializer=None, initargs=()): ''' Spawn a server process for this manager object ''' assert self._state.value == State.INITIAL if initializer is not None and not callable(initializer): raise TypeError('initializer must be a callable') # pipe over which we will retrieve address of server reader, writer = connection.Pipe(duplex=False) # spawn process which runs a server self._process = Process( target=type(self)._run_server, args=(self._registry, self._address, self._authkey, self._serializer, writer, initializer, initargs), ) ident = ':'.join(str(i) for i in self._process._identity) self._process.name = type(self).__name__ + '-' + ident self._process.start() # get address of server writer.close() self._address = reader.recv() reader.close() # register a finalizer self._state.value = State.STARTED self.shutdown = util.Finalize( self, type(self)._finalize_manager, args=(self._process, self._address, self._authkey, self._state, self._Client), exitpriority=0 ) @classmethod def _run_server(cls, registry, address, authkey, serializer, writer, initializer=None, initargs=()): ''' Create a server, report its address and run it ''' if initializer is not None: initializer(*initargs) # create server server = cls._Server(registry, address, authkey, serializer) # inform parent process of the server's address writer.send(server.address) writer.close() # run the manager util.info('manager serving at %r', server.address) server.serve_forever() def _create(self, typeid, *args, **kwds): ''' Create a new shared object; return the token and exposed tuple ''' assert self._state.value == State.STARTED, 'server not yet started' conn = self._Client(self._address, authkey=self._authkey) try: id, exposed = dispatch(conn, None, 'create', (typeid,)+args, kwds) finally: 
conn.close() return Token(typeid, self._address, id), exposed def join(self, timeout=None): ''' Join the manager process (if it has been spawned) ''' if self._process is not None: self._process.join(timeout) if not self._process.is_alive(): self._process = None def _debug_info(self): ''' Return some info about the servers shared objects and connections ''' conn = self._Client(self._address, authkey=self._authkey) try: return dispatch(conn, None, 'debug_info') finally: conn.close() def _number_of_objects(self): ''' Return the number of shared objects ''' conn = self._Client(self._address, authkey=self._authkey) try: return dispatch(conn, None, 'number_of_objects') finally: conn.close() def __enter__(self): if self._state.value == State.INITIAL: self.start() assert self._state.value == State.STARTED return self def __exit__(self, exc_type, exc_val, exc_tb): self.shutdown() @staticmethod def _finalize_manager(process, address, authkey, state, _Client): ''' Shutdown the manager process; will be registered as a finalizer ''' if process.is_alive(): util.info('sending shutdown message to manager') try: conn = _Client(address, authkey=authkey) try: dispatch(conn, None, 'shutdown') finally: conn.close() except Exception: pass process.join(timeout=1.0) if process.is_alive(): util.info('manager still alive') if hasattr(process, 'terminate'): util.info('trying to `terminate()` manager process') process.terminate() process.join(timeout=0.1) if process.is_alive(): util.info('manager still alive after terminate') state.value = State.SHUTDOWN try: del BaseProxy._address_to_local[address] except KeyError: pass address = property(lambda self: self._address) @classmethod def register(cls, typeid, callable=None, proxytype=None, exposed=None, method_to_typeid=None, create_method=True): ''' Register a typeid with the manager type ''' if '_registry' not in cls.__dict__: cls._registry = cls._registry.copy() if proxytype is None: proxytype = AutoProxy exposed = exposed or 
getattr(proxytype, '_exposed_', None) method_to_typeid = method_to_typeid or \ getattr(proxytype, '_method_to_typeid_', None) if method_to_typeid: for key, value in list(method_to_typeid.items()): assert type(key) is str, '%r is not a string' % key assert type(value) is str, '%r is not a string' % value cls._registry[typeid] = ( callable, exposed, method_to_typeid, proxytype ) if create_method: def temp(self, *args, **kwds): util.debug('requesting creation of a shared %r object', typeid) token, exp = self._create(typeid, *args, **kwds) proxy = proxytype( token, self._serializer, manager=self, authkey=self._authkey, exposed=exp ) conn = self._Client(token.address, authkey=self._authkey) dispatch(conn, None, 'decref', (token.id,)) return proxy temp.__name__ = typeid setattr(cls, typeid, temp) # # Subclass of set which get cleared after a fork # class ProcessLocalSet(set): def __init__(self): util.register_after_fork(self, lambda obj: obj.clear()) def __reduce__(self): return type(self), () # # Definition of BaseProxy # class BaseProxy(object): ''' A base for proxies of shared objects ''' _address_to_local = {} _mutex = util.ForkAwareThreadLock() def __init__(self, token, serializer, manager=None, authkey=None, exposed=None, incref=True): BaseProxy._mutex.acquire() try: tls_idset = BaseProxy._address_to_local.get(token.address, None) if tls_idset is None: tls_idset = util.ForkAwareLocal(), ProcessLocalSet() BaseProxy._address_to_local[token.address] = tls_idset finally: BaseProxy._mutex.release() # self._tls is used to record the connection used by this # thread to communicate with the manager at token.address self._tls = tls_idset[0] # self._idset is used to record the identities of all shared # objects for which the current process owns references and # which are in the manager at token.address self._idset = tls_idset[1] self._token = token self._id = self._token.id self._manager = manager self._serializer = serializer self._Client = listener_client[serializer][1] 
if authkey is not None: self._authkey = AuthenticationString(authkey) elif self._manager is not None: self._authkey = self._manager._authkey else: self._authkey = current_process().authkey if incref: self._incref() util.register_after_fork(self, BaseProxy._after_fork) def _connect(self): util.debug('making connection to manager') name = current_process().name if threading.current_thread().name != 'MainThread': name += '|' + threading.current_thread().name conn = self._Client(self._token.address, authkey=self._authkey) dispatch(conn, None, 'accept_connection', (name,)) self._tls.connection = conn def _callmethod(self, methodname, args=(), kwds={}): ''' Try to call a method of the referrent and return a copy of the result ''' try: conn = self._tls.connection except AttributeError: util.debug('thread %r does not own a connection', threading.current_thread().name) self._connect() conn = self._tls.connection conn.send((self._id, methodname, args, kwds)) kind, result = conn.recv() if kind == '#RETURN': return result elif kind == '#PROXY': exposed, token = result proxytype = self._manager._registry[token.typeid][-1] proxy = proxytype( token, self._serializer, manager=self._manager, authkey=self._authkey, exposed=exposed ) conn = self._Client(token.address, authkey=self._authkey) dispatch(conn, None, 'decref', (token.id,)) return proxy raise convert_to_error(kind, result) def _getvalue(self): ''' Get a copy of the value of the referent ''' return self._callmethod('#GETVALUE') def _incref(self): conn = self._Client(self._token.address, authkey=self._authkey) dispatch(conn, None, 'incref', (self._id,)) util.debug('INCREF %r', self._token.id) self._idset.add(self._id) state = self._manager and self._manager._state self._close = util.Finalize( self, BaseProxy._decref, args=(self._token, self._authkey, state, self._tls, self._idset, self._Client), exitpriority=10 ) @staticmethod def _decref(token, authkey, state, tls, idset, _Client): idset.discard(token.id) # check whether 
manager is still alive if state is None or state.value == State.STARTED: # tell manager this process no longer cares about referent try: util.debug('DECREF %r', token.id) conn = _Client(token.address, authkey=authkey) dispatch(conn, None, 'decref', (token.id,)) except Exception as e: util.debug('... decref failed %s', e) else: util.debug('DECREF %r -- manager already shutdown', token.id) # check whether we can close this thread's connection because # the process owns no more references to objects for this manager if not idset and hasattr(tls, 'connection'): util.debug('thread %r has no more proxies so closing conn', threading.current_thread().name) tls.connection.close() del tls.connection def _after_fork(self): self._manager = None try: self._incref() except Exception as e: # the proxy may just be for a manager which has shutdown util.info('incref failed: %s' % e) def __reduce__(self): kwds = {} if Popen.thread_is_spawning(): kwds['authkey'] = self._authkey if getattr(self, '_isauto', False): kwds['exposed'] = self._exposed_ return (RebuildProxy, (AutoProxy, self._token, self._serializer, kwds)) else: return (RebuildProxy, (type(self), self._token, self._serializer, kwds)) def __deepcopy__(self, memo): return self._getvalue() def __repr__(self): return '<%s object, typeid %r at %s>' % \ (type(self).__name__, self._token.typeid, '0x%x' % id(self)) def __str__(self): ''' Return representation of the referent (or a fall-back if that fails) ''' try: return self._callmethod('__repr__') except Exception: return repr(self)[:-1] + "; '__str__()' failed>" # # Function used for unpickling # def RebuildProxy(func, token, serializer, kwds): ''' Function used for unpickling proxy objects. If possible the shared object is returned, or otherwise a proxy for it. 
''' server = getattr(current_process(), '_manager_server', None) if server and server.address == token.address: return server.id_to_obj[token.id][0] else: incref = ( kwds.pop('incref', True) and not getattr(current_process(), '_inheriting', False) ) return func(token, serializer, incref=incref, **kwds) # # Functions to create proxies and proxy types # def MakeProxyType(name, exposed, _cache={}): ''' Return an proxy type whose methods are given by `exposed` ''' exposed = tuple(exposed) try: return _cache[(name, exposed)] except KeyError: pass dic = {} for meth in exposed: exec('''def %s(self, *args, **kwds): return self._callmethod(%r, args, kwds)''' % (meth, meth), dic) ProxyType = type(name, (BaseProxy,), dic) ProxyType._exposed_ = exposed _cache[(name, exposed)] = ProxyType return ProxyType def AutoProxy(token, serializer, manager=None, authkey=None, exposed=None, incref=True): ''' Return an auto-proxy for `token` ''' _Client = listener_client[serializer][1] if exposed is None: conn = _Client(token.address, authkey=authkey) try: exposed = dispatch(conn, None, 'get_methods', (token,)) finally: conn.close() if authkey is None and manager is not None: authkey = manager._authkey if authkey is None: authkey = current_process().authkey ProxyType = MakeProxyType('AutoProxy[%s]' % token.typeid, exposed) proxy = ProxyType(token, serializer, manager=manager, authkey=authkey, incref=incref) proxy._isauto = True return proxy # # Types/callables which we will register with SyncManager # class Namespace(object): def __init__(self, **kwds): self.__dict__.update(kwds) def __repr__(self): items = list(self.__dict__.items()) temp = [] for name, value in items: if not name.startswith('_'): temp.append('%s=%r' % (name, value)) temp.sort() return 'Namespace(%s)' % str.join(', ', temp) class Value(object): def __init__(self, typecode, value, lock=True): self._typecode = typecode self._value = value def get(self): return self._value def set(self, value): self._value = value def 
__repr__(self): return '%s(%r, %r)'%(type(self).__name__, self._typecode, self._value) value = property(get, set) def Array(typecode, sequence, lock=True): return array.array(typecode, sequence) # # Proxy types used by SyncManager # class IteratorProxy(BaseProxy): _exposed_ = ('__next__', 'send', 'throw', 'close') def __iter__(self): return self def __next__(self, *args): return self._callmethod('__next__', args) def send(self, *args): return self._callmethod('send', args) def throw(self, *args): return self._callmethod('throw', args) def close(self, *args): return self._callmethod('close', args) class AcquirerProxy(BaseProxy): _exposed_ = ('acquire', 'release') def acquire(self, blocking=True, timeout=None): args = (blocking,) if timeout is None else (blocking, timeout) return self._callmethod('acquire', args) def release(self): return self._callmethod('release') def __enter__(self): return self._callmethod('acquire') def __exit__(self, exc_type, exc_val, exc_tb): return self._callmethod('release') class ConditionProxy(AcquirerProxy): _exposed_ = ('acquire', 'release', 'wait', 'notify', 'notify_all') def wait(self, timeout=None): return self._callmethod('wait', (timeout,)) def notify(self): return self._callmethod('notify') def notify_all(self): return self._callmethod('notify_all') def wait_for(self, predicate, timeout=None): result = predicate() if result: return result if timeout is not None: endtime = _time() + timeout else: endtime = None waittime = None while not result: if endtime is not None: waittime = endtime - _time() if waittime <= 0: break self.wait(waittime) result = predicate() return result class EventProxy(BaseProxy): _exposed_ = ('is_set', 'set', 'clear', 'wait') def is_set(self): return self._callmethod('is_set') def set(self): return self._callmethod('set') def clear(self): return self._callmethod('clear') def wait(self, timeout=None): return self._callmethod('wait', (timeout,)) class BarrierProxy(BaseProxy): _exposed_ = ('__getattribute__', 
'wait', 'abort', 'reset') def wait(self, timeout=None): return self._callmethod('wait', (timeout,)) def abort(self): return self._callmethod('abort') def reset(self): return self._callmethod('reset') @property def parties(self): return self._callmethod('__getattribute__', ('parties',)) @property def n_waiting(self): return self._callmethod('__getattribute__', ('n_waiting',)) @property def broken(self): return self._callmethod('__getattribute__', ('broken',)) class NamespaceProxy(BaseProxy): _exposed_ = ('__getattribute__', '__setattr__', '__delattr__') def __getattr__(self, key): if key[0] == '_': return object.__getattribute__(self, key) callmethod = object.__getattribute__(self, '_callmethod') return callmethod('__getattribute__', (key,)) def __setattr__(self, key, value): if key[0] == '_': return object.__setattr__(self, key, value) callmethod = object.__getattribute__(self, '_callmethod') return callmethod('__setattr__', (key, value)) def __delattr__(self, key): if key[0] == '_': return object.__delattr__(self, key) callmethod = object.__getattribute__(self, '_callmethod') return callmethod('__delattr__', (key,)) class ValueProxy(BaseProxy): _exposed_ = ('get', 'set') def get(self): return self._callmethod('get') def set(self, value): return self._callmethod('set', (value,)) value = property(get, set) BaseListProxy = MakeProxyType('BaseListProxy', ( '__add__', '__contains__', '__delitem__', '__getitem__', '__len__', '__mul__', '__reversed__', '__rmul__', '__setitem__', 'append', 'count', 'extend', 'index', 'insert', 'pop', 'remove', 'reverse', 'sort', '__imul__' )) class ListProxy(BaseListProxy): def __iadd__(self, value): self._callmethod('extend', (value,)) return self def __imul__(self, value): self._callmethod('__imul__', (value,)) return self DictProxy = MakeProxyType('DictProxy', ( '__contains__', '__delitem__', '__getitem__', '__len__', '__setitem__', 'clear', 'copy', 'get', 'has_key', 'items', 'keys', 'pop', 'popitem', 'setdefault', 'update', 'values' 
)) ArrayProxy = MakeProxyType('ArrayProxy', ( '__len__', '__getitem__', '__setitem__' )) PoolProxy = MakeProxyType('PoolProxy', ( 'apply', 'apply_async', 'close', 'imap', 'imap_unordered', 'join', 'map', 'map_async', 'starmap', 'starmap_async', 'terminate' )) PoolProxy._method_to_typeid_ = { 'apply_async': 'AsyncResult', 'map_async': 'AsyncResult', 'starmap_async': 'AsyncResult', 'imap': 'Iterator', 'imap_unordered': 'Iterator' } # # Definition of SyncManager # class SyncManager(BaseManager): ''' Subclass of `BaseManager` which supports a number of shared object types. The types registered are those intended for the synchronization of threads, plus `dict`, `list` and `Namespace`. The `multiprocessing.Manager()` function creates started instances of this class. ''' SyncManager.register('Queue', queue.Queue) SyncManager.register('JoinableQueue', queue.Queue) SyncManager.register('Event', threading.Event, EventProxy) SyncManager.register('Lock', threading.Lock, AcquirerProxy) SyncManager.register('RLock', threading.RLock, AcquirerProxy) SyncManager.register('Semaphore', threading.Semaphore, AcquirerProxy) SyncManager.register('BoundedSemaphore', threading.BoundedSemaphore, AcquirerProxy) SyncManager.register('Condition', threading.Condition, ConditionProxy) SyncManager.register('Barrier', threading.Barrier, BarrierProxy) SyncManager.register('Pool', Pool, PoolProxy) SyncManager.register('list', list, ListProxy) SyncManager.register('dict', dict, DictProxy) SyncManager.register('Value', Value, ValueProxy) SyncManager.register('Array', Array, ArrayProxy) SyncManager.register('Namespace', Namespace, NamespaceProxy) # types returned by methods of PoolProxy SyncManager.register('Iterator', proxytype=IteratorProxy, create_method=False) SyncManager.register('AsyncResult', create_method=False)