summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorOren Milman <orenmn@gmail.com>2017-08-24 18:33:42 (GMT)
committerSteve Dower <steve.dower@microsoft.com>2017-08-24 18:33:42 (GMT)
commitde50360ac2fec81dbf733f6c3c739b39a8822a39 (patch)
tree056148095d23b20ddceae1d53f64197de8ddce5d /Lib
parent13614e375cc3637cf1311733d453df6107e964ea (diff)
downloadcpython-de50360ac2fec81dbf733f6c3c739b39a8822a39.zip
cpython-de50360ac2fec81dbf733f6c3c739b39a8822a39.tar.gz
cpython-de50360ac2fec81dbf733f6c3c739b39a8822a39.tar.bz2
bpo-29741: Update some methods in the _pyio module to also accept integer types. Patch by Oren Milman. (#560)
Diffstat (limited to 'Lib')
-rw-r--r--Lib/_pyio.py52
-rw-r--r--Lib/test/test_memoryio.py28
2 files changed, 66 insertions, 14 deletions
diff --git a/Lib/_pyio.py b/Lib/_pyio.py
index d50230d..4653847 100644
--- a/Lib/_pyio.py
+++ b/Lib/_pyio.py
@@ -504,8 +504,13 @@ class IOBase(metaclass=abc.ABCMeta):
return 1
if size is None:
size = -1
- elif not isinstance(size, int):
- raise TypeError("size must be an integer")
+ else:
+ try:
+ size_index = size.__index__
+ except AttributeError:
+ raise TypeError(f"{size!r} is not an integer")
+ else:
+ size = size_index()
res = bytearray()
while size < 0 or len(res) < size:
b = self.read(nreadahead())
@@ -868,6 +873,13 @@ class BytesIO(BufferedIOBase):
raise ValueError("read from closed file")
if size is None:
size = -1
+ else:
+ try:
+ size_index = size.__index__
+ except AttributeError:
+ raise TypeError(f"{size!r} is not an integer")
+ else:
+ size = size_index()
if size < 0:
size = len(self._buffer)
if len(self._buffer) <= self._pos:
@@ -905,9 +917,11 @@ class BytesIO(BufferedIOBase):
if self.closed:
raise ValueError("seek on closed file")
try:
- pos.__index__
- except AttributeError as err:
- raise TypeError("an integer is required") from err
+ pos_index = pos.__index__
+ except AttributeError:
+ raise TypeError(f"{pos!r} is not an integer")
+ else:
+ pos = pos_index()
if whence == 0:
if pos < 0:
raise ValueError("negative seek position %r" % (pos,))
@@ -932,9 +946,11 @@ class BytesIO(BufferedIOBase):
pos = self._pos
else:
try:
- pos.__index__
- except AttributeError as err:
- raise TypeError("an integer is required") from err
+ pos_index = pos.__index__
+ except AttributeError:
+ raise TypeError(f"{pos!r} is not an integer")
+ else:
+ pos = pos_index()
if pos < 0:
raise ValueError("negative truncate position %r" % (pos,))
del self._buffer[pos:]
@@ -2378,11 +2394,14 @@ class TextIOWrapper(TextIOBase):
self._checkReadable()
if size is None:
size = -1
+ else:
+ try:
+ size_index = size.__index__
+ except AttributeError:
+ raise TypeError(f"{size!r} is not an integer")
+ else:
+ size = size_index()
decoder = self._decoder or self._get_decoder()
- try:
- size.__index__
- except AttributeError as err:
- raise TypeError("an integer is required") from err
if size < 0:
# Read everything.
result = (self._get_decoded_chars() +
@@ -2413,8 +2432,13 @@ class TextIOWrapper(TextIOBase):
raise ValueError("read from closed file")
if size is None:
size = -1
- elif not isinstance(size, int):
- raise TypeError("size must be an integer")
+ else:
+ try:
+ size_index = size.__index__
+ except AttributeError:
+ raise TypeError(f"{size!r} is not an integer")
+ else:
+ size = size_index()
# Grab all the decoded text (we will rewind any extra bits later).
line = self._get_decoded_chars()
diff --git a/Lib/test/test_memoryio.py b/Lib/test/test_memoryio.py
index 80055ce..e16c57e 100644
--- a/Lib/test/test_memoryio.py
+++ b/Lib/test/test_memoryio.py
@@ -11,6 +11,13 @@ import _pyio as pyio
import pickle
import sys
+class IntLike:
+ def __init__(self, num):
+ self._num = num
+ def __index__(self):
+ return self._num
+ __int__ = __index__
+
class MemorySeekTestMixin:
def testInit(self):
@@ -116,7 +123,10 @@ class MemoryTestMixin:
memio = self.ioclass(buf)
self.assertRaises(ValueError, memio.truncate, -1)
+ self.assertRaises(ValueError, memio.truncate, IntLike(-1))
memio.seek(6)
+ self.assertEqual(memio.truncate(IntLike(8)), 8)
+ self.assertEqual(memio.getvalue(), buf[:8])
self.assertEqual(memio.truncate(), 6)
self.assertEqual(memio.getvalue(), buf[:6])
self.assertEqual(memio.truncate(4), 4)
@@ -131,6 +141,7 @@ class MemoryTestMixin:
self.assertRaises(TypeError, memio.truncate, '0')
memio.close()
self.assertRaises(ValueError, memio.truncate, 0)
+ self.assertRaises(ValueError, memio.truncate, IntLike(0))
def test_init(self):
buf = self.buftype("1234567890")
@@ -154,12 +165,19 @@ class MemoryTestMixin:
self.assertEqual(memio.read(900), buf[5:])
self.assertEqual(memio.read(), self.EOF)
memio.seek(0)
+ self.assertEqual(memio.read(IntLike(0)), self.EOF)
+ self.assertEqual(memio.read(IntLike(1)), buf[:1])
+ self.assertEqual(memio.read(IntLike(4)), buf[1:5])
+ self.assertEqual(memio.read(IntLike(900)), buf[5:])
+ memio.seek(0)
self.assertEqual(memio.read(), buf)
self.assertEqual(memio.read(), self.EOF)
self.assertEqual(memio.tell(), 10)
memio.seek(0)
self.assertEqual(memio.read(-1), buf)
memio.seek(0)
+ self.assertEqual(memio.read(IntLike(-1)), buf)
+ memio.seek(0)
self.assertEqual(type(memio.read()), type(buf))
memio.seek(100)
self.assertEqual(type(memio.read()), type(buf))
@@ -169,6 +187,8 @@ class MemoryTestMixin:
memio.seek(len(buf) + 1)
self.assertEqual(memio.read(1), self.EOF)
memio.seek(len(buf) + 1)
+ self.assertEqual(memio.read(IntLike(1)), self.EOF)
+ memio.seek(len(buf) + 1)
self.assertEqual(memio.read(), self.EOF)
memio.close()
self.assertRaises(ValueError, memio.read)
@@ -178,6 +198,7 @@ class MemoryTestMixin:
memio = self.ioclass(buf * 2)
self.assertEqual(memio.readline(0), self.EOF)
+ self.assertEqual(memio.readline(IntLike(0)), self.EOF)
self.assertEqual(memio.readline(), buf)
self.assertEqual(memio.readline(), buf)
self.assertEqual(memio.readline(), self.EOF)
@@ -186,9 +207,16 @@ class MemoryTestMixin:
self.assertEqual(memio.readline(5), buf[5:10])
self.assertEqual(memio.readline(5), buf[10:15])
memio.seek(0)
+ self.assertEqual(memio.readline(IntLike(5)), buf[:5])
+ self.assertEqual(memio.readline(IntLike(5)), buf[5:10])
+ self.assertEqual(memio.readline(IntLike(5)), buf[10:15])
+ memio.seek(0)
self.assertEqual(memio.readline(-1), buf)
memio.seek(0)
+ self.assertEqual(memio.readline(IntLike(-1)), buf)
+ memio.seek(0)
self.assertEqual(memio.readline(0), self.EOF)
+ self.assertEqual(memio.readline(IntLike(0)), self.EOF)
# Issue #24989: Buffer overread
memio.seek(len(buf) * 2 + 1)
self.assertEqual(memio.readline(), self.EOF)