diff options
author | Dong-hee Na <donghee.na@python.org> | 2022-09-10 20:44:10 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-09-10 20:44:10 (GMT) |
commit | 8d75a13fdece95ddc1bba42cad3aea3ccb396e05 (patch) | |
tree | d10fcf61314f11ecd5dd966b5c1c874808ad5727 /Lib | |
parent | c4e57fb6df15266920941b732ef3a58fb619d851 (diff) | |
download | cpython-8d75a13fdece95ddc1bba42cad3aea3ccb396e05.zip cpython-8d75a13fdece95ddc1bba42cad3aea3ccb396e05.tar.gz cpython-8d75a13fdece95ddc1bba42cad3aea3ccb396e05.tar.bz2 |
gh-90751: memoryview now supports half-float (#96738)
Co-authored-by: Antoine Pitrou <antoine@python.org>
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/test/test_buffer.py | 20 | ||||
-rw-r--r-- | Lib/test/test_memoryview.py | 9 |
2 files changed, 22 insertions, 7 deletions
diff --git a/Lib/test/test_buffer.py b/Lib/test/test_buffer.py index 468c6ea..8ac3b7e 100644 --- a/Lib/test/test_buffer.py +++ b/Lib/test/test_buffer.py @@ -64,7 +64,7 @@ NATIVE = { '?':0, 'c':0, 'b':0, 'B':0, 'h':0, 'H':0, 'i':0, 'I':0, 'l':0, 'L':0, 'n':0, 'N':0, - 'f':0, 'd':0, 'P':0 + 'e':0, 'f':0, 'd':0, 'P':0 } # NumPy does not have 'n' or 'N': @@ -89,7 +89,8 @@ STANDARD = { 'i':(-(1<<31), 1<<31), 'I':(0, 1<<32), 'l':(-(1<<31), 1<<31), 'L':(0, 1<<32), 'q':(-(1<<63), 1<<63), 'Q':(0, 1<<64), - 'f':(-(1<<63), 1<<63), 'd':(-(1<<1023), 1<<1023) + 'e':(-65519, 65520), 'f':(-(1<<63), 1<<63), + 'd':(-(1<<1023), 1<<1023) } def native_type_range(fmt): @@ -98,6 +99,8 @@ def native_type_range(fmt): lh = (0, 256) elif fmt == '?': lh = (0, 2) + elif fmt == 'e': + lh = (-65519, 65520) elif fmt == 'f': lh = (-(1<<63), 1<<63) elif fmt == 'd': @@ -125,7 +128,10 @@ if struct: for fmt in fmtdict['@']: fmtdict['@'][fmt] = native_type_range(fmt) +# Format codes suppported by the memoryview object MEMORYVIEW = NATIVE.copy() + +# Format codes suppported by array.array ARRAY = NATIVE.copy() for k in NATIVE: if not k in "bBhHiIlLfd": @@ -164,7 +170,7 @@ def randrange_fmt(mode, char, obj): x = b'\x01' if char == '?': x = bool(x) - if char == 'f' or char == 'd': + if char in 'efd': x = struct.pack(char, x) x = struct.unpack(char, x)[0] return x @@ -2246,7 +2252,7 @@ class TestBufferProtocol(unittest.TestCase): ### ### Fortran output: ### --------------- - ### >>> fortran_buf = nd.tostring(order='F') + ### >>> fortran_buf = nd.tobytes(order='F') ### >>> fortran_buf ### b'\x00\x04\x08\x01\x05\t\x02\x06\n\x03\x07\x0b' ### @@ -2289,7 +2295,7 @@ class TestBufferProtocol(unittest.TestCase): self.assertEqual(memoryview(y), memoryview(nd)) if numpy_array: - self.assertEqual(b, na.tostring(order='C')) + self.assertEqual(b, na.tobytes(order='C')) # 'F' request if f == 0: # 'C' to 'F' @@ -2312,7 +2318,7 @@ class TestBufferProtocol(unittest.TestCase): self.assertEqual(memoryview(y), memoryview(nd)) if numpy_array: - self.assertEqual(b, na.tostring(order='F')) + self.assertEqual(b, na.tobytes(order='F')) # 'A' request if f == ND_FORTRAN: @@ -2336,7 +2342,7 @@ class TestBufferProtocol(unittest.TestCase): self.assertEqual(memoryview(y), memoryview(nd)) if numpy_array: - self.assertEqual(b, na.tostring(order='A')) + self.assertEqual(b, na.tobytes(order='A')) # multi-dimensional, non-contiguous input nd = ndarray(list(range(12)), shape=[3, 4], flags=ND_WRITABLE|ND_PIL) diff --git a/Lib/test/test_memoryview.py b/Lib/test/test_memoryview.py index 9d1e1f3..0eb2a36 100644 --- a/Lib/test/test_memoryview.py +++ b/Lib/test/test_memoryview.py @@ -13,6 +13,7 @@ import array import io import copy import pickle +import struct from test.support import import_helper @@ -527,6 +528,14 @@ class OtherTest(unittest.TestCase): m[2:] = memoryview(p6).cast(format)[2:] self.assertEqual(d.value, 0.6) + def test_half_float(self): + half_data = struct.pack('eee', 0.0, -1.5, 1.5) + float_data = struct.pack('fff', 0.0, -1.5, 1.5) + half_view = memoryview(half_data).cast('e') + float_view = memoryview(float_data).cast('f') + self.assertEqual(half_view.nbytes * 2, float_view.nbytes) + self.assertListEqual(half_view.tolist(), float_view.tolist()) + def test_memoryview_hex(self): # Issue #9951: memoryview.hex() segfaults with non-contiguous buffers. x = b'0' * 200000 |