diff options
-rw-r--r-- | Lib/pickle.py | 55 | ||||
-rw-r--r-- | Lib/pickletools.py | 174 | ||||
-rw-r--r-- | Lib/test/pickletester.py | 69 | ||||
-rw-r--r-- | Lib/test/test_pickle.py | 12 | ||||
-rw-r--r-- | Modules/cPickle.c | 7 |
5 files changed, 172 insertions, 145 deletions
diff --git a/Lib/pickle.py b/Lib/pickle.py index f976ffb..62658cb 100644 --- a/Lib/pickle.py +++ b/Lib/pickle.py @@ -465,7 +465,7 @@ class Pickler: self.write(BININT1 + bytes([obj])) return if obj <= 0xffff: - self.write(BININT2, bytes([obj&0xff, obj>>8])) + self.write(BININT2 + bytes([obj&0xff, obj>>8])) return # Next check for 4-byte signed ints: high_bits = obj >> 31 # note that Python shift sign-extends @@ -820,6 +820,7 @@ class Unpickler: key = read(1) if not key: raise EOFError + assert isinstance(key, bytes) dispatch[key[0]](self) except _Stop as stopinst: return stopinst.value @@ -892,7 +893,7 @@ class Unpickler: dispatch[BININT1[0]] = load_binint1 def load_binint2(self): - self.append(mloads(b'i' + self.read(2) + '\000\000')) + self.append(mloads(b'i' + self.read(2) + b'\000\000')) dispatch[BININT2[0]] = load_binint2 def load_long(self): @@ -1111,7 +1112,7 @@ class Unpickler: dispatch[DUP[0]] = load_dup def load_get(self): - self.append(self.memo[self.readline()[:-1]]) + self.append(self.memo[str8(self.readline())[:-1]]) dispatch[GET[0]] = load_get def load_binget(self): @@ -1226,24 +1227,24 @@ def encode_long(x): byte in the LONG1 pickling context. >>> encode_long(0) - '' + b'' >>> encode_long(255) - '\xff\x00' + b'\xff\x00' >>> encode_long(32767) - '\xff\x7f' + b'\xff\x7f' >>> encode_long(-256) - '\x00\xff' + b'\x00\xff' >>> encode_long(-32768) - '\x00\x80' + b'\x00\x80' >>> encode_long(-128) - '\x80' + b'\x80' >>> encode_long(127) - '\x7f' + b'\x7f' >>> """ if x == 0: - return '' + return b'' if x > 0: ashex = hex(x) assert ashex.startswith("0x") @@ -1284,24 +1285,24 @@ def encode_long(x): ashex = ashex[2:] assert len(ashex) & 1 == 0, (x, ashex) binary = _binascii.unhexlify(ashex) - return binary[::-1] + return bytes(binary[::-1]) def decode_long(data): r"""Decode a long from a two's complement little-endian binary string. - >>> decode_long('') + >>> decode_long(b'') 0 - >>> decode_long("\xff\x00") + >>> decode_long(b"\xff\x00") 255 - >>> decode_long("\xff\x7f") + >>> decode_long(b"\xff\x7f") 32767 - >>> decode_long("\x00\xff") + >>> decode_long(b"\x00\xff") -256 - >>> decode_long("\x00\x80") + >>> decode_long(b"\x00\x80") -32768 - >>> decode_long("\x80") + >>> decode_long(b"\x80") -128 - >>> decode_long("\x7f") + >>> decode_long(b"\x7f") 127 """ @@ -1310,7 +1311,7 @@ def decode_long(data): return 0 ashex = _binascii.hexlify(data[::-1]) n = int(ashex, 16) # quadratic time before Python 2.3; linear now - if data[-1] >= '\x80': + if data[-1] >= 0x80: n -= 1 << (nbytes * 8) return n @@ -1320,15 +1321,19 @@ def dump(obj, file, protocol=None): Pickler(file, protocol).dump(obj) def dumps(obj, protocol=None): - file = io.BytesIO() - Pickler(file, protocol).dump(obj) - return file.getvalue() + f = io.BytesIO() + Pickler(f, protocol).dump(obj) + res = f.getvalue() + assert isinstance(res, bytes) + return res def load(file): return Unpickler(file).load() -def loads(str): - file = io.BytesIO(str) +def loads(s): + if isinstance(s, str): + raise TypeError("Can't load pickle from unicode string") + file = io.BytesIO(s) return Unpickler(file).load() # Doctest diff --git a/Lib/pickletools.py b/Lib/pickletools.py index c050fc5..c5c45eb 100644 --- a/Lib/pickletools.py +++ b/Lib/pickletools.py @@ -202,14 +202,14 @@ from struct import unpack as _unpack def read_uint1(f): r""" - >>> import StringIO - >>> read_uint1(StringIO.StringIO('\xff')) + >>> import io + >>> read_uint1(io.BytesIO(b'\xff')) 255 """ data = f.read(1) if data: - return ord(data) + return data[0] raise ValueError("not enough data in stream to read uint1") uint1 = ArgumentDescriptor( @@ -221,10 +221,10 @@ uint1 = ArgumentDescriptor( def read_uint2(f): r""" - >>> import StringIO - >>> read_uint2(StringIO.StringIO('\xff\x00')) + >>> import io + >>> read_uint2(io.BytesIO(b'\xff\x00')) 255 - >>> read_uint2(StringIO.StringIO('\xff\xff')) + >>> read_uint2(io.BytesIO(b'\xff\xff')) 65535 """ @@ -242,10 +242,10 @@ uint2 = ArgumentDescriptor( def read_int4(f): r""" - >>> import StringIO - >>> read_int4(StringIO.StringIO('\xff\x00\x00\x00')) + >>> import io + >>> read_int4(io.BytesIO(b'\xff\x00\x00\x00')) 255 - >>> read_int4(StringIO.StringIO('\x00\x00\x00\x80')) == -(2**31) + >>> read_int4(io.BytesIO(b'\x00\x00\x00\x80')) == -(2**31) True """ @@ -261,34 +261,48 @@ int4 = ArgumentDescriptor( doc="Four-byte signed integer, little-endian, 2's complement.") +def readline(f): + """Read a line from a binary file.""" + # XXX Slow but at least correct + b = bytes() + while True: + c = f.read(1) + if not c: + break + b += c + if c == b'\n': + break + return b + + def read_stringnl(f, decode=True, stripquotes=True): r""" - >>> import StringIO - >>> read_stringnl(StringIO.StringIO("'abcd'\nefg\n")) + >>> import io + >>> read_stringnl(io.BytesIO(b"'abcd'\nefg\n")) 'abcd' - >>> read_stringnl(StringIO.StringIO("\n")) + >>> read_stringnl(io.BytesIO(b"\n")) Traceback (most recent call last): ... - ValueError: no string quotes around '' + ValueError: no string quotes around b'' - >>> read_stringnl(StringIO.StringIO("\n"), stripquotes=False) + >>> read_stringnl(io.BytesIO(b"\n"), stripquotes=False) '' - >>> read_stringnl(StringIO.StringIO("''\n")) + >>> read_stringnl(io.BytesIO(b"''\n")) '' - >>> read_stringnl(StringIO.StringIO('"abcd"')) + >>> read_stringnl(io.BytesIO(b'"abcd"')) Traceback (most recent call last): ... ValueError: no newline found when trying to read stringnl Embedded escapes are undone in the result. - >>> read_stringnl(StringIO.StringIO(r"'a\n\\b\x00c\td'" + "\n'e'")) + >>> read_stringnl(io.BytesIO(br"'a\n\\b\x00c\td'" + b"\n'e'")) 'a\n\\b\x00c\td' """ - data = f.readline() + data = readline(f) if not data.endswith('\n'): raise ValueError("no newline found when trying to read stringnl") data = data[:-1] # lose the newline @@ -336,8 +350,8 @@ stringnl_noescape = ArgumentDescriptor( def read_stringnl_noescape_pair(f): r""" - >>> import StringIO - >>> read_stringnl_noescape_pair(StringIO.StringIO("Queue\nEmpty\njunk")) + >>> import io + >>> read_stringnl_noescape_pair(io.BytesIO(b"Queue\nEmpty\njunk")) 'Queue Empty' """ @@ -358,12 +372,12 @@ stringnl_noescape_pair = ArgumentDescriptor( def read_string4(f): r""" - >>> import StringIO - >>> read_string4(StringIO.StringIO("\x00\x00\x00\x00abc")) + >>> import io + >>> read_string4(io.BytesIO(b"\x00\x00\x00\x00abc")) '' - >>> read_string4(StringIO.StringIO("\x03\x00\x00\x00abcdef")) + >>> read_string4(io.BytesIO(b"\x03\x00\x00\x00abcdef")) 'abc' - >>> read_string4(StringIO.StringIO("\x00\x00\x00\x03abcdef")) + >>> read_string4(io.BytesIO(b"\x00\x00\x00\x03abcdef")) Traceback (most recent call last): ... ValueError: expected 50331648 bytes in a string4, but only 6 remain @@ -374,7 +388,7 @@ def read_string4(f): raise ValueError("string4 byte count < 0: %d" % n) data = f.read(n) if len(data) == n: - return data + return data.decode("latin-1") raise ValueError("expected %d bytes in a string4, but only %d remain" % (n, len(data))) @@ -392,10 +406,10 @@ string4 = ArgumentDescriptor( def read_string1(f): r""" - >>> import StringIO - >>> read_string1(StringIO.StringIO("\x00")) + >>> import io + >>> read_string1(io.BytesIO(b"\x00")) '' - >>> read_string1(StringIO.StringIO("\x03abcdef")) + >>> read_string1(io.BytesIO(b"\x03abcdef")) 'abc' """ @@ -403,7 +417,7 @@ def read_string1(f): assert n >= 0 data = f.read(n) if len(data) == n: - return data + return data.decode("latin-1") raise ValueError("expected %d bytes in a string1, but only %d remain" % (n, len(data))) @@ -421,12 +435,12 @@ string1 = ArgumentDescriptor( def read_unicodestringnl(f): r""" - >>> import StringIO - >>> read_unicodestringnl(StringIO.StringIO("abc\uabcd\njunk")) - u'abc\uabcd' + >>> import io + >>> read_unicodestringnl(io.BytesIO(b"abc\\uabcd\njunk")) == 'abc\uabcd' + True """ - data = f.readline() + data = readline(f) if not data.endswith('\n'): raise ValueError("no newline found when trying to read " "unicodestringnl") @@ -446,17 +460,17 @@ unicodestringnl = ArgumentDescriptor( def read_unicodestring4(f): r""" - >>> import StringIO - >>> s = u'abcd\uabcd' + >>> import io + >>> s = 'abcd\uabcd' >>> enc = s.encode('utf-8') >>> enc - 'abcd\xea\xaf\x8d' - >>> n = chr(len(enc)) + chr(0) * 3 # little-endian 4-byte length - >>> t = read_unicodestring4(StringIO.StringIO(n + enc + 'junk')) + b'abcd\xea\xaf\x8d' + >>> n = bytes([len(enc), 0, 0, 0]) # little-endian 4-byte length + >>> t = read_unicodestring4(io.BytesIO(n + enc + b'junk')) >>> s == t True - >>> read_unicodestring4(StringIO.StringIO(n + enc[:-1])) + >>> read_unicodestring4(io.BytesIO(n + enc[:-1])) Traceback (most recent call last): ... ValueError: expected 7 bytes in a unicodestring4, but only 6 remain @@ -486,14 +500,14 @@ unicodestring4 = ArgumentDescriptor( def read_decimalnl_short(f): r""" - >>> import StringIO - >>> read_decimalnl_short(StringIO.StringIO("1234\n56")) + >>> import io + >>> read_decimalnl_short(io.BytesIO(b"1234\n56")) 1234 - >>> read_decimalnl_short(StringIO.StringIO("1234L\n56")) + >>> read_decimalnl_short(io.BytesIO(b"1234L\n56")) Traceback (most recent call last): ... - ValueError: trailing 'L' not allowed in '1234L' + ValueError: trailing 'L' not allowed in b'1234L' """ s = read_stringnl(f, decode=False, stripquotes=False) @@ -515,12 +529,12 @@ def read_decimalnl_short(f): def read_decimalnl_long(f): r""" - >>> import StringIO + >>> import io - >>> read_decimalnl_long(StringIO.StringIO("1234L\n56")) + >>> read_decimalnl_long(io.BytesIO(b"1234L\n56")) 1234 - >>> read_decimalnl_long(StringIO.StringIO("123456789012345678901234L\n6")) + >>> read_decimalnl_long(io.BytesIO(b"123456789012345678901234L\n6")) 123456789012345678901234 """ @@ -554,8 +568,8 @@ decimalnl_long = ArgumentDescriptor( def read_floatnl(f): r""" - >>> import StringIO - >>> read_floatnl(StringIO.StringIO("-1.25\n6")) + >>> import io + >>> read_floatnl(io.BytesIO(b"-1.25\n6")) -1.25 """ s = read_stringnl(f, decode=False, stripquotes=False) @@ -576,11 +590,11 @@ floatnl = ArgumentDescriptor( def read_float8(f): r""" - >>> import StringIO, struct + >>> import io, struct >>> raw = struct.pack(">d", -1.25) >>> raw - '\xbf\xf4\x00\x00\x00\x00\x00\x00' - >>> read_float8(StringIO.StringIO(raw + "\n")) + b'\xbf\xf4\x00\x00\x00\x00\x00\x00' + >>> read_float8(io.BytesIO(raw + b"\n")) -1.25 """ @@ -614,16 +628,16 @@ from pickle import decode_long def read_long1(f): r""" - >>> import StringIO - >>> read_long1(StringIO.StringIO("\x00")) + >>> import io + >>> read_long1(io.BytesIO(b"\x00")) 0 - >>> read_long1(StringIO.StringIO("\x02\xff\x00")) + >>> read_long1(io.BytesIO(b"\x02\xff\x00")) 255 - >>> read_long1(StringIO.StringIO("\x02\xff\x7f")) + >>> read_long1(io.BytesIO(b"\x02\xff\x7f")) 32767 - >>> read_long1(StringIO.StringIO("\x02\x00\xff")) + >>> read_long1(io.BytesIO(b"\x02\x00\xff")) -256 - >>> read_long1(StringIO.StringIO("\x02\x00\x80")) + >>> read_long1(io.BytesIO(b"\x02\x00\x80")) -32768 """ @@ -646,16 +660,16 @@ long1 = ArgumentDescriptor( def read_long4(f): r""" - >>> import StringIO - >>> read_long4(StringIO.StringIO("\x02\x00\x00\x00\xff\x00")) + >>> import io + >>> read_long4(io.BytesIO(b"\x02\x00\x00\x00\xff\x00")) 255 - >>> read_long4(StringIO.StringIO("\x02\x00\x00\x00\xff\x7f")) + >>> read_long4(io.BytesIO(b"\x02\x00\x00\x00\xff\x7f")) 32767 - >>> read_long4(StringIO.StringIO("\x02\x00\x00\x00\x00\xff")) + >>> read_long4(io.BytesIO(b"\x02\x00\x00\x00\x00\xff")) -256 - >>> read_long4(StringIO.StringIO("\x02\x00\x00\x00\x00\x80")) + >>> read_long4(io.BytesIO(b"\x02\x00\x00\x00\x00\x80")) -32768 - >>> read_long1(StringIO.StringIO("\x00\x00\x00\x00")) + >>> read_long1(io.BytesIO(b"\x00\x00\x00\x00")) 0 """ @@ -701,7 +715,7 @@ class StackObject(object): ) def __init__(self, name, obtype, doc): - assert isinstance(name, str) + assert isinstance(name, basestring) self.name = name assert isinstance(obtype, type) or isinstance(obtype, tuple) @@ -710,7 +724,7 @@ class StackObject(object): assert isinstance(contained, type) self.obtype = obtype - assert isinstance(doc, str) + assert isinstance(doc, basestring) self.doc = doc def __repr__(self): @@ -846,10 +860,10 @@ class OpcodeInfo(object): def __init__(self, name, code, arg, stack_before, stack_after, proto, doc): - assert isinstance(name, str) + assert isinstance(name, basestring) self.name = name - assert isinstance(code, str) + assert isinstance(code, basestring) assert len(code) == 1 self.code = code @@ -869,7 +883,7 @@ class OpcodeInfo(object): assert isinstance(proto, int) and 0 <= proto <= 2 self.proto = proto - assert isinstance(doc, str) + assert isinstance(doc, basestring) self.doc = doc I = OpcodeInfo @@ -1819,10 +1833,9 @@ def genops(pickle): to query its current position) pos is None. """ - import cStringIO as StringIO - - if isinstance(pickle, str): - pickle = StringIO.StringIO(pickle) + if isinstance(pickle, bytes): + import io + pickle = io.BytesIO(pickle) if hasattr(pickle, "tell"): getpos = pickle.tell @@ -1832,9 +1845,9 @@ def genops(pickle): while True: pos = getpos() code = pickle.read(1) - opcode = code2op.get(code) + opcode = code2op.get(code.decode("latin-1")) if opcode is None: - if code == "": + if code == b"": raise ValueError("pickle exhausted before seeing STOP") else: raise ValueError("at position %s, opcode %r unknown" % ( @@ -1845,7 +1858,7 @@ def genops(pickle): else: arg = opcode.arg.reader(pickle) yield opcode, arg, pos - if code == '.': + if code == b'.': assert opcode.name == 'STOP' break @@ -1995,7 +2008,7 @@ class _Example: _dis_test = r""" >>> import pickle ->>> x = [1, 2, (3, 4), {'abc': u"def"}] +>>> x = [1, 2, (3, 4), {str8('abc'): "def"}] >>> pkl = pickle.dumps(x, 0) >>> dis(pkl) 0: ( MARK @@ -2016,7 +2029,7 @@ _dis_test = r""" 27: p PUT 2 30: S STRING 'abc' 37: p PUT 3 - 40: V UNICODE u'def' + 40: V UNICODE 'def' 45: p PUT 4 48: s SETITEM 49: a APPEND @@ -2041,7 +2054,7 @@ Try again with a "binary" pickle. 17: q BINPUT 2 19: U SHORT_BINSTRING 'abc' 24: q BINPUT 3 - 26: X BINUNICODE u'def' + 26: X BINUNICODE 'def' 34: q BINPUT 4 36: s SETITEM 37: e APPENDS (MARK at 3) @@ -2216,13 +2229,14 @@ highest protocol among opcodes = 2 _memo_test = r""" >>> import pickle ->>> from StringIO import StringIO ->>> f = StringIO() +>>> import io +>>> f = io.BytesIO() >>> p = pickle.Pickler(f, 2) >>> x = [1, 2, 3] >>> p.dump(x) >>> p.dump(x) >>> f.seek(0) +0 >>> memo = {} >>> dis(f, memo=memo) 0: \x80 PROTO 2 diff --git a/Lib/test/pickletester.py b/Lib/test/pickletester.py index 5d84eff..9b21a5f 100644 --- a/Lib/test/pickletester.py +++ b/Lib/test/pickletester.py @@ -21,7 +21,7 @@ protocols = range(pickle.HIGHEST_PROTOCOL + 1) # Return True if opcode code appears in the pickle, else False. def opcode_in_pickle(code, pickle): for op, dummy, dummy in pickletools.genops(pickle): - if op.code == code: + if op.code == code.decode("latin-1"): return True return False @@ -29,7 +29,7 @@ def opcode_in_pickle(code, pickle): def count_opcode(code, pickle): n = 0 for op, dummy, dummy in pickletools.genops(pickle): - if op.code == code: + if op.code == code.decode("latin-1"): n += 1 return n @@ -95,7 +95,7 @@ class use_metaclass(object, metaclass=metaclass): # the object returned by create_data(). # break into multiple strings to avoid confusing font-lock-mode -DATA0 = """(lp1 +DATA0 = b"""(lp1 I0 aL1L aF2 @@ -103,7 +103,7 @@ ac__builtin__ complex p2 """ + \ -"""(F3 +b"""(F3 F0 tRp3 aI1 @@ -118,15 +118,15 @@ aI2147483647 aI-2147483647 aI-2147483648 a""" + \ -"""(S'abc' +b"""(S'abc' p4 g4 """ + \ -"""(i__main__ +b"""(i__main__ C p5 """ + \ -"""(dp6 +b"""(dp6 S'foo' p7 I1 @@ -213,14 +213,14 @@ DATA0_DIS = """\ highest protocol among opcodes = 0 """ -DATA1 = (']q\x01(K\x00L1L\nG@\x00\x00\x00\x00\x00\x00\x00' - 'c__builtin__\ncomplex\nq\x02(G@\x08\x00\x00\x00\x00\x00' - '\x00G\x00\x00\x00\x00\x00\x00\x00\x00tRq\x03K\x01J\xff\xff' - '\xff\xffK\xffJ\x01\xff\xff\xffJ\x00\xff\xff\xffM\xff\xff' - 'J\x01\x00\xff\xffJ\x00\x00\xff\xffJ\xff\xff\xff\x7fJ\x01\x00' - '\x00\x80J\x00\x00\x00\x80(U\x03abcq\x04h\x04(c__main__\n' - 'C\nq\x05oq\x06}q\x07(U\x03fooq\x08K\x01U\x03barq\tK\x02ubh' - '\x06tq\nh\nK\x05e.' +DATA1 = (b']q\x01(K\x00L1L\nG@\x00\x00\x00\x00\x00\x00\x00' + b'c__builtin__\ncomplex\nq\x02(G@\x08\x00\x00\x00\x00\x00' + b'\x00G\x00\x00\x00\x00\x00\x00\x00\x00tRq\x03K\x01J\xff\xff' + b'\xff\xffK\xffJ\x01\xff\xff\xffJ\x00\xff\xff\xffM\xff\xff' + b'J\x01\x00\xff\xffJ\x00\x00\xff\xffJ\xff\xff\xff\x7fJ\x01\x00' + b'\x00\x80J\x00\x00\x00\x80(U\x03abcq\x04h\x04(c__main__\n' + b'C\nq\x05oq\x06}q\x07(U\x03fooq\x08K\x01U\x03barq\tK\x02ubh' + b'\x06tq\nh\nK\x05e.' ) # Disassembly of DATA1. @@ -280,13 +280,13 @@ DATA1_DIS = """\ highest protocol among opcodes = 1 """ -DATA2 = ('\x80\x02]q\x01(K\x00\x8a\x01\x01G@\x00\x00\x00\x00\x00\x00\x00' - 'c__builtin__\ncomplex\nq\x02G@\x08\x00\x00\x00\x00\x00\x00G\x00' - '\x00\x00\x00\x00\x00\x00\x00\x86Rq\x03K\x01J\xff\xff\xff\xffK' - '\xffJ\x01\xff\xff\xffJ\x00\xff\xff\xffM\xff\xffJ\x01\x00\xff\xff' - 'J\x00\x00\xff\xffJ\xff\xff\xff\x7fJ\x01\x00\x00\x80J\x00\x00\x00' - '\x80(U\x03abcq\x04h\x04(c__main__\nC\nq\x05oq\x06}q\x07(U\x03foo' - 'q\x08K\x01U\x03barq\tK\x02ubh\x06tq\nh\nK\x05e.') +DATA2 = (b'\x80\x02]q\x01(K\x00\x8a\x01\x01G@\x00\x00\x00\x00\x00\x00\x00' + b'c__builtin__\ncomplex\nq\x02G@\x08\x00\x00\x00\x00\x00\x00G\x00' + b'\x00\x00\x00\x00\x00\x00\x00\x86Rq\x03K\x01J\xff\xff\xff\xffK' + b'\xffJ\x01\xff\xff\xffJ\x00\xff\xff\xffM\xff\xffJ\x01\x00\xff\xff' + b'J\x00\x00\xff\xffJ\xff\xff\xff\x7fJ\x01\x00\x00\x80J\x00\x00\x00' + b'\x80(U\x03abcq\x04h\x04(c__main__\nC\nq\x05oq\x06}q\x07(U\x03foo' + b'q\x08K\x01U\x03barq\tK\x02ubh\x06tq\nh\nK\x05e.') # Disassembly of DATA2. DATA2_DIS = """\ @@ -465,7 +465,7 @@ class AbstractPickleTests(unittest.TestCase): self.assert_(x[0].attr[1] is x) def test_garyp(self): - self.assertRaises(self.error, self.loads, 'garyp') + self.assertRaises(self.error, self.loads, b'garyp') def test_insecure_strings(self): insecure = ["abc", "2 + 2", # not quoted @@ -479,7 +479,7 @@ class AbstractPickleTests(unittest.TestCase): #"'\\\\a\'\'\'\\\'\\\\\''", ] for s in insecure: - buf = "S" + s + "\012p0\012." + buf = b"S" + bytes(s) + b"\012p0\012." self.assertRaises(ValueError, self.loads, buf) if have_unicode: @@ -505,12 +505,12 @@ class AbstractPickleTests(unittest.TestCase): def test_maxint64(self): maxint64 = (1 << 63) - 1 - data = 'I' + str(maxint64) + '\n.' + data = b'I' + bytes(str(maxint64)) + b'\n.' got = self.loads(data) self.assertEqual(got, maxint64) # Try too with a bogus literal. - data = 'I' + str(maxint64) + 'JUNK\n.' + data = b'I' + bytes(str(maxint64)) + b'JUNK\n.' self.assertRaises(ValueError, self.loads, data) def test_long(self): @@ -535,7 +535,7 @@ class AbstractPickleTests(unittest.TestCase): @run_with_locale('LC_ALL', 'de_DE', 'fr_FR') def test_float_format(self): # make sure that floats are formatted locale independent - self.assertEqual(self.dumps(1.2)[0:3], 'F1.') + self.assertEqual(self.dumps(1.2)[0:3], b'F1.') def test_reduce(self): pass @@ -577,12 +577,12 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: expected = build_none if proto >= 2: - expected = pickle.PROTO + chr(proto) + expected + expected = pickle.PROTO + bytes([proto]) + expected p = self.dumps(None, proto) self.assertEqual(p, expected) oob = protocols[-1] + 1 # a future protocol - badpickle = pickle.PROTO + chr(oob) + build_none + badpickle = pickle.PROTO + bytes([oob]) + build_none try: self.loads(badpickle) except ValueError as detail: @@ -708,8 +708,8 @@ class AbstractPickleTests(unittest.TestCase): # Dump using protocol 1 for comparison. s1 = self.dumps(x, 1) - self.assert_(__name__ in s1) - self.assert_("MyList" in s1) + self.assert_(bytes(__name__) in s1) + self.assert_(b"MyList" in s1) self.assertEqual(opcode_in_pickle(opcode, s1), False) y = self.loads(s1) @@ -718,9 +718,9 @@ class AbstractPickleTests(unittest.TestCase): # Dump using protocol 2 for test. s2 = self.dumps(x, 2) - self.assert_(__name__ not in s2) - self.assert_("MyList" not in s2) - self.assertEqual(opcode_in_pickle(opcode, s2), True) + self.assert_(bytes(__name__) not in s2) + self.assert_(b"MyList" not in s2) + self.assertEqual(opcode_in_pickle(opcode, s2), True, repr(s2)) y = self.loads(s2) self.assertEqual(list(x), list(y)) @@ -770,6 +770,7 @@ class AbstractPickleTests(unittest.TestCase): x = dict.fromkeys(range(n)) for proto in protocols: s = self.dumps(x, proto) + assert isinstance(s, bytes) y = self.loads(s) self.assertEqual(x, y) num_setitems = count_opcode(pickle.SETITEMS, s) diff --git a/Lib/test/test_pickle.py b/Lib/test/test_pickle.py index 585644e..11254f4 100644 --- a/Lib/test/test_pickle.py +++ b/Lib/test/test_pickle.py @@ -1,6 +1,6 @@ import pickle import unittest -from cStringIO import StringIO +import io from test import test_support @@ -26,16 +26,16 @@ class PicklerTests(AbstractPickleTests): error = KeyError def dumps(self, arg, proto=0, fast=0): - f = StringIO() + f = io.BytesIO() p = pickle.Pickler(f, proto) if fast: p.fast = fast p.dump(arg) f.seek(0) - return f.read() + return bytes(f.read()) def loads(self, buf): - f = StringIO(buf) + f = io.BytesIO(buf) u = pickle.Unpickler(f) return u.load() @@ -45,7 +45,7 @@ class PersPicklerTests(AbstractPersistentPicklerTests): class PersPickler(pickle.Pickler): def persistent_id(subself, obj): return self.persistent_id(obj) - f = StringIO() + f = io.BytesIO() p = PersPickler(f, proto) if fast: p.fast = fast @@ -57,7 +57,7 @@ class PersPicklerTests(AbstractPersistentPicklerTests): class PersUnpickler(pickle.Unpickler): def persistent_load(subself, obj): return self.persistent_load(obj) - f = StringIO(buf) + f = io.BytesIO(buf) u = PersUnpickler(f) return u.load() diff --git a/Modules/cPickle.c b/Modules/cPickle.c index 00641d8..639e68b 100644 --- a/Modules/cPickle.c +++ b/Modules/cPickle.c @@ -5241,6 +5241,13 @@ cpm_dumps(PyObject *self, PyObject *args, PyObject *kwds) goto finally; res = PycStringIO->cgetvalue(file); + if (res == NULL) + goto finally; + if (!PyBytes_Check(res)) { + PyObject *tmp = res; + res = PyBytes_FromObject(res); + Py_DECREF(tmp); + } finally: Py_XDECREF(pickler); |