summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorSerhiy Storchaka <storchaka@gmail.com>2016-07-17 08:24:17 (GMT)
committerSerhiy Storchaka <storchaka@gmail.com>2016-07-17 08:24:17 (GMT)
commitdec25afab1c325c28621dda3ba2b32dbc200c8b3 (patch)
treee1b19ba17d8dab7f94aa6f0053264030b954dc32 /Lib
parent6fd76bceda3fefc5e5814108c5fe819050613d33 (diff)
downloadcpython-dec25afab1c325c28621dda3ba2b32dbc200c8b3.zip
cpython-dec25afab1c325c28621dda3ba2b32dbc200c8b3.tar.gz
cpython-dec25afab1c325c28621dda3ba2b32dbc200c8b3.tar.bz2
Issue #17711: Fixed unpickling by the persistent ID with protocol 0.
Original patch by Alexandre Vassalotti.
Diffstat (limited to 'Lib')
-rw-r--r--Lib/pickle.py12
-rw-r--r--Lib/test/pickletester.py29
-rw-r--r--Lib/test/test_pickle.py33
3 files changed, 64 insertions, 10 deletions
diff --git a/Lib/pickle.py b/Lib/pickle.py
index 7760425..040ecb2 100644
--- a/Lib/pickle.py
+++ b/Lib/pickle.py
@@ -529,7 +529,11 @@ class _Pickler:
self.save(pid, save_persistent_id=False)
self.write(BINPERSID)
else:
- self.write(PERSID + str(pid).encode("ascii") + b'\n')
+ try:
+ self.write(PERSID + str(pid).encode("ascii") + b'\n')
+ except UnicodeEncodeError:
+ raise PicklingError(
+ "persistent IDs in protocol 0 must be ASCII strings")
def save_reduce(self, func, args, state=None, listitems=None,
dictitems=None, obj=None):
@@ -1075,7 +1079,11 @@ class _Unpickler:
dispatch[FRAME[0]] = load_frame
def load_persid(self):
- pid = self.readline()[:-1].decode("ascii")
+ try:
+ pid = self.readline()[:-1].decode("ascii")
+ except UnicodeDecodeError:
+ raise UnpicklingError(
+ "persistent IDs in protocol 0 must be ASCII strings")
self.append(self.persistent_load(pid))
dispatch[PERSID[0]] = load_persid
diff --git a/Lib/test/pickletester.py b/Lib/test/pickletester.py
index f252a0a..7922b54 100644
--- a/Lib/test/pickletester.py
+++ b/Lib/test/pickletester.py
@@ -2629,6 +2629,35 @@ class AbstractPersistentPicklerTests(unittest.TestCase):
self.assertEqual(self.load_false_count, 1)
+class AbstractIdentityPersistentPicklerTests(unittest.TestCase):
+
+ def persistent_id(self, obj):
+ return obj
+
+ def persistent_load(self, pid):
+ return pid
+
+ def _check_return_correct_type(self, obj, proto):
+ unpickled = self.loads(self.dumps(obj, proto))
+ self.assertIsInstance(unpickled, type(obj))
+ self.assertEqual(unpickled, obj)
+
+ def test_return_correct_type(self):
+ for proto in protocols:
+ # Protocol 0 supports only ASCII strings.
+ if proto == 0:
+ self._check_return_correct_type("abc", 0)
+ else:
+ for obj in [b"abc\n", "abc\n", -1, -1.1 * 0.1, str]:
+ self._check_return_correct_type(obj, proto)
+
+ def test_protocol0_is_ascii_only(self):
+ non_ascii_str = "\N{EMPTY SET}"
+ self.assertRaises(pickle.PicklingError, self.dumps, non_ascii_str, 0)
+ pickled = pickle.PERSID + non_ascii_str.encode('utf-8') + b'\n.'
+ self.assertRaises(pickle.UnpicklingError, self.loads, pickled)
+
+
class AbstractPicklerUnpicklerObjectTests(unittest.TestCase):
pickler_class = None
diff --git a/Lib/test/test_pickle.py b/Lib/test/test_pickle.py
index ee7a667..d467d52 100644
--- a/Lib/test/test_pickle.py
+++ b/Lib/test/test_pickle.py
@@ -14,6 +14,7 @@ from test.pickletester import AbstractUnpickleTests
from test.pickletester import AbstractPickleTests
from test.pickletester import AbstractPickleModuleTests
from test.pickletester import AbstractPersistentPicklerTests
+from test.pickletester import AbstractIdentityPersistentPicklerTests
from test.pickletester import AbstractPicklerUnpicklerObjectTests
from test.pickletester import AbstractDispatchTableTests
from test.pickletester import BigmemPickleTests
@@ -82,10 +83,7 @@ class InMemoryPickleTests(AbstractPickleTests, AbstractUnpickleTests,
return pickle.loads(buf, **kwds)
-class PyPersPicklerTests(AbstractPersistentPicklerTests):
-
- pickler = pickle._Pickler
- unpickler = pickle._Unpickler
+class PersistentPicklerUnpicklerMixin(object):
def dumps(self, arg, proto=None):
class PersPickler(self.pickler):
@@ -94,8 +92,7 @@ class PyPersPicklerTests(AbstractPersistentPicklerTests):
f = io.BytesIO()
p = PersPickler(f, proto)
p.dump(arg)
- f.seek(0)
- return f.read()
+ return f.getvalue()
def loads(self, buf, **kwds):
class PersUnpickler(self.unpickler):
@@ -106,6 +103,20 @@ class PyPersPicklerTests(AbstractPersistentPicklerTests):
return u.load()
+class PyPersPicklerTests(AbstractPersistentPicklerTests,
+ PersistentPicklerUnpicklerMixin):
+
+ pickler = pickle._Pickler
+ unpickler = pickle._Unpickler
+
+
+class PyIdPersPicklerTests(AbstractIdentityPersistentPicklerTests,
+ PersistentPicklerUnpicklerMixin):
+
+ pickler = pickle._Pickler
+ unpickler = pickle._Unpickler
+
+
class PyPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests):
pickler_class = pickle._Pickler
@@ -144,6 +155,10 @@ if has_c_implementation:
pickler = _pickle.Pickler
unpickler = _pickle.Unpickler
+ class CIdPersPicklerTests(PyIdPersPicklerTests):
+ pickler = _pickle.Pickler
+ unpickler = _pickle.Unpickler
+
class CDumpPickle_LoadPickle(PyPicklerTests):
pickler = _pickle.Pickler
unpickler = pickle._Unpickler
@@ -409,11 +424,13 @@ class CompatPickleTests(unittest.TestCase):
def test_main():
- tests = [PickleTests, PyUnpicklerTests, PyPicklerTests, PyPersPicklerTests,
+ tests = [PickleTests, PyUnpicklerTests, PyPicklerTests,
+ PyPersPicklerTests, PyIdPersPicklerTests,
PyDispatchTableTests, PyChainDispatchTableTests,
CompatPickleTests]
if has_c_implementation:
- tests.extend([CUnpicklerTests, CPicklerTests, CPersPicklerTests,
+ tests.extend([CUnpicklerTests, CPicklerTests,
+ CPersPicklerTests, CIdPersPicklerTests,
CDumpPickle_LoadPickle, DumpPickle_CLoadPickle,
PyPicklerUnpicklerObjectTests,
CPicklerUnpicklerObjectTests,