summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
Diffstat (limited to 'Lib')
-rw-r--r--Lib/pickle.py50
-rw-r--r--Lib/test/pickletester.py49
2 files changed, 78 insertions, 21 deletions
diff --git a/Lib/pickle.py b/Lib/pickle.py
index 33c97c8..d719ceb 100644
--- a/Lib/pickle.py
+++ b/Lib/pickle.py
@@ -782,14 +782,10 @@ class _Pickler:
self.write(FLOAT + repr(obj).encode("ascii") + b'\n')
dispatch[float] = save_float
- def save_bytes(self, obj):
- if self.proto < 3:
- if not obj: # bytes object is empty
- self.save_reduce(bytes, (), obj=obj)
- else:
- self.save_reduce(codecs.encode,
- (str(obj, 'latin1'), 'latin1'), obj=obj)
- return
+ def _save_bytes_no_memo(self, obj):
+ # helper for writing bytes objects for protocol >= 3
+ # without memoizing them
+ assert self.proto >= 3
n = len(obj)
if n <= 0xff:
self.write(SHORT_BINBYTES + pack("<B", n) + obj)
@@ -799,9 +795,29 @@ class _Pickler:
self._write_large_bytes(BINBYTES + pack("<I", n), obj)
else:
self.write(BINBYTES + pack("<I", n) + obj)
+
+ def save_bytes(self, obj):
+ if self.proto < 3:
+ if not obj: # bytes object is empty
+ self.save_reduce(bytes, (), obj=obj)
+ else:
+ self.save_reduce(codecs.encode,
+ (str(obj, 'latin1'), 'latin1'), obj=obj)
+ return
+ self._save_bytes_no_memo(obj)
self.memoize(obj)
dispatch[bytes] = save_bytes
+ def _save_bytearray_no_memo(self, obj):
+ # helper for writing bytearray objects for protocol >= 5
+ # without memoizing them
+ assert self.proto >= 5
+ n = len(obj)
+ if n >= self.framer._FRAME_SIZE_TARGET:
+ self._write_large_bytes(BYTEARRAY8 + pack("<Q", n), obj)
+ else:
+ self.write(BYTEARRAY8 + pack("<Q", n) + obj)
+
def save_bytearray(self, obj):
if self.proto < 5:
if not obj: # bytearray is empty
@@ -809,11 +825,7 @@ class _Pickler:
else:
self.save_reduce(bytearray, (bytes(obj),), obj=obj)
return
- n = len(obj)
- if n >= self.framer._FRAME_SIZE_TARGET:
- self._write_large_bytes(BYTEARRAY8 + pack("<Q", n), obj)
- else:
- self.write(BYTEARRAY8 + pack("<Q", n) + obj)
+ self._save_bytearray_no_memo(obj)
self.memoize(obj)
dispatch[bytearray] = save_bytearray
@@ -832,10 +844,18 @@ class _Pickler:
if in_band:
# Write data in-band
# XXX The C implementation avoids a copy here
+ buf = m.tobytes()
+ in_memo = id(buf) in self.memo
if m.readonly:
- self.save_bytes(m.tobytes())
+ if in_memo:
+ self._save_bytes_no_memo(buf)
+ else:
+ self.save_bytes(buf)
else:
- self.save_bytearray(m.tobytes())
+ if in_memo:
+ self._save_bytearray_no_memo(buf)
+ else:
+ self.save_bytearray(buf)
else:
# Write data out-of-band
self.write(NEXT_BUFFER)
diff --git a/Lib/test/pickletester.py b/Lib/test/pickletester.py
index 93e7dbb..9922591 100644
--- a/Lib/test/pickletester.py
+++ b/Lib/test/pickletester.py
@@ -1845,6 +1845,25 @@ class AbstractPickleTests:
p = self.dumps(s, proto)
self.assert_is_copy(s, self.loads(p))
+ def test_bytes_memoization(self):
+ for proto in protocols:
+ for array_type in [bytes, ZeroCopyBytes]:
+ for s in b'', b'xyz', b'xyz'*100:
+ with self.subTest(proto=proto, array_type=array_type, s=s, independent=False):
+ b = array_type(s)
+ p = self.dumps((b, b), proto)
+ x, y = self.loads(p)
+ self.assertIs(x, y)
+ self.assert_is_copy((b, b), (x, y))
+
+ with self.subTest(proto=proto, array_type=array_type, s=s, independent=True):
+ b1, b2 = array_type(s), array_type(s)
+ p = self.dumps((b1, b2), proto)
+ # Note that (b1, b2) = self.loads(p) might have identical
+ # components, i.e., b1 is b2, but this is not always the
+ # case if the content is large (equality still holds).
+ self.assert_is_copy((b1, b2), self.loads(p))
+
def test_bytearray(self):
for proto in protocols:
for s in b'', b'xyz', b'xyz'*100:
@@ -1864,13 +1883,31 @@ class AbstractPickleTests:
self.assertNotIn(b'bytearray', p)
self.assertTrue(opcode_in_pickle(pickle.BYTEARRAY8, p))
- def test_bytearray_memoization_bug(self):
+ def test_bytearray_memoization(self):
for proto in protocols:
- for s in b'', b'xyz', b'xyz'*100:
- b = bytearray(s)
- p = self.dumps((b, b), proto)
- b1, b2 = self.loads(p)
- self.assertIs(b1, b2)
+ for array_type in [bytearray, ZeroCopyBytearray]:
+ for s in b'', b'xyz', b'xyz'*100:
+ with self.subTest(proto=proto, array_type=array_type, s=s, independent=False):
+ b = array_type(s)
+ p = self.dumps((b, b), proto)
+ b1, b2 = self.loads(p)
+ self.assertIs(b1, b2)
+
+ with self.subTest(proto=proto, array_type=array_type, s=s, independent=True):
+ b1a, b2a = array_type(s), array_type(s)
+ # Unlike bytes, equal but independent bytearray objects are
+ # never identical.
+ self.assertIsNot(b1a, b2a)
+
+ p = self.dumps((b1a, b2a), proto)
+ b1b, b2b = self.loads(p)
+ self.assertIsNot(b1b, b2b)
+
+ self.assertIsNot(b1a, b1b)
+ self.assert_is_copy(b1a, b1b)
+
+ self.assertIsNot(b2a, b2b)
+ self.assert_is_copy(b2a, b2b)
def test_ints(self):
for proto in protocols: