path: root/Lib/test/test_memoryview.py
"""Unit tests for the memoryview

XXX We need more tests! Some tests are in test_bytes
"""

import unittest
import test.support
import sys
import gc
import weakref
import array
import io


class AbstractMemoryTests:
    source_bytes = b"abcdef"

    @property
    def _source(self):
        return self.source_bytes

    @property
    def _types(self):
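        # filter(None, ...) silently drops whichever of ro_type/rw_type a
        # subclass sets to None (e.g. BaseArrayMemoryTests has no
        # read-only source type).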
        return filter(None, [self.ro_type, self.rw_type])

    def check_getitem_with_type(self, tp):
        item = self.getitem_type
        b = tp(self._source)
        oldrefcount = sys.getrefcount(b)
        m = self._view(b)
        self.assertEqual(m[0], item(b"a"))
        self.assertTrue(isinstance(m[0], bytes), type(m[0]))
        self.assertEqual(m[5], item(b"f"))
        self.assertEqual(m[-1], item(b"f"))
        self.assertEqual(m[-6], item(b"a"))
        # Bounds checking
        self.assertRaises(IndexError, lambda: m[6])
        self.assertRaises(IndexError, lambda: m[-7])
        self.assertRaises(IndexError, lambda: m[sys.maxsize])
        self.assertRaises(IndexError, lambda: m[-sys.maxsize])
        # Type checking
        self.assertRaises(TypeError, lambda: m[None])
        self.assertRaises(TypeError, lambda: m[0.0])
        self.assertRaises(TypeError, lambda: m["a"])
        m = None
        self.assertEqual(sys.getrefcount(b), oldrefcount)

    def test_getitem(self):
        for tp in self._types:
            self.check_getitem_with_type(tp)

    def test_iter(self):
        for tp in self._types:
            b = tp(self._source)
            m = self._view(b)
            self.assertEqual(list(m), [m[i] for i in range(len(m))])

    def test_setitem_readonly(self):
        if not self.ro_type:
            return
        b = self.ro_type(self._source)
        oldrefcount = sys.getrefcount(b)
        m = self._view(b)
        def setitem(value):
            m[0] = value
        self.assertRaises(TypeError, setitem, b"a")
        self.assertRaises(TypeError, setitem, 65)
        self.assertRaises(TypeError, setitem, memoryview(b"a"))
        m = None
        self.assertEqual(sys.getrefcount(b), oldrefcount)

    def test_setitem_writable(self):
        if not self.rw_type:
            return
        tp = self.rw_type
        b = self.rw_type(self._source)
        oldrefcount = sys.getrefcount(b)
        m = self._view(b)
        m[0] = tp(b"0")
        self._check_contents(tp, b, b"0bcdef")
        m[1:3] = tp(b"12")
        self._check_contents(tp, b, b"012def")
        m[1:1] = tp(b"")
        self._check_contents(tp, b, b"012def")
        m[:] = tp(b"abcdef")
        self._check_contents(tp, b, b"abcdef")

        # Overlapping copies of a view into itself
        m[0:3] = m[2:5]
        self._check_contents(tp, b, b"cdedef")
        m[:] = tp(b"abcdef")
        m[2:5] = m[0:3]
        self._check_contents(tp, b, b"ababcf")

        def setitem(key, value):
            m[key] = tp(value)
        # Bounds checking
        self.assertRaises(IndexError, setitem, 6, b"a")
        self.assertRaises(IndexError, setitem, -7, b"a")
        self.assertRaises(IndexError, setitem, sys.maxsize, b"a")
        self.assertRaises(IndexError, setitem, -sys.maxsize, b"a")
        # Wrong index/slice types
        self.assertRaises(TypeError, setitem, 0.0, b"a")
        self.assertRaises(TypeError, setitem, (0,), b"a")
        self.assertRaises(TypeError, setitem, "a", b"a")
        # Trying to resize the memory object
        self.assertRaises(ValueError, setitem, 0, b"")
        self.assertRaises(ValueError, setitem, 0, b"ab")
        self.assertRaises(ValueError, setitem, slice(1,1), b"a")
        self.assertRaises(ValueError, setitem, slice(0,2), b"a")

        m = None
        self.assertEqual(sys.getrefcount(b), oldrefcount)

    def test_delitem(self):
        for tp in self._types:
            b = tp(self._source)
            m = self._view(b)
            with self.assertRaises(TypeError):
                del m[1]
            with self.assertRaises(TypeError):
                del m[1:4]

    def test_tobytes(self):
        for tp in self._types:
            m = self._view(tp(self._source))
            b = m.tobytes()
            # This calls self.getitem_type() on each separate byte of b"abcdef"
            expected = b"".join(
                self.getitem_type(bytes([c])) for c in b"abcdef")
            self.assertEqual(b, expected)
            self.assertTrue(isinstance(b, bytes), type(b))

    def test_tolist(self):
        for tp in self._types:
            m = self._view(tp(self._source))
            l = m.tolist()
            self.assertEqual(l, list(b"abcdef"))

    def test_compare(self):
        # memoryviews can compare for equality with other objects
        # having the buffer interface.
        for tp in self._types:
            m = self._view(tp(self._source))
            for tp_comp in self._types:
                self.assertTrue(m == tp_comp(b"abcdef"))
                self.assertFalse(m != tp_comp(b"abcdef"))
                self.assertFalse(m == tp_comp(b"abcde"))
                self.assertTrue(m != tp_comp(b"abcde"))
                self.assertFalse(m == tp_comp(b"abcde1"))
                self.assertTrue(m != tp_comp(b"abcde1"))
            self.assertTrue(m == m)
            self.assertTrue(m == m[:])
            self.assertTrue(m[0:6] == m[:])
            self.assertFalse(m[0:5] == m)

            # Comparison with objects which don't support the buffer API
            self.assertFalse(m == "abcdef")
            self.assertTrue(m != "abcdef")
            self.assertFalse("abcdef" == m)
            self.assertTrue("abcdef" != m)

            # Unordered comparisons
            for c in (m, b"abcdef"):
                self.assertRaises(TypeError, lambda: m < c)
                self.assertRaises(TypeError, lambda: c <= m)
                self.assertRaises(TypeError, lambda: m >= c)
                self.assertRaises(TypeError, lambda: c > m)

    def check_attributes_with_type(self, tp):
        m = self._view(tp(self._source))
        self.assertEqual(m.format, self.format)
        self.assertEqual(m.itemsize, self.itemsize)
        self.assertEqual(m.ndim, 1)
        self.assertEqual(m.shape, (6,))
        self.assertEqual(len(m), 6)
        self.assertEqual(m.strides, (self.itemsize,))
        self.assertEqual(m.suboffsets, None)
        return m

    def test_attributes_readonly(self):
        if not self.ro_type:
            return
        m = self.check_attributes_with_type(self.ro_type)
        self.assertEqual(m.readonly, True)

    def test_attributes_writable(self):
        if not self.rw_type:
            return
        m = self.check_attributes_with_type(self.rw_type)
        self.assertEqual(m.readonly, False)

    def test_getbuffer(self):
        # Test PyObject_GetBuffer() on a memoryview object.
        for tp in self._types:
            b = tp(self._source)
            oldrefcount = sys.getrefcount(b)
            m = self._view(b)
            oldviewrefcount = sys.getrefcount(m)
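            # str(m, "utf-8") decodes the view through the buffer
            # protocol, which exercises PyObject_GetBuffer() on it.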
            s = str(m, "utf-8")
            self._check_contents(tp, b, s.encode("utf-8"))
            self.assertEqual(sys.getrefcount(m), oldviewrefcount)
            m = None
            self.assertEqual(sys.getrefcount(b), oldrefcount)

    def test_gc(self):
        for tp in self._types:
            if not isinstance(tp, type):
                # If tp is a factory rather than a plain type, skip
                continue

            class MySource(tp):
                pass
            class MyObject:
                pass

            # Create a reference cycle through a memoryview object
            b = MySource(tp(b'abc'))
            m = self._view(b)
            o = MyObject()
            b.m = m
            b.o = o
            wr = weakref.ref(o)
            b = m = o = None
            # The cycle must be broken
            gc.collect()
            self.assertTrue(wr() is None, wr())

    def test_writable_readonly(self):
        # Issue #10451: memoryview incorrectly exposes a readonly
        # buffer as writable causing a segfault if using mmap
        tp = self.ro_type
        if tp is None:
            return
        b = tp(self._source)
        m = self._view(b)
        i = io.BytesIO(b'ZZZZ')
        self.assertRaises(TypeError, i.readinto, m)

# Variations on source objects for the buffer: bytes-like objects, then arrays
# with itemsize > 1.
# NOTE: support for multi-dimensional objects is unimplemented.
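#
# The concrete mixins fill in the abstract attributes expected by
# AbstractMemoryTests: ro_type/rw_type are the read-only and writable
# source constructors (either may be None), getitem_type maps a single
# source byte to the object returned by m[i], and itemsize/format
# describe the element layout.  A rough sketch of one more variant
# (hypothetical, not used by this suite) would look like:
#
#     class BaseUnsignedIntArrayMemoryTests(AbstractMemoryTests):
#         ro_type = None
#         rw_type = lambda self, b: array.array('I', list(b))
#         getitem_type = lambda self, b: array.array('I', list(b)).tostring()
#         itemsize = array.array('I').itemsize
#         format = 'I'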

class BaseBytesMemoryTests(AbstractMemoryTests):
    ro_type = bytes
    rw_type = bytearray
    getitem_type = bytes
    itemsize = 1
    format = 'B'

class BaseArrayMemoryTests(AbstractMemoryTests):
    ro_type = None
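    # The two callables below are stored as plain class attributes, so
    # attribute lookup turns them into bound methods; the leading 'self'
    # parameter accounts for that when they are called as tp(source).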
    rw_type = lambda self, b: array.array('i', list(b))
    getitem_type = lambda self, b: array.array('i', list(b)).tostring()
    itemsize = array.array('i').itemsize
    format = 'i'

    def test_getbuffer(self):
        # XXX Test should be adapted for non-byte buffers
        pass

    def test_tolist(self):
        # XXX NotImplementedError: tolist() only supports byte views
        pass


# Variations on indirection levels: memoryview, slice of memoryview,
# slice of slice of memoryview.
# This is important to test allocation subtleties.
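# Concretely: the plain variant views b"abcdef" directly, while the two
# slice variants wrap the b"abcdef" payload of b"XabcdefY" as
#
#     memoryview(obj)[1:7]        # one slicing step
#     memoryview(obj)[:7][1:]     # slice of a slice
#
# so that every abstract test also runs against views whose start
# offset and length differ from those of the underlying buffer.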

class BaseMemoryviewTests:
    def _view(self, obj):
        return memoryview(obj)

    def _check_contents(self, tp, obj, contents):
        self.assertEqual(obj, tp(contents))

class BaseMemorySliceTests:
    source_bytes = b"XabcdefY"

    def _view(self, obj):
        m = memoryview(obj)
        return m[1:7]

    def _check_contents(self, tp, obj, contents):
        self.assertEqual(obj[1:7], tp(contents))

    def test_refs(self):
        for tp in self._types:
            m = memoryview(tp(self._source))
            oldrefcount = sys.getrefcount(m)
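            # Create and immediately discard a slice; this must not leak
            # a reference to the parent view.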
            m[1:2]
            self.assertEqual(sys.getrefcount(m), oldrefcount)

class BaseMemorySliceSliceTests:
    source_bytes = b"XabcdefY"

    def _view(self, obj):
        m = memoryview(obj)
        return m[:7][1:]

    def _check_contents(self, tp, obj, contents):
        self.assertEqual(obj[1:7], tp(contents))


# Concrete test classes

class BytesMemoryviewTest(unittest.TestCase,
    BaseMemoryviewTests, BaseBytesMemoryTests):

    def test_constructor(self):
        for tp in self._types:
            ob = tp(self._source)
            self.assertTrue(memoryview(ob))
            self.assertTrue(memoryview(object=ob))
            self.assertRaises(TypeError, memoryview)
            self.assertRaises(TypeError, memoryview, ob, ob)
            self.assertRaises(TypeError, memoryview, argument=ob)
            self.assertRaises(TypeError, memoryview, ob, argument=True)

class ArrayMemoryviewTest(unittest.TestCase,
    BaseMemoryviewTests, BaseArrayMemoryTests):

    def test_array_assign(self):
        # Issue #4569: segfault when mutating a memoryview with itemsize != 1
        a = array.array('i', range(10))
        m = memoryview(a)
        new_a = array.array('i', range(9, -1, -1))
        m[:] = new_a
        self.assertEqual(a, new_a)


class BytesMemorySliceTest(unittest.TestCase,
    BaseMemorySliceTests, BaseBytesMemoryTests):
    pass

class ArrayMemorySliceTest(unittest.TestCase,
    BaseMemorySliceTests, BaseArrayMemoryTests):
    pass

class BytesMemorySliceSliceTest(unittest.TestCase,
    BaseMemorySliceSliceTests, BaseBytesMemoryTests):
    pass

class ArrayMemorySliceSliceTest(unittest.TestCase,
    BaseMemorySliceSliceTests, BaseArrayMemoryTests):
    pass


def test_main():
    test.support.run_unittest(__name__)

if __name__ == "__main__":
    test_main()