diff options
Diffstat (limited to 'Lib/pickle.py')
-rw-r--r-- | Lib/pickle.py | 866 |
1 files changed, 538 insertions, 328 deletions
diff --git a/Lib/pickle.py b/Lib/pickle.py index 386ffba..67382ae 100644 --- a/Lib/pickle.py +++ b/Lib/pickle.py @@ -23,12 +23,13 @@ Misc variables: """ -from types import FunctionType, BuiltinFunctionType +from types import FunctionType from copyreg import dispatch_table from copyreg import _extension_registry, _inverted_registry, _extension_cache -import marshal +from itertools import islice import sys -import struct +from sys import maxsize +from struct import pack, unpack import re import io import codecs @@ -41,28 +42,24 @@ __all__ = ["PickleError", "PicklingError", "UnpicklingError", "Pickler", bytes_types = (bytes, bytearray) # These are purely informational; no code uses these. -format_version = "3.0" # File format version we write +format_version = "4.0" # File format version we write compatible_formats = ["1.0", # Original protocol 0 "1.1", # Protocol 0 with INST added "1.2", # Original protocol 1 "1.3", # Protocol 1 with BINFLOAT added "2.0", # Protocol 2 "3.0", # Protocol 3 + "4.0", # Protocol 4 ] # Old format versions we can read # This is the highest protocol number we know how to read. -HIGHEST_PROTOCOL = 3 +HIGHEST_PROTOCOL = 4 # The protocol we write by default. May be less than HIGHEST_PROTOCOL. # We intentionally write a protocol that Python 2.x cannot read; # there are too many issues with that. DEFAULT_PROTOCOL = 3 -# Why use struct.pack() for pickling but marshal.loads() for -# unpickling? struct.pack() is 40% faster than marshal.dumps(), but -# marshal.loads() is twice as fast as struct.unpack()! -mloads = marshal.loads - class PickleError(Exception): """A common base class for the other pickling exceptions.""" pass @@ -168,7 +165,183 @@ _tuplesize2code = [EMPTY_TUPLE, TUPLE1, TUPLE2, TUPLE3] BINBYTES = b'B' # push bytes; counted binary string argument SHORT_BINBYTES = b'C' # " " ; " " " " < 256 bytes -__all__.extend([x for x in dir() if re.match("[A-Z][A-Z0-9_]+$",x)]) +# Protocol 4 +SHORT_BINUNICODE = b'\x8c' # push short string; UTF-8 length < 256 bytes +BINUNICODE8 = b'\x8d' # push very long string +BINBYTES8 = b'\x8e' # push very long bytes string +EMPTY_SET = b'\x8f' # push empty set on the stack +ADDITEMS = b'\x90' # modify set by adding topmost stack items +FROZENSET = b'\x91' # build frozenset from topmost stack items +NEWOBJ_EX = b'\x92' # like NEWOBJ but work with keyword only arguments +STACK_GLOBAL = b'\x93' # same as GLOBAL but using names on the stacks +MEMOIZE = b'\x94' # store top of the stack in memo +FRAME = b'\x95' # indicate the beginning of a new frame + +__all__.extend([x for x in dir() if re.match("[A-Z][A-Z0-9_]+$", x)]) + + +class _Framer: + + _FRAME_SIZE_TARGET = 64 * 1024 + + def __init__(self, file_write): + self.file_write = file_write + self.current_frame = None + + def start_framing(self): + self.current_frame = io.BytesIO() + + def end_framing(self): + if self.current_frame and self.current_frame.tell() > 0: + self.commit_frame(force=True) + self.current_frame = None + + def commit_frame(self, force=False): + if self.current_frame: + f = self.current_frame + if f.tell() >= self._FRAME_SIZE_TARGET or force: + with f.getbuffer() as data: + n = len(data) + write = self.file_write + write(FRAME) + write(pack("<Q", n)) + write(data) + f.seek(0) + f.truncate() + + def write(self, data): + if self.current_frame: + return self.current_frame.write(data) + else: + return self.file_write(data) + + +class _Unframer: + + def __init__(self, file_read, file_readline, file_tell=None): + self.file_read = file_read + self.file_readline = file_readline + self.current_frame = None + + def read(self, n): + if self.current_frame: + data = self.current_frame.read(n) + if not data and n != 0: + self.current_frame = None + return self.file_read(n) + if len(data) < n: + raise UnpicklingError( + "pickle exhausted before end of frame") + return data + else: + return self.file_read(n) + + def readline(self): + if self.current_frame: + data = self.current_frame.readline() + if not data: + self.current_frame = None + return self.file_readline() + if data[-1] != b'\n'[0]: + raise UnpicklingError( + "pickle exhausted before end of frame") + return data + else: + return self.file_readline() + + def load_frame(self, frame_size): + if self.current_frame and self.current_frame.read() != b'': + raise UnpicklingError( + "beginning of a new frame before end of current frame") + self.current_frame = io.BytesIO(self.file_read(frame_size)) + + +# Tools used for pickling. + +def _getattribute(obj, name, allow_qualname=False): + dotted_path = name.split(".") + if not allow_qualname and len(dotted_path) > 1: + raise AttributeError("Can't get qualified attribute {!r} on {!r}; " + + "use protocols >= 4 to enable support" + .format(name, obj)) + for subpath in dotted_path: + if subpath == '<locals>': + raise AttributeError("Can't get local attribute {!r} on {!r}" + .format(name, obj)) + try: + obj = getattr(obj, subpath) + except AttributeError: + raise AttributeError("Can't get attribute {!r} on {!r}" + .format(name, obj)) + return obj + +def whichmodule(obj, name, allow_qualname=False): + """Find the module an object belong to.""" + module_name = getattr(obj, '__module__', None) + if module_name is not None: + return module_name + # Protect the iteration by using a list copy of sys.modules against dynamic + # modules that trigger imports of other modules upon calls to getattr. + for module_name, module in list(sys.modules.items()): + if module_name == '__main__' or module is None: + continue + try: + if _getattribute(module, name, allow_qualname) is obj: + return module_name + except AttributeError: + pass + return '__main__' + +def encode_long(x): + r"""Encode a long to a two's complement little-endian binary string. + Note that 0 is a special case, returning an empty string, to save a + byte in the LONG1 pickling context. + + >>> encode_long(0) + b'' + >>> encode_long(255) + b'\xff\x00' + >>> encode_long(32767) + b'\xff\x7f' + >>> encode_long(-256) + b'\x00\xff' + >>> encode_long(-32768) + b'\x00\x80' + >>> encode_long(-128) + b'\x80' + >>> encode_long(127) + b'\x7f' + >>> + """ + if x == 0: + return b'' + nbytes = (x.bit_length() >> 3) + 1 + result = x.to_bytes(nbytes, byteorder='little', signed=True) + if x < 0 and nbytes > 1: + if result[-1] == 0xff and (result[-2] & 0x80) != 0: + result = result[:-1] + return result + +def decode_long(data): + r"""Decode a long from a two's complement little-endian binary string. + + >>> decode_long(b'') + 0 + >>> decode_long(b"\xff\x00") + 255 + >>> decode_long(b"\xff\x7f") + 32767 + >>> decode_long(b"\x00\xff") + -256 + >>> decode_long(b"\x00\x80") + -32768 + >>> decode_long(b"\x80") + -128 + >>> decode_long(b"\x7f") + 127 + """ + return int.from_bytes(data, byteorder='little', signed=True) + # Pickling machinery @@ -177,24 +350,25 @@ class _Pickler: def __init__(self, file, protocol=None, *, fix_imports=True): """This takes a binary file for writing a pickle data stream. - The optional protocol argument tells the pickler to use the - given protocol; supported protocols are 0, 1, 2, 3. The default - protocol is 3; a backward-incompatible protocol designed for - Python 3.0. + The optional *protocol* argument tells the pickler to use the + given protocol; supported protocols are 0, 1, 2, 3 and 4. The + default protocol is 3; a backward-incompatible protocol designed + for Python 3. Specifying a negative protocol version selects the highest protocol version supported. The higher the protocol used, the more recent the version of Python needed to read the pickle produced. - The file argument must have a write() method that accepts a single - bytes argument. It can thus be a file object opened for binary - writing, a io.BytesIO instance, or any other custom object that - meets this interface. + The *file* argument must have a write() method that accepts a + single bytes argument. It can thus be a file object opened for + binary writing, a io.BytesIO instance, or any other custom + object that meets this interface. - If fix_imports is True and protocol is less than 3, pickle will try to - map the new Python 3.x names to the old module names used in Python - 2.x, so that the pickle data stream is readable with Python 2.x. + If *fix_imports* is True and *protocol* is less than 3, pickle + will try to map the new Python 3 names to the old module names + used in Python 2, so that the pickle data stream is readable + with Python 2. """ if protocol is None: protocol = DEFAULT_PROTOCOL @@ -203,9 +377,11 @@ class _Pickler: elif not 0 <= protocol <= HIGHEST_PROTOCOL: raise ValueError("pickle protocol must be <= %d" % HIGHEST_PROTOCOL) try: - self.write = file.write + self._file_write = file.write except AttributeError: raise TypeError("file must have a 'write' attribute") + self.framer = _Framer(self._file_write) + self.write = self.framer.write self.memo = {} self.proto = int(protocol) self.bin = protocol >= 1 @@ -216,10 +392,9 @@ class _Pickler: """Clears the pickler's "memo". The memo is the data structure that remembers which objects the - pickler has already seen, so that shared or recursive objects are - pickled by reference and not by value. This method is useful when - re-using picklers. - + pickler has already seen, so that shared or recursive objects + are pickled by reference and not by value. This method is + useful when re-using picklers. """ self.memo.clear() @@ -227,13 +402,16 @@ class _Pickler: """Write a pickled representation of obj to the open file.""" # Check whether Pickler was initialized correctly. This is # only needed to mimic the behavior of _pickle.Pickler.dump(). - if not hasattr(self, "write"): + if not hasattr(self, "_file_write"): raise PicklingError("Pickler.__init__() was not called by " "%s.__init__()" % (self.__class__.__name__,)) if self.proto >= 2: - self.write(PROTO + bytes([self.proto])) + self.write(PROTO + pack("<B", self.proto)) + if self.proto >= 4: + self.framer.start_framing() self.save(obj) self.write(STOP) + self.framer.end_framing() def memoize(self, obj): """Store an object in the memo.""" @@ -253,31 +431,35 @@ class _Pickler: if self.fast: return assert id(obj) not in self.memo - memo_len = len(self.memo) - self.write(self.put(memo_len)) - self.memo[id(obj)] = memo_len, obj + idx = len(self.memo) + self.write(self.put(idx)) + self.memo[id(obj)] = idx, obj # Return a PUT (BINPUT, LONG_BINPUT) opcode string, with argument i. - def put(self, i, pack=struct.pack): - if self.bin: - if i < 256: - return BINPUT + bytes([i]) + def put(self, idx): + if self.proto >= 4: + return MEMOIZE + elif self.bin: + if idx < 256: + return BINPUT + pack("<B", idx) else: - return LONG_BINPUT + pack("<I", i) - - return PUT + repr(i).encode("ascii") + b'\n' + return LONG_BINPUT + pack("<I", idx) + else: + return PUT + repr(idx).encode("ascii") + b'\n' # Return a GET (BINGET, LONG_BINGET) opcode string, with argument i. - def get(self, i, pack=struct.pack): + def get(self, i): if self.bin: if i < 256: - return BINGET + bytes([i]) + return BINGET + pack("<B", i) else: return LONG_BINGET + pack("<I", i) return GET + repr(i).encode("ascii") + b'\n' def save(self, obj, save_persistent_id=True): + self.framer.commit_frame() + # Check for persistent id (defined by a subclass) pid = self.persistent_id(obj) if pid is not None and save_persistent_id: @@ -286,20 +468,20 @@ class _Pickler: # Check the memo x = self.memo.get(id(obj)) - if x: + if x is not None: self.write(self.get(x[0])) return # Check the type dispatch table t = type(obj) f = self.dispatch.get(t) - if f: + if f is not None: f(self, obj) # Call unbound method with explicit self return # Check private dispatch table if any, or else copyreg.dispatch_table reduce = getattr(self, 'dispatch_table', dispatch_table).get(t) - if reduce: + if reduce is not None: rv = reduce(obj) else: # Check for a class with a custom metaclass; treat as regular class @@ -313,11 +495,11 @@ class _Pickler: # Check for a __reduce_ex__ method, fall back to __reduce__ reduce = getattr(obj, "__reduce_ex__", None) - if reduce: + if reduce is not None: rv = reduce(self.proto) else: reduce = getattr(obj, "__reduce__", None) - if reduce: + if reduce is not None: rv = reduce() else: raise PicklingError("Can't pickle %r object: %r" % @@ -353,24 +535,33 @@ class _Pickler: else: self.write(PERSID + str(pid).encode("ascii") + b'\n') - def save_reduce(self, func, args, state=None, - listitems=None, dictitems=None, obj=None): + def save_reduce(self, func, args, state=None, listitems=None, + dictitems=None, obj=None): # This API is called by some subclasses - # Assert that args is a tuple if not isinstance(args, tuple): - raise PicklingError("args from save_reduce() should be a tuple") - - # Assert that func is callable + raise PicklingError("args from save_reduce() must be a tuple") if not callable(func): - raise PicklingError("func from save_reduce() should be callable") + raise PicklingError("func from save_reduce() must be callable") save = self.save write = self.write - # Protocol 2 special case: if func's name is __newobj__, use NEWOBJ - if self.proto >= 2 and getattr(func, "__name__", "") == "__newobj__": - # A __reduce__ implementation can direct protocol 2 to + func_name = getattr(func, "__name__", "") + if self.proto >= 4 and func_name == "__newobj_ex__": + cls, args, kwargs = args + if not hasattr(cls, "__new__"): + raise PicklingError("args[0] from {} args has no __new__" + .format(func_name)) + if obj is not None and cls is not obj.__class__: + raise PicklingError("args[0] from {} args has the wrong class" + .format(func_name)) + save(cls) + save(args) + save(kwargs) + write(NEWOBJ_EX) + elif self.proto >= 2 and func_name == "__newobj__": + # A __reduce__ implementation can direct protocol 2 or newer to # use the more efficient NEWOBJ opcode, while still # allowing protocol 0 and 1 to work normally. For this to # work, the function returned by __reduce__ should be @@ -413,7 +604,13 @@ class _Pickler: write(REDUCE) if obj is not None: - self.memoize(obj) + # If the object is already in the memo, this means it is + # recursive. In this case, throw away everything we put on the + # stack, and fetch the object back from the memo. + if id(obj) in self.memo: + write(POP + self.get(self.memo[id(obj)][0])) + else: + self.memoize(obj) # More new special cases (that work with older protocols as # well): when __reduce__ returns a tuple with 4 or 5 items, @@ -438,22 +635,14 @@ class _Pickler: self.write(NONE) dispatch[type(None)] = save_none - def save_ellipsis(self, obj): - self.save_global(Ellipsis, 'Ellipsis') - dispatch[type(Ellipsis)] = save_ellipsis - - def save_notimplemented(self, obj): - self.save_global(NotImplemented, 'NotImplemented') - dispatch[type(NotImplemented)] = save_notimplemented - def save_bool(self, obj): if self.proto >= 2: - self.write(obj and NEWTRUE or NEWFALSE) + self.write(NEWTRUE if obj else NEWFALSE) else: - self.write(obj and TRUE or FALSE) + self.write(TRUE if obj else FALSE) dispatch[bool] = save_bool - def save_long(self, obj, pack=struct.pack): + def save_long(self, obj): if self.bin: # If the int is small enough to fit in a signed 4-byte 2's-comp # format, we can store it more efficiently than the general @@ -461,93 +650,95 @@ class _Pickler: # First one- and two-byte unsigned ints: if obj >= 0: if obj <= 0xff: - self.write(BININT1 + bytes([obj])) + self.write(BININT1 + pack("<B", obj)) return if obj <= 0xffff: - self.write(BININT2 + bytes([obj&0xff, obj>>8])) + self.write(BININT2 + pack("<H", obj)) return # Next check for 4-byte signed ints: - high_bits = obj >> 31 # note that Python shift sign-extends - if high_bits == 0 or high_bits == -1: - # All high bits are copies of bit 2**31, so the value - # fits in a 4-byte signed int. + if -0x80000000 <= obj <= 0x7fffffff: self.write(BININT + pack("<i", obj)) return if self.proto >= 2: encoded = encode_long(obj) n = len(encoded) if n < 256: - self.write(LONG1 + bytes([n]) + encoded) + self.write(LONG1 + pack("<B", n) + encoded) else: self.write(LONG4 + pack("<i", n) + encoded) return self.write(LONG + repr(obj).encode("ascii") + b'L\n') dispatch[int] = save_long - def save_float(self, obj, pack=struct.pack): + def save_float(self, obj): if self.bin: self.write(BINFLOAT + pack('>d', obj)) else: self.write(FLOAT + repr(obj).encode("ascii") + b'\n') dispatch[float] = save_float - def save_bytes(self, obj, pack=struct.pack): + def save_bytes(self, obj): if self.proto < 3: - if len(obj) == 0: + if not obj: # bytes object is empty self.save_reduce(bytes, (), obj=obj) else: self.save_reduce(codecs.encode, (str(obj, 'latin1'), 'latin1'), obj=obj) return n = len(obj) - if n < 256: - self.write(SHORT_BINBYTES + bytes([n]) + bytes(obj)) + if n <= 0xff: + self.write(SHORT_BINBYTES + pack("<B", n) + obj) + elif n > 0xffffffff and self.proto >= 4: + self.write(BINBYTES8 + pack("<Q", n) + obj) else: - self.write(BINBYTES + pack("<I", n) + bytes(obj)) + self.write(BINBYTES + pack("<I", n) + obj) self.memoize(obj) dispatch[bytes] = save_bytes - def save_str(self, obj, pack=struct.pack): + def save_str(self, obj): if self.bin: encoded = obj.encode('utf-8', 'surrogatepass') n = len(encoded) - self.write(BINUNICODE + pack("<I", n) + encoded) + if n <= 0xff and self.proto >= 4: + self.write(SHORT_BINUNICODE + pack("<B", n) + encoded) + elif n > 0xffffffff and self.proto >= 4: + self.write(BINUNICODE8 + pack("<Q", n) + encoded) + else: + self.write(BINUNICODE + pack("<I", n) + encoded) else: obj = obj.replace("\\", "\\u005c") obj = obj.replace("\n", "\\u000a") - self.write(UNICODE + bytes(obj.encode('raw-unicode-escape')) + + self.write(UNICODE + obj.encode('raw-unicode-escape') + b'\n') self.memoize(obj) dispatch[str] = save_str def save_tuple(self, obj): - write = self.write - proto = self.proto - - n = len(obj) - if n == 0: - if proto: - write(EMPTY_TUPLE) + if not obj: # tuple is empty + if self.bin: + self.write(EMPTY_TUPLE) else: - write(MARK + TUPLE) + self.write(MARK + TUPLE) return + n = len(obj) save = self.save memo = self.memo - if n <= 3 and proto >= 2: + if n <= 3 and self.proto >= 2: for element in obj: save(element) # Subtle. Same as in the big comment below. if id(obj) in memo: get = self.get(memo[id(obj)][0]) - write(POP * n + get) + self.write(POP * n + get) else: - write(_tuplesize2code[n]) + self.write(_tuplesize2code[n]) self.memoize(obj) return # proto 0 or proto 1 and tuple isn't empty, or proto > 1 and tuple # has more than 3 elements. + write = self.write write(MARK) for element in obj: save(element) @@ -561,25 +752,23 @@ class _Pickler: # could have been done in the "for element" loop instead, but # recursive tuples are a rare thing. get = self.get(memo[id(obj)][0]) - if proto: + if self.bin: write(POP_MARK + get) else: # proto 0 -- POP_MARK not available write(POP * (n+1) + get) return # No recursion. - self.write(TUPLE) + write(TUPLE) self.memoize(obj) dispatch[tuple] = save_tuple def save_list(self, obj): - write = self.write - if self.bin: - write(EMPTY_LIST) + self.write(EMPTY_LIST) else: # proto 0 -- can't use EMPTY_LIST - write(MARK + LIST) + self.write(MARK + LIST) self.memoize(obj) self._batch_appends(obj) @@ -599,17 +788,9 @@ class _Pickler: write(APPEND) return - items = iter(items) - r = range(self._BATCHSIZE) - while items is not None: - tmp = [] - for i in r: - try: - x = next(items) - tmp.append(x) - except StopIteration: - items = None - break + it = iter(items) + while True: + tmp = list(islice(it, self._BATCHSIZE)) n = len(tmp) if n > 1: write(MARK) @@ -620,14 +801,14 @@ class _Pickler: save(tmp[0]) write(APPEND) # else tmp is empty, and we're done + if n < self._BATCHSIZE: + return def save_dict(self, obj): - write = self.write - if self.bin: - write(EMPTY_DICT) + self.write(EMPTY_DICT) else: # proto 0 -- can't use EMPTY_DICT - write(MARK + DICT) + self.write(MARK + DICT) self.memoize(obj) self._batch_setitems(obj.items()) @@ -648,16 +829,9 @@ class _Pickler: write(SETITEM) return - items = iter(items) - r = range(self._BATCHSIZE) - while items is not None: - tmp = [] - for i in r: - try: - tmp.append(next(items)) - except StopIteration: - items = None - break + it = iter(items) + while True: + tmp = list(islice(it, self._BATCHSIZE)) n = len(tmp) if n > 1: write(MARK) @@ -671,55 +845,109 @@ class _Pickler: save(v) write(SETITEM) # else tmp is empty, and we're done + if n < self._BATCHSIZE: + return + + def save_set(self, obj): + save = self.save + write = self.write + + if self.proto < 4: + self.save_reduce(set, (list(obj),), obj=obj) + return + + write(EMPTY_SET) + self.memoize(obj) + + it = iter(obj) + while True: + batch = list(islice(it, self._BATCHSIZE)) + n = len(batch) + if n > 0: + write(MARK) + for item in batch: + save(item) + write(ADDITEMS) + if n < self._BATCHSIZE: + return + dispatch[set] = save_set + + def save_frozenset(self, obj): + save = self.save + write = self.write + + if self.proto < 4: + self.save_reduce(frozenset, (list(obj),), obj=obj) + return + + write(MARK) + for item in obj: + save(item) + + if id(obj) in self.memo: + # If the object is already in the memo, this means it is + # recursive. In this case, throw away everything we put on the + # stack, and fetch the object back from the memo. + write(POP_MARK + self.get(self.memo[id(obj)][0])) + return - def save_global(self, obj, name=None, pack=struct.pack): + write(FROZENSET) + self.memoize(obj) + dispatch[frozenset] = save_frozenset + + def save_global(self, obj, name=None): write = self.write memo = self.memo + if name is None and self.proto >= 4: + name = getattr(obj, '__qualname__', None) if name is None: name = obj.__name__ - module = getattr(obj, "__module__", None) - if module is None: - module = whichmodule(obj, name) - + module_name = whichmodule(obj, name, allow_qualname=self.proto >= 4) try: - __import__(module, level=0) - mod = sys.modules[module] - klass = getattr(mod, name) + __import__(module_name, level=0) + module = sys.modules[module_name] + obj2 = _getattribute(module, name, allow_qualname=self.proto >= 4) except (ImportError, KeyError, AttributeError): raise PicklingError( "Can't pickle %r: it's not found as %s.%s" % - (obj, module, name)) + (obj, module_name, name)) else: - if klass is not obj: + if obj2 is not obj: raise PicklingError( "Can't pickle %r: it's not the same object as %s.%s" % - (obj, module, name)) + (obj, module_name, name)) if self.proto >= 2: - code = _extension_registry.get((module, name)) + code = _extension_registry.get((module_name, name)) if code: assert code > 0 if code <= 0xff: - write(EXT1 + bytes([code])) + write(EXT1 + pack("<B", code)) elif code <= 0xffff: - write(EXT2 + bytes([code&0xff, code>>8])) + write(EXT2 + pack("<H", code)) else: write(EXT4 + pack("<i", code)) return # Non-ASCII identifiers are supported only with protocols >= 3. - if self.proto >= 3: - write(GLOBAL + bytes(module, "utf-8") + b'\n' + + if self.proto >= 4: + self.save(module_name) + self.save(name) + write(STACK_GLOBAL) + elif self.proto >= 3: + write(GLOBAL + bytes(module_name, "utf-8") + b'\n' + bytes(name, "utf-8") + b'\n') else: if self.fix_imports: - if (module, name) in _compat_pickle.REVERSE_NAME_MAPPING: - module, name = _compat_pickle.REVERSE_NAME_MAPPING[(module, name)] - if module in _compat_pickle.REVERSE_IMPORT_MAPPING: - module = _compat_pickle.REVERSE_IMPORT_MAPPING[module] + r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING + r_import_mapping = _compat_pickle.REVERSE_IMPORT_MAPPING + if (module_name, name) in r_name_mapping: + module_name, name = r_name_mapping[(module_name, name)] + elif module_name in r_import_mapping: + module_name = r_import_mapping[module_name] try: - write(GLOBAL + bytes(module, "ascii") + b'\n' + + write(GLOBAL + bytes(module_name, "ascii") + b'\n' + bytes(name, "ascii") + b'\n') except UnicodeEncodeError: raise PicklingError( @@ -738,58 +966,8 @@ class _Pickler: return self.save_global(obj) dispatch[FunctionType] = save_global - dispatch[BuiltinFunctionType] = save_global dispatch[type] = save_type -# Pickling helpers - -def _keep_alive(x, memo): - """Keeps a reference to the object x in the memo. - - Because we remember objects by their id, we have - to assure that possibly temporary objects are kept - alive by referencing them. - We store a reference at the id of the memo, which should - normally not be used unless someone tries to deepcopy - the memo itself... - """ - try: - memo[id(memo)].append(x) - except KeyError: - # aha, this is the first one :-) - memo[id(memo)]=[x] - - -# A cache for whichmodule(), mapping a function object to the name of -# the module in which the function was found. - -classmap = {} # called classmap for backwards compatibility - -def whichmodule(func, funcname): - """Figure out the module in which a function occurs. - - Search sys.modules for the module. - Cache in classmap. - Return a module name. - If the function cannot be found, return "__main__". - """ - # Python functions should always get an __module__ from their globals. - mod = getattr(func, "__module__", None) - if mod is not None: - return mod - if func in classmap: - return classmap[func] - - for name, module in list(sys.modules.items()): - if module is None: - continue # skip dummy package entries - if name != '__main__' and getattr(module, funcname, None) is func: - break - else: - name = '__main__' - classmap[func] = name - return name - # Unpickling machinery @@ -799,8 +977,14 @@ class _Unpickler: encoding="ASCII", errors="strict"): """This takes a binary file for reading a pickle data stream. - The protocol version of the pickle is detected automatically, so no - proto argument is needed. + The protocol version of the pickle is detected automatically, so + no proto argument is needed. + + The argument *file* must have two methods, a read() method that + takes an integer argument, and a readline() method that requires + no arguments. Both methods should return bytes. Thus *file* + can be a binary file object opened for reading, a io.BytesIO + object, or any other custom object that meets this interface. The file-like object must have two methods, a read() method that takes an integer argument, and a readline() method that @@ -809,16 +993,17 @@ class _Unpickler: reading, a BytesIO object, or any other custom object that meets this interface. - Optional keyword arguments are *fix_imports*, *encoding* and *errors*, - which are used to control compatiblity support for pickle stream - generated by Python 2.x. If *fix_imports* is True, pickle will try to - map the old Python 2.x names to the new names used in Python 3.x. The - *encoding* and *errors* tell pickle how to decode 8-bit string - instances pickled by Python 2.x; these default to 'ASCII' and - 'strict', respectively. + Optional keyword arguments are *fix_imports*, *encoding* and + *errors*, which are used to control compatiblity support for + pickle stream generated by Python 2. If *fix_imports* is True, + pickle will try to map the old Python 2 names to the new names + used in Python 3. The *encoding* and *errors* tell pickle how + to decode 8-bit string instances pickled by Python 2; these + default to 'ASCII' and 'strict', respectively. *encoding* can be + 'bytes' to read theses 8-bit string instances as bytes objects. """ - self.readline = file.readline - self.read = file.read + self._file_readline = file.readline + self._file_read = file.read self.memo = {} self.encoding = encoding self.errors = errors @@ -832,16 +1017,20 @@ class _Unpickler: """ # Check whether Unpickler was initialized correctly. This is # only needed to mimic the behavior of _pickle.Unpickler.dump(). - if not hasattr(self, "read"): + if not hasattr(self, "_file_read"): raise UnpicklingError("Unpickler.__init__() was not called by " "%s.__init__()" % (self.__class__.__name__,)) + self._unframer = _Unframer(self._file_read, self._file_readline) + self.read = self._unframer.read + self.readline = self._unframer.readline self.mark = object() # any new unique object self.stack = [] self.append = self.stack.append + self.proto = 0 read = self.read dispatch = self.dispatch try: - while 1: + while True: key = read(1) if not key: raise EOFError @@ -871,12 +1060,19 @@ class _Unpickler: dispatch = {} def load_proto(self): - proto = ord(self.read(1)) + proto = self.read(1)[0] if not 0 <= proto <= HIGHEST_PROTOCOL: raise ValueError("unsupported pickle protocol: %d" % proto) self.proto = proto dispatch[PROTO[0]] = load_proto + def load_frame(self): + frame_size, = unpack('<Q', self.read(8)) + if frame_size > sys.maxsize: + raise ValueError("frame size > sys.maxsize: %d" % frame_size) + self._unframer.load_frame(frame_size) + dispatch[FRAME[0]] = load_frame + def load_persid(self): pid = self.readline()[:-1].decode("ascii") self.append(self.persistent_load(pid)) @@ -906,43 +1102,40 @@ class _Unpickler: elif data == TRUE[1:]: val = True else: - try: - val = int(data, 0) - except ValueError: - val = int(data, 0) + val = int(data, 0) self.append(val) dispatch[INT[0]] = load_int def load_binint(self): - self.append(mloads(b'i' + self.read(4))) + self.append(unpack('<i', self.read(4))[0]) dispatch[BININT[0]] = load_binint def load_binint1(self): - self.append(ord(self.read(1))) + self.append(self.read(1)[0]) dispatch[BININT1[0]] = load_binint1 def load_binint2(self): - self.append(mloads(b'i' + self.read(2) + b'\000\000')) + self.append(unpack('<H', self.read(2))[0]) dispatch[BININT2[0]] = load_binint2 def load_long(self): - val = self.readline()[:-1].decode("ascii") - if val and val[-1] == 'L': + val = self.readline()[:-1] + if val and val[-1] == b'L'[0]: val = val[:-1] self.append(int(val, 0)) dispatch[LONG[0]] = load_long def load_long1(self): - n = ord(self.read(1)) + n = self.read(1)[0] data = self.read(n) self.append(decode_long(data)) dispatch[LONG1[0]] = load_long1 def load_long4(self): - n = mloads(b'i' + self.read(4)) + n, = unpack('<i', self.read(4)) if n < 0: # Corrupt or hostile pickle -- we never write one like this - raise UnpicklingError("LONG pickle has negative byte count"); + raise UnpicklingError("LONG pickle has negative byte count") data = self.read(n) self.append(decode_long(data)) dispatch[LONG4[0]] = load_long4 @@ -951,39 +1144,43 @@ class _Unpickler: self.append(float(self.readline()[:-1])) dispatch[FLOAT[0]] = load_float - def load_binfloat(self, unpack=struct.unpack): + def load_binfloat(self): self.append(unpack('>d', self.read(8))[0]) dispatch[BINFLOAT[0]] = load_binfloat + def _decode_string(self, value): + # Used to allow strings from Python 2 to be decoded either as + # bytes or Unicode strings. This should be used only with the + # STRING, BINSTRING and SHORT_BINSTRING opcodes. + if self.encoding == "bytes": + return value + else: + return value.decode(self.encoding, self.errors) + def load_string(self): - orig = self.readline() - rep = orig[:-1] - for q in (b'"', b"'"): # double or single quote - if rep.startswith(q): - if len(rep) < 2 or not rep.endswith(q): - raise ValueError("insecure string pickle") - rep = rep[len(q):-len(q)] - break + data = self.readline()[:-1] + # Strip outermost quotes + if len(data) >= 2 and data[0] == data[-1] and data[0] in b'"\'': + data = data[1:-1] else: - raise ValueError("insecure string pickle: %r" % orig) - self.append(codecs.escape_decode(rep)[0] - .decode(self.encoding, self.errors)) + raise UnpicklingError("the STRING opcode argument must be quoted") + self.append(self._decode_string(codecs.escape_decode(data)[0])) dispatch[STRING[0]] = load_string def load_binstring(self): # Deprecated BINSTRING uses signed 32-bit length - len = mloads(b'i' + self.read(4)) + len, = unpack('<i', self.read(4)) if len < 0: - raise UnpicklingError("BINSTRING pickle has negative byte count"); + raise UnpicklingError("BINSTRING pickle has negative byte count") data = self.read(len) - value = str(data, self.encoding, self.errors) - self.append(value) + self.append(self._decode_string(data)) dispatch[BINSTRING[0]] = load_binstring - def load_binbytes(self, unpack=struct.unpack, maxsize=sys.maxsize): + def load_binbytes(self): len, = unpack('<I', self.read(4)) if len > maxsize: - raise UnpicklingError("BINBYTES exceeds system's maximum size of %d bytes" % maxsize); + raise UnpicklingError("BINBYTES exceeds system's maximum size " + "of %d bytes" % maxsize) self.append(self.read(len)) dispatch[BINBYTES[0]] = load_binbytes @@ -991,25 +1188,38 @@ class _Unpickler: self.append(str(self.readline()[:-1], 'raw-unicode-escape')) dispatch[UNICODE[0]] = load_unicode - def load_binunicode(self, unpack=struct.unpack, maxsize=sys.maxsize): + def load_binunicode(self): len, = unpack('<I', self.read(4)) if len > maxsize: - raise UnpicklingError("BINUNICODE exceeds system's maximum size of %d bytes" % maxsize); + raise UnpicklingError("BINUNICODE exceeds system's maximum size " + "of %d bytes" % maxsize) self.append(str(self.read(len), 'utf-8', 'surrogatepass')) dispatch[BINUNICODE[0]] = load_binunicode + def load_binunicode8(self): + len, = unpack('<Q', self.read(8)) + if len > maxsize: + raise UnpicklingError("BINUNICODE8 exceeds system's maximum size " + "of %d bytes" % maxsize) + self.append(str(self.read(len), 'utf-8', 'surrogatepass')) + dispatch[BINUNICODE8[0]] = load_binunicode8 + def load_short_binstring(self): - len = ord(self.read(1)) - data = bytes(self.read(len)) - value = str(data, self.encoding, self.errors) - self.append(value) + len = self.read(1)[0] + data = self.read(len) + self.append(self._decode_string(data)) dispatch[SHORT_BINSTRING[0]] = load_short_binstring def load_short_binbytes(self): - len = ord(self.read(1)) - self.append(bytes(self.read(len))) + len = self.read(1)[0] + self.append(self.read(len)) dispatch[SHORT_BINBYTES[0]] = load_short_binbytes + def load_short_binunicode(self): + len = self.read(1)[0] + self.append(str(self.read(len), 'utf-8', 'surrogatepass')) + dispatch[SHORT_BINUNICODE[0]] = load_short_binunicode + def load_tuple(self): k = self.marker() self.stack[k:] = [tuple(self.stack[k+1:])] @@ -1039,6 +1249,15 @@ class _Unpickler: self.append({}) dispatch[EMPTY_DICT[0]] = load_empty_dictionary + def load_empty_set(self): + self.append(set()) + dispatch[EMPTY_SET[0]] = load_empty_set + + def load_frozenset(self): + k = self.marker() + self.stack[k:] = [frozenset(self.stack[k+1:])] + dispatch[FROZENSET[0]] = load_frozenset + def load_list(self): k = self.marker() self.stack[k:] = [self.stack[k+1:]] @@ -1046,12 +1265,9 @@ class _Unpickler: def load_dict(self): k = self.marker() - d = {} items = self.stack[k+1:] - for i in range(0, len(items), 2): - key = items[i] - value = items[i+1] - d[key] = value + d = {items[i]: items[i+1] + for i in range(0, len(items), 2)} self.stack[k:] = [d] dispatch[DICT[0]] = load_dict @@ -1090,11 +1306,19 @@ class _Unpickler: def load_newobj(self): args = self.stack.pop() - cls = self.stack[-1] + cls = self.stack.pop() obj = cls.__new__(cls, *args) - self.stack[-1] = obj + self.append(obj) dispatch[NEWOBJ[0]] = load_newobj + def load_newobj_ex(self): + kwargs = self.stack.pop() + args = self.stack.pop() + cls = self.stack.pop() + obj = cls.__new__(cls, *args, **kwargs) + self.append(obj) + dispatch[NEWOBJ_EX[0]] = load_newobj_ex + def load_global(self): module = self.readline()[:-1].decode("utf-8") name = self.readline()[:-1].decode("utf-8") @@ -1102,18 +1326,26 @@ class _Unpickler: self.append(klass) dispatch[GLOBAL[0]] = load_global + def load_stack_global(self): + name = self.stack.pop() + module = self.stack.pop() + if type(name) is not str or type(module) is not str: + raise UnpicklingError("STACK_GLOBAL requires str") + self.append(self.find_class(module, name)) + dispatch[STACK_GLOBAL[0]] = load_stack_global + def load_ext1(self): - code = ord(self.read(1)) + code = self.read(1)[0] self.get_extension(code) dispatch[EXT1[0]] = load_ext1 def load_ext2(self): - code = mloads(b'i' + self.read(2) + b'\000\000') + code, = unpack('<H', self.read(2)) self.get_extension(code) dispatch[EXT2[0]] = load_ext2 def load_ext4(self): - code = mloads(b'i' + self.read(4)) + code, = unpack('<i', self.read(4)) self.get_extension(code) dispatch[EXT4[0]] = load_ext4 @@ -1127,7 +1359,7 @@ class _Unpickler: if not key: if code <= 0: # note that 0 is forbidden # Corrupt or hostile pickle. - raise UnpicklingError("EXT specifies code <= 0"); + raise UnpicklingError("EXT specifies code <= 0") raise ValueError("unregistered extension code %d" % code) obj = self.find_class(*key) _extension_cache[code] = obj @@ -1138,12 +1370,11 @@ class _Unpickler: if self.proto < 3 and self.fix_imports: if (module, name) in _compat_pickle.NAME_MAPPING: module, name = _compat_pickle.NAME_MAPPING[(module, name)] - if module in _compat_pickle.IMPORT_MAPPING: + elif module in _compat_pickle.IMPORT_MAPPING: module = _compat_pickle.IMPORT_MAPPING[module] __import__(module, level=0) - mod = sys.modules[module] - klass = getattr(mod, name) - return klass + return _getattribute(sys.modules[module], name, + allow_qualname=self.proto >= 4) def load_reduce(self): stack = self.stack @@ -1181,7 +1412,7 @@ class _Unpickler: self.append(self.memo[i]) dispatch[BINGET[0]] = load_binget - def load_long_binget(self, unpack=struct.unpack): + def load_long_binget(self): i, = unpack('<I', self.read(4)) self.append(self.memo[i]) dispatch[LONG_BINGET[0]] = load_long_binget @@ -1200,13 +1431,18 @@ class _Unpickler: self.memo[i] = self.stack[-1] dispatch[BINPUT[0]] = load_binput - def load_long_binput(self, unpack=struct.unpack, maxsize=sys.maxsize): + def load_long_binput(self): i, = unpack('<I', self.read(4)) if i > maxsize: raise ValueError("negative LONG_BINPUT argument") self.memo[i] = self.stack[-1] dispatch[LONG_BINPUT[0]] = load_long_binput + def load_memoize(self): + memo = self.memo + memo[len(memo)] = self.stack[-1] + dispatch[MEMOIZE[0]] = load_memoize + def load_append(self): stack = self.stack value = stack.pop() @@ -1246,12 +1482,26 @@ class _Unpickler: del stack[mark:] dispatch[SETITEMS[0]] = load_setitems + def load_additems(self): + stack = self.stack + mark = self.marker() + set_obj = stack[mark - 1] + items = stack[mark + 1:] + if isinstance(set_obj, set): + set_obj.update(items) + else: + add = set_obj.add + for item in items: + add(item) + del stack[mark:] + dispatch[ADDITEMS[0]] = load_additems + def load_build(self): stack = self.stack state = stack.pop() inst = stack[-1] setstate = getattr(inst, "__setstate__", None) - if setstate: + if setstate is not None: setstate(state) return slotstate = None @@ -1279,86 +1529,46 @@ class _Unpickler: raise _Stop(value) dispatch[STOP[0]] = load_stop -# Encode/decode ints. - -def encode_long(x): - r"""Encode a long to a two's complement little-endian binary string. - Note that 0 is a special case, returning an empty string, to save a - byte in the LONG1 pickling context. - - >>> encode_long(0) - b'' - >>> encode_long(255) - b'\xff\x00' - >>> encode_long(32767) - b'\xff\x7f' - >>> encode_long(-256) - b'\x00\xff' - >>> encode_long(-32768) - b'\x00\x80' - >>> encode_long(-128) - b'\x80' - >>> encode_long(127) - b'\x7f' - >>> - """ - if x == 0: - return b'' - nbytes = (x.bit_length() >> 3) + 1 - result = x.to_bytes(nbytes, byteorder='little', signed=True) - if x < 0 and nbytes > 1: - if result[-1] == 0xff and (result[-2] & 0x80) != 0: - result = result[:-1] - return result - -def decode_long(data): - r"""Decode an int from a two's complement little-endian binary string. - - >>> decode_long(b'') - 0 - >>> decode_long(b"\xff\x00") - 255 - >>> decode_long(b"\xff\x7f") - 32767 - >>> decode_long(b"\x00\xff") - -256 - >>> decode_long(b"\x00\x80") - -32768 - >>> decode_long(b"\x80") - -128 - >>> decode_long(b"\x7f") - 127 - """ - return int.from_bytes(data, byteorder='little', signed=True) # Shorthands -def dump(obj, file, protocol=None, *, fix_imports=True): - Pickler(file, protocol, fix_imports=fix_imports).dump(obj) +def _dump(obj, file, protocol=None, *, fix_imports=True): + _Pickler(file, protocol, fix_imports=fix_imports).dump(obj) -def dumps(obj, protocol=None, *, fix_imports=True): +def _dumps(obj, protocol=None, *, fix_imports=True): f = io.BytesIO() - Pickler(f, protocol, fix_imports=fix_imports).dump(obj) + _Pickler(f, protocol, fix_imports=fix_imports).dump(obj) res = f.getvalue() assert isinstance(res, bytes_types) return res -def load(file, *, fix_imports=True, encoding="ASCII", errors="strict"): - return Unpickler(file, fix_imports=fix_imports, +def _load(file, *, fix_imports=True, encoding="ASCII", errors="strict"): + return _Unpickler(file, fix_imports=fix_imports, encoding=encoding, errors=errors).load() -def loads(s, *, fix_imports=True, encoding="ASCII", errors="strict"): +def _loads(s, *, fix_imports=True, encoding="ASCII", errors="strict"): if isinstance(s, str): raise TypeError("Can't load pickle from unicode string") file = io.BytesIO(s) - return Unpickler(file, fix_imports=fix_imports, - encoding=encoding, errors=errors).load() + return _Unpickler(file, fix_imports=fix_imports, + encoding=encoding, errors=errors).load() # Use the faster _pickle if possible try: - from _pickle import * + from _pickle import ( + PickleError, + PicklingError, + UnpicklingError, + Pickler, + Unpickler, + dump, + dumps, + load, + loads + ) except ImportError: Pickler, Unpickler = _Pickler, _Unpickler + dump, dumps, load, loads = _dump, _dumps, _load, _loads # Doctest def _test(): |