summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorDong-hee Na <donghee.na@python.org>2022-09-10 20:44:10 (GMT)
committerGitHub <noreply@github.com>2022-09-10 20:44:10 (GMT)
commit8d75a13fdece95ddc1bba42cad3aea3ccb396e05 (patch)
treed10fcf61314f11ecd5dd966b5c1c874808ad5727 /Lib
parentc4e57fb6df15266920941b732ef3a58fb619d851 (diff)
downloadcpython-8d75a13fdece95ddc1bba42cad3aea3ccb396e05.zip
cpython-8d75a13fdece95ddc1bba42cad3aea3ccb396e05.tar.gz
cpython-8d75a13fdece95ddc1bba42cad3aea3ccb396e05.tar.bz2
gh-90751: memoryview now supports half-float (#96738)
Co-authored-by: Antoine Pitrou <antoine@python.org>
Diffstat (limited to 'Lib')
-rw-r--r--Lib/test/test_buffer.py20
-rw-r--r--Lib/test/test_memoryview.py9
2 files changed, 22 insertions, 7 deletions
diff --git a/Lib/test/test_buffer.py b/Lib/test/test_buffer.py
index 468c6ea..8ac3b7e 100644
--- a/Lib/test/test_buffer.py
+++ b/Lib/test/test_buffer.py
@@ -64,7 +64,7 @@ NATIVE = {
'?':0, 'c':0, 'b':0, 'B':0,
'h':0, 'H':0, 'i':0, 'I':0,
'l':0, 'L':0, 'n':0, 'N':0,
- 'f':0, 'd':0, 'P':0
+ 'e':0, 'f':0, 'd':0, 'P':0
}
# NumPy does not have 'n' or 'N':
@@ -89,7 +89,8 @@ STANDARD = {
'i':(-(1<<31), 1<<31), 'I':(0, 1<<32),
'l':(-(1<<31), 1<<31), 'L':(0, 1<<32),
'q':(-(1<<63), 1<<63), 'Q':(0, 1<<64),
- 'f':(-(1<<63), 1<<63), 'd':(-(1<<1023), 1<<1023)
+ 'e':(-65519, 65520), 'f':(-(1<<63), 1<<63),
+ 'd':(-(1<<1023), 1<<1023)
}
def native_type_range(fmt):
@@ -98,6 +99,8 @@ def native_type_range(fmt):
lh = (0, 256)
elif fmt == '?':
lh = (0, 2)
+ elif fmt == 'e':
+ lh = (-65519, 65520)
elif fmt == 'f':
lh = (-(1<<63), 1<<63)
elif fmt == 'd':
@@ -125,7 +128,10 @@ if struct:
for fmt in fmtdict['@']:
fmtdict['@'][fmt] = native_type_range(fmt)
+# Format codes suppported by the memoryview object
MEMORYVIEW = NATIVE.copy()
+
+# Format codes suppported by array.array
ARRAY = NATIVE.copy()
for k in NATIVE:
if not k in "bBhHiIlLfd":
@@ -164,7 +170,7 @@ def randrange_fmt(mode, char, obj):
x = b'\x01'
if char == '?':
x = bool(x)
- if char == 'f' or char == 'd':
+ if char in 'efd':
x = struct.pack(char, x)
x = struct.unpack(char, x)[0]
return x
@@ -2246,7 +2252,7 @@ class TestBufferProtocol(unittest.TestCase):
###
### Fortran output:
### ---------------
- ### >>> fortran_buf = nd.tostring(order='F')
+ ### >>> fortran_buf = nd.tobytes(order='F')
### >>> fortran_buf
### b'\x00\x04\x08\x01\x05\t\x02\x06\n\x03\x07\x0b'
###
@@ -2289,7 +2295,7 @@ class TestBufferProtocol(unittest.TestCase):
self.assertEqual(memoryview(y), memoryview(nd))
if numpy_array:
- self.assertEqual(b, na.tostring(order='C'))
+ self.assertEqual(b, na.tobytes(order='C'))
# 'F' request
if f == 0: # 'C' to 'F'
@@ -2312,7 +2318,7 @@ class TestBufferProtocol(unittest.TestCase):
self.assertEqual(memoryview(y), memoryview(nd))
if numpy_array:
- self.assertEqual(b, na.tostring(order='F'))
+ self.assertEqual(b, na.tobytes(order='F'))
# 'A' request
if f == ND_FORTRAN:
@@ -2336,7 +2342,7 @@ class TestBufferProtocol(unittest.TestCase):
self.assertEqual(memoryview(y), memoryview(nd))
if numpy_array:
- self.assertEqual(b, na.tostring(order='A'))
+ self.assertEqual(b, na.tobytes(order='A'))
# multi-dimensional, non-contiguous input
nd = ndarray(list(range(12)), shape=[3, 4], flags=ND_WRITABLE|ND_PIL)
diff --git a/Lib/test/test_memoryview.py b/Lib/test/test_memoryview.py
index 9d1e1f3..0eb2a36 100644
--- a/Lib/test/test_memoryview.py
+++ b/Lib/test/test_memoryview.py
@@ -13,6 +13,7 @@ import array
import io
import copy
import pickle
+import struct
from test.support import import_helper
@@ -527,6 +528,14 @@ class OtherTest(unittest.TestCase):
m[2:] = memoryview(p6).cast(format)[2:]
self.assertEqual(d.value, 0.6)
+ def test_half_float(self):
+ half_data = struct.pack('eee', 0.0, -1.5, 1.5)
+ float_data = struct.pack('fff', 0.0, -1.5, 1.5)
+ half_view = memoryview(half_data).cast('e')
+ float_view = memoryview(float_data).cast('f')
+ self.assertEqual(half_view.nbytes * 2, float_view.nbytes)
+ self.assertListEqual(half_view.tolist(), float_view.tolist())
+
def test_memoryview_hex(self):
# Issue #9951: memoryview.hex() segfaults with non-contiguous buffers.
x = b'0' * 200000