diff options
author | Serhiy Storchaka <storchaka@gmail.com> | 2013-04-14 10:37:02 (GMT) |
---|---|---|
committer | Serhiy Storchaka <storchaka@gmail.com> | 2013-04-14 10:37:02 (GMT) |
commit | a3e32c92cf3170f5a91c23c6aacd5be310e9a842 (patch) | |
tree | e0cfcd13fc2e1576411cef5c274977aa6f251fab /Lib/pickle.py | |
parent | c8fb047d6993e5b559bba8449e230364ae47764b (diff) | |
download | cpython-a3e32c92cf3170f5a91c23c6aacd5be310e9a842.zip cpython-a3e32c92cf3170f5a91c23c6aacd5be310e9a842.tar.gz cpython-a3e32c92cf3170f5a91c23c6aacd5be310e9a842.tar.bz2 |
Closes #16551. Cleanup pickle.py.
Diffstat (limited to 'Lib/pickle.py')
-rw-r--r-- | Lib/pickle.py | 223 |
1 files changed, 85 insertions, 138 deletions
diff --git a/Lib/pickle.py b/Lib/pickle.py index 168865d..998fce0 100644 --- a/Lib/pickle.py +++ b/Lib/pickle.py @@ -26,9 +26,10 @@ Misc variables: from types import FunctionType, BuiltinFunctionType from copyreg import dispatch_table from copyreg import _extension_registry, _inverted_registry, _extension_cache -import marshal +from itertools import islice import sys -import struct +from sys import maxsize +from struct import pack, unpack import re import io import codecs @@ -58,11 +59,6 @@ HIGHEST_PROTOCOL = 3 # there are too many issues with that. DEFAULT_PROTOCOL = 3 -# Why use struct.pack() for pickling but marshal.loads() for -# unpickling? struct.pack() is 40% faster than marshal.dumps(), but -# marshal.loads() is twice as fast as struct.unpack()! -mloads = marshal.loads - class PickleError(Exception): """A common base class for the other pickling exceptions.""" pass @@ -231,7 +227,7 @@ class _Pickler: raise PicklingError("Pickler.__init__() was not called by " "%s.__init__()" % (self.__class__.__name__,)) if self.proto >= 2: - self.write(PROTO + bytes([self.proto])) + self.write(PROTO + pack("<B", self.proto)) self.save(obj) self.write(STOP) @@ -258,20 +254,20 @@ class _Pickler: self.memo[id(obj)] = memo_len, obj # Return a PUT (BINPUT, LONG_BINPUT) opcode string, with argument i. - def put(self, i, pack=struct.pack): + def put(self, i): if self.bin: if i < 256: - return BINPUT + bytes([i]) + return BINPUT + pack("<B", i) else: return LONG_BINPUT + pack("<I", i) return PUT + repr(i).encode("ascii") + b'\n' # Return a GET (BINGET, LONG_BINGET) opcode string, with argument i. - def get(self, i, pack=struct.pack): + def get(self, i): if self.bin: if i < 256: - return BINGET + bytes([i]) + return BINGET + pack("<B", i) else: return LONG_BINGET + pack("<I", i) @@ -286,20 +282,20 @@ class _Pickler: # Check the memo x = self.memo.get(id(obj)) - if x: + if x is not None: self.write(self.get(x[0])) return # Check the type dispatch table t = type(obj) f = self.dispatch.get(t) - if f: + if f is not None: f(self, obj) # Call unbound method with explicit self return # Check private dispatch table if any, or else copyreg.dispatch_table reduce = getattr(self, 'dispatch_table', dispatch_table).get(t) - if reduce: + if reduce is not None: rv = reduce(obj) else: # Check for a class with a custom metaclass; treat as regular class @@ -313,11 +309,11 @@ class _Pickler: # Check for a __reduce_ex__ method, fall back to __reduce__ reduce = getattr(obj, "__reduce_ex__", None) - if reduce: + if reduce is not None: rv = reduce(self.proto) else: reduce = getattr(obj, "__reduce__", None) - if reduce: + if reduce is not None: rv = reduce() else: raise PicklingError("Can't pickle %r object: %r" % @@ -448,12 +444,12 @@ class _Pickler: def save_bool(self, obj): if self.proto >= 2: - self.write(obj and NEWTRUE or NEWFALSE) + self.write(NEWTRUE if obj else NEWFALSE) else: - self.write(obj and TRUE or FALSE) + self.write(TRUE if obj else FALSE) dispatch[bool] = save_bool - def save_long(self, obj, pack=struct.pack): + def save_long(self, obj): if self.bin: # If the int is small enough to fit in a signed 4-byte 2's-comp # format, we can store it more efficiently than the general @@ -461,39 +457,36 @@ class _Pickler: # First one- and two-byte unsigned ints: if obj >= 0: if obj <= 0xff: - self.write(BININT1 + bytes([obj])) + self.write(BININT1 + pack("<B", obj)) return if obj <= 0xffff: - self.write(BININT2 + bytes([obj&0xff, obj>>8])) + self.write(BININT2 + pack("<H", obj)) return # Next check for 4-byte signed ints: - high_bits = obj >> 31 # note that Python shift sign-extends - if high_bits == 0 or high_bits == -1: - # All high bits are copies of bit 2**31, so the value - # fits in a 4-byte signed int. + if -0x80000000 <= obj <= 0x7fffffff: self.write(BININT + pack("<i", obj)) return if self.proto >= 2: encoded = encode_long(obj) n = len(encoded) if n < 256: - self.write(LONG1 + bytes([n]) + encoded) + self.write(LONG1 + pack("<B", n) + encoded) else: self.write(LONG4 + pack("<i", n) + encoded) return self.write(LONG + repr(obj).encode("ascii") + b'L\n') dispatch[int] = save_long - def save_float(self, obj, pack=struct.pack): + def save_float(self, obj): if self.bin: self.write(BINFLOAT + pack('>d', obj)) else: self.write(FLOAT + repr(obj).encode("ascii") + b'\n') dispatch[float] = save_float - def save_bytes(self, obj, pack=struct.pack): + def save_bytes(self, obj): if self.proto < 3: - if len(obj) == 0: + if not obj: # bytes object is empty self.save_reduce(bytes, (), obj=obj) else: self.save_reduce(codecs.encode, @@ -501,13 +494,13 @@ class _Pickler: return n = len(obj) if n < 256: - self.write(SHORT_BINBYTES + bytes([n]) + bytes(obj)) + self.write(SHORT_BINBYTES + pack("<B", n) + obj) else: - self.write(BINBYTES + pack("<I", n) + bytes(obj)) + self.write(BINBYTES + pack("<I", n) + obj) self.memoize(obj) dispatch[bytes] = save_bytes - def save_str(self, obj, pack=struct.pack): + def save_str(self, obj): if self.bin: encoded = obj.encode('utf-8', 'surrogatepass') n = len(encoded) @@ -515,39 +508,36 @@ class _Pickler: else: obj = obj.replace("\\", "\\u005c") obj = obj.replace("\n", "\\u000a") - self.write(UNICODE + bytes(obj.encode('raw-unicode-escape')) + - b'\n') + self.write(UNICODE + obj.encode('raw-unicode-escape') + b'\n') self.memoize(obj) dispatch[str] = save_str def save_tuple(self, obj): - write = self.write - proto = self.proto - - n = len(obj) - if n == 0: - if proto: - write(EMPTY_TUPLE) + if not obj: # tuple is empty + if self.bin: + self.write(EMPTY_TUPLE) else: - write(MARK + TUPLE) + self.write(MARK + TUPLE) return + n = len(obj) save = self.save memo = self.memo - if n <= 3 and proto >= 2: + if n <= 3 and self.proto >= 2: for element in obj: save(element) # Subtle. Same as in the big comment below. if id(obj) in memo: get = self.get(memo[id(obj)][0]) - write(POP * n + get) + self.write(POP * n + get) else: - write(_tuplesize2code[n]) + self.write(_tuplesize2code[n]) self.memoize(obj) return # proto 0 or proto 1 and tuple isn't empty, or proto > 1 and tuple # has more than 3 elements. + write = self.write write(MARK) for element in obj: save(element) @@ -561,25 +551,23 @@ class _Pickler: # could have been done in the "for element" loop instead, but # recursive tuples are a rare thing. get = self.get(memo[id(obj)][0]) - if proto: + if self.bin: write(POP_MARK + get) else: # proto 0 -- POP_MARK not available write(POP * (n+1) + get) return # No recursion. - self.write(TUPLE) + write(TUPLE) self.memoize(obj) dispatch[tuple] = save_tuple def save_list(self, obj): - write = self.write - if self.bin: - write(EMPTY_LIST) + self.write(EMPTY_LIST) else: # proto 0 -- can't use EMPTY_LIST - write(MARK + LIST) + self.write(MARK + LIST) self.memoize(obj) self._batch_appends(obj) @@ -599,17 +587,9 @@ class _Pickler: write(APPEND) return - items = iter(items) - r = range(self._BATCHSIZE) - while items is not None: - tmp = [] - for i in r: - try: - x = next(items) - tmp.append(x) - except StopIteration: - items = None - break + it = iter(items) + while True: + tmp = list(islice(it, self._BATCHSIZE)) n = len(tmp) if n > 1: write(MARK) @@ -620,14 +600,14 @@ class _Pickler: save(tmp[0]) write(APPEND) # else tmp is empty, and we're done + if n < self._BATCHSIZE: + return def save_dict(self, obj): - write = self.write - if self.bin: - write(EMPTY_DICT) + self.write(EMPTY_DICT) else: # proto 0 -- can't use EMPTY_DICT - write(MARK + DICT) + self.write(MARK + DICT) self.memoize(obj) self._batch_setitems(obj.items()) @@ -648,16 +628,9 @@ class _Pickler: write(SETITEM) return - items = iter(items) - r = range(self._BATCHSIZE) - while items is not None: - tmp = [] - for i in r: - try: - tmp.append(next(items)) - except StopIteration: - items = None - break + it = iter(items) + while True: + tmp = list(islice(it, self._BATCHSIZE)) n = len(tmp) if n > 1: write(MARK) @@ -671,8 +644,10 @@ class _Pickler: save(v) write(SETITEM) # else tmp is empty, and we're done + if n < self._BATCHSIZE: + return - def save_global(self, obj, name=None, pack=struct.pack): + def save_global(self, obj, name=None): write = self.write memo = self.memo @@ -702,9 +677,9 @@ class _Pickler: if code: assert code > 0 if code <= 0xff: - write(EXT1 + bytes([code])) + write(EXT1 + pack("<B", code)) elif code <= 0xffff: - write(EXT2 + bytes([code&0xff, code>>8])) + write(EXT2 + pack("<H", code)) else: write(EXT4 + pack("<i", code)) return @@ -732,25 +707,6 @@ class _Pickler: dispatch[BuiltinFunctionType] = save_global dispatch[type] = save_global -# Pickling helpers - -def _keep_alive(x, memo): - """Keeps a reference to the object x in the memo. - - Because we remember objects by their id, we have - to assure that possibly temporary objects are kept - alive by referencing them. - We store a reference at the id of the memo, which should - normally not be used unless someone tries to deepcopy - the memo itself... - """ - try: - memo[id(memo)].append(x) - except KeyError: - # aha, this is the first one :-) - memo[id(memo)]=[x] - - # A cache for whichmodule(), mapping a function object to the name of # the module in which the function was found. @@ -832,7 +788,7 @@ class _Unpickler: read = self.read dispatch = self.dispatch try: - while 1: + while True: key = read(1) if not key: raise EOFError @@ -862,7 +818,7 @@ class _Unpickler: dispatch = {} def load_proto(self): - proto = ord(self.read(1)) + proto = self.read(1)[0] if not 0 <= proto <= HIGHEST_PROTOCOL: raise ValueError("unsupported pickle protocol: %d" % proto) self.proto = proto @@ -897,40 +853,37 @@ class _Unpickler: elif data == TRUE[1:]: val = True else: - try: - val = int(data, 0) - except ValueError: - val = int(data, 0) + val = int(data, 0) self.append(val) dispatch[INT[0]] = load_int def load_binint(self): - self.append(mloads(b'i' + self.read(4))) + self.append(unpack('<i', self.read(4))[0]) dispatch[BININT[0]] = load_binint def load_binint1(self): - self.append(ord(self.read(1))) + self.append(self.read(1)[0]) dispatch[BININT1[0]] = load_binint1 def load_binint2(self): - self.append(mloads(b'i' + self.read(2) + b'\000\000')) + self.append(unpack('<H', self.read(2))[0]) dispatch[BININT2[0]] = load_binint2 def load_long(self): - val = self.readline()[:-1].decode("ascii") - if val and val[-1] == 'L': + val = self.readline()[:-1] + if val and val[-1] == b'L'[0]: val = val[:-1] self.append(int(val, 0)) dispatch[LONG[0]] = load_long def load_long1(self): - n = ord(self.read(1)) + n = self.read(1)[0] data = self.read(n) self.append(decode_long(data)) dispatch[LONG1[0]] = load_long1 def load_long4(self): - n = mloads(b'i' + self.read(4)) + n, = unpack('<i', self.read(4)) if n < 0: # Corrupt or hostile pickle -- we never write one like this raise UnpicklingError("LONG pickle has negative byte count") @@ -942,28 +895,25 @@ class _Unpickler: self.append(float(self.readline()[:-1])) dispatch[FLOAT[0]] = load_float - def load_binfloat(self, unpack=struct.unpack): + def load_binfloat(self): self.append(unpack('>d', self.read(8))[0]) dispatch[BINFLOAT[0]] = load_binfloat def load_string(self): orig = self.readline() rep = orig[:-1] - for q in (b'"', b"'"): # double or single quote - if rep.startswith(q): - if not rep.endswith(q): - raise ValueError("insecure string pickle") - rep = rep[len(q):-len(q)] - break + # Strip outermost quotes + if rep[0] == rep[-1] and rep[0] in b'"\'': + rep = rep[1:-1] else: - raise ValueError("insecure string pickle: %r" % orig) + raise ValueError("insecure string pickle") self.append(codecs.escape_decode(rep)[0] .decode(self.encoding, self.errors)) dispatch[STRING[0]] = load_string def load_binstring(self): # Deprecated BINSTRING uses signed 32-bit length - len = mloads(b'i' + self.read(4)) + len, = unpack('<i', self.read(4)) if len < 0: raise UnpicklingError("BINSTRING pickle has negative byte count") data = self.read(len) @@ -971,7 +921,7 @@ class _Unpickler: self.append(value) dispatch[BINSTRING[0]] = load_binstring - def load_binbytes(self, unpack=struct.unpack, maxsize=sys.maxsize): + def load_binbytes(self): len, = unpack('<I', self.read(4)) if len > maxsize: raise UnpicklingError("BINBYTES exceeds system's maximum size " @@ -983,7 +933,7 @@ class _Unpickler: self.append(str(self.readline()[:-1], 'raw-unicode-escape')) dispatch[UNICODE[0]] = load_unicode - def load_binunicode(self, unpack=struct.unpack, maxsize=sys.maxsize): + def load_binunicode(self): len, = unpack('<I', self.read(4)) if len > maxsize: raise UnpicklingError("BINUNICODE exceeds system's maximum size " @@ -992,15 +942,15 @@ class _Unpickler: dispatch[BINUNICODE[0]] = load_binunicode def load_short_binstring(self): - len = ord(self.read(1)) - data = bytes(self.read(len)) + len = self.read(1)[0] + data = self.read(len) value = str(data, self.encoding, self.errors) self.append(value) dispatch[SHORT_BINSTRING[0]] = load_short_binstring def load_short_binbytes(self): - len = ord(self.read(1)) - self.append(bytes(self.read(len))) + len = self.read(1)[0] + self.append(self.read(len)) dispatch[SHORT_BINBYTES[0]] = load_short_binbytes def load_tuple(self): @@ -1039,12 +989,9 @@ class _Unpickler: def load_dict(self): k = self.marker() - d = {} items = self.stack[k+1:] - for i in range(0, len(items), 2): - key = items[i] - value = items[i+1] - d[key] = value + d = {items[i]: items[i+1] + for i in range(0, len(items), 2)} self.stack[k:] = [d] dispatch[DICT[0]] = load_dict @@ -1096,17 +1043,17 @@ class _Unpickler: dispatch[GLOBAL[0]] = load_global def load_ext1(self): - code = ord(self.read(1)) + code = self.read(1)[0] self.get_extension(code) dispatch[EXT1[0]] = load_ext1 def load_ext2(self): - code = mloads(b'i' + self.read(2) + b'\000\000') + code, = unpack('<H', self.read(2)) self.get_extension(code) dispatch[EXT2[0]] = load_ext2 def load_ext4(self): - code = mloads(b'i' + self.read(4)) + code, = unpack('<i', self.read(4)) self.get_extension(code) dispatch[EXT4[0]] = load_ext4 @@ -1174,7 +1121,7 @@ class _Unpickler: self.append(self.memo[i]) dispatch[BINGET[0]] = load_binget - def load_long_binget(self, unpack=struct.unpack): + def load_long_binget(self): i, = unpack('<I', self.read(4)) self.append(self.memo[i]) dispatch[LONG_BINGET[0]] = load_long_binget @@ -1193,7 +1140,7 @@ class _Unpickler: self.memo[i] = self.stack[-1] dispatch[BINPUT[0]] = load_binput - def load_long_binput(self, unpack=struct.unpack, maxsize=sys.maxsize): + def load_long_binput(self): i, = unpack('<I', self.read(4)) if i > maxsize: raise ValueError("negative LONG_BINPUT argument") @@ -1238,7 +1185,7 @@ class _Unpickler: state = stack.pop() inst = stack[-1] setstate = getattr(inst, "__setstate__", None) - if setstate: + if setstate is not None: setstate(state) return slotstate = None |