From de50360ac2fec81dbf733f6c3c739b39a8822a39 Mon Sep 17 00:00:00 2001 From: Oren Milman Date: Thu, 24 Aug 2017 21:33:42 +0300 Subject: bpo-29741: Update some methods in the _pyio module to also accept integer types. Patch by Oren Milman. (#560) --- Lib/_pyio.py | 52 ++++++++++++++++------ Lib/test/test_memoryio.py | 28 ++++++++++++ .../2017-08-23-00-31-32.bpo-29741.EBn_DM.rst | 2 + 3 files changed, 68 insertions(+), 14 deletions(-) create mode 100644 Misc/NEWS.d/next/Library/2017-08-23-00-31-32.bpo-29741.EBn_DM.rst diff --git a/Lib/_pyio.py b/Lib/_pyio.py index d50230d..4653847 100644 --- a/Lib/_pyio.py +++ b/Lib/_pyio.py @@ -504,8 +504,13 @@ class IOBase(metaclass=abc.ABCMeta): return 1 if size is None: size = -1 - elif not isinstance(size, int): - raise TypeError("size must be an integer") + else: + try: + size_index = size.__index__ + except AttributeError: + raise TypeError(f"{size!r} is not an integer") + else: + size = size_index() res = bytearray() while size < 0 or len(res) < size: b = self.read(nreadahead()) @@ -868,6 +873,13 @@ class BytesIO(BufferedIOBase): raise ValueError("read from closed file") if size is None: size = -1 + else: + try: + size_index = size.__index__ + except AttributeError: + raise TypeError(f"{size!r} is not an integer") + else: + size = size_index() if size < 0: size = len(self._buffer) if len(self._buffer) <= self._pos: @@ -905,9 +917,11 @@ class BytesIO(BufferedIOBase): if self.closed: raise ValueError("seek on closed file") try: - pos.__index__ - except AttributeError as err: - raise TypeError("an integer is required") from err + pos_index = pos.__index__ + except AttributeError: + raise TypeError(f"{pos!r} is not an integer") + else: + pos = pos_index() if whence == 0: if pos < 0: raise ValueError("negative seek position %r" % (pos,)) @@ -932,9 +946,11 @@ class BytesIO(BufferedIOBase): pos = self._pos else: try: - pos.__index__ - except AttributeError as err: - raise TypeError("an integer is required") from err + pos_index = pos.__index__ + except AttributeError: + raise TypeError(f"{pos!r} is not an integer") + else: + pos = pos_index() if pos < 0: raise ValueError("negative truncate position %r" % (pos,)) del self._buffer[pos:] @@ -2378,11 +2394,14 @@ class TextIOWrapper(TextIOBase): self._checkReadable() if size is None: size = -1 + else: + try: + size_index = size.__index__ + except AttributeError: + raise TypeError(f"{size!r} is not an integer") + else: + size = size_index() decoder = self._decoder or self._get_decoder() - try: - size.__index__ - except AttributeError as err: - raise TypeError("an integer is required") from err if size < 0: # Read everything. result = (self._get_decoded_chars() + @@ -2413,8 +2432,13 @@ class TextIOWrapper(TextIOBase): raise ValueError("read from closed file") if size is None: size = -1 - elif not isinstance(size, int): - raise TypeError("size must be an integer") + else: + try: + size_index = size.__index__ + except AttributeError: + raise TypeError(f"{size!r} is not an integer") + else: + size = size_index() # Grab all the decoded text (we will rewind any extra bits later). line = self._get_decoded_chars() diff --git a/Lib/test/test_memoryio.py b/Lib/test/test_memoryio.py index 80055ce..e16c57e 100644 --- a/Lib/test/test_memoryio.py +++ b/Lib/test/test_memoryio.py @@ -11,6 +11,13 @@ import _pyio as pyio import pickle import sys +class IntLike: + def __init__(self, num): + self._num = num + def __index__(self): + return self._num + __int__ = __index__ + class MemorySeekTestMixin: def testInit(self): @@ -116,7 +123,10 @@ class MemoryTestMixin: memio = self.ioclass(buf) self.assertRaises(ValueError, memio.truncate, -1) + self.assertRaises(ValueError, memio.truncate, IntLike(-1)) memio.seek(6) + self.assertEqual(memio.truncate(IntLike(8)), 8) + self.assertEqual(memio.getvalue(), buf[:8]) self.assertEqual(memio.truncate(), 6) self.assertEqual(memio.getvalue(), buf[:6]) self.assertEqual(memio.truncate(4), 4) @@ -131,6 +141,7 @@ class MemoryTestMixin: self.assertRaises(TypeError, memio.truncate, '0') memio.close() self.assertRaises(ValueError, memio.truncate, 0) + self.assertRaises(ValueError, memio.truncate, IntLike(0)) def test_init(self): buf = self.buftype("1234567890") @@ -154,12 +165,19 @@ class MemoryTestMixin: self.assertEqual(memio.read(900), buf[5:]) self.assertEqual(memio.read(), self.EOF) memio.seek(0) + self.assertEqual(memio.read(IntLike(0)), self.EOF) + self.assertEqual(memio.read(IntLike(1)), buf[:1]) + self.assertEqual(memio.read(IntLike(4)), buf[1:5]) + self.assertEqual(memio.read(IntLike(900)), buf[5:]) + memio.seek(0) self.assertEqual(memio.read(), buf) self.assertEqual(memio.read(), self.EOF) self.assertEqual(memio.tell(), 10) memio.seek(0) self.assertEqual(memio.read(-1), buf) memio.seek(0) + self.assertEqual(memio.read(IntLike(-1)), buf) + memio.seek(0) self.assertEqual(type(memio.read()), type(buf)) memio.seek(100) self.assertEqual(type(memio.read()), type(buf)) @@ -169,6 +187,8 @@ class MemoryTestMixin: memio.seek(len(buf) + 1) self.assertEqual(memio.read(1), self.EOF) memio.seek(len(buf) + 1) + self.assertEqual(memio.read(IntLike(1)), self.EOF) + memio.seek(len(buf) + 1) self.assertEqual(memio.read(), self.EOF) memio.close() self.assertRaises(ValueError, memio.read) @@ -178,6 +198,7 @@ class MemoryTestMixin: memio = self.ioclass(buf * 2) self.assertEqual(memio.readline(0), self.EOF) + self.assertEqual(memio.readline(IntLike(0)), self.EOF) self.assertEqual(memio.readline(), buf) self.assertEqual(memio.readline(), buf) self.assertEqual(memio.readline(), self.EOF) @@ -186,9 +207,16 @@ class MemoryTestMixin: self.assertEqual(memio.readline(5), buf[5:10]) self.assertEqual(memio.readline(5), buf[10:15]) memio.seek(0) + self.assertEqual(memio.readline(IntLike(5)), buf[:5]) + self.assertEqual(memio.readline(IntLike(5)), buf[5:10]) + self.assertEqual(memio.readline(IntLike(5)), buf[10:15]) + memio.seek(0) self.assertEqual(memio.readline(-1), buf) memio.seek(0) + self.assertEqual(memio.readline(IntLike(-1)), buf) + memio.seek(0) self.assertEqual(memio.readline(0), self.EOF) + self.assertEqual(memio.readline(IntLike(0)), self.EOF) # Issue #24989: Buffer overread memio.seek(len(buf) * 2 + 1) self.assertEqual(memio.readline(), self.EOF) diff --git a/Misc/NEWS.d/next/Library/2017-08-23-00-31-32.bpo-29741.EBn_DM.rst b/Misc/NEWS.d/next/Library/2017-08-23-00-31-32.bpo-29741.EBn_DM.rst new file mode 100644 index 0000000..dce720b --- /dev/null +++ b/Misc/NEWS.d/next/Library/2017-08-23-00-31-32.bpo-29741.EBn_DM.rst @@ -0,0 +1,2 @@ +Update some methods in the _pyio module to also accept integer types. Patch +by Oren Milman. -- cgit v0.12