summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_picklebuffer.py
blob: 435b3e038aa394f653119f38f84adad5460479ef (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""Unit tests for the PickleBuffer object.

Pickling tests themselves are in pickletester.py.
"""

import gc
from pickle import PickleBuffer
import weakref
import unittest

from test.support import import_helper


class B(bytes):
    """A plain ``bytes`` subclass.

    Subclassing lets instances carry arbitrary attributes, which the cycle
    test relies on to make a bytes object reference a PickleBuffer wrapping it.
    """


class PickleBufferTest(unittest.TestCase):
    """Tests for construction, buffer export, release and raw() of PickleBuffer."""

    def check_memoryview(self, pb, equiv):
        """Assert that the buffer exported by *pb* matches that of *equiv*.

        Compares the buffer metadata (size, read-only flag, item size,
        shape, strides, contiguity, format) and the byte contents.  Both
        memoryviews are released on exit by their ``with`` blocks.
        """
        with memoryview(pb) as m:
            with memoryview(equiv) as expected:
                self.assertEqual(m.nbytes, expected.nbytes)
                self.assertEqual(m.readonly, expected.readonly)
                self.assertEqual(m.itemsize, expected.itemsize)
                self.assertEqual(m.shape, expected.shape)
                self.assertEqual(m.strides, expected.strides)
                self.assertEqual(m.c_contiguous, expected.c_contiguous)
                self.assertEqual(m.f_contiguous, expected.f_contiguous)
                self.assertEqual(m.format, expected.format)
                self.assertEqual(m.tobytes(), expected.tobytes())

    def test_constructor_failure(self):
        # A PickleBuffer needs exactly one argument supporting the buffer
        # protocol; str does not.
        with self.assertRaises(TypeError):
            PickleBuffer()
        with self.assertRaises(TypeError):
            PickleBuffer("foo")
        # Released memoryview fails taking a buffer
        m = memoryview(b"foo")
        m.release()
        with self.assertRaises(ValueError):
            PickleBuffer(m)

    def test_basics(self):
        # The exported buffer's writability follows the wrapped object:
        # read-only for bytes, writable for bytearray.
        pb = PickleBuffer(b"foo")
        self.assertEqual(b"foo", bytes(pb))
        with memoryview(pb) as m:
            self.assertTrue(m.readonly)

        pb = PickleBuffer(bytearray(b"foo"))
        self.assertEqual(b"foo", bytes(pb))
        with memoryview(pb) as m:
            self.assertFalse(m.readonly)
            m[0] = 48
        # The write through the view is visible through the PickleBuffer.
        self.assertEqual(b"0oo", bytes(pb))

    def test_release(self):
        pb = PickleBuffer(b"foo")
        pb.release()
        with self.assertRaises(ValueError) as raises:
            memoryview(pb)
        self.assertIn("operation forbidden on released PickleBuffer object",
                      str(raises.exception))
        # Idempotency
        pb.release()

    def test_cycle(self):
        # The bytes object references the PickleBuffer that wraps it,
        # forming a reference cycle.  The cycle collector must be able to
        # free it.
        b = B(b"foo")
        pb = PickleBuffer(b)
        b.cycle = pb
        wpb = weakref.ref(pb)
        del b, pb
        gc.collect()
        self.assertIsNone(wpb())

    def test_ndarray_2d(self):
        # C-contiguous
        ndarray = import_helper.import_module("_testbuffer").ndarray
        arr = ndarray(list(range(12)), shape=(4, 3), format='<i')
        self.assertTrue(arr.c_contiguous)
        self.assertFalse(arr.f_contiguous)
        pb = PickleBuffer(arr)
        self.check_memoryview(pb, arr)
        # Non-contiguous
        arr = arr[::2]
        self.assertFalse(arr.c_contiguous)
        self.assertFalse(arr.f_contiguous)
        pb = PickleBuffer(arr)
        self.check_memoryview(pb, arr)
        # F-contiguous: with 4-byte items, strides (4, 12) make the first
        # axis vary fastest in memory.
        arr = ndarray(list(range(12)), shape=(3, 4), strides=(4, 12), format='<i')
        self.assertTrue(arr.f_contiguous)
        self.assertFalse(arr.c_contiguous)
        pb = PickleBuffer(arr)
        self.check_memoryview(pb, arr)

    # Tests for PickleBuffer.raw()

    def check_raw(self, obj, equiv):
        """Assert that ``PickleBuffer(obj).raw()`` is a memoryview matching *equiv*.

        *equiv* is a bytes-like object.  check_memoryview therefore also
        checks that the raw view has *equiv*'s flat shape and format.
        """
        pb = PickleBuffer(obj)
        with pb.raw() as m:
            self.assertIsInstance(m, memoryview)
            self.check_memoryview(m, equiv)

    def test_raw(self):
        for obj in (b"foo", bytearray(b"foo")):
            with self.subTest(obj=obj):
                self.check_raw(obj, obj)

    def test_raw_ndarray(self):
        # 1-D, contiguous
        ndarray = import_helper.import_module("_testbuffer").ndarray
        arr = ndarray(list(range(3)), shape=(3,), format='<h')
        equiv = b"\x00\x00\x01\x00\x02\x00"
        self.check_raw(arr, equiv)
        # 2-D, C-contiguous
        arr = ndarray(list(range(6)), shape=(2, 3), format='<h')
        equiv = b"\x00\x00\x01\x00\x02\x00\x03\x00\x04\x00\x05\x00"
        self.check_raw(arr, equiv)
        # 2-D, F-contiguous
        arr = ndarray(list(range(6)), shape=(2, 3), strides=(2, 4),
                      format='<h')
        # Note this is different from arr.tobytes(): raw() exposes the
        # bytes in memory order, whereas tobytes() uses logical C order.
        equiv = b"\x00\x00\x01\x00\x02\x00\x03\x00\x04\x00\x05\x00"
        self.check_raw(arr, equiv)
        # 0-D
        arr = ndarray(456, shape=(), format='<i')
        equiv = b'\xc8\x01\x00\x00'
        self.check_raw(arr, equiv)

    def check_raw_non_contiguous(self, obj):
        """Assert that raw() refuses to flatten a non-contiguous *obj*."""
        pb = PickleBuffer(obj)
        with self.assertRaisesRegex(BufferError, "non-contiguous"):
            pb.raw()

    def test_raw_non_contiguous(self):
        # 1-D
        ndarray = import_helper.import_module("_testbuffer").ndarray
        arr = ndarray(list(range(6)), shape=(6,), format='<i')[::2]
        self.check_raw_non_contiguous(arr)
        # 2-D
        arr = ndarray(list(range(12)), shape=(4, 3), format='<i')[::2]
        self.check_raw_non_contiguous(arr)

    def test_raw_released(self):
        pb = PickleBuffer(b"foo")
        pb.release()
        with self.assertRaises(ValueError) as raises:
            pb.raw()
        # Check the message as well as the type, as test_release does for
        # buffer export.  This also uses the captured context, which was
        # previously bound and never read.
        self.assertIn("operation forbidden on released PickleBuffer object",
                      str(raises.exception))


if __name__ == "__main__":
    # Allow running this file directly as a script; unittest.main() runs the
    # TestCase classes defined in this module.
    unittest.main()