From c9dc4a2a8a6dcfe1674685bea4a4af935c0e37ca Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Sat, 23 Nov 2013 18:59:12 +0100 Subject: Issue #17810: Implement PEP 3154, pickle protocol 4. Most of the work is by Alexandre. --- Doc/library/pickle.rst | 35 +- Doc/whatsnew/3.4.rst | 15 + Lib/copyreg.py | 6 + Lib/pickle.py | 582 +++++++++++++------ Lib/pickletools.py | 471 +++++++++++++--- Lib/test/pickletester.py | 487 +++++++++++----- Lib/test/test_descr.py | 605 ++++++++++++++------ Misc/NEWS | 2 + Modules/_pickle.c | 1380 ++++++++++++++++++++++++++++++++++------------ Objects/classobject.c | 26 +- Objects/descrobject.c | 45 +- Objects/typeobject.c | 466 ++++++++++++---- 12 files changed, 3123 insertions(+), 997 deletions(-) diff --git a/Doc/library/pickle.rst b/Doc/library/pickle.rst index 9404a47..49ec9c1 100644 --- a/Doc/library/pickle.rst +++ b/Doc/library/pickle.rst @@ -459,12 +459,29 @@ implementation of this behaviour:: Classes can alter the default behaviour by providing one or several special methods: +.. method:: object.__getnewargs_ex__() + + In protocols 4 and newer, classes that implements the + :meth:`__getnewargs_ex__` method can dictate the values passed to the + :meth:`__new__` method upon unpickling. The method must return a pair + ``(args, kwargs)`` where *args* is a tuple of positional arguments + and *kwargs* a dictionary of named arguments for constructing the + object. Those will be passed to the :meth:`__new__` method upon + unpickling. + + You should implement this method if the :meth:`__new__` method of your + class requires keyword-only arguments. Otherwise, it is recommended for + compatibility to implement :meth:`__getnewargs__`. + + .. method:: object.__getnewargs__() - In protocol 2 and newer, classes that implements the :meth:`__getnewargs__` - method can dictate the values passed to the :meth:`__new__` method upon - unpickling. This is often needed for classes whose :meth:`__new__` method - requires arguments. + This method serve a similar purpose as :meth:`__getnewargs_ex__` but + for protocols 2 and newer. It must return a tuple of arguments `args` + which will be passed to the :meth:`__new__` method upon unpickling. + + In protocols 4 and newer, :meth:`__getnewargs__` will not be called if + :meth:`__getnewargs_ex__` is defined. .. method:: object.__getstate__() @@ -496,10 +513,10 @@ the methods :meth:`__getstate__` and :meth:`__setstate__`. At unpickling time, some methods like :meth:`__getattr__`, :meth:`__getattribute__`, or :meth:`__setattr__` may be called upon the - instance. In case those methods rely on some internal invariant being true, - the type should implement :meth:`__getnewargs__` to establish such an - invariant; otherwise, neither :meth:`__new__` nor :meth:`__init__` will be - called. + instance. In case those methods rely on some internal invariant being + true, the type should implement :meth:`__getnewargs__` or + :meth:`__getnewargs_ex__` to establish such an invariant; otherwise, + neither :meth:`__new__` nor :meth:`__init__` will be called. .. index:: pair: copy; protocol @@ -511,7 +528,7 @@ objects. [#]_ Although powerful, implementing :meth:`__reduce__` directly in your classes is error prone. For this reason, class designers should use the high-level -interface (i.e., :meth:`__getnewargs__`, :meth:`__getstate__` and +interface (i.e., :meth:`__getnewargs_ex__`, :meth:`__getstate__` and :meth:`__setstate__`) whenever possible. We will show, however, cases where using :meth:`__reduce__` is the only option or leads to more efficient pickling or both. diff --git a/Doc/whatsnew/3.4.rst b/Doc/whatsnew/3.4.rst index b509516..6f949a9 100644 --- a/Doc/whatsnew/3.4.rst +++ b/Doc/whatsnew/3.4.rst @@ -109,6 +109,7 @@ New expected features for Python implementations: Significantly Improved Library Modules: * Single-dispatch generic functions in :mod:`functoools` (:pep:`443`) +* New :mod:`pickle` protocol 4 (:pep:`3154`) * SHA-3 (Keccak) support for :mod:`hashlib`. * TLSv1.1 and TLSv1.2 support for :mod:`ssl`. * :mod:`multiprocessing` now has option to avoid using :func:`os.fork` @@ -285,6 +286,20 @@ described in the PEP. Existing importers should be updated to implement the new methods. +Pickle protocol 4 +================= + +The new :mod:`pickle` protocol addresses a number of issues that were present +in previous protocols, such as the serialization of nested classes, very +large strings and containers, or classes whose :meth:`__new__` method takes +keyword-only arguments. It also brings a couple efficiency improvements. + +.. seealso:: + + :pep:`3154` - Pickle protocol 4 + PEP written by Antoine Pitrou and implemented by Alexandre Vassalotti. + + Other Language Changes ====================== diff --git a/Lib/copyreg.py b/Lib/copyreg.py index 66c0f8a..67f5bb0 100644 --- a/Lib/copyreg.py +++ b/Lib/copyreg.py @@ -87,6 +87,12 @@ def _reduce_ex(self, proto): def __newobj__(cls, *args): return cls.__new__(cls, *args) +def __newobj_ex__(cls, args, kwargs): + """Used by pickle protocol 4, instead of __newobj__ to allow classes with + keyword-only arguments to be pickled correctly. + """ + return cls.__new__(cls, *args, **kwargs) + def _slotnames(cls): """Return a list of slot names for a given class. diff --git a/Lib/pickle.py b/Lib/pickle.py index dbc196a..d1f1538 100644 --- a/Lib/pickle.py +++ b/Lib/pickle.py @@ -23,7 +23,7 @@ Misc variables: """ -from types import FunctionType, BuiltinFunctionType +from types import FunctionType, BuiltinFunctionType, ModuleType from copyreg import dispatch_table from copyreg import _extension_registry, _inverted_registry, _extension_cache from itertools import islice @@ -42,17 +42,18 @@ __all__ = ["PickleError", "PicklingError", "UnpicklingError", "Pickler", bytes_types = (bytes, bytearray) # These are purely informational; no code uses these. -format_version = "3.0" # File format version we write +format_version = "4.0" # File format version we write compatible_formats = ["1.0", # Original protocol 0 "1.1", # Protocol 0 with INST added "1.2", # Original protocol 1 "1.3", # Protocol 1 with BINFLOAT added "2.0", # Protocol 2 "3.0", # Protocol 3 + "4.0", # Protocol 4 ] # Old format versions we can read # This is the highest protocol number we know how to read. -HIGHEST_PROTOCOL = 3 +HIGHEST_PROTOCOL = 4 # The protocol we write by default. May be less than HIGHEST_PROTOCOL. # We intentionally write a protocol that Python 2.x cannot read; @@ -164,7 +165,196 @@ _tuplesize2code = [EMPTY_TUPLE, TUPLE1, TUPLE2, TUPLE3] BINBYTES = b'B' # push bytes; counted binary string argument SHORT_BINBYTES = b'C' # " " ; " " " " < 256 bytes -__all__.extend([x for x in dir() if re.match("[A-Z][A-Z0-9_]+$",x)]) +# Protocol 4 +SHORT_BINUNICODE = b'\x8c' # push short string; UTF-8 length < 256 bytes +BINUNICODE8 = b'\x8d' # push very long string +BINBYTES8 = b'\x8e' # push very long bytes string +EMPTY_SET = b'\x8f' # push empty set on the stack +ADDITEMS = b'\x90' # modify set by adding topmost stack items +FROZENSET = b'\x91' # build frozenset from topmost stack items +NEWOBJ_EX = b'\x92' # like NEWOBJ but work with keyword only arguments +STACK_GLOBAL = b'\x93' # same as GLOBAL but using names on the stacks +MEMOIZE = b'\x94' # store top of the stack in memo +FRAME = b'\x95' # indicate the beginning of a new frame + +__all__.extend([x for x in dir() if re.match("[A-Z][A-Z0-9_]+$", x)]) + + +class _Framer: + + _FRAME_SIZE_TARGET = 64 * 1024 + + def __init__(self, file_write): + self.file_write = file_write + self.current_frame = None + + def _commit_frame(self): + f = self.current_frame + with f.getbuffer() as data: + n = len(data) + write = self.file_write + write(FRAME) + write(pack("= self._FRAME_SIZE_TARGET: + self._commit_frame() + return f.write(data) + +class _Unframer: + + def __init__(self, file_read, file_readline, file_tell=None): + self.file_read = file_read + self.file_readline = file_readline + self.file_tell = file_tell + self.framing_enabled = False + self.current_frame = None + self.frame_start = None + + def read(self, n): + if n == 0: + return b'' + _file_read = self.file_read + if not self.framing_enabled: + return _file_read(n) + f = self.current_frame + if f is not None: + data = f.read(n) + if data: + if len(data) < n: + raise UnpicklingError( + "pickle exhausted before end of frame") + return data + frame_opcode = _file_read(1) + if frame_opcode != FRAME: + raise UnpicklingError( + "expected a FRAME opcode, got {} instead".format(frame_opcode)) + frame_size, = unpack(" sys.maxsize: + raise ValueError("frame size > sys.maxsize: %d" % frame_size) + if self.file_tell is not None: + self.frame_start = self.file_tell() + f = self.current_frame = io.BytesIO(_file_read(frame_size)) + self.readline = f.readline + data = f.read(n) + assert len(data) == n, (len(data), n) + return data + + def readline(self): + if not self.framing_enabled: + return self.file_readline() + else: + return self.current_frame.readline() + + def tell(self): + if self.file_tell is None: + return None + elif self.current_frame is None: + return self.file_tell() + else: + return self.frame_start + self.current_frame.tell() + + +# Tools used for pickling. + +def _getattribute(obj, name, allow_qualname=False): + dotted_path = name.split(".") + if not allow_qualname and len(dotted_path) > 1: + raise AttributeError("Can't get qualified attribute {!r} on {!r}; " + + "use protocols >= 4 to enable support" + .format(name, obj)) + for subpath in dotted_path: + if subpath == '': + raise AttributeError("Can't get local attribute {!r} on {!r}" + .format(name, obj)) + try: + obj = getattr(obj, subpath) + except AttributeError: + raise AttributeError("Can't get attribute {!r} on {!r}" + .format(name, obj)) + return obj + +def whichmodule(obj, name, allow_qualname=False): + """Find the module an object belong to.""" + module_name = getattr(obj, '__module__', None) + if module_name is not None: + return module_name + for module_name, module in sys.modules.items(): + if module_name == '__main__' or module is None: + continue + try: + if _getattribute(module, name, allow_qualname) is obj: + return module_name + except AttributeError: + pass + return '__main__' + +def encode_long(x): + r"""Encode a long to a two's complement little-endian binary string. + Note that 0 is a special case, returning an empty string, to save a + byte in the LONG1 pickling context. + + >>> encode_long(0) + b'' + >>> encode_long(255) + b'\xff\x00' + >>> encode_long(32767) + b'\xff\x7f' + >>> encode_long(-256) + b'\x00\xff' + >>> encode_long(-32768) + b'\x00\x80' + >>> encode_long(-128) + b'\x80' + >>> encode_long(127) + b'\x7f' + >>> + """ + if x == 0: + return b'' + nbytes = (x.bit_length() >> 3) + 1 + result = x.to_bytes(nbytes, byteorder='little', signed=True) + if x < 0 and nbytes > 1: + if result[-1] == 0xff and (result[-2] & 0x80) != 0: + result = result[:-1] + return result + +def decode_long(data): + r"""Decode a long from a two's complement little-endian binary string. + + >>> decode_long(b'') + 0 + >>> decode_long(b"\xff\x00") + 255 + >>> decode_long(b"\xff\x7f") + 32767 + >>> decode_long(b"\x00\xff") + -256 + >>> decode_long(b"\x00\x80") + -32768 + >>> decode_long(b"\x80") + -128 + >>> decode_long(b"\x7f") + 127 + """ + return int.from_bytes(data, byteorder='little', signed=True) + # Pickling machinery @@ -174,9 +364,9 @@ class _Pickler: """This takes a binary file for writing a pickle data stream. The optional protocol argument tells the pickler to use the - given protocol; supported protocols are 0, 1, 2, 3. The default - protocol is 3; a backward-incompatible protocol designed for - Python 3.0. + given protocol; supported protocols are 0, 1, 2, 3 and 4. The + default protocol is 3; a backward-incompatible protocol designed for + Python 3. Specifying a negative protocol version selects the highest protocol version supported. The higher the protocol used, the @@ -189,8 +379,8 @@ class _Pickler: meets this interface. If fix_imports is True and protocol is less than 3, pickle will try to - map the new Python 3.x names to the old module names used in Python - 2.x, so that the pickle data stream is readable with Python 2.x. + map the new Python 3 names to the old module names used in Python 2, + so that the pickle data stream is readable with Python 2. """ if protocol is None: protocol = DEFAULT_PROTOCOL @@ -199,7 +389,7 @@ class _Pickler: elif not 0 <= protocol <= HIGHEST_PROTOCOL: raise ValueError("pickle protocol must be <= %d" % HIGHEST_PROTOCOL) try: - self.write = file.write + self._file_write = file.write except AttributeError: raise TypeError("file must have a 'write' attribute") self.memo = {} @@ -223,13 +413,22 @@ class _Pickler: """Write a pickled representation of obj to the open file.""" # Check whether Pickler was initialized correctly. This is # only needed to mimic the behavior of _pickle.Pickler.dump(). - if not hasattr(self, "write"): + if not hasattr(self, "_file_write"): raise PicklingError("Pickler.__init__() was not called by " "%s.__init__()" % (self.__class__.__name__,)) if self.proto >= 2: - self.write(PROTO + pack("= 4: + framer = _Framer(self._file_write) + framer.start_framing() + self.write = framer.write + else: + framer = None + self.write = self._file_write self.save(obj) self.write(STOP) + if framer is not None: + framer.end_framing() def memoize(self, obj): """Store an object in the memo.""" @@ -249,19 +448,21 @@ class _Pickler: if self.fast: return assert id(obj) not in self.memo - memo_len = len(self.memo) - self.write(self.put(memo_len)) - self.memo[id(obj)] = memo_len, obj + idx = len(self.memo) + self.write(self.put(idx)) + self.memo[id(obj)] = idx, obj # Return a PUT (BINPUT, LONG_BINPUT) opcode string, with argument i. - def put(self, i): - if self.bin: - if i < 256: - return BINPUT + pack("= 4: + return MEMOIZE + elif self.bin: + if idx < 256: + return BINPUT + pack("= 2 and getattr(func, "__name__", "") == "__newobj__": - # A __reduce__ implementation can direct protocol 2 to + func_name = getattr(func, "__name__", "") + if self.proto >= 4 and func_name == "__newobj_ex__": + cls, args, kwargs = args + if not hasattr(cls, "__new__"): + raise PicklingError("args[0] from {} args has no __new__" + .format(func_name)) + if obj is not None and cls is not obj.__class__: + raise PicklingError("args[0] from {} args has the wrong class" + .format(func_name)) + save(cls) + save(args) + save(kwargs) + write(NEWOBJ_EX) + elif self.proto >= 2 and func_name == "__newobj__": + # A __reduce__ implementation can direct protocol 2 or newer to # use the more efficient NEWOBJ opcode, while still # allowing protocol 0 and 1 to work normally. For this to # work, the function returned by __reduce__ should be @@ -409,7 +619,13 @@ class _Pickler: write(REDUCE) if obj is not None: - self.memoize(obj) + # If the object is already in the memo, this means it is + # recursive. In this case, throw away everything we put on the + # stack, and fetch the object back from the memo. + if id(obj) in self.memo: + write(POP + self.get(self.memo[id(obj)][0])) + else: + self.memoize(obj) # More new special cases (that work with older protocols as # well): when __reduce__ returns a tuple with 4 or 5 items, @@ -493,8 +709,10 @@ class _Pickler: (str(obj, 'latin1'), 'latin1'), obj=obj) return n = len(obj) - if n < 256: + if n <= 0xff: self.write(SHORT_BINBYTES + pack(" 0xffffffff and self.proto >= 4: + self.write(BINBYTES8 + pack("= 4: + self.write(SHORT_BINUNICODE + pack(" 0xffffffff and self.proto >= 4: + self.write(BINUNICODE8 + pack(" 0: + write(MARK) + for item in batch: + save(item) + write(ADDITEMS) + if n < self._BATCHSIZE: + return + dispatch[set] = save_set + + def save_frozenset(self, obj): + save = self.save + write = self.write + + if self.proto < 4: + self.save_reduce(frozenset, (list(obj),), obj=obj) + return + + write(MARK) + for item in obj: + save(item) + + if id(obj) in self.memo: + # If the object is already in the memo, this means it is + # recursive. In this case, throw away everything we put on the + # stack, and fetch the object back from the memo. + write(POP_MARK + self.get(self.memo[id(obj)][0])) + return + + write(FROZENSET) + self.memoize(obj) + dispatch[frozenset] = save_frozenset + def save_global(self, obj, name=None): write = self.write memo = self.memo + if name is None and self.proto >= 4: + name = getattr(obj, '__qualname__', None) if name is None: name = obj.__name__ - module = getattr(obj, "__module__", None) - if module is None: - module = whichmodule(obj, name) - + module_name = whichmodule(obj, name, allow_qualname=self.proto >= 4) try: - __import__(module, level=0) - mod = sys.modules[module] - klass = getattr(mod, name) + __import__(module_name, level=0) + module = sys.modules[module_name] + obj2 = _getattribute(module, name, allow_qualname=self.proto >= 4) except (ImportError, KeyError, AttributeError): raise PicklingError( "Can't pickle %r: it's not found as %s.%s" % - (obj, module, name)) + (obj, module_name, name)) else: - if klass is not obj: + if obj2 is not obj: raise PicklingError( "Can't pickle %r: it's not the same object as %s.%s" % - (obj, module, name)) + (obj, module_name, name)) if self.proto >= 2: - code = _extension_registry.get((module, name)) + code = _extension_registry.get((module_name, name)) if code: assert code > 0 if code <= 0xff: @@ -684,17 +954,23 @@ class _Pickler: write(EXT4 + pack("= 3. - if self.proto >= 3: - write(GLOBAL + bytes(module, "utf-8") + b'\n' + + if self.proto >= 4: + self.save(module_name) + self.save(name) + write(STACK_GLOBAL) + elif self.proto >= 3: + write(GLOBAL + bytes(module_name, "utf-8") + b'\n' + bytes(name, "utf-8") + b'\n') else: if self.fix_imports: - if (module, name) in _compat_pickle.REVERSE_NAME_MAPPING: - module, name = _compat_pickle.REVERSE_NAME_MAPPING[(module, name)] - if module in _compat_pickle.REVERSE_IMPORT_MAPPING: - module = _compat_pickle.REVERSE_IMPORT_MAPPING[module] + r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING + r_import_mapping = _compat_pickle.REVERSE_IMPORT_MAPPING + if (module_name, name) in r_name_mapping: + module_name, name = r_name_mapping[(module_name, name)] + if module_name in r_import_mapping: + module_name = r_import_mapping[module_name] try: - write(GLOBAL + bytes(module, "ascii") + b'\n' + + write(GLOBAL + bytes(module_name, "ascii") + b'\n' + bytes(name, "ascii") + b'\n') except UnicodeEncodeError: raise PicklingError( @@ -703,40 +979,16 @@ class _Pickler: self.memoize(obj) + def save_method(self, obj): + if obj.__self__ is None or type(obj.__self__) is ModuleType: + self.save_global(obj) + else: + self.save_reduce(getattr, (obj.__self__, obj.__name__), obj=obj) + dispatch[FunctionType] = save_global - dispatch[BuiltinFunctionType] = save_global + dispatch[BuiltinFunctionType] = save_method dispatch[type] = save_global -# A cache for whichmodule(), mapping a function object to the name of -# the module in which the function was found. - -classmap = {} # called classmap for backwards compatibility - -def whichmodule(func, funcname): - """Figure out the module in which a function occurs. - - Search sys.modules for the module. - Cache in classmap. - Return a module name. - If the function cannot be found, return "__main__". - """ - # Python functions should always get an __module__ from their globals. - mod = getattr(func, "__module__", None) - if mod is not None: - return mod - if func in classmap: - return classmap[func] - - for name, module in list(sys.modules.items()): - if module is None: - continue # skip dummy package entries - if name != '__main__' and getattr(module, funcname, None) is func: - break - else: - name = '__main__' - classmap[func] = name - return name - # Unpickling machinery @@ -764,8 +1016,8 @@ class _Unpickler: instances pickled by Python 2.x; these default to 'ASCII' and 'strict', respectively. """ - self.readline = file.readline - self.read = file.read + self._file_readline = file.readline + self._file_read = file.read self.memo = {} self.encoding = encoding self.errors = errors @@ -779,12 +1031,16 @@ class _Unpickler: """ # Check whether Unpickler was initialized correctly. This is # only needed to mimic the behavior of _pickle.Unpickler.dump(). - if not hasattr(self, "read"): + if not hasattr(self, "_file_read"): raise UnpicklingError("Unpickler.__init__() was not called by " "%s.__init__()" % (self.__class__.__name__,)) + self._unframer = _Unframer(self._file_read, self._file_readline) + self.read = self._unframer.read + self.readline = self._unframer.readline self.mark = object() # any new unique object self.stack = [] self.append = self.stack.append + self.proto = 0 read = self.read dispatch = self.dispatch try: @@ -822,6 +1078,8 @@ class _Unpickler: if not 0 <= proto <= HIGHEST_PROTOCOL: raise ValueError("unsupported pickle protocol: %d" % proto) self.proto = proto + if proto >= 4: + self._unframer.framing_enabled = True dispatch[PROTO[0]] = load_proto def load_persid(self): @@ -940,6 +1198,14 @@ class _Unpickler: self.append(str(self.read(len), 'utf-8', 'surrogatepass')) dispatch[BINUNICODE[0]] = load_binunicode + def load_binunicode8(self): + len, = unpack(' maxsize: + raise UnpicklingError("BINUNICODE8 exceeds system's maximum size " + "of %d bytes" % maxsize) + self.append(str(self.read(len), 'utf-8', 'surrogatepass')) + dispatch[BINUNICODE8[0]] = load_binunicode8 + def load_short_binstring(self): len = self.read(1)[0] data = self.read(len) @@ -952,6 +1218,11 @@ class _Unpickler: self.append(self.read(len)) dispatch[SHORT_BINBYTES[0]] = load_short_binbytes + def load_short_binunicode(self): + len = self.read(1)[0] + self.append(str(self.read(len), 'utf-8', 'surrogatepass')) + dispatch[SHORT_BINUNICODE[0]] = load_short_binunicode + def load_tuple(self): k = self.marker() self.stack[k:] = [tuple(self.stack[k+1:])] @@ -981,6 +1252,15 @@ class _Unpickler: self.append({}) dispatch[EMPTY_DICT[0]] = load_empty_dictionary + def load_empty_set(self): + self.append(set()) + dispatch[EMPTY_SET[0]] = load_empty_set + + def load_frozenset(self): + k = self.marker() + self.stack[k:] = [frozenset(self.stack[k+1:])] + dispatch[FROZENSET[0]] = load_frozenset + def load_list(self): k = self.marker() self.stack[k:] = [self.stack[k+1:]] @@ -1029,11 +1309,19 @@ class _Unpickler: def load_newobj(self): args = self.stack.pop() - cls = self.stack[-1] + cls = self.stack.pop() obj = cls.__new__(cls, *args) - self.stack[-1] = obj + self.append(obj) dispatch[NEWOBJ[0]] = load_newobj + def load_newobj_ex(self): + kwargs = self.stack.pop() + args = self.stack.pop() + cls = self.stack.pop() + obj = cls.__new__(cls, *args, **kwargs) + self.append(obj) + dispatch[NEWOBJ_EX[0]] = load_newobj_ex + def load_global(self): module = self.readline()[:-1].decode("utf-8") name = self.readline()[:-1].decode("utf-8") @@ -1041,6 +1329,14 @@ class _Unpickler: self.append(klass) dispatch[GLOBAL[0]] = load_global + def load_stack_global(self): + name = self.stack.pop() + module = self.stack.pop() + if type(name) is not str or type(module) is not str: + raise UnpicklingError("STACK_GLOBAL requires str") + self.append(self.find_class(module, name)) + dispatch[STACK_GLOBAL[0]] = load_stack_global + def load_ext1(self): code = self.read(1)[0] self.get_extension(code) @@ -1080,9 +1376,8 @@ class _Unpickler: if module in _compat_pickle.IMPORT_MAPPING: module = _compat_pickle.IMPORT_MAPPING[module] __import__(module, level=0) - mod = sys.modules[module] - klass = getattr(mod, name) - return klass + return _getattribute(sys.modules[module], name, + allow_qualname=self.proto >= 4) def load_reduce(self): stack = self.stack @@ -1146,6 +1441,11 @@ class _Unpickler: self.memo[i] = self.stack[-1] dispatch[LONG_BINPUT[0]] = load_long_binput + def load_memoize(self): + memo = self.memo + memo[len(memo)] = self.stack[-1] + dispatch[MEMOIZE[0]] = load_memoize + def load_append(self): stack = self.stack value = stack.pop() @@ -1185,6 +1485,20 @@ class _Unpickler: del stack[mark:] dispatch[SETITEMS[0]] = load_setitems + def load_additems(self): + stack = self.stack + mark = self.marker() + set_obj = stack[mark - 1] + items = stack[mark + 1:] + if isinstance(set_obj, set): + set_obj.update(items) + else: + add = set_obj.add + for item in items: + add(item) + del stack[mark:] + dispatch[ADDITEMS[0]] = load_additems + def load_build(self): stack = self.stack state = stack.pop() @@ -1218,86 +1532,46 @@ class _Unpickler: raise _Stop(value) dispatch[STOP[0]] = load_stop -# Encode/decode ints. - -def encode_long(x): - r"""Encode a long to a two's complement little-endian binary string. - Note that 0 is a special case, returning an empty string, to save a - byte in the LONG1 pickling context. - - >>> encode_long(0) - b'' - >>> encode_long(255) - b'\xff\x00' - >>> encode_long(32767) - b'\xff\x7f' - >>> encode_long(-256) - b'\x00\xff' - >>> encode_long(-32768) - b'\x00\x80' - >>> encode_long(-128) - b'\x80' - >>> encode_long(127) - b'\x7f' - >>> - """ - if x == 0: - return b'' - nbytes = (x.bit_length() >> 3) + 1 - result = x.to_bytes(nbytes, byteorder='little', signed=True) - if x < 0 and nbytes > 1: - if result[-1] == 0xff and (result[-2] & 0x80) != 0: - result = result[:-1] - return result - -def decode_long(data): - r"""Decode an int from a two's complement little-endian binary string. - - >>> decode_long(b'') - 0 - >>> decode_long(b"\xff\x00") - 255 - >>> decode_long(b"\xff\x7f") - 32767 - >>> decode_long(b"\x00\xff") - -256 - >>> decode_long(b"\x00\x80") - -32768 - >>> decode_long(b"\x80") - -128 - >>> decode_long(b"\x7f") - 127 - """ - return int.from_bytes(data, byteorder='little', signed=True) # Shorthands -def dump(obj, file, protocol=None, *, fix_imports=True): - Pickler(file, protocol, fix_imports=fix_imports).dump(obj) +def _dump(obj, file, protocol=None, *, fix_imports=True): + _Pickler(file, protocol, fix_imports=fix_imports).dump(obj) -def dumps(obj, protocol=None, *, fix_imports=True): +def _dumps(obj, protocol=None, *, fix_imports=True): f = io.BytesIO() - Pickler(f, protocol, fix_imports=fix_imports).dump(obj) + _Pickler(f, protocol, fix_imports=fix_imports).dump(obj) res = f.getvalue() assert isinstance(res, bytes_types) return res -def load(file, *, fix_imports=True, encoding="ASCII", errors="strict"): - return Unpickler(file, fix_imports=fix_imports, +def _load(file, *, fix_imports=True, encoding="ASCII", errors="strict"): + return _Unpickler(file, fix_imports=fix_imports, encoding=encoding, errors=errors).load() -def loads(s, *, fix_imports=True, encoding="ASCII", errors="strict"): +def _loads(s, *, fix_imports=True, encoding="ASCII", errors="strict"): if isinstance(s, str): raise TypeError("Can't load pickle from unicode string") file = io.BytesIO(s) - return Unpickler(file, fix_imports=fix_imports, - encoding=encoding, errors=errors).load() + return _Unpickler(file, fix_imports=fix_imports, + encoding=encoding, errors=errors).load() # Use the faster _pickle if possible try: - from _pickle import * + from _pickle import ( + PickleError, + PicklingError, + UnpicklingError, + Pickler, + Unpickler, + dump, + dumps, + load, + loads + ) except ImportError: Pickler, Unpickler = _Pickler, _Unpickler + dump, dumps, load, loads = _dump, _dumps, _load, _loads # Doctest def _test(): diff --git a/Lib/pickletools.py b/Lib/pickletools.py index e92146d..d711bf0 100644 --- a/Lib/pickletools.py +++ b/Lib/pickletools.py @@ -11,6 +11,7 @@ dis(pickle, out=None, memo=None, indentlevel=4) ''' import codecs +import io import pickle import re import sys @@ -168,6 +169,7 @@ UP_TO_NEWLINE = -1 TAKEN_FROM_ARGUMENT1 = -2 # num bytes is 1-byte unsigned int TAKEN_FROM_ARGUMENT4 = -3 # num bytes is 4-byte signed little-endian int TAKEN_FROM_ARGUMENT4U = -4 # num bytes is 4-byte unsigned little-endian int +TAKEN_FROM_ARGUMENT8U = -5 # num bytes is 8-byte unsigned little-endian int class ArgumentDescriptor(object): __slots__ = ( @@ -175,7 +177,7 @@ class ArgumentDescriptor(object): 'name', # length of argument, in bytes; an int; UP_TO_NEWLINE and - # TAKEN_FROM_ARGUMENT{1,4} are negative values for variable-length + # TAKEN_FROM_ARGUMENT{1,4,8} are negative values for variable-length # cases 'n', @@ -196,7 +198,8 @@ class ArgumentDescriptor(object): n in (UP_TO_NEWLINE, TAKEN_FROM_ARGUMENT1, TAKEN_FROM_ARGUMENT4, - TAKEN_FROM_ARGUMENT4U)) + TAKEN_FROM_ARGUMENT4U, + TAKEN_FROM_ARGUMENT8U)) self.n = n self.reader = reader @@ -288,6 +291,27 @@ uint4 = ArgumentDescriptor( doc="Four-byte unsigned integer, little-endian.") +def read_uint8(f): + r""" + >>> import io + >>> read_uint8(io.BytesIO(b'\xff\x00\x00\x00\x00\x00\x00\x00')) + 255 + >>> read_uint8(io.BytesIO(b'\xff' * 8)) == 2**64-1 + True + """ + + data = f.read(8) + if len(data) == 8: + return _unpack(">> import io @@ -381,6 +405,36 @@ stringnl_noescape_pair = ArgumentDescriptor( a single blank separating the two strings. """) + +def read_string1(f): + r""" + >>> import io + >>> read_string1(io.BytesIO(b"\x00")) + '' + >>> read_string1(io.BytesIO(b"\x03abcdef")) + 'abc' + """ + + n = read_uint1(f) + assert n >= 0 + data = f.read(n) + if len(data) == n: + return data.decode("latin-1") + raise ValueError("expected %d bytes in a string1, but only %d remain" % + (n, len(data))) + +string1 = ArgumentDescriptor( + name="string1", + n=TAKEN_FROM_ARGUMENT1, + reader=read_string1, + doc="""A counted string. + + The first argument is a 1-byte unsigned int giving the number + of bytes in the string, and the second argument is that many + bytes. + """) + + def read_string4(f): r""" >>> import io @@ -415,28 +469,28 @@ string4 = ArgumentDescriptor( """) -def read_string1(f): +def read_bytes1(f): r""" >>> import io - >>> read_string1(io.BytesIO(b"\x00")) - '' - >>> read_string1(io.BytesIO(b"\x03abcdef")) - 'abc' + >>> read_bytes1(io.BytesIO(b"\x00")) + b'' + >>> read_bytes1(io.BytesIO(b"\x03abcdef")) + b'abc' """ n = read_uint1(f) assert n >= 0 data = f.read(n) if len(data) == n: - return data.decode("latin-1") - raise ValueError("expected %d bytes in a string1, but only %d remain" % + return data + raise ValueError("expected %d bytes in a bytes1, but only %d remain" % (n, len(data))) -string1 = ArgumentDescriptor( - name="string1", +bytes1 = ArgumentDescriptor( + name="bytes1", n=TAKEN_FROM_ARGUMENT1, - reader=read_string1, - doc="""A counted string. + reader=read_bytes1, + doc="""A counted bytes string. The first argument is a 1-byte unsigned int giving the number of bytes in the string, and the second argument is that many @@ -486,6 +540,7 @@ def read_bytes4(f): """ n = read_uint4(f) + assert n >= 0 if n > sys.maxsize: raise ValueError("bytes4 byte count > sys.maxsize: %d" % n) data = f.read(n) @@ -505,6 +560,39 @@ bytes4 = ArgumentDescriptor( """) +def read_bytes8(f): + r""" + >>> import io + >>> read_bytes8(io.BytesIO(b"\x00\x00\x00\x00\x00\x00\x00\x00abc")) + b'' + >>> read_bytes8(io.BytesIO(b"\x03\x00\x00\x00\x00\x00\x00\x00abcdef")) + b'abc' + >>> read_bytes8(io.BytesIO(b"\x00\x00\x00\x00\x00\x00\x03\x00abcdef")) + Traceback (most recent call last): + ... + ValueError: expected 844424930131968 bytes in a bytes8, but only 6 remain + """ + + n = read_uint8(f) + assert n >= 0 + if n > sys.maxsize: + raise ValueError("bytes8 byte count > sys.maxsize: %d" % n) + data = f.read(n) + if len(data) == n: + return data + raise ValueError("expected %d bytes in a bytes8, but only %d remain" % + (n, len(data))) + +bytes8 = ArgumentDescriptor( + name="bytes8", + n=TAKEN_FROM_ARGUMENT8U, + reader=read_bytes8, + doc="""A counted bytes string. + + The first argument is a 8-byte little-endian unsigned int giving + the number of bytes, and the second argument is that many bytes. + """) + def read_unicodestringnl(f): r""" >>> import io @@ -530,6 +618,46 @@ unicodestringnl = ArgumentDescriptor( escape sequences. """) + +def read_unicodestring1(f): + r""" + >>> import io + >>> s = 'abcd\uabcd' + >>> enc = s.encode('utf-8') + >>> enc + b'abcd\xea\xaf\x8d' + >>> n = bytes([len(enc)]) # little-endian 1-byte length + >>> t = read_unicodestring1(io.BytesIO(n + enc + b'junk')) + >>> s == t + True + + >>> read_unicodestring1(io.BytesIO(n + enc[:-1])) + Traceback (most recent call last): + ... + ValueError: expected 7 bytes in a unicodestring1, but only 6 remain + """ + + n = read_uint1(f) + assert n >= 0 + data = f.read(n) + if len(data) == n: + return str(data, 'utf-8', 'surrogatepass') + raise ValueError("expected %d bytes in a unicodestring1, but only %d " + "remain" % (n, len(data))) + +unicodestring1 = ArgumentDescriptor( + name="unicodestring1", + n=TAKEN_FROM_ARGUMENT1, + reader=read_unicodestring1, + doc="""A counted Unicode string. + + The first argument is a 1-byte little-endian signed int + giving the number of bytes in the string, and the second + argument-- the UTF-8 encoding of the Unicode string -- + contains that many bytes. + """) + + def read_unicodestring4(f): r""" >>> import io @@ -549,6 +677,7 @@ def read_unicodestring4(f): """ n = read_uint4(f) + assert n >= 0 if n > sys.maxsize: raise ValueError("unicodestring4 byte count > sys.maxsize: %d" % n) data = f.read(n) @@ -570,6 +699,47 @@ unicodestring4 = ArgumentDescriptor( """) +def read_unicodestring8(f): + r""" + >>> import io + >>> s = 'abcd\uabcd' + >>> enc = s.encode('utf-8') + >>> enc + b'abcd\xea\xaf\x8d' + >>> n = bytes([len(enc)]) + bytes(7) # little-endian 8-byte length + >>> t = read_unicodestring8(io.BytesIO(n + enc + b'junk')) + >>> s == t + True + + >>> read_unicodestring8(io.BytesIO(n + enc[:-1])) + Traceback (most recent call last): + ... + ValueError: expected 7 bytes in a unicodestring8, but only 6 remain + """ + + n = read_uint8(f) + assert n >= 0 + if n > sys.maxsize: + raise ValueError("unicodestring8 byte count > sys.maxsize: %d" % n) + data = f.read(n) + if len(data) == n: + return str(data, 'utf-8', 'surrogatepass') + raise ValueError("expected %d bytes in a unicodestring8, but only %d " + "remain" % (n, len(data))) + +unicodestring8 = ArgumentDescriptor( + name="unicodestring8", + n=TAKEN_FROM_ARGUMENT8U, + reader=read_unicodestring8, + doc="""A counted Unicode string. + + The first argument is a 8-byte little-endian signed int + giving the number of bytes in the string, and the second + argument-- the UTF-8 encoding of the Unicode string -- + contains that many bytes. + """) + + def read_decimalnl_short(f): r""" >>> import io @@ -859,6 +1029,16 @@ pydict = StackObject( obtype=dict, doc="A Python dict object.") +pyset = StackObject( + name="set", + obtype=set, + doc="A Python set object.") + +pyfrozenset = StackObject( + name="frozenset", + obtype=set, + doc="A Python frozenset object.") + anyobject = StackObject( name='any', obtype=object, @@ -1142,6 +1322,19 @@ opcodes = [ literally as the string content. """), + I(name='BINBYTES8', + code='\x8e', + arg=bytes8, + stack_before=[], + stack_after=[pybytes], + proto=4, + doc="""Push a Python bytes object. + + There are two arguments: the first is a 8-byte unsigned int giving + the number of bytes in the string, and the second is that many bytes, + which are taken literally as the string content. + """), + # Ways to spell None. I(name='NONE', @@ -1190,6 +1383,19 @@ opcodes = [ until the next newline character. """), + I(name='SHORT_BINUNICODE', + code='\x8c', + arg=unicodestring1, + stack_before=[], + stack_after=[pyunicode], + proto=4, + doc="""Push a Python Unicode string object. + + There are two arguments: the first is a 1-byte little-endian signed int + giving the number of bytes in the string. The second is that many + bytes, and is the UTF-8 encoding of the Unicode string. + """), + I(name='BINUNICODE', code='X', arg=unicodestring4, @@ -1203,6 +1409,19 @@ opcodes = [ bytes, and is the UTF-8 encoding of the Unicode string. """), + I(name='BINUNICODE8', + code='\x8d', + arg=unicodestring8, + stack_before=[], + stack_after=[pyunicode], + proto=4, + doc="""Push a Python Unicode string object. + + There are two arguments: the first is a 8-byte little-endian signed int + giving the number of bytes in the string. The second is that many + bytes, and is the UTF-8 encoding of the Unicode string. + """), + # Ways to spell floats. I(name='FLOAT', @@ -1428,6 +1647,54 @@ opcodes = [ 1, 2, ..., n, and in that order. """), + # Ways to build sets + + I(name='EMPTY_SET', + code='\x8f', + arg=None, + stack_before=[], + stack_after=[pyset], + proto=4, + doc="Push an empty set."), + + I(name='ADDITEMS', + code='\x90', + arg=None, + stack_before=[pyset, markobject, stackslice], + stack_after=[pyset], + proto=4, + doc="""Add an arbitrary number of items to an existing set. + + The slice of the stack following the topmost markobject is taken as + a sequence of items, added to the set immediately under the topmost + markobject. Everything at and after the topmost markobject is popped, + leaving the mutated set at the top of the stack. + + Stack before: ... pyset markobject item_1 ... item_n + Stack after: ... pyset + + where pyset has been modified via pyset.add(item_i) = item_i for i in + 1, 2, ..., n, and in that order. + """), + + # Way to build frozensets + + I(name='FROZENSET', + code='\x91', + arg=None, + stack_before=[markobject, stackslice], + stack_after=[pyfrozenset], + proto=4, + doc="""Build a frozenset out of the topmost slice, after markobject. + + All the stack entries following the topmost markobject are placed into + a single Python frozenset, which single frozenset object replaces all + of the stack from the topmost markobject onward. For example, + + Stack before: ... markobject 1 2 3 + Stack after: ... frozenset({1, 2, 3}) + """), + # Stack manipulation. I(name='POP', @@ -1549,6 +1816,18 @@ opcodes = [ unsigned little-endian integer following. """), + I(name='MEMOIZE', + code='\x94', + arg=None, + stack_before=[anyobject], + stack_after=[anyobject], + proto=4, + doc="""Store the stack top into the memo. The stack is not popped. + + The index of the memo location to write is the number of + elements currently present in the memo. + """), + # Access the extension registry (predefined objects). Akin to the GET # family. @@ -1614,6 +1893,15 @@ opcodes = [ stack, so unpickling subclasses can override this form of lookup. """), + I(name='STACK_GLOBAL', + code='\x93', + arg=None, + stack_before=[pyunicode, pyunicode], + stack_after=[anyobject], + proto=0, + doc="""Push a global object (module.attr) on the stack. + """), + # Ways to build objects of classes pickle doesn't know about directly # (user-defined classes). I despair of documenting this accurately # and comprehensibly -- you really have to read the pickle code to @@ -1770,6 +2058,21 @@ opcodes = [ onto the stack. """), + I(name='NEWOBJ_EX', + code='\x92', + arg=None, + stack_before=[anyobject, anyobject, anyobject], + stack_after=[anyobject], + proto=4, + doc="""Build an object instance. + + The stack before should be thought of as containing a class + object followed by an argument tuple and by a keyword argument dict + (the dict being the stack top). Call these cls and args. They are + popped off the stack, and the value returned by + cls.__new__(cls, *args, *kwargs) is pushed back onto the stack. + """), + # Machine control. I(name='PROTO', @@ -1797,6 +2100,20 @@ opcodes = [ empty then. """), + # Framing support. + + I(name='FRAME', + code='\x95', + arg=uint8, + stack_before=[], + stack_after=[], + proto=4, + doc="""Indicate the beginning of a new frame. + + The unpickler may use this opcode to safely prefetch data from its + underlying stream. + """), + # Ways to deal with persistent IDs. I(name='PERSID', @@ -1903,6 +2220,38 @@ del assure_pickle_consistency ############################################################################## # A pickle opcode generator. +def _genops(data, yield_end_pos=False): + if isinstance(data, bytes_types): + data = io.BytesIO(data) + + if hasattr(data, "tell"): + getpos = data.tell + else: + getpos = lambda: None + + while True: + pos = getpos() + code = data.read(1) + opcode = code2op.get(code.decode("latin-1")) + if opcode is None: + if code == b"": + raise ValueError("pickle exhausted before seeing STOP") + else: + raise ValueError("at position %s, opcode %r unknown" % ( + "" if pos is None else pos, + code)) + if opcode.arg is None: + arg = None + else: + arg = opcode.arg.reader(data) + if yield_end_pos: + yield opcode, arg, pos, getpos() + else: + yield opcode, arg, pos + if code == b'.': + assert opcode.name == 'STOP' + break + def genops(pickle): """Generate all the opcodes in a pickle. @@ -1926,62 +2275,47 @@ def genops(pickle): used. Else (the pickle doesn't have a tell(), and it's not obvious how to query its current position) pos is None. """ - - if isinstance(pickle, bytes_types): - import io - pickle = io.BytesIO(pickle) - - if hasattr(pickle, "tell"): - getpos = pickle.tell - else: - getpos = lambda: None - - while True: - pos = getpos() - code = pickle.read(1) - opcode = code2op.get(code.decode("latin-1")) - if opcode is None: - if code == b"": - raise ValueError("pickle exhausted before seeing STOP") - else: - raise ValueError("at position %s, opcode %r unknown" % ( - pos is None and "" or pos, - code)) - if opcode.arg is None: - arg = None - else: - arg = opcode.arg.reader(pickle) - yield opcode, arg, pos - if code == b'.': - assert opcode.name == 'STOP' - break + return _genops(pickle) ############################################################################## # A pickle optimizer. def optimize(p): 'Optimize a pickle string by removing unused PUT opcodes' - gets = set() # set of args used by a GET opcode - puts = [] # (arg, startpos, stoppos) for the PUT opcodes - prevpos = None # set to pos if previous opcode was a PUT - for opcode, arg, pos in genops(p): - if prevpos is not None: - puts.append((prevarg, prevpos, pos)) - prevpos = None + not_a_put = object() + gets = { not_a_put } # set of args used by a GET opcode + opcodes = [] # (startpos, stoppos, putid) + proto = 0 + for opcode, arg, pos, end_pos in _genops(p, yield_end_pos=True): if 'PUT' in opcode.name: - prevarg, prevpos = arg, pos - elif 'GET' in opcode.name: - gets.add(arg) - - # Copy the pickle string except for PUTS without a corresponding GET - s = [] - i = 0 - for arg, start, stop in puts: - j = stop if (arg in gets) else start - s.append(p[i:j]) - i = stop - s.append(p[i:]) - return b''.join(s) + opcodes.append((pos, end_pos, arg)) + elif 'FRAME' in opcode.name: + pass + else: + if 'GET' in opcode.name: + gets.add(arg) + elif opcode.name == 'PROTO': + assert pos == 0, pos + proto = arg + opcodes.append((pos, end_pos, not_a_put)) + prevpos, prevarg = pos, None + + # Copy the opcodes except for PUTS without a corresponding GET + out = io.BytesIO() + opcodes = iter(opcodes) + if proto >= 2: + # Write the PROTO header before any framing + start, stop, _ = next(opcodes) + out.write(p[start:stop]) + buf = pickle._Framer(out.write) + if proto >= 4: + buf.start_framing() + for start, stop, putid in opcodes: + if putid in gets: + buf.write(p[start:stop]) + if proto >= 4: + buf.end_framing() + return out.getvalue() ############################################################################## # A symbolic pickle disassembler. @@ -2081,17 +2415,20 @@ def dis(pickle, out=None, memo=None, indentlevel=4, annotate=0): errormsg = markmsg = "no MARK exists on stack" # Check for correct memo usage. - if opcode.name in ("PUT", "BINPUT", "LONG_BINPUT"): - assert arg is not None - if arg in memo: + if opcode.name in ("PUT", "BINPUT", "LONG_BINPUT", "MEMOIZE"): + if opcode.name == "MEMOIZE": + memo_idx = len(memo) + else: + assert arg is not None + memo_idx = arg + if memo_idx in memo: errormsg = "memo key %r already defined" % arg elif not stack: errormsg = "stack is empty -- can't store into memo" elif stack[-1] is markobject: errormsg = "can't store markobject in the memo" else: - memo[arg] = stack[-1] - + memo[memo_idx] = stack[-1] elif opcode.name in ("GET", "BINGET", "LONG_BINGET"): if arg in memo: assert len(after) == 1 diff --git a/Lib/test/pickletester.py b/Lib/test/pickletester.py index 1971120..cadc5a7 100644 --- a/Lib/test/pickletester.py +++ b/Lib/test/pickletester.py @@ -1,9 +1,10 @@ +import copyreg import io -import unittest import pickle import pickletools +import random import sys -import copyreg +import unittest import weakref from http.cookies import SimpleCookie @@ -95,6 +96,9 @@ class E(C): def __getinitargs__(self): return () +class H(object): + pass + import __main__ __main__.C = C C.__module__ = "__main__" @@ -102,6 +106,8 @@ __main__.D = D D.__module__ = "__main__" __main__.E = E E.__module__ = "__main__" +__main__.H = H +H.__module__ = "__main__" class myint(int): def __init__(self, x): @@ -428,6 +434,7 @@ def create_data(): x.append(5) return x + class AbstractPickleTests(unittest.TestCase): # Subclass must define self.dumps, self.loads. @@ -436,23 +443,41 @@ class AbstractPickleTests(unittest.TestCase): def setUp(self): pass + def assert_is_copy(self, obj, objcopy, msg=None): + """Utility method to verify if two objects are copies of each others. + """ + if msg is None: + msg = "{!r} is not a copy of {!r}".format(obj, objcopy) + self.assertEqual(obj, objcopy, msg=msg) + self.assertIs(type(obj), type(objcopy), msg=msg) + if hasattr(obj, '__dict__'): + self.assertDictEqual(obj.__dict__, objcopy.__dict__, msg=msg) + self.assertIsNot(obj.__dict__, objcopy.__dict__, msg=msg) + if hasattr(obj, '__slots__'): + self.assertListEqual(obj.__slots__, objcopy.__slots__, msg=msg) + for slot in obj.__slots__: + self.assertEqual( + hasattr(obj, slot), hasattr(objcopy, slot), msg=msg) + self.assertEqual(getattr(obj, slot, None), + getattr(objcopy, slot, None), msg=msg) + def test_misc(self): # test various datatypes not tested by testdata for proto in protocols: x = myint(4) s = self.dumps(x, proto) y = self.loads(s) - self.assertEqual(x, y) + self.assert_is_copy(x, y) x = (1, ()) s = self.dumps(x, proto) y = self.loads(s) - self.assertEqual(x, y) + self.assert_is_copy(x, y) x = initarg(1, x) s = self.dumps(x, proto) y = self.loads(s) - self.assertEqual(x, y) + self.assert_is_copy(x, y) # XXX test __reduce__ protocol? @@ -461,16 +486,16 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(expected, proto) got = self.loads(s) - self.assertEqual(expected, got) + self.assert_is_copy(expected, got) def test_load_from_data0(self): - self.assertEqual(self._testdata, self.loads(DATA0)) + self.assert_is_copy(self._testdata, self.loads(DATA0)) def test_load_from_data1(self): - self.assertEqual(self._testdata, self.loads(DATA1)) + self.assert_is_copy(self._testdata, self.loads(DATA1)) def test_load_from_data2(self): - self.assertEqual(self._testdata, self.loads(DATA2)) + self.assert_is_copy(self._testdata, self.loads(DATA2)) def test_load_classic_instance(self): # See issue5180. Test loading 2.x pickles that @@ -492,7 +517,7 @@ class AbstractPickleTests(unittest.TestCase): b"X\n" b"p0\n" b"(dp1\nb.").replace(b'X', xname) - self.assertEqual(X(*args), self.loads(pickle0)) + self.assert_is_copy(X(*args), self.loads(pickle0)) # Protocol 1 (binary mode pickle) """ @@ -509,7 +534,7 @@ class AbstractPickleTests(unittest.TestCase): pickle1 = (b'(c__main__\n' b'X\n' b'q\x00oq\x01}q\x02b.').replace(b'X', xname) - self.assertEqual(X(*args), self.loads(pickle1)) + self.assert_is_copy(X(*args), self.loads(pickle1)) # Protocol 2 (pickle2 = b'\x80\x02' + pickle1) """ @@ -527,7 +552,7 @@ class AbstractPickleTests(unittest.TestCase): pickle2 = (b'\x80\x02(c__main__\n' b'X\n' b'q\x00oq\x01}q\x02b.').replace(b'X', xname) - self.assertEqual(X(*args), self.loads(pickle2)) + self.assert_is_copy(X(*args), self.loads(pickle2)) # There are gratuitous differences between pickles produced by # pickle and cPickle, largely because cPickle starts PUT indices at @@ -552,6 +577,7 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(l, proto) x = self.loads(s) + self.assertIsInstance(x, list) self.assertEqual(len(x), 1) self.assertTrue(x is x[0]) @@ -561,6 +587,7 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(t, proto) x = self.loads(s) + self.assertIsInstance(x, tuple) self.assertEqual(len(x), 1) self.assertEqual(len(x[0]), 1) self.assertTrue(x is x[0][0]) @@ -571,15 +598,39 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(d, proto) x = self.loads(s) + self.assertIsInstance(x, dict) self.assertEqual(list(x.keys()), [1]) self.assertTrue(x[1] is x) + def test_recursive_set(self): + h = H() + y = set({h}) + h.attr = y + for proto in protocols: + s = self.dumps(y, proto) + x = self.loads(s) + self.assertIsInstance(x, set) + self.assertIs(list(x)[0].attr, x) + self.assertEqual(len(x), 1) + + def test_recursive_frozenset(self): + h = H() + y = frozenset({h}) + h.attr = y + for proto in protocols: + s = self.dumps(y, proto) + x = self.loads(s) + self.assertIsInstance(x, frozenset) + self.assertIs(list(x)[0].attr, x) + self.assertEqual(len(x), 1) + def test_recursive_inst(self): i = C() i.attr = i for proto in protocols: s = self.dumps(i, proto) x = self.loads(s) + self.assertIsInstance(x, C) self.assertEqual(dir(x), dir(i)) self.assertIs(x.attr, x) @@ -592,6 +643,7 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(l, proto) x = self.loads(s) + self.assertIsInstance(x, list) self.assertEqual(len(x), 1) self.assertEqual(dir(x[0]), dir(i)) self.assertEqual(list(x[0].attr.keys()), [1]) @@ -599,7 +651,8 @@ class AbstractPickleTests(unittest.TestCase): def test_get(self): self.assertRaises(KeyError, self.loads, b'g0\np0') - self.assertEqual(self.loads(b'((Kdtp0\nh\x00l.))'), [(100,), (100,)]) + self.assert_is_copy([(100,), (100,)], + self.loads(b'((Kdtp0\nh\x00l.))')) def test_unicode(self): endcases = ['', '<\\u>', '<\\\u1234>', '<\n>', @@ -610,26 +663,26 @@ class AbstractPickleTests(unittest.TestCase): for u in endcases: p = self.dumps(u, proto) u2 = self.loads(p) - self.assertEqual(u2, u) + self.assert_is_copy(u, u2) def test_unicode_high_plane(self): t = '\U00012345' for proto in protocols: p = self.dumps(t, proto) t2 = self.loads(p) - self.assertEqual(t2, t) + self.assert_is_copy(t, t2) def test_bytes(self): for proto in protocols: for s in b'', b'xyz', b'xyz'*100: p = self.dumps(s, proto) - self.assertEqual(self.loads(p), s) + self.assert_is_copy(s, self.loads(p)) for s in [bytes([i]) for i in range(256)]: p = self.dumps(s, proto) - self.assertEqual(self.loads(p), s) + self.assert_is_copy(s, self.loads(p)) for s in [bytes([i, i]) for i in range(256)]: p = self.dumps(s, proto) - self.assertEqual(self.loads(p), s) + self.assert_is_copy(s, self.loads(p)) def test_ints(self): import sys @@ -639,14 +692,14 @@ class AbstractPickleTests(unittest.TestCase): for expected in (-n, n): s = self.dumps(expected, proto) n2 = self.loads(s) - self.assertEqual(expected, n2) + self.assert_is_copy(expected, n2) n = n >> 1 def test_maxint64(self): maxint64 = (1 << 63) - 1 data = b'I' + str(maxint64).encode("ascii") + b'\n.' got = self.loads(data) - self.assertEqual(got, maxint64) + self.assert_is_copy(maxint64, got) # Try too with a bogus literal. data = b'I' + str(maxint64).encode("ascii") + b'JUNK\n.' @@ -661,7 +714,7 @@ class AbstractPickleTests(unittest.TestCase): for n in npos, -npos: pickle = self.dumps(n, proto) got = self.loads(pickle) - self.assertEqual(n, got) + self.assert_is_copy(n, got) # Try a monster. This is quadratic-time in protos 0 & 1, so don't # bother with those. nbase = int("deadbeeffeedface", 16) @@ -669,7 +722,7 @@ class AbstractPickleTests(unittest.TestCase): for n in nbase, -nbase: p = self.dumps(n, 2) got = self.loads(p) - self.assertEqual(n, got) + self.assert_is_copy(n, got) def test_float(self): test_values = [0.0, 4.94e-324, 1e-310, 7e-308, 6.626e-34, 0.1, 0.5, @@ -679,7 +732,7 @@ class AbstractPickleTests(unittest.TestCase): for value in test_values: pickle = self.dumps(value, proto) got = self.loads(pickle) - self.assertEqual(value, got) + self.assert_is_copy(value, got) @run_with_locale('LC_ALL', 'de_DE', 'fr_FR') def test_float_format(self): @@ -711,6 +764,7 @@ class AbstractPickleTests(unittest.TestCase): s = self.dumps(a, proto) b = self.loads(s) self.assertEqual(a, b) + self.assertIs(type(a), type(b)) def test_structseq(self): import time @@ -720,48 +774,48 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(t, proto) u = self.loads(s) - self.assertEqual(t, u) + self.assert_is_copy(t, u) if hasattr(os, "stat"): t = os.stat(os.curdir) s = self.dumps(t, proto) u = self.loads(s) - self.assertEqual(t, u) + self.assert_is_copy(t, u) if hasattr(os, "statvfs"): t = os.statvfs(os.curdir) s = self.dumps(t, proto) u = self.loads(s) - self.assertEqual(t, u) + self.assert_is_copy(t, u) def test_ellipsis(self): for proto in protocols: s = self.dumps(..., proto) u = self.loads(s) - self.assertEqual(..., u) + self.assertIs(..., u) def test_notimplemented(self): for proto in protocols: s = self.dumps(NotImplemented, proto) u = self.loads(s) - self.assertEqual(NotImplemented, u) + self.assertIs(NotImplemented, u) # Tests for protocol 2 def test_proto(self): - build_none = pickle.NONE + pickle.STOP for proto in protocols: - expected = build_none + pickled = self.dumps(None, proto) if proto >= 2: - expected = pickle.PROTO + bytes([proto]) + expected - p = self.dumps(None, proto) - self.assertEqual(p, expected) + proto_header = pickle.PROTO + bytes([proto]) + self.assertTrue(pickled.startswith(proto_header)) + else: + self.assertEqual(count_opcode(pickle.PROTO, pickled), 0) oob = protocols[-1] + 1 # a future protocol + build_none = pickle.NONE + pickle.STOP badpickle = pickle.PROTO + bytes([oob]) + build_none try: self.loads(badpickle) - except ValueError as detail: - self.assertTrue(str(detail).startswith( - "unsupported pickle protocol")) + except ValueError as err: + self.assertIn("unsupported pickle protocol", str(err)) else: self.fail("expected bad protocol number to raise ValueError") @@ -770,7 +824,7 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(x, proto) y = self.loads(s) - self.assertEqual(x, y) + self.assert_is_copy(x, y) self.assertEqual(opcode_in_pickle(pickle.LONG1, s), proto >= 2) def test_long4(self): @@ -778,7 +832,7 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(x, proto) y = self.loads(s) - self.assertEqual(x, y) + self.assert_is_copy(x, y) self.assertEqual(opcode_in_pickle(pickle.LONG4, s), proto >= 2) def test_short_tuples(self): @@ -816,9 +870,9 @@ class AbstractPickleTests(unittest.TestCase): for x in a, b, c, d, e: s = self.dumps(x, proto) y = self.loads(s) - self.assertEqual(x, y, (proto, x, s, y)) - expected = expected_opcode[proto, len(x)] - self.assertEqual(opcode_in_pickle(expected, s), True) + self.assert_is_copy(x, y) + expected = expected_opcode[min(proto, 3), len(x)] + self.assertTrue(opcode_in_pickle(expected, s)) def test_singletons(self): # Map (proto, singleton) to expected opcode. @@ -842,8 +896,8 @@ class AbstractPickleTests(unittest.TestCase): s = self.dumps(x, proto) y = self.loads(s) self.assertTrue(x is y, (proto, x, s, y)) - expected = expected_opcode[proto, x] - self.assertEqual(opcode_in_pickle(expected, s), True) + expected = expected_opcode[min(proto, 3), x] + self.assertTrue(opcode_in_pickle(expected, s)) def test_newobj_tuple(self): x = MyTuple([1, 2, 3]) @@ -852,8 +906,7 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(x, proto) y = self.loads(s) - self.assertEqual(tuple(x), tuple(y)) - self.assertEqual(x.__dict__, y.__dict__) + self.assert_is_copy(x, y) def test_newobj_list(self): x = MyList([1, 2, 3]) @@ -862,8 +915,7 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(x, proto) y = self.loads(s) - self.assertEqual(list(x), list(y)) - self.assertEqual(x.__dict__, y.__dict__) + self.assert_is_copy(x, y) def test_newobj_generic(self): for proto in protocols: @@ -874,6 +926,7 @@ class AbstractPickleTests(unittest.TestCase): s = self.dumps(x, proto) y = self.loads(s) detail = (proto, C, B, x, y, type(y)) + self.assert_is_copy(x, y) # XXX revisit self.assertEqual(B(x), B(y), detail) self.assertEqual(x.__dict__, y.__dict__, detail) @@ -912,11 +965,10 @@ class AbstractPickleTests(unittest.TestCase): s1 = self.dumps(x, 1) self.assertIn(__name__.encode("utf-8"), s1) self.assertIn(b"MyList", s1) - self.assertEqual(opcode_in_pickle(opcode, s1), False) + self.assertFalse(opcode_in_pickle(opcode, s1)) y = self.loads(s1) - self.assertEqual(list(x), list(y)) - self.assertEqual(x.__dict__, y.__dict__) + self.assert_is_copy(x, y) # Dump using protocol 2 for test. s2 = self.dumps(x, 2) @@ -925,9 +977,7 @@ class AbstractPickleTests(unittest.TestCase): self.assertEqual(opcode_in_pickle(opcode, s2), True, repr(s2)) y = self.loads(s2) - self.assertEqual(list(x), list(y)) - self.assertEqual(x.__dict__, y.__dict__) - + self.assert_is_copy(x, y) finally: e.restore() @@ -951,7 +1001,7 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(x, proto) y = self.loads(s) - self.assertEqual(x, y) + self.assert_is_copy(x, y) num_appends = count_opcode(pickle.APPENDS, s) self.assertEqual(num_appends, proto > 0) @@ -960,7 +1010,7 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(x, proto) y = self.loads(s) - self.assertEqual(x, y) + self.assert_is_copy(x, y) num_appends = count_opcode(pickle.APPENDS, s) if proto == 0: self.assertEqual(num_appends, 0) @@ -974,7 +1024,7 @@ class AbstractPickleTests(unittest.TestCase): s = self.dumps(x, proto) self.assertIsInstance(s, bytes_types) y = self.loads(s) - self.assertEqual(x, y) + self.assert_is_copy(x, y) num_setitems = count_opcode(pickle.SETITEMS, s) self.assertEqual(num_setitems, proto > 0) @@ -983,22 +1033,49 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(x, proto) y = self.loads(s) - self.assertEqual(x, y) + self.assert_is_copy(x, y) num_setitems = count_opcode(pickle.SETITEMS, s) if proto == 0: self.assertEqual(num_setitems, 0) else: self.assertTrue(num_setitems >= 2) + def test_set_chunking(self): + n = 10 # too small to chunk + x = set(range(n)) + for proto in protocols: + s = self.dumps(x, proto) + y = self.loads(s) + self.assert_is_copy(x, y) + num_additems = count_opcode(pickle.ADDITEMS, s) + if proto < 4: + self.assertEqual(num_additems, 0) + else: + self.assertEqual(num_additems, 1) + + n = 2500 # expect at least two chunks when proto >= 4 + x = set(range(n)) + for proto in protocols: + s = self.dumps(x, proto) + y = self.loads(s) + self.assert_is_copy(x, y) + num_additems = count_opcode(pickle.ADDITEMS, s) + if proto < 4: + self.assertEqual(num_additems, 0) + else: + self.assertGreaterEqual(num_additems, 2) + def test_simple_newobj(self): x = object.__new__(SimpleNewObj) # avoid __init__ x.abc = 666 for proto in protocols: s = self.dumps(x, proto) - self.assertEqual(opcode_in_pickle(pickle.NEWOBJ, s), proto >= 2) + self.assertEqual(opcode_in_pickle(pickle.NEWOBJ, s), + 2 <= proto < 4) + self.assertEqual(opcode_in_pickle(pickle.NEWOBJ_EX, s), + proto >= 4) y = self.loads(s) # will raise TypeError if __init__ called - self.assertEqual(y.abc, 666) - self.assertEqual(x.__dict__, y.__dict__) + self.assert_is_copy(x, y) def test_newobj_list_slots(self): x = SlotList([1, 2, 3]) @@ -1006,10 +1083,7 @@ class AbstractPickleTests(unittest.TestCase): x.bar = "hello" s = self.dumps(x, 2) y = self.loads(s) - self.assertEqual(list(x), list(y)) - self.assertEqual(x.__dict__, y.__dict__) - self.assertEqual(x.foo, y.foo) - self.assertEqual(x.bar, y.bar) + self.assert_is_copy(x, y) def test_reduce_overrides_default_reduce_ex(self): for proto in protocols: @@ -1058,11 +1132,10 @@ class AbstractPickleTests(unittest.TestCase): @no_tracing def test_bad_getattr(self): + # Issue #3514: crash when there is an infinite loop in __getattr__ x = BadGetattr() - for proto in 0, 1: + for proto in protocols: self.assertRaises(RuntimeError, self.dumps, x, proto) - # protocol 2 don't raise a RuntimeError. - d = self.dumps(x, 2) def test_reduce_bad_iterator(self): # Issue4176: crash when 4th and 5th items of __reduce__() @@ -1095,11 +1168,10 @@ class AbstractPickleTests(unittest.TestCase): obj = [dict(large_dict), dict(large_dict), dict(large_dict)] for proto in protocols: - dumped = self.dumps(obj, proto) - loaded = self.loads(dumped) - self.assertEqual(loaded, obj, - "Failed protocol %d: %r != %r" - % (proto, obj, loaded)) + with self.subTest(proto=proto): + dumped = self.dumps(obj, proto) + loaded = self.loads(dumped) + self.assert_is_copy(obj, loaded) def test_attribute_name_interning(self): # Test that attribute names of pickled objects are interned when @@ -1155,11 +1227,14 @@ class AbstractPickleTests(unittest.TestCase): def test_int_pickling_efficiency(self): # Test compacity of int representation (see issue #12744) for proto in protocols: - sizes = [len(self.dumps(2**n, proto)) for n in range(70)] - # the size function is monotonic - self.assertEqual(sorted(sizes), sizes) - if proto >= 2: - self.assertLessEqual(sizes[-1], 14) + with self.subTest(proto=proto): + pickles = [self.dumps(2**n, proto) for n in range(70)] + sizes = list(map(len, pickles)) + # the size function is monotonic + self.assertEqual(sorted(sizes), sizes) + if proto >= 2: + for p in pickles: + self.assertFalse(opcode_in_pickle(pickle.LONG, p)) def check_negative_32b_binXXX(self, dumped): if sys.maxsize > 2**32: @@ -1242,6 +1317,137 @@ class AbstractPickleTests(unittest.TestCase): else: self._check_pickling_with_opcode(obj, pickle.SETITEMS, proto) + # Exercise framing (proto >= 4) for significant workloads + + FRAME_SIZE_TARGET = 64 * 1024 + + def test_framing_many_objects(self): + obj = list(range(10**5)) + for proto in range(4, pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto): + pickled = self.dumps(obj, proto) + unpickled = self.loads(pickled) + self.assertEqual(obj, unpickled) + # Test the framing heuristic is sane, + # assuming a given frame size target. + bytes_per_frame = (len(pickled) / + pickled.count(b'\x00\x00\x00\x00\x00')) + self.assertGreater(bytes_per_frame, + self.FRAME_SIZE_TARGET / 2) + self.assertLessEqual(bytes_per_frame, + self.FRAME_SIZE_TARGET * 1) + + def test_framing_large_objects(self): + N = 1024 * 1024 + obj = [b'x' * N, b'y' * N, b'z' * N] + for proto in range(4, pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto): + pickled = self.dumps(obj, proto) + unpickled = self.loads(pickled) + self.assertEqual(obj, unpickled) + # At least one frame was emitted per large bytes object. + n_frames = pickled.count(b'\x00\x00\x00\x00\x00') + self.assertGreaterEqual(n_frames, len(obj)) + + def test_nested_names(self): + global Nested + class Nested: + class A: + class B: + class C: + pass + + for proto in range(4, pickle.HIGHEST_PROTOCOL + 1): + for obj in [Nested.A, Nested.A.B, Nested.A.B.C]: + with self.subTest(proto=proto, obj=obj): + unpickled = self.loads(self.dumps(obj, proto)) + self.assertIs(obj, unpickled) + + def test_py_methods(self): + global PyMethodsTest + class PyMethodsTest: + @staticmethod + def cheese(): + return "cheese" + @classmethod + def wine(cls): + assert cls is PyMethodsTest + return "wine" + def biscuits(self): + assert isinstance(self, PyMethodsTest) + return "biscuits" + class Nested: + "Nested class" + @staticmethod + def ketchup(): + return "ketchup" + @classmethod + def maple(cls): + assert cls is PyMethodsTest.Nested + return "maple" + def pie(self): + assert isinstance(self, PyMethodsTest.Nested) + return "pie" + + py_methods = ( + PyMethodsTest.cheese, + PyMethodsTest.wine, + PyMethodsTest().biscuits, + PyMethodsTest.Nested.ketchup, + PyMethodsTest.Nested.maple, + PyMethodsTest.Nested().pie + ) + py_unbound_methods = ( + (PyMethodsTest.biscuits, PyMethodsTest), + (PyMethodsTest.Nested.pie, PyMethodsTest.Nested) + ) + for proto in range(4, pickle.HIGHEST_PROTOCOL + 1): + for method in py_methods: + with self.subTest(proto=proto, method=method): + unpickled = self.loads(self.dumps(method, proto)) + self.assertEqual(method(), unpickled()) + for method, cls in py_unbound_methods: + obj = cls() + with self.subTest(proto=proto, method=method): + unpickled = self.loads(self.dumps(method, proto)) + self.assertEqual(method(obj), unpickled(obj)) + + + def test_c_methods(self): + global Subclass + class Subclass(tuple): + class Nested(str): + pass + + c_methods = ( + # bound built-in method + ("abcd".index, ("c",)), + # unbound built-in method + (str.index, ("abcd", "c")), + # bound "slot" method + ([1, 2, 3].__len__, ()), + # unbound "slot" method + (list.__len__, ([1, 2, 3],)), + # bound "coexist" method + ({1, 2}.__contains__, (2,)), + # unbound "coexist" method + (set.__contains__, ({1, 2}, 2)), + # built-in class method + (dict.fromkeys, (("a", 1), ("b", 2))), + # built-in static method + (bytearray.maketrans, (b"abc", b"xyz")), + # subclass methods + (Subclass([1,2,2]).count, (2,)), + (Subclass.count, (Subclass([1,2,2]), 2)), + (Subclass.Nested("sweet").count, ("e",)), + (Subclass.Nested.count, (Subclass.Nested("sweet"), "e")), + ) + for proto in range(4, pickle.HIGHEST_PROTOCOL + 1): + for method, args in c_methods: + with self.subTest(proto=proto, method=method): + unpickled = self.loads(self.dumps(method, proto)) + self.assertEqual(method(*args), unpickled(*args)) + class BigmemPickleTests(unittest.TestCase): @@ -1252,10 +1458,11 @@ class BigmemPickleTests(unittest.TestCase): data = 1 << (8 * size) try: for proto in protocols: - if proto < 2: - continue - with self.assertRaises((ValueError, OverflowError)): - self.dumps(data, protocol=proto) + with self.subTest(proto=proto): + if proto < 2: + continue + with self.assertRaises((ValueError, OverflowError)): + self.dumps(data, protocol=proto) finally: data = None @@ -1268,14 +1475,15 @@ class BigmemPickleTests(unittest.TestCase): data = b"abcd" * (size // 4) try: for proto in protocols: - if proto < 3: - continue - try: - pickled = self.dumps(data, protocol=proto) - self.assertTrue(b"abcd" in pickled[:15]) - self.assertTrue(b"abcd" in pickled[-15:]) - finally: - pickled = None + with self.subTest(proto=proto): + if proto < 3: + continue + try: + pickled = self.dumps(data, protocol=proto) + self.assertTrue(b"abcd" in pickled[:19]) + self.assertTrue(b"abcd" in pickled[-18:]) + finally: + pickled = None finally: data = None @@ -1284,10 +1492,11 @@ class BigmemPickleTests(unittest.TestCase): data = b"a" * size try: for proto in protocols: - if proto < 3: - continue - with self.assertRaises((ValueError, OverflowError)): - self.dumps(data, protocol=proto) + with self.subTest(proto=proto): + if proto < 3: + continue + with self.assertRaises((ValueError, OverflowError)): + self.dumps(data, protocol=proto) finally: data = None @@ -1299,27 +1508,38 @@ class BigmemPickleTests(unittest.TestCase): data = "abcd" * (size // 4) try: for proto in protocols: - try: - pickled = self.dumps(data, protocol=proto) - self.assertTrue(b"abcd" in pickled[:15]) - self.assertTrue(b"abcd" in pickled[-15:]) - finally: - pickled = None + with self.subTest(proto=proto): + try: + pickled = self.dumps(data, protocol=proto) + self.assertTrue(b"abcd" in pickled[:19]) + self.assertTrue(b"abcd" in pickled[-18:]) + finally: + pickled = None finally: data = None - # BINUNICODE (protocols 1, 2 and 3) cannot carry more than - # 2**32 - 1 bytes of utf-8 encoded unicode. + # BINUNICODE (protocols 1, 2 and 3) cannot carry more than 2**32 - 1 bytes + # of utf-8 encoded unicode. BINUNICODE8 (protocol 4) supports these huge + # unicode strings however. - @bigmemtest(size=_4G, memuse=1 + ascii_char_size, dry_run=False) + @bigmemtest(size=_4G, memuse=2 + ascii_char_size, dry_run=False) def test_huge_str_64b(self, size): - data = "a" * size + data = "abcd" * (size // 4) try: for proto in protocols: - if proto == 0: - continue - with self.assertRaises((ValueError, OverflowError)): - self.dumps(data, protocol=proto) + with self.subTest(proto=proto): + if proto == 0: + continue + if proto < 4: + with self.assertRaises((ValueError, OverflowError)): + self.dumps(data, protocol=proto) + else: + try: + pickled = self.dumps(data, protocol=proto) + self.assertTrue(b"abcd" in pickled[:19]) + self.assertTrue(b"abcd" in pickled[-18:]) + finally: + pickled = None finally: data = None @@ -1363,8 +1583,8 @@ class REX_five(object): return object.__reduce__(self) class REX_six(object): - """This class is used to check the 4th argument (list iterator) of the reduce - protocol. + """This class is used to check the 4th argument (list iterator) of + the reduce protocol. """ def __init__(self, items=None): self.items = items if items is not None else [] @@ -1376,8 +1596,8 @@ class REX_six(object): return type(self), (), None, iter(self.items), None class REX_seven(object): - """This class is used to check the 5th argument (dict iterator) of the reduce - protocol. + """This class is used to check the 5th argument (dict iterator) of + the reduce protocol. """ def __init__(self, table=None): self.table = table if table is not None else {} @@ -1415,10 +1635,16 @@ class MyList(list): class MyDict(dict): sample = {"a": 1, "b": 2} +class MySet(set): + sample = {"a", "b"} + +class MyFrozenSet(frozenset): + sample = frozenset({"a", "b"}) + myclasses = [MyInt, MyFloat, MyComplex, MyStr, MyUnicode, - MyTuple, MyList, MyDict] + MyTuple, MyList, MyDict, MySet, MyFrozenSet] class SlotList(MyList): @@ -1428,6 +1654,8 @@ class SimpleNewObj(object): def __init__(self, a, b, c): # raise an error, to make sure this isn't called raise TypeError("SimpleNewObj.__init__() didn't expect to get called") + def __eq__(self, other): + return self.__dict__ == other.__dict__ class BadGetattr: def __getattr__(self, key): @@ -1464,7 +1692,7 @@ class AbstractPickleModuleTests(unittest.TestCase): def test_highest_protocol(self): # Of course this needs to be changed when HIGHEST_PROTOCOL changes. - self.assertEqual(pickle.HIGHEST_PROTOCOL, 3) + self.assertEqual(pickle.HIGHEST_PROTOCOL, 4) def test_callapi(self): f = io.BytesIO() @@ -1645,22 +1873,23 @@ class AbstractPicklerUnpicklerObjectTests(unittest.TestCase): def _check_multiple_unpicklings(self, ioclass): for proto in protocols: - data1 = [(x, str(x)) for x in range(2000)] + [b"abcde", len] - f = ioclass() - pickler = self.pickler_class(f, protocol=proto) - pickler.dump(data1) - pickled = f.getvalue() - - N = 5 - f = ioclass(pickled * N) - unpickler = self.unpickler_class(f) - for i in range(N): - if f.seekable(): - pos = f.tell() - self.assertEqual(unpickler.load(), data1) - if f.seekable(): - self.assertEqual(f.tell(), pos + len(pickled)) - self.assertRaises(EOFError, unpickler.load) + with self.subTest(proto=proto): + data1 = [(x, str(x)) for x in range(2000)] + [b"abcde", len] + f = ioclass() + pickler = self.pickler_class(f, protocol=proto) + pickler.dump(data1) + pickled = f.getvalue() + + N = 5 + f = ioclass(pickled * N) + unpickler = self.unpickler_class(f) + for i in range(N): + if f.seekable(): + pos = f.tell() + self.assertEqual(unpickler.load(), data1) + if f.seekable(): + self.assertEqual(f.tell(), pos + len(pickled)) + self.assertRaises(EOFError, unpickler.load) def test_multiple_unpicklings_seekable(self): self._check_multiple_unpicklings(io.BytesIO) diff --git a/Lib/test/test_descr.py b/Lib/test/test_descr.py index 5aa4b0a..8ef51de 100644 --- a/Lib/test/test_descr.py +++ b/Lib/test/test_descr.py @@ -1,8 +1,11 @@ import builtins +import copyreg import gc +import itertools +import math +import pickle import sys import types -import math import unittest import weakref @@ -3153,176 +3156,6 @@ order (MRO) for bases """ self.assertEqual(e.a, 1) self.assertEqual(can_delete_dict(e), can_delete_dict(ValueError())) - def test_pickles(self): - # Testing pickling and copying new-style classes and objects... - import pickle - - def sorteditems(d): - L = list(d.items()) - L.sort() - return L - - global C - class C(object): - def __init__(self, a, b): - super(C, self).__init__() - self.a = a - self.b = b - def __repr__(self): - return "C(%r, %r)" % (self.a, self.b) - - global C1 - class C1(list): - def __new__(cls, a, b): - return super(C1, cls).__new__(cls) - def __getnewargs__(self): - return (self.a, self.b) - def __init__(self, a, b): - self.a = a - self.b = b - def __repr__(self): - return "C1(%r, %r)<%r>" % (self.a, self.b, list(self)) - - global C2 - class C2(int): - def __new__(cls, a, b, val=0): - return super(C2, cls).__new__(cls, val) - def __getnewargs__(self): - return (self.a, self.b, int(self)) - def __init__(self, a, b, val=0): - self.a = a - self.b = b - def __repr__(self): - return "C2(%r, %r)<%r>" % (self.a, self.b, int(self)) - - global C3 - class C3(object): - def __init__(self, foo): - self.foo = foo - def __getstate__(self): - return self.foo - def __setstate__(self, foo): - self.foo = foo - - global C4classic, C4 - class C4classic: # classic - pass - class C4(C4classic, object): # mixed inheritance - pass - - for bin in 0, 1: - for cls in C, C1, C2: - s = pickle.dumps(cls, bin) - cls2 = pickle.loads(s) - self.assertIs(cls2, cls) - - a = C1(1, 2); a.append(42); a.append(24) - b = C2("hello", "world", 42) - s = pickle.dumps((a, b), bin) - x, y = pickle.loads(s) - self.assertEqual(x.__class__, a.__class__) - self.assertEqual(sorteditems(x.__dict__), sorteditems(a.__dict__)) - self.assertEqual(y.__class__, b.__class__) - self.assertEqual(sorteditems(y.__dict__), sorteditems(b.__dict__)) - self.assertEqual(repr(x), repr(a)) - self.assertEqual(repr(y), repr(b)) - # Test for __getstate__ and __setstate__ on new style class - u = C3(42) - s = pickle.dumps(u, bin) - v = pickle.loads(s) - self.assertEqual(u.__class__, v.__class__) - self.assertEqual(u.foo, v.foo) - # Test for picklability of hybrid class - u = C4() - u.foo = 42 - s = pickle.dumps(u, bin) - v = pickle.loads(s) - self.assertEqual(u.__class__, v.__class__) - self.assertEqual(u.foo, v.foo) - - # Testing copy.deepcopy() - import copy - for cls in C, C1, C2: - cls2 = copy.deepcopy(cls) - self.assertIs(cls2, cls) - - a = C1(1, 2); a.append(42); a.append(24) - b = C2("hello", "world", 42) - x, y = copy.deepcopy((a, b)) - self.assertEqual(x.__class__, a.__class__) - self.assertEqual(sorteditems(x.__dict__), sorteditems(a.__dict__)) - self.assertEqual(y.__class__, b.__class__) - self.assertEqual(sorteditems(y.__dict__), sorteditems(b.__dict__)) - self.assertEqual(repr(x), repr(a)) - self.assertEqual(repr(y), repr(b)) - - def test_pickle_slots(self): - # Testing pickling of classes with __slots__ ... - import pickle - # Pickling of classes with __slots__ but without __getstate__ should fail - # (if using protocol 0 or 1) - global B, C, D, E - class B(object): - pass - for base in [object, B]: - class C(base): - __slots__ = ['a'] - class D(C): - pass - try: - pickle.dumps(C(), 0) - except TypeError: - pass - else: - self.fail("should fail: pickle C instance - %s" % base) - try: - pickle.dumps(C(), 0) - except TypeError: - pass - else: - self.fail("should fail: pickle D instance - %s" % base) - # Give C a nice generic __getstate__ and __setstate__ - class C(base): - __slots__ = ['a'] - def __getstate__(self): - try: - d = self.__dict__.copy() - except AttributeError: - d = {} - for cls in self.__class__.__mro__: - for sn in cls.__dict__.get('__slots__', ()): - try: - d[sn] = getattr(self, sn) - except AttributeError: - pass - return d - def __setstate__(self, d): - for k, v in list(d.items()): - setattr(self, k, v) - class D(C): - pass - # Now it should work - x = C() - y = pickle.loads(pickle.dumps(x)) - self.assertNotHasAttr(y, 'a') - x.a = 42 - y = pickle.loads(pickle.dumps(x)) - self.assertEqual(y.a, 42) - x = D() - x.a = 42 - x.b = 100 - y = pickle.loads(pickle.dumps(x)) - self.assertEqual(y.a + y.b, 142) - # A subclass that adds a slot should also work - class E(C): - __slots__ = ['b'] - x = E() - x.a = 42 - x.b = "foo" - y = pickle.loads(pickle.dumps(x)) - self.assertEqual(y.a, x.a) - self.assertEqual(y.b, x.b) - def test_binary_operator_override(self): # Testing overrides of binary operations... class I(int): @@ -4690,11 +4523,439 @@ class MiscTests(unittest.TestCase): self.assertEqual(X.mykey2, 'from Base2') +class PicklingTests(unittest.TestCase): + + def _check_reduce(self, proto, obj, args=(), kwargs={}, state=None, + listitems=None, dictitems=None): + if proto >= 4: + reduce_value = obj.__reduce_ex__(proto) + self.assertEqual(reduce_value[:3], + (copyreg.__newobj_ex__, + (type(obj), args, kwargs), + state)) + if listitems is not None: + self.assertListEqual(list(reduce_value[3]), listitems) + else: + self.assertIsNone(reduce_value[3]) + if dictitems is not None: + self.assertDictEqual(dict(reduce_value[4]), dictitems) + else: + self.assertIsNone(reduce_value[4]) + elif proto >= 2: + reduce_value = obj.__reduce_ex__(proto) + self.assertEqual(reduce_value[:3], + (copyreg.__newobj__, + (type(obj),) + args, + state)) + if listitems is not None: + self.assertListEqual(list(reduce_value[3]), listitems) + else: + self.assertIsNone(reduce_value[3]) + if dictitems is not None: + self.assertDictEqual(dict(reduce_value[4]), dictitems) + else: + self.assertIsNone(reduce_value[4]) + else: + base_type = type(obj).__base__ + reduce_value = (copyreg._reconstructor, + (type(obj), + base_type, + None if base_type is object else base_type(obj))) + if state is not None: + reduce_value += (state,) + self.assertEqual(obj.__reduce_ex__(proto), reduce_value) + self.assertEqual(obj.__reduce__(), reduce_value) + + def test_reduce(self): + protocols = range(pickle.HIGHEST_PROTOCOL + 1) + args = (-101, "spam") + kwargs = {'bacon': -201, 'fish': -301} + state = {'cheese': -401} + + class C1: + def __getnewargs__(self): + return args + obj = C1() + for proto in protocols: + self._check_reduce(proto, obj, args) + + for name, value in state.items(): + setattr(obj, name, value) + for proto in protocols: + self._check_reduce(proto, obj, args, state=state) + + class C2: + def __getnewargs__(self): + return "bad args" + obj = C2() + for proto in protocols: + if proto >= 2: + with self.assertRaises(TypeError): + obj.__reduce_ex__(proto) + + class C3: + def __getnewargs_ex__(self): + return (args, kwargs) + obj = C3() + for proto in protocols: + if proto >= 4: + self._check_reduce(proto, obj, args, kwargs) + elif proto >= 2: + with self.assertRaises(ValueError): + obj.__reduce_ex__(proto) + + class C4: + def __getnewargs_ex__(self): + return (args, "bad dict") + class C5: + def __getnewargs_ex__(self): + return ("bad tuple", kwargs) + class C6: + def __getnewargs_ex__(self): + return () + class C7: + def __getnewargs_ex__(self): + return "bad args" + for proto in protocols: + for cls in C4, C5, C6, C7: + obj = cls() + if proto >= 2: + with self.assertRaises((TypeError, ValueError)): + obj.__reduce_ex__(proto) + + class C8: + def __getnewargs_ex__(self): + return (args, kwargs) + obj = C8() + for proto in protocols: + if 2 <= proto < 4: + with self.assertRaises(ValueError): + obj.__reduce_ex__(proto) + class C9: + def __getnewargs_ex__(self): + return (args, {}) + obj = C9() + for proto in protocols: + self._check_reduce(proto, obj, args) + + class C10: + def __getnewargs_ex__(self): + raise IndexError + obj = C10() + for proto in protocols: + if proto >= 2: + with self.assertRaises(IndexError): + obj.__reduce_ex__(proto) + + class C11: + def __getstate__(self): + return state + obj = C11() + for proto in protocols: + self._check_reduce(proto, obj, state=state) + + class C12: + def __getstate__(self): + return "not dict" + obj = C12() + for proto in protocols: + self._check_reduce(proto, obj, state="not dict") + + class C13: + def __getstate__(self): + raise IndexError + obj = C13() + for proto in protocols: + with self.assertRaises(IndexError): + obj.__reduce_ex__(proto) + if proto < 2: + with self.assertRaises(IndexError): + obj.__reduce__() + + class C14: + __slots__ = tuple(state) + def __init__(self): + for name, value in state.items(): + setattr(self, name, value) + + obj = C14() + for proto in protocols: + if proto >= 2: + self._check_reduce(proto, obj, state=(None, state)) + else: + with self.assertRaises(TypeError): + obj.__reduce_ex__(proto) + with self.assertRaises(TypeError): + obj.__reduce__() + + class C15(dict): + pass + obj = C15({"quebec": -601}) + for proto in protocols: + self._check_reduce(proto, obj, dictitems=dict(obj)) + + class C16(list): + pass + obj = C16(["yukon"]) + for proto in protocols: + self._check_reduce(proto, obj, listitems=list(obj)) + + def _assert_is_copy(self, obj, objcopy, msg=None): + """Utility method to verify if two objects are copies of each others. + """ + if msg is None: + msg = "{!r} is not a copy of {!r}".format(obj, objcopy) + if type(obj).__repr__ is object.__repr__: + # We have this limitation for now because we use the object's repr + # to help us verify that the two objects are copies. This allows + # us to delegate the non-generic verification logic to the objects + # themselves. + raise ValueError("object passed to _assert_is_copy must " + + "override the __repr__ method.") + self.assertIsNot(obj, objcopy, msg=msg) + self.assertIs(type(obj), type(objcopy), msg=msg) + if hasattr(obj, '__dict__'): + self.assertDictEqual(obj.__dict__, objcopy.__dict__, msg=msg) + self.assertIsNot(obj.__dict__, objcopy.__dict__, msg=msg) + if hasattr(obj, '__slots__'): + self.assertListEqual(obj.__slots__, objcopy.__slots__, msg=msg) + for slot in obj.__slots__: + self.assertEqual( + hasattr(obj, slot), hasattr(objcopy, slot), msg=msg) + self.assertEqual(getattr(obj, slot, None), + getattr(objcopy, slot, None), msg=msg) + self.assertEqual(repr(obj), repr(objcopy), msg=msg) + + @staticmethod + def _generate_pickle_copiers(): + """Utility method to generate the many possible pickle configurations. + """ + class PickleCopier: + "This class copies object using pickle." + def __init__(self, proto, dumps, loads): + self.proto = proto + self.dumps = dumps + self.loads = loads + def copy(self, obj): + return self.loads(self.dumps(obj, self.proto)) + def __repr__(self): + # We try to be as descriptive as possible here since this is + # the string which we will allow us to tell the pickle + # configuration we are using during debugging. + return ("PickleCopier(proto={}, dumps={}.{}, loads={}.{})" + .format(self.proto, + self.dumps.__module__, self.dumps.__qualname__, + self.loads.__module__, self.loads.__qualname__)) + return (PickleCopier(*args) for args in + itertools.product(range(pickle.HIGHEST_PROTOCOL + 1), + {pickle.dumps, pickle._dumps}, + {pickle.loads, pickle._loads})) + + def test_pickle_slots(self): + # Tests pickling of classes with __slots__. + + # Pickling of classes with __slots__ but without __getstate__ should + # fail (if using protocol 0 or 1) + global C + class C: + __slots__ = ['a'] + with self.assertRaises(TypeError): + pickle.dumps(C(), 0) + + global D + class D(C): + pass + with self.assertRaises(TypeError): + pickle.dumps(D(), 0) + + class C: + "A class with __getstate__ and __setstate__ implemented." + __slots__ = ['a'] + def __getstate__(self): + state = getattr(self, '__dict__', {}).copy() + for cls in type(self).__mro__: + for slot in cls.__dict__.get('__slots__', ()): + try: + state[slot] = getattr(self, slot) + except AttributeError: + pass + return state + def __setstate__(self, state): + for k, v in state.items(): + setattr(self, k, v) + def __repr__(self): + return "%s()<%r>" % (type(self).__name__, self.__getstate__()) + + class D(C): + "A subclass of a class with slots." + pass + + global E + class E(C): + "A subclass with an extra slot." + __slots__ = ['b'] + + # Now it should work + for pickle_copier in self._generate_pickle_copiers(): + with self.subTest(pickle_copier=pickle_copier): + x = C() + y = pickle_copier.copy(x) + self._assert_is_copy(x, y) + + x.a = 42 + y = pickle_copier.copy(x) + self._assert_is_copy(x, y) + + x = D() + x.a = 42 + x.b = 100 + y = pickle_copier.copy(x) + self._assert_is_copy(x, y) + + x = E() + x.a = 42 + x.b = "foo" + y = pickle_copier.copy(x) + self._assert_is_copy(x, y) + + def test_reduce_copying(self): + # Tests pickling and copying new-style classes and objects. + global C1 + class C1: + "The state of this class is copyable via its instance dict." + ARGS = (1, 2) + NEED_DICT_COPYING = True + def __init__(self, a, b): + super().__init__() + self.a = a + self.b = b + def __repr__(self): + return "C1(%r, %r)" % (self.a, self.b) + + global C2 + class C2(list): + "A list subclass copyable via __getnewargs__." + ARGS = (1, 2) + NEED_DICT_COPYING = False + def __new__(cls, a, b): + self = super().__new__(cls) + self.a = a + self.b = b + return self + def __init__(self, *args): + super().__init__() + # This helps testing that __init__ is not called during the + # unpickling process, which would cause extra appends. + self.append("cheese") + @classmethod + def __getnewargs__(cls): + return cls.ARGS + def __repr__(self): + return "C2(%r, %r)<%r>" % (self.a, self.b, list(self)) + + global C3 + class C3(list): + "A list subclass copyable via __getstate__." + ARGS = (1, 2) + NEED_DICT_COPYING = False + def __init__(self, a, b): + self.a = a + self.b = b + # This helps testing that __init__ is not called during the + # unpickling process, which would cause extra appends. + self.append("cheese") + @classmethod + def __getstate__(cls): + return cls.ARGS + def __setstate__(self, state): + a, b = state + self.a = a + self.b = b + def __repr__(self): + return "C3(%r, %r)<%r>" % (self.a, self.b, list(self)) + + global C4 + class C4(int): + "An int subclass copyable via __getnewargs__." + ARGS = ("hello", "world", 1) + NEED_DICT_COPYING = False + def __new__(cls, a, b, value): + self = super().__new__(cls, value) + self.a = a + self.b = b + return self + @classmethod + def __getnewargs__(cls): + return cls.ARGS + def __repr__(self): + return "C4(%r, %r)<%r>" % (self.a, self.b, int(self)) + + global C5 + class C5(int): + "An int subclass copyable via __getnewargs_ex__." + ARGS = (1, 2) + KWARGS = {'value': 3} + NEED_DICT_COPYING = False + def __new__(cls, a, b, *, value=0): + self = super().__new__(cls, value) + self.a = a + self.b = b + return self + @classmethod + def __getnewargs_ex__(cls): + return (cls.ARGS, cls.KWARGS) + def __repr__(self): + return "C5(%r, %r)<%r>" % (self.a, self.b, int(self)) + + test_classes = (C1, C2, C3, C4, C5) + # Testing copying through pickle + pickle_copiers = self._generate_pickle_copiers() + for cls, pickle_copier in itertools.product(test_classes, pickle_copiers): + with self.subTest(cls=cls, pickle_copier=pickle_copier): + kwargs = getattr(cls, 'KWARGS', {}) + obj = cls(*cls.ARGS, **kwargs) + proto = pickle_copier.proto + if 2 <= proto < 4 and hasattr(cls, '__getnewargs_ex__'): + with self.assertRaises(ValueError): + pickle_copier.dumps(obj, proto) + continue + objcopy = pickle_copier.copy(obj) + self._assert_is_copy(obj, objcopy) + # For test classes that supports this, make sure we didn't go + # around the reduce protocol by simply copying the attribute + # dictionary. We clear attributes using the previous copy to + # not mutate the original argument. + if proto >= 2 and not cls.NEED_DICT_COPYING: + objcopy.__dict__.clear() + objcopy2 = pickle_copier.copy(objcopy) + self._assert_is_copy(obj, objcopy2) + + # Testing copying through copy.deepcopy() + for cls in test_classes: + with self.subTest(cls=cls): + kwargs = getattr(cls, 'KWARGS', {}) + obj = cls(*cls.ARGS, **kwargs) + # XXX: We need to modify the copy module to support PEP 3154's + # reduce protocol 4. + if hasattr(cls, '__getnewargs_ex__'): + continue + objcopy = deepcopy(obj) + self._assert_is_copy(obj, objcopy) + # For test classes that supports this, make sure we didn't go + # around the reduce protocol by simply copying the attribute + # dictionary. We clear attributes using the previous copy to + # not mutate the original argument. + if not cls.NEED_DICT_COPYING: + objcopy.__dict__.clear() + objcopy2 = deepcopy(objcopy) + self._assert_is_copy(obj, objcopy2) + + def test_main(): # Run all local test cases, with PTypesLongInitTest first. support.run_unittest(PTypesLongInitTest, OperatorsTest, ClassPropertiesAndMethods, DictProxyTests, - MiscTests) + MiscTests, PicklingTests) if __name__ == "__main__": test_main() diff --git a/Misc/NEWS b/Misc/NEWS index 336c3f4..06bb771 100644 --- a/Misc/NEWS +++ b/Misc/NEWS @@ -68,6 +68,8 @@ Core and Builtins Library ------- +- Issue #17810: Implement PEP 3154, pickle protocol 4. + - Issue #19668: Added support for the cp1125 encoding. - Issue #19689: Add ssl.create_default_context() factory function. It creates diff --git a/Modules/_pickle.c b/Modules/_pickle.c index 9852cd3..f9aa043 100644 --- a/Modules/_pickle.c +++ b/Modules/_pickle.c @@ -6,7 +6,7 @@ PyDoc_STRVAR(pickle_module_doc, /* Bump this when new opcodes are added to the pickle protocol. */ enum { - HIGHEST_PROTOCOL = 3, + HIGHEST_PROTOCOL = 4, DEFAULT_PROTOCOL = 3 }; @@ -71,7 +71,19 @@ enum opcode { /* Protocol 3 (Python 3.x) */ BINBYTES = 'B', - SHORT_BINBYTES = 'C' + SHORT_BINBYTES = 'C', + + /* Protocol 4 */ + SHORT_BINUNICODE = '\x8c', + BINUNICODE8 = '\x8d', + BINBYTES8 = '\x8e', + EMPTY_SET = '\x8f', + ADDITEMS = '\x90', + FROZENSET = '\x91', + NEWOBJ_EX = '\x92', + STACK_GLOBAL = '\x93', + MEMOIZE = '\x94', + FRAME = '\x95' }; /* These aren't opcodes -- they're ways to pickle bools before protocol 2 @@ -103,7 +115,11 @@ enum { MAX_WRITE_BUF_SIZE = 64 * 1024, /* Prefetch size when unpickling (disabled on unpeekable streams) */ - PREFETCH = 8192 * 16 + PREFETCH = 8192 * 16, + + FRAME_SIZE_TARGET = 64 * 1024, + + FRAME_HEADER_SIZE = 9 }; /* Exception classes for pickle. These should override the ones defined in @@ -136,9 +152,6 @@ static PyObject *empty_tuple = NULL; /* For looking up name pairs in copyreg._extension_registry. */ static PyObject *two_tuple = NULL; -_Py_IDENTIFIER(__name__); -_Py_IDENTIFIER(modules); - static int stack_underflow(void) { @@ -332,7 +345,12 @@ typedef struct PicklerObject { Py_ssize_t max_output_len; /* Allocation size of output_buffer. */ int proto; /* Pickle protocol number, >= 0 */ int bin; /* Boolean, true if proto > 0 */ - Py_ssize_t buf_size; /* Size of the current buffered pickle data */ + int framing; /* True when framing is enabled, proto >= 4 */ + Py_ssize_t frame_start; /* Position in output_buffer where the + where the current frame begins. -1 if there + is no frame currently open. */ + + Py_ssize_t buf_size; /* Size of the current buffered pickle data */ int fast; /* Enable fast mode if set to a true value. The fast mode disable the usage of memo, therefore speeding the pickling process by @@ -352,7 +370,8 @@ typedef struct UnpicklerObject { /* The unpickler memo is just an array of PyObject *s. Using a dict is unnecessary, since the keys are contiguous ints. */ PyObject **memo; - Py_ssize_t memo_size; + Py_ssize_t memo_size; /* Capacity of the memo array */ + Py_ssize_t memo_len; /* Number of objects in the memo */ PyObject *arg; PyObject *pers_func; /* persistent_load() method, can be NULL. */ @@ -362,7 +381,9 @@ typedef struct UnpicklerObject { char *input_line; Py_ssize_t input_len; Py_ssize_t next_read_idx; + Py_ssize_t frame_end_idx; Py_ssize_t prefetched_idx; /* index of first prefetched byte */ + PyObject *read; /* read() method of the input stream. */ PyObject *readline; /* readline() method of the input stream. */ PyObject *peek; /* peek() method of the input stream, or NULL */ @@ -380,6 +401,7 @@ typedef struct UnpicklerObject { int proto; /* Protocol of the pickle loaded. */ int fix_imports; /* Indicate whether Unpickler should fix the name of globals pickled by Python 2.x. */ + int framing; /* True when framing is enabled, proto >= 4 */ } UnpicklerObject; /* Forward declarations */ @@ -673,15 +695,63 @@ _Pickler_ClearBuffer(PicklerObject *self) if (self->output_buffer == NULL) return -1; self->output_len = 0; + self->frame_start = -1; return 0; } +static void +_Pickler_WriteFrameHeader(PicklerObject *self, char *qdata, size_t frame_len) +{ + qdata[0] = (unsigned char)FRAME; + qdata[1] = (unsigned char)(frame_len & 0xff); + qdata[2] = (unsigned char)((frame_len >> 8) & 0xff); + qdata[3] = (unsigned char)((frame_len >> 16) & 0xff); + qdata[4] = (unsigned char)((frame_len >> 24) & 0xff); + qdata[5] = (unsigned char)((frame_len >> 32) & 0xff); + qdata[6] = (unsigned char)((frame_len >> 40) & 0xff); + qdata[7] = (unsigned char)((frame_len >> 48) & 0xff); + qdata[8] = (unsigned char)((frame_len >> 56) & 0xff); +} + +static int +_Pickler_CommitFrame(PicklerObject *self) +{ + size_t frame_len; + char *qdata; + + if (!self->framing || self->frame_start == -1) + return 0; + frame_len = self->output_len - self->frame_start - FRAME_HEADER_SIZE; + qdata = PyBytes_AS_STRING(self->output_buffer) + self->frame_start; + _Pickler_WriteFrameHeader(self, qdata, frame_len); + self->frame_start = -1; + return 0; +} + +static int +_Pickler_OpcodeBoundary(PicklerObject *self) +{ + Py_ssize_t frame_len; + + if (!self->framing || self->frame_start == -1) + return 0; + frame_len = self->output_len - self->frame_start - FRAME_HEADER_SIZE; + if (frame_len >= FRAME_SIZE_TARGET) + return _Pickler_CommitFrame(self); + else + return 0; +} + static PyObject * _Pickler_GetString(PicklerObject *self) { PyObject *output_buffer = self->output_buffer; assert(self->output_buffer != NULL); + + if (_Pickler_CommitFrame(self)) + return NULL; + self->output_buffer = NULL; /* Resize down to exact size */ if (_PyBytes_Resize(&output_buffer, self->output_len) < 0) @@ -696,6 +766,7 @@ _Pickler_FlushToFile(PicklerObject *self) assert(self->write != NULL); + /* This will commit the frame first */ output = _Pickler_GetString(self); if (output == NULL) return -1; @@ -706,57 +777,93 @@ _Pickler_FlushToFile(PicklerObject *self) } static Py_ssize_t -_Pickler_Write(PicklerObject *self, const char *s, Py_ssize_t n) +_Pickler_Write(PicklerObject *self, const char *s, Py_ssize_t data_len) { - Py_ssize_t i, required; + Py_ssize_t i, n, required; char *buffer; + int need_new_frame; assert(s != NULL); + need_new_frame = (self->framing && self->frame_start == -1); + + if (need_new_frame) + n = data_len + FRAME_HEADER_SIZE; + else + n = data_len; required = self->output_len + n; - if (required > self->max_output_len) { - if (self->write != NULL && required > MAX_WRITE_BUF_SIZE) { - /* XXX This reallocates a new buffer every time, which is a bit - wasteful. */ - if (_Pickler_FlushToFile(self) < 0) - return -1; - if (_Pickler_ClearBuffer(self) < 0) - return -1; - } - if (self->write != NULL && n > MAX_WRITE_BUF_SIZE) { - /* we already flushed above, so the buffer is empty */ - PyObject *result; - /* XXX we could spare an intermediate copy and pass - a memoryview instead */ - PyObject *output = PyBytes_FromStringAndSize(s, n); - if (s == NULL) + if (self->write != NULL && required > MAX_WRITE_BUF_SIZE) { + /* XXX This reallocates a new buffer every time, which is a bit + wasteful. */ + if (_Pickler_FlushToFile(self) < 0) + return -1; + if (_Pickler_ClearBuffer(self) < 0) + return -1; + /* The previous frame was just committed by _Pickler_FlushToFile */ + need_new_frame = self->framing; + if (need_new_frame) + n = data_len + FRAME_HEADER_SIZE; + else + n = data_len; + required = self->output_len + n; + } + if (self->write != NULL && n > MAX_WRITE_BUF_SIZE) { + /* For large pickle chunks, we write directly to the output + file instead of buffering. Note the buffer is empty at this + point (it was flushed above, since required >= n). */ + PyObject *output, *result; + if (need_new_frame) { + char frame_header[FRAME_HEADER_SIZE]; + _Pickler_WriteFrameHeader(self, frame_header, (size_t) data_len); + output = PyBytes_FromStringAndSize(frame_header, FRAME_HEADER_SIZE); + if (output == NULL) return -1; result = _Pickler_FastCall(self, self->write, output); Py_XDECREF(result); - return (result == NULL) ? -1 : 0; - } - else { - if (self->output_len >= PY_SSIZE_T_MAX / 2 - n) { - PyErr_NoMemory(); - return -1; - } - self->max_output_len = (self->output_len + n) / 2 * 3; - if (_PyBytes_Resize(&self->output_buffer, self->max_output_len) < 0) + if (result == NULL) return -1; } + /* XXX we could spare an intermediate copy and pass + a memoryview instead */ + output = PyBytes_FromStringAndSize(s, data_len); + if (output == NULL) + return -1; + result = _Pickler_FastCall(self, self->write, output); + Py_XDECREF(result); + return (result == NULL) ? -1 : 0; + } + if (required > self->max_output_len) { + /* Make place in buffer for the pickle chunk */ + if (self->output_len >= PY_SSIZE_T_MAX / 2 - n) { + PyErr_NoMemory(); + return -1; + } + self->max_output_len = (self->output_len + n) / 2 * 3; + if (_PyBytes_Resize(&self->output_buffer, self->max_output_len) < 0) + return -1; } buffer = PyBytes_AS_STRING(self->output_buffer); - if (n < 8) { + if (need_new_frame) { + /* Setup new frame */ + Py_ssize_t frame_start = self->output_len; + self->frame_start = frame_start; + for (i = 0; i < FRAME_HEADER_SIZE; i++) { + /* Write an invalid value, for debugging */ + buffer[frame_start + i] = 0xFE; + } + self->output_len += FRAME_HEADER_SIZE; + } + if (data_len < 8) { /* This is faster than memcpy when the string is short. */ - for (i = 0; i < n; i++) { + for (i = 0; i < data_len; i++) { buffer[self->output_len + i] = s[i]; } } else { - memcpy(buffer + self->output_len, s, n); + memcpy(buffer + self->output_len, s, data_len); } - self->output_len += n; - return n; + self->output_len += data_len; + return data_len; } static PicklerObject * @@ -774,6 +881,8 @@ _Pickler_New(void) self->write = NULL; self->proto = 0; self->bin = 0; + self->framing = 0; + self->frame_start = -1; self->fast = 0; self->fast_nesting = 0; self->fix_imports = 0; @@ -868,6 +977,7 @@ _Unpickler_SetStringInput(UnpicklerObject *self, PyObject *input) self->input_buffer = self->buffer.buf; self->input_len = self->buffer.len; self->next_read_idx = 0; + self->frame_end_idx = -1; self->prefetched_idx = self->input_len; return self->input_len; } @@ -932,7 +1042,7 @@ _Unpickler_ReadFromFile(UnpicklerObject *self, Py_ssize_t n) return -1; /* Prefetch some data without advancing the file pointer, if possible */ - if (self->peek) { + if (self->peek && !self->framing) { PyObject *len, *prefetched; len = PyLong_FromSsize_t(PREFETCH); if (len == NULL) { @@ -980,7 +1090,7 @@ _Unpickler_ReadFromFile(UnpicklerObject *self, Py_ssize_t n) Returns -1 (with an exception set) on failure. On success, return the number of chars read. */ static Py_ssize_t -_Unpickler_Read(UnpicklerObject *self, char **s, Py_ssize_t n) +_Unpickler_ReadUnframed(UnpicklerObject *self, char **s, Py_ssize_t n) { Py_ssize_t num_read; @@ -1006,6 +1116,67 @@ _Unpickler_Read(UnpicklerObject *self, char **s, Py_ssize_t n) } static Py_ssize_t +_Unpickler_Read(UnpicklerObject *self, char **s, Py_ssize_t n) +{ + if (self->framing && + (self->frame_end_idx == -1 || + self->frame_end_idx <= self->next_read_idx)) { + /* Need to read new frame */ + char *dummy; + unsigned char *frame_start; + size_t frame_len; + if (_Unpickler_ReadUnframed(self, &dummy, FRAME_HEADER_SIZE) < 0) + return -1; + frame_start = (unsigned char *) dummy; + if (frame_start[0] != (unsigned char)FRAME) { + PyErr_Format(UnpicklingError, + "expected FRAME opcode, got 0x%x instead", + frame_start[0]); + return -1; + } + frame_len = (size_t) frame_start[1]; + frame_len |= (size_t) frame_start[2] << 8; + frame_len |= (size_t) frame_start[3] << 16; + frame_len |= (size_t) frame_start[4] << 24; +#if SIZEOF_SIZE_T >= 8 + frame_len |= (size_t) frame_start[5] << 32; + frame_len |= (size_t) frame_start[6] << 40; + frame_len |= (size_t) frame_start[7] << 48; + frame_len |= (size_t) frame_start[8] << 56; +#else + if (frame_start[5] || frame_start[6] || + frame_start[7] || frame_start[8]) { + PyErr_Format(PyExc_OverflowError, + "Frame size too large for 32-bit build"); + return -1; + } +#endif + if (frame_len > PY_SSIZE_T_MAX) { + PyErr_Format(UnpicklingError, "Invalid frame length"); + return -1; + } + if (frame_len < n) { + PyErr_Format(UnpicklingError, "Bad framing"); + return -1; + } + if (_Unpickler_ReadUnframed(self, &dummy /* unused */, + frame_len) < 0) + return -1; + /* Rewind to start of frame */ + self->frame_end_idx = self->next_read_idx; + self->next_read_idx -= frame_len; + } + if (self->framing) { + /* Check for bad input */ + if (n + self->next_read_idx > self->frame_end_idx) { + PyErr_Format(UnpicklingError, "Bad framing"); + return -1; + } + } + return _Unpickler_ReadUnframed(self, s, n); +} + +static Py_ssize_t _Unpickler_CopyLine(UnpicklerObject *self, char *line, Py_ssize_t len, char **result) { @@ -1102,7 +1273,12 @@ _Unpickler_MemoPut(UnpicklerObject *self, Py_ssize_t idx, PyObject *value) Py_INCREF(value); old_item = self->memo[idx]; self->memo[idx] = value; - Py_XDECREF(old_item); + if (old_item != NULL) { + Py_DECREF(old_item); + } + else { + self->memo_len++; + } return 0; } @@ -1150,6 +1326,7 @@ _Unpickler_New(void) self->input_line = NULL; self->input_len = 0; self->next_read_idx = 0; + self->frame_end_idx = -1; self->prefetched_idx = 0; self->read = NULL; self->readline = NULL; @@ -1160,9 +1337,11 @@ _Unpickler_New(void) self->num_marks = 0; self->marks_size = 0; self->proto = 0; + self->framing = 0; self->fix_imports = 0; memset(&self->buffer, 0, sizeof(Py_buffer)); self->memo_size = 32; + self->memo_len = 0; self->memo = _Unpickler_NewMemo(self->memo_size); self->stack = (Pdata *)Pdata_New(); @@ -1277,36 +1456,44 @@ memo_get(PicklerObject *self, PyObject *key) static int memo_put(PicklerObject *self, PyObject *obj) { - Py_ssize_t x; char pdata[30]; Py_ssize_t len; - int status = 0; + Py_ssize_t idx; + + const char memoize_op = MEMOIZE; if (self->fast) return 0; + if (_Pickler_OpcodeBoundary(self)) + return -1; - x = PyMemoTable_Size(self->memo); - if (PyMemoTable_Set(self->memo, obj, x) < 0) - goto error; + idx = PyMemoTable_Size(self->memo); + if (PyMemoTable_Set(self->memo, obj, idx) < 0) + return -1; - if (!self->bin) { + if (self->proto >= 4) { + if (_Pickler_Write(self, &memoize_op, 1) < 0) + return -1; + return 0; + } + else if (!self->bin) { pdata[0] = PUT; PyOS_snprintf(pdata + 1, sizeof(pdata) - 1, - "%" PY_FORMAT_SIZE_T "d\n", x); + "%" PY_FORMAT_SIZE_T "d\n", idx); len = strlen(pdata); } else { - if (x < 256) { + if (idx < 256) { pdata[0] = BINPUT; - pdata[1] = (unsigned char)x; + pdata[1] = (unsigned char)idx; len = 2; } - else if (x <= 0xffffffffL) { + else if (idx <= 0xffffffffL) { pdata[0] = LONG_BINPUT; - pdata[1] = (unsigned char)(x & 0xff); - pdata[2] = (unsigned char)((x >> 8) & 0xff); - pdata[3] = (unsigned char)((x >> 16) & 0xff); - pdata[4] = (unsigned char)((x >> 24) & 0xff); + pdata[1] = (unsigned char)(idx & 0xff); + pdata[2] = (unsigned char)((idx >> 8) & 0xff); + pdata[3] = (unsigned char)((idx >> 16) & 0xff); + pdata[4] = (unsigned char)((idx >> 24) & 0xff); len = 5; } else { /* unlikely */ @@ -1315,57 +1502,94 @@ memo_put(PicklerObject *self, PyObject *obj) return -1; } } - if (_Pickler_Write(self, pdata, len) < 0) - goto error; + return -1; - if (0) { - error: - status = -1; - } + return 0; +} - return status; +static PyObject * +getattribute(PyObject *obj, PyObject *name, int allow_qualname) { + PyObject *dotted_path; + Py_ssize_t i; + _Py_static_string(PyId_dot, "."); + _Py_static_string(PyId_locals, ""); + + dotted_path = PyUnicode_Split(name, _PyUnicode_FromId(&PyId_dot), -1); + if (dotted_path == NULL) { + return NULL; + } + assert(Py_SIZE(dotted_path) >= 1); + if (!allow_qualname && Py_SIZE(dotted_path) > 1) { + PyErr_Format(PyExc_AttributeError, + "Can't get qualified attribute %R on %R;" + "use protocols >= 4 to enable support", + name, obj); + Py_DECREF(dotted_path); + return NULL; + } + Py_INCREF(obj); + for (i = 0; i < Py_SIZE(dotted_path); i++) { + PyObject *subpath = PyList_GET_ITEM(dotted_path, i); + PyObject *tmp; + PyObject *result = PyUnicode_RichCompare( + subpath, _PyUnicode_FromId(&PyId_locals), Py_EQ); + int is_equal = (result == Py_True); + assert(PyBool_Check(result)); + Py_DECREF(result); + if (is_equal) { + PyErr_Format(PyExc_AttributeError, + "Can't get local attribute %R on %R", name, obj); + Py_DECREF(dotted_path); + Py_DECREF(obj); + return NULL; + } + tmp = PyObject_GetAttr(obj, subpath); + Py_DECREF(obj); + if (tmp == NULL) { + if (PyErr_ExceptionMatches(PyExc_AttributeError)) { + PyErr_Clear(); + PyErr_Format(PyExc_AttributeError, + "Can't get attribute %R on %R", name, obj); + } + Py_DECREF(dotted_path); + return NULL; + } + obj = tmp; + } + Py_DECREF(dotted_path); + return obj; } static PyObject * -whichmodule(PyObject *global, PyObject *global_name) +whichmodule(PyObject *global, PyObject *global_name, int allow_qualname) { - Py_ssize_t i, j; - static PyObject *module_str = NULL; - static PyObject *main_str = NULL; PyObject *module_name; PyObject *modules_dict; PyObject *module; PyObject *obj; + Py_ssize_t i, j; + _Py_IDENTIFIER(__module__); + _Py_IDENTIFIER(modules); + _Py_IDENTIFIER(__main__); - if (module_str == NULL) { - module_str = PyUnicode_InternFromString("__module__"); - if (module_str == NULL) - return NULL; - main_str = PyUnicode_InternFromString("__main__"); - if (main_str == NULL) - return NULL; - } - - module_name = PyObject_GetAttr(global, module_str); + module_name = _PyObject_GetAttrId(global, &PyId___module__); - /* In some rare cases (e.g., bound methods of extension types), - __module__ can be None. If it is so, then search sys.modules - for the module of global. */ - if (module_name == Py_None) { - Py_DECREF(module_name); - goto search; + if (module_name == NULL) { + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) + return NULL; + PyErr_Clear(); } - - if (module_name) { - return module_name; + else { + /* In some rare cases (e.g., bound methods of extension types), + __module__ can be None. If it is so, then search sys.modules for + the module of global. */ + if (module_name != Py_None) + return module_name; + Py_CLEAR(module_name); } - if (PyErr_ExceptionMatches(PyExc_AttributeError)) - PyErr_Clear(); - else - return NULL; + assert(module_name == NULL); - search: modules_dict = _PySys_GetObjectId(&PyId_modules); if (modules_dict == NULL) { PyErr_SetString(PyExc_RuntimeError, "unable to get sys.modules"); @@ -1373,34 +1597,35 @@ whichmodule(PyObject *global, PyObject *global_name) } i = 0; - module_name = NULL; while ((j = PyDict_Next(modules_dict, &i, &module_name, &module))) { - if (PyObject_RichCompareBool(module_name, main_str, Py_EQ) == 1) + PyObject *result = PyUnicode_RichCompare( + module_name, _PyUnicode_FromId(&PyId___main__), Py_EQ); + int is_equal = (result == Py_True); + assert(PyBool_Check(result)); + Py_DECREF(result); + if (is_equal) + continue; + if (module == Py_None) continue; - obj = PyObject_GetAttr(module, global_name); + obj = getattribute(module, global_name, allow_qualname); if (obj == NULL) { - if (PyErr_ExceptionMatches(PyExc_AttributeError)) - PyErr_Clear(); - else + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) return NULL; + PyErr_Clear(); continue; } - if (obj != global) { + if (obj == global) { Py_DECREF(obj); - continue; + Py_INCREF(module_name); + return module_name; } - Py_DECREF(obj); - break; } /* If no module is found, use __main__. */ - if (!j) { - module_name = main_str; - } - + module_name = _PyUnicode_FromId(&PyId___main__); Py_INCREF(module_name); return module_name; } @@ -1744,22 +1969,17 @@ save_bytes(PicklerObject *self, PyObject *obj) reduce_value = Py_BuildValue("(O())", (PyObject*)&PyBytes_Type); } else { - static PyObject *latin1 = NULL; PyObject *unicode_str = PyUnicode_DecodeLatin1(PyBytes_AS_STRING(obj), PyBytes_GET_SIZE(obj), "strict"); + _Py_IDENTIFIER(latin1); + if (unicode_str == NULL) return -1; - if (latin1 == NULL) { - latin1 = PyUnicode_InternFromString("latin1"); - if (latin1 == NULL) { - Py_DECREF(unicode_str); - return -1; - } - } reduce_value = Py_BuildValue("(O(OO))", - codecs_encode, unicode_str, latin1); + codecs_encode, unicode_str, + _PyUnicode_FromId(&PyId_latin1)); Py_DECREF(unicode_str); } @@ -1773,14 +1993,14 @@ save_bytes(PicklerObject *self, PyObject *obj) } else { Py_ssize_t size; - char header[5]; + char header[9]; Py_ssize_t len; size = PyBytes_GET_SIZE(obj); if (size < 0) return -1; - if (size < 256) { + if (size <= 0xff) { header[0] = SHORT_BINBYTES; header[1] = (unsigned char)size; len = 2; @@ -1793,6 +2013,14 @@ save_bytes(PicklerObject *self, PyObject *obj) header[4] = (unsigned char)((size >> 24) & 0xff); len = 5; } + else if (self->proto >= 4) { + int i; + header[0] = BINBYTES8; + for (i = 0; i < 8; i++) { + header[i+1] = (unsigned char)((size >> (8 * i)) & 0xff); + } + len = 8; + } else { PyErr_SetString(PyExc_OverflowError, "cannot serialize a bytes object larger than 4 GiB"); @@ -1882,26 +2110,39 @@ done: static int write_utf8(PicklerObject *self, char *data, Py_ssize_t size) { - char pdata[5]; + char header[9]; + Py_ssize_t len; + + if (size <= 0xff && self->proto >= 4) { + header[0] = SHORT_BINUNICODE; + header[1] = (unsigned char)(size & 0xff); + len = 2; + } + else if (size <= 0xffffffffUL) { + header[0] = BINUNICODE; + header[1] = (unsigned char)(size & 0xff); + header[2] = (unsigned char)((size >> 8) & 0xff); + header[3] = (unsigned char)((size >> 16) & 0xff); + header[4] = (unsigned char)((size >> 24) & 0xff); + len = 5; + } + else if (self->proto >= 4) { + int i; -#if SIZEOF_SIZE_T > 4 - if (size > 0xffffffffUL) { - /* string too large */ + header[0] = BINUNICODE8; + for (i = 0; i < 8; i++) { + header[i+1] = (unsigned char)((size >> (8 * i)) & 0xff); + } + len = 9; + } + else { PyErr_SetString(PyExc_OverflowError, "cannot serialize a string larger than 4GiB"); return -1; } -#endif - - pdata[0] = BINUNICODE; - pdata[1] = (unsigned char)(size & 0xff); - pdata[2] = (unsigned char)((size >> 8) & 0xff); - pdata[3] = (unsigned char)((size >> 16) & 0xff); - pdata[4] = (unsigned char)((size >> 24) & 0xff); - if (_Pickler_Write(self, pdata, sizeof(pdata)) < 0) + if (_Pickler_Write(self, header, len) < 0) return -1; - if (_Pickler_Write(self, data, size) < 0) return -1; @@ -2598,6 +2839,214 @@ save_dict(PicklerObject *self, PyObject *obj) } static int +save_set(PicklerObject *self, PyObject *obj) +{ + PyObject *item; + int i; + Py_ssize_t set_size, ppos = 0; + Py_hash_t hash; + + const char empty_set_op = EMPTY_SET; + const char mark_op = MARK; + const char additems_op = ADDITEMS; + + if (self->proto < 4) { + PyObject *items; + PyObject *reduce_value; + int status; + + items = PySequence_List(obj); + if (items == NULL) { + return -1; + } + reduce_value = Py_BuildValue("(O(O))", (PyObject*)&PySet_Type, items); + Py_DECREF(items); + if (reduce_value == NULL) { + return -1; + } + /* save_reduce() will memoize the object automatically. */ + status = save_reduce(self, reduce_value, obj); + Py_DECREF(reduce_value); + return status; + } + + if (_Pickler_Write(self, &empty_set_op, 1) < 0) + return -1; + + if (memo_put(self, obj) < 0) + return -1; + + set_size = PySet_GET_SIZE(obj); + if (set_size == 0) + return 0; /* nothing to do */ + + /* Write in batches of BATCHSIZE. */ + do { + i = 0; + if (_Pickler_Write(self, &mark_op, 1) < 0) + return -1; + while (_PySet_NextEntry(obj, &ppos, &item, &hash)) { + if (save(self, item, 0) < 0) + return -1; + if (++i == BATCHSIZE) + break; + } + if (_Pickler_Write(self, &additems_op, 1) < 0) + return -1; + if (PySet_GET_SIZE(obj) != set_size) { + PyErr_Format( + PyExc_RuntimeError, + "set changed size during iteration"); + return -1; + } + } while (i == BATCHSIZE); + + return 0; +} + +static int +save_frozenset(PicklerObject *self, PyObject *obj) +{ + PyObject *iter; + + const char mark_op = MARK; + const char frozenset_op = FROZENSET; + + if (self->fast && !fast_save_enter(self, obj)) + return -1; + + if (self->proto < 4) { + PyObject *items; + PyObject *reduce_value; + int status; + + items = PySequence_List(obj); + if (items == NULL) { + return -1; + } + reduce_value = Py_BuildValue("(O(O))", (PyObject*)&PyFrozenSet_Type, + items); + Py_DECREF(items); + if (reduce_value == NULL) { + return -1; + } + /* save_reduce() will memoize the object automatically. */ + status = save_reduce(self, reduce_value, obj); + Py_DECREF(reduce_value); + return status; + } + + if (_Pickler_Write(self, &mark_op, 1) < 0) + return -1; + + iter = PyObject_GetIter(obj); + for (;;) { + PyObject *item; + + item = PyIter_Next(iter); + if (item == NULL) { + if (PyErr_Occurred()) { + Py_DECREF(iter); + return -1; + } + break; + } + if (save(self, item, 0) < 0) { + Py_DECREF(item); + Py_DECREF(iter); + return -1; + } + Py_DECREF(item); + } + Py_DECREF(iter); + + /* If the object is already in the memo, this means it is + recursive. In this case, throw away everything we put on the + stack, and fetch the object back from the memo. */ + if (PyMemoTable_Get(self->memo, obj)) { + const char pop_mark_op = POP_MARK; + + if (_Pickler_Write(self, &pop_mark_op, 1) < 0) + return -1; + if (memo_get(self, obj) < 0) + return -1; + return 0; + } + + if (_Pickler_Write(self, &frozenset_op, 1) < 0) + return -1; + if (memo_put(self, obj) < 0) + return -1; + + return 0; +} + +static int +fix_imports(PyObject **module_name, PyObject **global_name) +{ + PyObject *key; + PyObject *item; + + key = PyTuple_Pack(2, *module_name, *global_name); + if (key == NULL) + return -1; + item = PyDict_GetItemWithError(name_mapping_3to2, key); + Py_DECREF(key); + if (item) { + PyObject *fixed_module_name; + PyObject *fixed_global_name; + + if (!PyTuple_Check(item) || PyTuple_GET_SIZE(item) != 2) { + PyErr_Format(PyExc_RuntimeError, + "_compat_pickle.REVERSE_NAME_MAPPING values " + "should be 2-tuples, not %.200s", + Py_TYPE(item)->tp_name); + return -1; + } + fixed_module_name = PyTuple_GET_ITEM(item, 0); + fixed_global_name = PyTuple_GET_ITEM(item, 1); + if (!PyUnicode_Check(fixed_module_name) || + !PyUnicode_Check(fixed_global_name)) { + PyErr_Format(PyExc_RuntimeError, + "_compat_pickle.REVERSE_NAME_MAPPING values " + "should be pairs of str, not (%.200s, %.200s)", + Py_TYPE(fixed_module_name)->tp_name, + Py_TYPE(fixed_global_name)->tp_name); + return -1; + } + + Py_CLEAR(*module_name); + Py_CLEAR(*global_name); + Py_INCREF(fixed_module_name); + Py_INCREF(fixed_global_name); + *module_name = fixed_module_name; + *global_name = fixed_global_name; + } + else if (PyErr_Occurred()) { + return -1; + } + + item = PyDict_GetItemWithError(import_mapping_3to2, *module_name); + if (item) { + if (!PyUnicode_Check(item)) { + PyErr_Format(PyExc_RuntimeError, + "_compat_pickle.REVERSE_IMPORT_MAPPING values " + "should be strings, not %.200s", + Py_TYPE(item)->tp_name); + return -1; + } + Py_CLEAR(*module_name); + Py_INCREF(item); + *module_name = item; + } + else if (PyErr_Occurred()) { + return -1; + } + + return 0; +} + +static int save_global(PicklerObject *self, PyObject *obj, PyObject *name) { PyObject *global_name = NULL; @@ -2605,20 +3054,32 @@ save_global(PicklerObject *self, PyObject *obj, PyObject *name) PyObject *module = NULL; PyObject *cls; int status = 0; + _Py_IDENTIFIER(__name__); + _Py_IDENTIFIER(__qualname__); const char global_op = GLOBAL; if (name) { + Py_INCREF(name); global_name = name; - Py_INCREF(global_name); } else { - global_name = _PyObject_GetAttrId(obj, &PyId___name__); - if (global_name == NULL) - goto error; + if (self->proto >= 4) { + global_name = _PyObject_GetAttrId(obj, &PyId___qualname__); + if (global_name == NULL) { + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) + goto error; + PyErr_Clear(); + } + } + if (global_name == NULL) { + global_name = _PyObject_GetAttrId(obj, &PyId___name__); + if (global_name == NULL) + goto error; + } } - module_name = whichmodule(obj, global_name); + module_name = whichmodule(obj, global_name, self->proto >= 4); if (module_name == NULL) goto error; @@ -2637,11 +3098,11 @@ save_global(PicklerObject *self, PyObject *obj, PyObject *name) obj, module_name); goto error; } - cls = PyObject_GetAttr(module, global_name); + cls = getattribute(module, global_name, self->proto >= 4); if (cls == NULL) { PyErr_Format(PicklingError, - "Can't pickle %R: attribute lookup %S.%S failed", - obj, module_name, global_name); + "Can't pickle %R: attribute lookup %S on %S failed", + obj, global_name, module_name); goto error; } if (cls != obj) { @@ -2715,120 +3176,82 @@ save_global(PicklerObject *self, PyObject *obj, PyObject *name) goto error; } else { - /* Generate a normal global opcode if we are using a pickle - protocol <= 2, or if the object is not registered in the - extension registry. */ - PyObject *encoded; - PyObject *(*unicode_encoder)(PyObject *); - gen_global: - if (_Pickler_Write(self, &global_op, 1) < 0) - goto error; + if (self->proto >= 4) { + const char stack_global_op = STACK_GLOBAL; + + save(self, module_name, 0); + save(self, global_name, 0); - /* Since Python 3.0 now supports non-ASCII identifiers, we encode both - the module name and the global name using UTF-8. We do so only when - we are using the pickle protocol newer than version 3. This is to - ensure compatibility with older Unpickler running on Python 2.x. */ - if (self->proto >= 3) { - unicode_encoder = PyUnicode_AsUTF8String; + if (_Pickler_Write(self, &stack_global_op, 1) < 0) + goto error; } else { - unicode_encoder = PyUnicode_AsASCIIString; - } + /* Generate a normal global opcode if we are using a pickle + protocol < 4, or if the object is not registered in the + extension registry. */ + PyObject *encoded; + PyObject *(*unicode_encoder)(PyObject *); - /* For protocol < 3 and if the user didn't request against doing so, - we convert module names to the old 2.x module names. */ - if (self->fix_imports) { - PyObject *key; - PyObject *item; - - key = PyTuple_Pack(2, module_name, global_name); - if (key == NULL) + if (_Pickler_Write(self, &global_op, 1) < 0) goto error; - item = PyDict_GetItemWithError(name_mapping_3to2, key); - Py_DECREF(key); - if (item) { - if (!PyTuple_Check(item) || PyTuple_GET_SIZE(item) != 2) { - PyErr_Format(PyExc_RuntimeError, - "_compat_pickle.REVERSE_NAME_MAPPING values " - "should be 2-tuples, not %.200s", - Py_TYPE(item)->tp_name); - goto error; - } - Py_CLEAR(module_name); - Py_CLEAR(global_name); - module_name = PyTuple_GET_ITEM(item, 0); - global_name = PyTuple_GET_ITEM(item, 1); - if (!PyUnicode_Check(module_name) || - !PyUnicode_Check(global_name)) { - PyErr_Format(PyExc_RuntimeError, - "_compat_pickle.REVERSE_NAME_MAPPING values " - "should be pairs of str, not (%.200s, %.200s)", - Py_TYPE(module_name)->tp_name, - Py_TYPE(global_name)->tp_name); + + /* For protocol < 3 and if the user didn't request against doing + so, we convert module names to the old 2.x module names. */ + if (self->proto < 3 && self->fix_imports) { + if (fix_imports(&module_name, &global_name) < 0) { goto error; } - Py_INCREF(module_name); - Py_INCREF(global_name); } - else if (PyErr_Occurred()) { + + /* Since Python 3.0 now supports non-ASCII identifiers, we encode + both the module name and the global name using UTF-8. We do so + only when we are using the pickle protocol newer than version + 3. This is to ensure compatibility with older Unpickler running + on Python 2.x. */ + if (self->proto == 3) { + unicode_encoder = PyUnicode_AsUTF8String; + } + else { + unicode_encoder = PyUnicode_AsASCIIString; + } + encoded = unicode_encoder(module_name); + if (encoded == NULL) { + if (PyErr_ExceptionMatches(PyExc_UnicodeEncodeError)) + PyErr_Format(PicklingError, + "can't pickle module identifier '%S' using " + "pickle protocol %i", + module_name, self->proto); goto error; } - - item = PyDict_GetItemWithError(import_mapping_3to2, module_name); - if (item) { - if (!PyUnicode_Check(item)) { - PyErr_Format(PyExc_RuntimeError, - "_compat_pickle.REVERSE_IMPORT_MAPPING values " - "should be strings, not %.200s", - Py_TYPE(item)->tp_name); - goto error; - } - Py_CLEAR(module_name); - module_name = item; - Py_INCREF(module_name); - } - else if (PyErr_Occurred()) { + if (_Pickler_Write(self, PyBytes_AS_STRING(encoded), + PyBytes_GET_SIZE(encoded)) < 0) { + Py_DECREF(encoded); goto error; } - } - - /* Save the name of the module. */ - encoded = unicode_encoder(module_name); - if (encoded == NULL) { - if (PyErr_ExceptionMatches(PyExc_UnicodeEncodeError)) - PyErr_Format(PicklingError, - "can't pickle module identifier '%S' using " - "pickle protocol %i", module_name, self->proto); - goto error; - } - if (_Pickler_Write(self, PyBytes_AS_STRING(encoded), - PyBytes_GET_SIZE(encoded)) < 0) { Py_DECREF(encoded); - goto error; - } - Py_DECREF(encoded); - if(_Pickler_Write(self, "\n", 1) < 0) - goto error; + if(_Pickler_Write(self, "\n", 1) < 0) + goto error; - /* Save the name of the module. */ - encoded = unicode_encoder(global_name); - if (encoded == NULL) { - if (PyErr_ExceptionMatches(PyExc_UnicodeEncodeError)) - PyErr_Format(PicklingError, - "can't pickle global identifier '%S' using " - "pickle protocol %i", global_name, self->proto); - goto error; - } - if (_Pickler_Write(self, PyBytes_AS_STRING(encoded), - PyBytes_GET_SIZE(encoded)) < 0) { + /* Save the name of the module. */ + encoded = unicode_encoder(global_name); + if (encoded == NULL) { + if (PyErr_ExceptionMatches(PyExc_UnicodeEncodeError)) + PyErr_Format(PicklingError, + "can't pickle global identifier '%S' using " + "pickle protocol %i", + global_name, self->proto); + goto error; + } + if (_Pickler_Write(self, PyBytes_AS_STRING(encoded), + PyBytes_GET_SIZE(encoded)) < 0) { + Py_DECREF(encoded); + goto error; + } Py_DECREF(encoded); - goto error; + if (_Pickler_Write(self, "\n", 1) < 0) + goto error; } - Py_DECREF(encoded); - if(_Pickler_Write(self, "\n", 1) < 0) - goto error; - /* Memoize the object. */ if (memo_put(self, obj) < 0) goto error; @@ -2927,14 +3350,9 @@ static PyObject * get_class(PyObject *obj) { PyObject *cls; - static PyObject *str_class; + _Py_IDENTIFIER(__class__); - if (str_class == NULL) { - str_class = PyUnicode_InternFromString("__class__"); - if (str_class == NULL) - return NULL; - } - cls = PyObject_GetAttr(obj, str_class); + cls = _PyObject_GetAttrId(obj, &PyId___class__); if (cls == NULL) { if (PyErr_ExceptionMatches(PyExc_AttributeError)) { PyErr_Clear(); @@ -2957,12 +3375,12 @@ save_reduce(PicklerObject *self, PyObject *args, PyObject *obj) PyObject *listitems = Py_None; PyObject *dictitems = Py_None; Py_ssize_t size; - - int use_newobj = self->proto >= 2; + int use_newobj = 0, use_newobj_ex = 0; const char reduce_op = REDUCE; const char build_op = BUILD; const char newobj_op = NEWOBJ; + const char newobj_ex_op = NEWOBJ_EX; size = PyTuple_Size(args); if (size < 2 || size > 5) { @@ -3007,33 +3425,75 @@ save_reduce(PicklerObject *self, PyObject *args, PyObject *obj) return -1; } - /* Protocol 2 special case: if callable's name is __newobj__, use - NEWOBJ. */ - if (use_newobj) { - static PyObject *newobj_str = NULL; + if (self->proto >= 2) { PyObject *name; - - if (newobj_str == NULL) { - newobj_str = PyUnicode_InternFromString("__newobj__"); - if (newobj_str == NULL) - return -1; - } + _Py_IDENTIFIER(__name__); name = _PyObject_GetAttrId(callable, &PyId___name__); if (name == NULL) { - if (PyErr_ExceptionMatches(PyExc_AttributeError)) - PyErr_Clear(); - else + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) { return -1; - use_newobj = 0; + } + PyErr_Clear(); + } + else if (self->proto >= 4) { + _Py_IDENTIFIER(__newobj_ex__); + use_newobj_ex = PyUnicode_Check(name) && + PyUnicode_Compare( + name, _PyUnicode_FromId(&PyId___newobj_ex__)) == 0; + Py_DECREF(name); } else { + _Py_IDENTIFIER(__newobj__); use_newobj = PyUnicode_Check(name) && - PyUnicode_Compare(name, newobj_str) == 0; + PyUnicode_Compare( + name, _PyUnicode_FromId(&PyId___newobj__)) == 0; Py_DECREF(name); } } - if (use_newobj) { + + if (use_newobj_ex) { + PyObject *cls; + PyObject *args; + PyObject *kwargs; + + if (Py_SIZE(argtup) != 3) { + PyErr_Format(PicklingError, + "length of the NEWOBJ_EX argument tuple must be " + "exactly 3, not %zd", Py_SIZE(argtup)); + return -1; + } + + cls = PyTuple_GET_ITEM(argtup, 0); + if (!PyType_Check(cls)) { + PyErr_Format(PicklingError, + "first item from NEWOBJ_EX argument tuple must " + "be a class, not %.200s", Py_TYPE(cls)->tp_name); + return -1; + } + args = PyTuple_GET_ITEM(argtup, 1); + if (!PyTuple_Check(args)) { + PyErr_Format(PicklingError, + "second item from NEWOBJ_EX argument tuple must " + "be a tuple, not %.200s", Py_TYPE(args)->tp_name); + return -1; + } + kwargs = PyTuple_GET_ITEM(argtup, 2); + if (!PyDict_Check(kwargs)) { + PyErr_Format(PicklingError, + "third item from NEWOBJ_EX argument tuple must " + "be a dict, not %.200s", Py_TYPE(kwargs)->tp_name); + return -1; + } + + if (save(self, cls, 0) < 0 || + save(self, args, 0) < 0 || + save(self, kwargs, 0) < 0 || + _Pickler_Write(self, &newobj_ex_op, 1) < 0) { + return -1; + } + } + else if (use_newobj) { PyObject *cls; PyObject *newargtup; PyObject *obj_class; @@ -3117,8 +3577,23 @@ save_reduce(PicklerObject *self, PyObject *args, PyObject *obj) the caller do not want to memoize the object. Not particularly useful, but that is to mimic the behavior save_reduce() in pickle.py when obj is None. */ - if (obj && memo_put(self, obj) < 0) - return -1; + if (obj != NULL) { + /* If the object is already in the memo, this means it is + recursive. In this case, throw away everything we put on the + stack, and fetch the object back from the memo. */ + if (PyMemoTable_Get(self->memo, obj)) { + const char pop_op = POP; + + if (_Pickler_Write(self, &pop_op, 1) < 0) + return -1; + if (memo_get(self, obj) < 0) + return -1; + + return 0; + } + else if (memo_put(self, obj) < 0) + return -1; + } if (listitems && batch_list(self, listitems) < 0) return -1; @@ -3136,6 +3611,34 @@ save_reduce(PicklerObject *self, PyObject *args, PyObject *obj) } static int +save_method(PicklerObject *self, PyObject *obj) +{ + PyObject *method_self = PyCFunction_GET_SELF(obj); + + if (method_self == NULL || PyModule_Check(method_self)) { + return save_global(self, obj, NULL); + } + else { + PyObject *builtins; + PyObject *getattr; + PyObject *reduce_value; + int status = -1; + _Py_IDENTIFIER(getattr); + + builtins = PyEval_GetBuiltins(); + getattr = _PyDict_GetItemId(builtins, &PyId_getattr); + reduce_value = \ + Py_BuildValue("O(Os)", getattr, method_self, + ((PyCFunctionObject *)obj)->m_ml->ml_name); + if (reduce_value != NULL) { + status = save_reduce(self, reduce_value, obj); + Py_DECREF(reduce_value); + } + return status; + } +} + +static int save(PicklerObject *self, PyObject *obj, int pers_save) { PyTypeObject *type; @@ -3213,6 +3716,14 @@ save(PicklerObject *self, PyObject *obj, int pers_save) status = save_dict(self, obj); goto done; } + else if (type == &PySet_Type) { + status = save_set(self, obj); + goto done; + } + else if (type == &PyFrozenSet_Type) { + status = save_frozenset(self, obj); + goto done; + } else if (type == &PyList_Type) { status = save_list(self, obj); goto done; @@ -3236,7 +3747,7 @@ save(PicklerObject *self, PyObject *obj, int pers_save) } } else if (type == &PyCFunction_Type) { - status = save_global(self, obj, NULL); + status = save_method(self, obj); goto done; } @@ -3269,18 +3780,9 @@ save(PicklerObject *self, PyObject *obj, int pers_save) goto done; } else { - static PyObject *reduce_str = NULL; - static PyObject *reduce_ex_str = NULL; + _Py_IDENTIFIER(__reduce__); + _Py_IDENTIFIER(__reduce_ex__); - /* Cache the name of the reduce methods. */ - if (reduce_str == NULL) { - reduce_str = PyUnicode_InternFromString("__reduce__"); - if (reduce_str == NULL) - goto error; - reduce_ex_str = PyUnicode_InternFromString("__reduce_ex__"); - if (reduce_ex_str == NULL) - goto error; - } /* XXX: If the __reduce__ method is defined, __reduce_ex__ is automatically defined as __reduce__. While this is convenient, this @@ -3291,7 +3793,7 @@ save(PicklerObject *self, PyObject *obj, int pers_save) don't actually have to check for a __reduce__ method. */ /* Check for a __reduce_ex__ method. */ - reduce_func = PyObject_GetAttr(obj, reduce_ex_str); + reduce_func = _PyObject_GetAttrId(obj, &PyId___reduce_ex__); if (reduce_func != NULL) { PyObject *proto; proto = PyLong_FromLong(self->proto); @@ -3305,7 +3807,7 @@ save(PicklerObject *self, PyObject *obj, int pers_save) else goto error; /* Check for a __reduce__ method. */ - reduce_func = PyObject_GetAttr(obj, reduce_str); + reduce_func = _PyObject_GetAttrId(obj, &PyId___reduce__); if (reduce_func != NULL) { reduce_value = PyObject_Call(reduce_func, empty_tuple, NULL); } @@ -3338,6 +3840,8 @@ save(PicklerObject *self, PyObject *obj, int pers_save) status = -1; } done: + if (status == 0) + status = _Pickler_OpcodeBoundary(self); Py_LeaveRecursiveCall(); Py_XDECREF(reduce_func); Py_XDECREF(reduce_value); @@ -3358,6 +3862,8 @@ dump(PicklerObject *self, PyObject *obj) header[1] = (unsigned char)self->proto; if (_Pickler_Write(self, header, 2) < 0) return -1; + if (self->proto >= 4) + self->framing = 1; } if (save(self, obj, 0) < 0 || @@ -3478,9 +3984,9 @@ PyDoc_STRVAR(Pickler_doc, "This takes a binary file for writing a pickle data stream.\n" "\n" "The optional protocol argument tells the pickler to use the\n" -"given protocol; supported protocols are 0, 1, 2, 3. The default\n" -"protocol is 3; a backward-incompatible protocol designed for\n" -"Python 3.0.\n" +"given protocol; supported protocols are 0, 1, 2, 3 and 4. The\n" +"default protocol is 3; a backward-incompatible protocol designed for\n" +"Python 3.\n" "\n" "Specifying a negative protocol version selects the highest\n" "protocol version supported. The higher the protocol used, the\n" @@ -3493,8 +3999,8 @@ PyDoc_STRVAR(Pickler_doc, "meets this interface.\n" "\n" "If fix_imports is True and protocol is less than 3, pickle will try to\n" -"map the new Python 3.x names to the old module names used in Python\n" -"2.x, so that the pickle data stream is readable with Python 2.x.\n"); +"map the new Python 3 names to the old module names used in Python 2,\n" +"so that the pickle data stream is readable with Python 2.\n"); static int Pickler_init(PicklerObject *self, PyObject *args, PyObject *kwds) @@ -3987,17 +4493,15 @@ load_bool(UnpicklerObject *self, PyObject *boolean) * as a C Py_ssize_t, or -1 if it's higher than PY_SSIZE_T_MAX. */ static Py_ssize_t -calc_binsize(char *bytes, int size) +calc_binsize(char *bytes, int nbytes) { unsigned char *s = (unsigned char *)bytes; + int i; size_t x = 0; - assert(size == 4); - - x = (size_t) s[0]; - x |= (size_t) s[1] << 8; - x |= (size_t) s[2] << 16; - x |= (size_t) s[3] << 24; + for (i = 0; i < nbytes; i++) { + x |= (size_t) s[i] << (8 * i); + } if (x > PY_SSIZE_T_MAX) return -1; @@ -4011,21 +4515,21 @@ calc_binsize(char *bytes, int size) * of x-platform bugs. */ static long -calc_binint(char *bytes, int size) +calc_binint(char *bytes, int nbytes) { unsigned char *s = (unsigned char *)bytes; - int i = size; + int i; long x = 0; - for (i = 0; i < size; i++) { - x |= (long)s[i] << (i * 8); + for (i = 0; i < nbytes; i++) { + x |= (long)s[i] << (8 * i); } /* Unlike BININT1 and BININT2, BININT (more accurately BININT4) * is signed, so on a box with longs bigger than 4 bytes we need * to extend a BININT's sign bit to the full width. */ - if (SIZEOF_LONG > 4 && size == 4) { + if (SIZEOF_LONG > 4 && nbytes == 4) { x |= -(x & (1L << 31)); } @@ -4233,49 +4737,27 @@ load_string(UnpicklerObject *self) } static int -load_binbytes(UnpicklerObject *self) +load_counted_binbytes(UnpicklerObject *self, int nbytes) { PyObject *bytes; - Py_ssize_t x; + Py_ssize_t size; char *s; - if (_Unpickler_Read(self, &s, 4) < 0) + if (_Unpickler_Read(self, &s, nbytes) < 0) return -1; - x = calc_binsize(s, 4); - if (x < 0) { + size = calc_binsize(s, nbytes); + if (size < 0) { PyErr_Format(PyExc_OverflowError, "BINBYTES exceeds system's maximum size of %zd bytes", PY_SSIZE_T_MAX); return -1; } - if (_Unpickler_Read(self, &s, x) < 0) - return -1; - bytes = PyBytes_FromStringAndSize(s, x); - if (bytes == NULL) - return -1; - - PDATA_PUSH(self->stack, bytes, -1); - return 0; -} - -static int -load_short_binbytes(UnpicklerObject *self) -{ - PyObject *bytes; - Py_ssize_t x; - char *s; - - if (_Unpickler_Read(self, &s, 1) < 0) - return -1; - - x = (unsigned char)s[0]; - - if (_Unpickler_Read(self, &s, x) < 0) + if (_Unpickler_Read(self, &s, size) < 0) return -1; - bytes = PyBytes_FromStringAndSize(s, x); + bytes = PyBytes_FromStringAndSize(s, size); if (bytes == NULL) return -1; @@ -4284,51 +4766,27 @@ load_short_binbytes(UnpicklerObject *self) } static int -load_binstring(UnpicklerObject *self) +load_counted_binstring(UnpicklerObject *self, int nbytes) { PyObject *str; - Py_ssize_t x; + Py_ssize_t size; char *s; - if (_Unpickler_Read(self, &s, 4) < 0) + if (_Unpickler_Read(self, &s, nbytes) < 0) return -1; - x = calc_binint(s, 4); - if (x < 0) { - PyErr_SetString(UnpicklingError, - "BINSTRING pickle has negative byte count"); + size = calc_binsize(s, nbytes); + if (size < 0) { + PyErr_Format(UnpicklingError, + "BINSTRING exceeds system's maximum size of %zd bytes", + PY_SSIZE_T_MAX); return -1; } - if (_Unpickler_Read(self, &s, x) < 0) - return -1; - - /* Convert Python 2.x strings to unicode. */ - str = PyUnicode_Decode(s, x, self->encoding, self->errors); - if (str == NULL) - return -1; - - PDATA_PUSH(self->stack, str, -1); - return 0; -} - -static int -load_short_binstring(UnpicklerObject *self) -{ - PyObject *str; - Py_ssize_t x; - char *s; - - if (_Unpickler_Read(self, &s, 1) < 0) - return -1; - - x = (unsigned char)s[0]; - - if (_Unpickler_Read(self, &s, x) < 0) + if (_Unpickler_Read(self, &s, size) < 0) return -1; - /* Convert Python 2.x strings to unicode. */ - str = PyUnicode_Decode(s, x, self->encoding, self->errors); + str = PyUnicode_Decode(s, size, self->encoding, self->errors); if (str == NULL) return -1; @@ -4357,16 +4815,16 @@ load_unicode(UnpicklerObject *self) } static int -load_binunicode(UnpicklerObject *self) +load_counted_binunicode(UnpicklerObject *self, int nbytes) { PyObject *str; Py_ssize_t size; char *s; - if (_Unpickler_Read(self, &s, 4) < 0) + if (_Unpickler_Read(self, &s, nbytes) < 0) return -1; - size = calc_binsize(s, 4); + size = calc_binsize(s, nbytes); if (size < 0) { PyErr_Format(PyExc_OverflowError, "BINUNICODE exceeds system's maximum size of %zd bytes", @@ -4374,7 +4832,6 @@ load_binunicode(UnpicklerObject *self) return -1; } - if (_Unpickler_Read(self, &s, size) < 0) return -1; @@ -4446,6 +4903,17 @@ load_empty_dict(UnpicklerObject *self) } static int +load_empty_set(UnpicklerObject *self) +{ + PyObject *set; + + if ((set = PySet_New(NULL)) == NULL) + return -1; + PDATA_PUSH(self->stack, set, -1); + return 0; +} + +static int load_list(UnpicklerObject *self) { PyObject *list; @@ -4487,6 +4955,29 @@ load_dict(UnpicklerObject *self) return 0; } +static int +load_frozenset(UnpicklerObject *self) +{ + PyObject *items; + PyObject *frozenset; + Py_ssize_t i; + + if ((i = marker(self)) < 0) + return -1; + + items = Pdata_poptuple(self->stack, i); + if (items == NULL) + return -1; + + frozenset = PyFrozenSet_New(items); + Py_DECREF(items); + if (frozenset == NULL) + return -1; + + PDATA_PUSH(self->stack, frozenset, -1); + return 0; +} + static PyObject * instantiate(PyObject *cls, PyObject *args) { @@ -4638,6 +5129,57 @@ load_newobj(UnpicklerObject *self) } static int +load_newobj_ex(UnpicklerObject *self) +{ + PyObject *cls, *args, *kwargs; + PyObject *obj; + + PDATA_POP(self->stack, kwargs); + if (kwargs == NULL) { + return -1; + } + PDATA_POP(self->stack, args); + if (args == NULL) { + Py_DECREF(kwargs); + return -1; + } + PDATA_POP(self->stack, cls); + if (cls == NULL) { + Py_DECREF(kwargs); + Py_DECREF(args); + return -1; + } + + if (!PyType_Check(cls)) { + Py_DECREF(kwargs); + Py_DECREF(args); + Py_DECREF(cls); + PyErr_Format(UnpicklingError, + "NEWOBJ_EX class argument must be a type, not %.200s", + Py_TYPE(cls)->tp_name); + return -1; + } + + if (((PyTypeObject *)cls)->tp_new == NULL) { + Py_DECREF(kwargs); + Py_DECREF(args); + Py_DECREF(cls); + PyErr_SetString(UnpicklingError, + "NEWOBJ_EX class argument doesn't have __new__"); + return -1; + } + obj = ((PyTypeObject *)cls)->tp_new((PyTypeObject *)cls, args, kwargs); + Py_DECREF(kwargs); + Py_DECREF(args); + Py_DECREF(cls); + if (obj == NULL) { + return -1; + } + PDATA_PUSH(self->stack, obj, -1); + return 0; +} + +static int load_global(UnpicklerObject *self) { PyObject *global = NULL; @@ -4674,6 +5216,31 @@ load_global(UnpicklerObject *self) } static int +load_stack_global(UnpicklerObject *self) +{ + PyObject *global; + PyObject *module_name; + PyObject *global_name; + + PDATA_POP(self->stack, global_name); + PDATA_POP(self->stack, module_name); + if (module_name == NULL || !PyUnicode_CheckExact(module_name) || + global_name == NULL || !PyUnicode_CheckExact(global_name)) { + PyErr_SetString(UnpicklingError, "STACK_GLOBAL requires str"); + Py_XDECREF(global_name); + Py_XDECREF(module_name); + return -1; + } + global = find_class(self, module_name, global_name); + Py_DECREF(global_name); + Py_DECREF(module_name); + if (global == NULL) + return -1; + PDATA_PUSH(self->stack, global, -1); + return 0; +} + +static int load_persid(UnpicklerObject *self) { PyObject *pid; @@ -5017,6 +5584,18 @@ load_long_binput(UnpicklerObject *self) } static int +load_memoize(UnpicklerObject *self) +{ + PyObject *value; + + if (Py_SIZE(self->stack) <= 0) + return stack_underflow(); + value = self->stack->data[Py_SIZE(self->stack) - 1]; + + return _Unpickler_MemoPut(self, self->memo_len, value); +} + +static int do_append(UnpicklerObject *self, Py_ssize_t x) { PyObject *value; @@ -5132,6 +5711,59 @@ load_setitems(UnpicklerObject *self) } static int +load_additems(UnpicklerObject *self) +{ + PyObject *set; + Py_ssize_t mark, len, i; + + mark = marker(self); + len = Py_SIZE(self->stack); + if (mark > len || mark <= 0) + return stack_underflow(); + if (len == mark) /* nothing to do */ + return 0; + + set = self->stack->data[mark - 1]; + + if (PySet_Check(set)) { + PyObject *items; + int status; + + items = Pdata_poptuple(self->stack, mark); + if (items == NULL) + return -1; + + status = _PySet_Update(set, items); + Py_DECREF(items); + return status; + } + else { + PyObject *add_func; + _Py_IDENTIFIER(add); + + add_func = _PyObject_GetAttrId(set, &PyId_add); + if (add_func == NULL) + return -1; + for (i = mark; i < len; i++) { + PyObject *result; + PyObject *item; + + item = self->stack->data[i]; + result = _Unpickler_FastCall(self, add_func, item); + if (result == NULL) { + Pdata_clear(self->stack, i + 1); + Py_SIZE(self->stack) = mark; + return -1; + } + Py_DECREF(result); + } + Py_SIZE(self->stack) = mark; + } + + return 0; +} + +static int load_build(UnpicklerObject *self) { PyObject *state, *inst, *slotstate; @@ -5325,6 +5957,7 @@ load_proto(UnpicklerObject *self) i = (unsigned char)s[0]; if (i <= HIGHEST_PROTOCOL) { self->proto = i; + self->framing = (self->proto >= 4); return 0; } @@ -5340,6 +5973,8 @@ load(UnpicklerObject *self) char *s; self->num_marks = 0; + self->proto = 0; + self->framing = 0; if (Py_SIZE(self->stack)) Pdata_clear(self->stack, 0); @@ -5365,13 +6000,16 @@ load(UnpicklerObject *self) OP_ARG(LONG4, load_counted_long, 4) OP(FLOAT, load_float) OP(BINFLOAT, load_binfloat) - OP(BINBYTES, load_binbytes) - OP(SHORT_BINBYTES, load_short_binbytes) - OP(BINSTRING, load_binstring) - OP(SHORT_BINSTRING, load_short_binstring) + OP_ARG(SHORT_BINBYTES, load_counted_binbytes, 1) + OP_ARG(BINBYTES, load_counted_binbytes, 4) + OP_ARG(BINBYTES8, load_counted_binbytes, 8) + OP_ARG(SHORT_BINSTRING, load_counted_binstring, 1) + OP_ARG(BINSTRING, load_counted_binstring, 4) OP(STRING, load_string) OP(UNICODE, load_unicode) - OP(BINUNICODE, load_binunicode) + OP_ARG(SHORT_BINUNICODE, load_counted_binunicode, 1) + OP_ARG(BINUNICODE, load_counted_binunicode, 4) + OP_ARG(BINUNICODE8, load_counted_binunicode, 8) OP_ARG(EMPTY_TUPLE, load_counted_tuple, 0) OP_ARG(TUPLE1, load_counted_tuple, 1) OP_ARG(TUPLE2, load_counted_tuple, 2) @@ -5381,10 +6019,15 @@ load(UnpicklerObject *self) OP(LIST, load_list) OP(EMPTY_DICT, load_empty_dict) OP(DICT, load_dict) + OP(EMPTY_SET, load_empty_set) + OP(ADDITEMS, load_additems) + OP(FROZENSET, load_frozenset) OP(OBJ, load_obj) OP(INST, load_inst) OP(NEWOBJ, load_newobj) + OP(NEWOBJ_EX, load_newobj_ex) OP(GLOBAL, load_global) + OP(STACK_GLOBAL, load_stack_global) OP(APPEND, load_append) OP(APPENDS, load_appends) OP(BUILD, load_build) @@ -5396,6 +6039,7 @@ load(UnpicklerObject *self) OP(BINPUT, load_binput) OP(LONG_BINPUT, load_long_binput) OP(PUT, load_put) + OP(MEMOIZE, load_memoize) OP(POP, load_pop) OP(POP_MARK, load_pop_mark) OP(SETITEM, load_setitem) @@ -5485,6 +6129,7 @@ Unpickler_find_class(UnpicklerObject *self, PyObject *args) PyObject *modules_dict; PyObject *module; PyObject *module_name, *global_name; + _Py_IDENTIFIER(modules); if (!PyArg_UnpackTuple(args, "find_class", 2, 2, &module_name, &global_name)) @@ -5556,11 +6201,11 @@ Unpickler_find_class(UnpicklerObject *self, PyObject *args) module = PyImport_Import(module_name); if (module == NULL) return NULL; - global = PyObject_GetAttr(module, global_name); + global = getattribute(module, global_name, self->proto >= 4); Py_DECREF(module); } else { - global = PyObject_GetAttr(module, global_name); + global = getattribute(module, global_name, self->proto >= 4); } return global; } @@ -5723,6 +6368,7 @@ Unpickler_init(UnpicklerObject *self, PyObject *args, PyObject *kwds) self->arg = NULL; self->proto = 0; + self->framing = 0; return 0; } diff --git a/Objects/classobject.c b/Objects/classobject.c index 27f7ef4..272f575 100644 --- a/Objects/classobject.c +++ b/Objects/classobject.c @@ -69,6 +69,30 @@ PyMethod_New(PyObject *func, PyObject *self) return (PyObject *)im; } +static PyObject * +method_reduce(PyMethodObject *im) +{ + PyObject *self = PyMethod_GET_SELF(im); + PyObject *func = PyMethod_GET_FUNCTION(im); + PyObject *builtins; + PyObject *getattr; + PyObject *funcname; + _Py_IDENTIFIER(getattr); + + funcname = _PyObject_GetAttrId(func, &PyId___name__); + if (funcname == NULL) { + return NULL; + } + builtins = PyEval_GetBuiltins(); + getattr = _PyDict_GetItemId(builtins, &PyId_getattr); + return Py_BuildValue("O(ON)", getattr, self, funcname); +} + +static PyMethodDef method_methods[] = { + {"__reduce__", (PyCFunction)method_reduce, METH_NOARGS, NULL}, + {NULL, NULL} +}; + /* Descriptors for PyMethod attributes */ /* im_func and im_self are stored in the PyMethod object */ @@ -367,7 +391,7 @@ PyTypeObject PyMethod_Type = { offsetof(PyMethodObject, im_weakreflist), /* tp_weaklistoffset */ 0, /* tp_iter */ 0, /* tp_iternext */ - 0, /* tp_methods */ + method_methods, /* tp_methods */ method_memberlist, /* tp_members */ method_getset, /* tp_getset */ 0, /* tp_base */ diff --git a/Objects/descrobject.c b/Objects/descrobject.c index d4f8048..da88f86 100644 --- a/Objects/descrobject.c +++ b/Objects/descrobject.c @@ -398,6 +398,24 @@ descr_get_qualname(PyDescrObject *descr) return descr->d_qualname; } +static PyObject * +descr_reduce(PyDescrObject *descr) +{ + PyObject *builtins; + PyObject *getattr; + _Py_IDENTIFIER(getattr); + + builtins = PyEval_GetBuiltins(); + getattr = _PyDict_GetItemId(builtins, &PyId_getattr); + return Py_BuildValue("O(OO)", getattr, PyDescr_TYPE(descr), + PyDescr_NAME(descr)); +} + +static PyMethodDef descr_methods[] = { + {"__reduce__", (PyCFunction)descr_reduce, METH_NOARGS, NULL}, + {NULL, NULL} +}; + static PyMemberDef descr_members[] = { {"__objclass__", T_OBJECT, offsetof(PyDescrObject, d_type), READONLY}, {"__name__", T_OBJECT, offsetof(PyDescrObject, d_name), READONLY}, @@ -494,7 +512,7 @@ PyTypeObject PyMethodDescr_Type = { 0, /* tp_weaklistoffset */ 0, /* tp_iter */ 0, /* tp_iternext */ - 0, /* tp_methods */ + descr_methods, /* tp_methods */ descr_members, /* tp_members */ method_getset, /* tp_getset */ 0, /* tp_base */ @@ -532,7 +550,7 @@ PyTypeObject PyClassMethodDescr_Type = { 0, /* tp_weaklistoffset */ 0, /* tp_iter */ 0, /* tp_iternext */ - 0, /* tp_methods */ + descr_methods, /* tp_methods */ descr_members, /* tp_members */ method_getset, /* tp_getset */ 0, /* tp_base */ @@ -569,7 +587,7 @@ PyTypeObject PyMemberDescr_Type = { 0, /* tp_weaklistoffset */ 0, /* tp_iter */ 0, /* tp_iternext */ - 0, /* tp_methods */ + descr_methods, /* tp_methods */ descr_members, /* tp_members */ member_getset, /* tp_getset */ 0, /* tp_base */ @@ -643,7 +661,7 @@ PyTypeObject PyWrapperDescr_Type = { 0, /* tp_weaklistoffset */ 0, /* tp_iter */ 0, /* tp_iternext */ - 0, /* tp_methods */ + descr_methods, /* tp_methods */ descr_members, /* tp_members */ wrapperdescr_getset, /* tp_getset */ 0, /* tp_base */ @@ -1085,6 +1103,23 @@ wrapper_repr(wrapperobject *wp) wp->self); } +static PyObject * +wrapper_reduce(wrapperobject *wp) +{ + PyObject *builtins; + PyObject *getattr; + _Py_IDENTIFIER(getattr); + + builtins = PyEval_GetBuiltins(); + getattr = _PyDict_GetItemId(builtins, &PyId_getattr); + return Py_BuildValue("O(OO)", getattr, wp->self, PyDescr_NAME(wp->descr)); +} + +static PyMethodDef wrapper_methods[] = { + {"__reduce__", (PyCFunction)wrapper_reduce, METH_NOARGS, NULL}, + {NULL, NULL} +}; + static PyMemberDef wrapper_members[] = { {"__self__", T_OBJECT, offsetof(wrapperobject, self), READONLY}, {0} @@ -1193,7 +1228,7 @@ PyTypeObject _PyMethodWrapper_Type = { 0, /* tp_weaklistoffset */ 0, /* tp_iter */ 0, /* tp_iternext */ - 0, /* tp_methods */ + wrapper_methods, /* tp_methods */ wrapper_members, /* tp_members */ wrapper_getsets, /* tp_getset */ 0, /* tp_base */ diff --git a/Objects/typeobject.c b/Objects/typeobject.c index 09f77fa..5e951de 100644 --- a/Objects/typeobject.c +++ b/Objects/typeobject.c @@ -3405,151 +3405,430 @@ import_copyreg(void) return cached_copyreg_module; } -static PyObject * -slotnames(PyObject *cls) +Py_LOCAL(PyObject *) +_PyType_GetSlotNames(PyTypeObject *cls) { - PyObject *clsdict; PyObject *copyreg; PyObject *slotnames; _Py_IDENTIFIER(__slotnames__); _Py_IDENTIFIER(_slotnames); - clsdict = ((PyTypeObject *)cls)->tp_dict; - slotnames = _PyDict_GetItemId(clsdict, &PyId___slotnames__); - if (slotnames != NULL && PyList_Check(slotnames)) { + assert(PyType_Check(cls)); + + /* Get the slot names from the cache in the class if possible. */ + slotnames = _PyDict_GetItemIdWithError(cls->tp_dict, &PyId___slotnames__); + if (slotnames != NULL) { + if (slotnames != Py_None && !PyList_Check(slotnames)) { + PyErr_Format(PyExc_TypeError, + "%.200s.__slotnames__ should be a list or None, " + "not %.200s", + cls->tp_name, Py_TYPE(slotnames)->tp_name); + return NULL; + } Py_INCREF(slotnames); return slotnames; } + else { + if (PyErr_Occurred()) { + return NULL; + } + /* The class does not have the slot names cached yet. */ + } copyreg = import_copyreg(); if (copyreg == NULL) return NULL; - slotnames = _PyObject_CallMethodId(copyreg, &PyId__slotnames, "O", cls); + /* Use _slotnames function from the copyreg module to find the slots + by this class and its bases. This function will cache the result + in __slotnames__. */ + slotnames = _PyObject_CallMethodIdObjArgs(copyreg, &PyId__slotnames, + cls, NULL); Py_DECREF(copyreg); - if (slotnames != NULL && - slotnames != Py_None && - !PyList_Check(slotnames)) - { + if (slotnames == NULL) + return NULL; + + if (slotnames != Py_None && !PyList_Check(slotnames)) { PyErr_SetString(PyExc_TypeError, - "copyreg._slotnames didn't return a list or None"); + "copyreg._slotnames didn't return a list or None"); Py_DECREF(slotnames); - slotnames = NULL; + return NULL; } return slotnames; } -static PyObject * -reduce_2(PyObject *obj) +Py_LOCAL(PyObject *) +_PyObject_GetState(PyObject *obj) { - PyObject *cls, *getnewargs; - PyObject *args = NULL, *args2 = NULL; - PyObject *getstate = NULL, *state = NULL, *names = NULL; - PyObject *slots = NULL, *listitems = NULL, *dictitems = NULL; - PyObject *copyreg = NULL, *newobj = NULL, *res = NULL; - Py_ssize_t i, n; - _Py_IDENTIFIER(__getnewargs__); + PyObject *state; + PyObject *getstate; _Py_IDENTIFIER(__getstate__); - _Py_IDENTIFIER(__newobj__); - cls = (PyObject *) Py_TYPE(obj); + getstate = _PyObject_GetAttrId(obj, &PyId___getstate__); + if (getstate == NULL) { + PyObject *slotnames; - getnewargs = _PyObject_GetAttrId(obj, &PyId___getnewargs__); - if (getnewargs != NULL) { - args = PyObject_CallObject(getnewargs, NULL); - Py_DECREF(getnewargs); - if (args != NULL && !PyTuple_Check(args)) { - PyErr_Format(PyExc_TypeError, - "__getnewargs__ should return a tuple, " - "not '%.200s'", Py_TYPE(args)->tp_name); - goto end; + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) { + return NULL; } - } - else { PyErr_Clear(); - args = PyTuple_New(0); - } - if (args == NULL) - goto end; - getstate = _PyObject_GetAttrId(obj, &PyId___getstate__); - if (getstate != NULL) { - state = PyObject_CallObject(getstate, NULL); - Py_DECREF(getstate); - if (state == NULL) - goto end; - } - else { - PyObject **dict; - PyErr_Clear(); - dict = _PyObject_GetDictPtr(obj); - if (dict && *dict) - state = *dict; - else - state = Py_None; - Py_INCREF(state); - names = slotnames(cls); - if (names == NULL) - goto end; - if (names != Py_None && PyList_GET_SIZE(names) > 0) { - assert(PyList_Check(names)); + { + PyObject **dict; + dict = _PyObject_GetDictPtr(obj); + /* It is possible that the object's dict is not initialized + yet. In this case, we will return None for the state. + We also return None if the dict is empty to make the behavior + consistent regardless whether the dict was initialized or not. + This make unit testing easier. */ + if (dict != NULL && *dict != NULL && PyDict_Size(*dict) > 0) { + state = *dict; + } + else { + state = Py_None; + } + Py_INCREF(state); + } + + slotnames = _PyType_GetSlotNames(Py_TYPE(obj)); + if (slotnames == NULL) { + Py_DECREF(state); + return NULL; + } + + assert(slotnames == Py_None || PyList_Check(slotnames)); + if (slotnames != Py_None && Py_SIZE(slotnames) > 0) { + PyObject *slots; + Py_ssize_t slotnames_size, i; + slots = PyDict_New(); - if (slots == NULL) - goto end; - n = 0; - /* Can't pre-compute the list size; the list - is stored on the class so accessible to other - threads, which may be run by DECREF */ - for (i = 0; i < PyList_GET_SIZE(names); i++) { + if (slots == NULL) { + Py_DECREF(slotnames); + Py_DECREF(state); + return NULL; + } + + slotnames_size = Py_SIZE(slotnames); + for (i = 0; i < slotnames_size; i++) { PyObject *name, *value; - name = PyList_GET_ITEM(names, i); + + name = PyList_GET_ITEM(slotnames, i); value = PyObject_GetAttr(obj, name); - if (value == NULL) + if (value == NULL) { + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) { + goto error; + } + /* It is not an error if the attribute is not present. */ PyErr_Clear(); + } else { - int err = PyDict_SetItem(slots, name, - value); + int err = PyDict_SetItem(slots, name, value); Py_DECREF(value); - if (err) - goto end; - n++; + if (err) { + goto error; + } + } + + /* The list is stored on the class so it may mutates while we + iterate over it */ + if (slotnames_size != Py_SIZE(slotnames)) { + PyErr_Format(PyExc_RuntimeError, + "__slotsname__ changed size during iteration"); + goto error; + } + + /* We handle errors within the loop here. */ + if (0) { + error: + Py_DECREF(slotnames); + Py_DECREF(slots); + Py_DECREF(state); + return NULL; } } - if (n) { - state = Py_BuildValue("(NO)", state, slots); - if (state == NULL) - goto end; + + /* If we found some slot attributes, pack them in a tuple along + the orginal attribute dictionary. */ + if (PyDict_Size(slots) > 0) { + PyObject *state2; + + state2 = PyTuple_Pack(2, state, slots); + Py_DECREF(state); + if (state2 == NULL) { + Py_DECREF(slotnames); + Py_DECREF(slots); + return NULL; + } + state = state2; } + Py_DECREF(slots); + } + Py_DECREF(slotnames); + } + else { /* getstate != NULL */ + state = PyObject_CallObject(getstate, NULL); + Py_DECREF(getstate); + if (state == NULL) + return NULL; + } + + return state; +} + +Py_LOCAL(int) +_PyObject_GetNewArguments(PyObject *obj, PyObject **args, PyObject **kwargs) +{ + PyObject *getnewargs, *getnewargs_ex; + _Py_IDENTIFIER(__getnewargs_ex__); + _Py_IDENTIFIER(__getnewargs__); + + if (args == NULL || kwargs == NULL) { + PyErr_BadInternalCall(); + return -1; + } + + /* We first attempt to fetch the arguments for __new__ by calling + __getnewargs_ex__ on the object. */ + getnewargs_ex = _PyObject_GetAttrId(obj, &PyId___getnewargs_ex__); + if (getnewargs_ex != NULL) { + PyObject *newargs = PyObject_CallObject(getnewargs_ex, NULL); + Py_DECREF(getnewargs_ex); + if (newargs == NULL) { + return -1; + } + if (!PyTuple_Check(newargs)) { + PyErr_Format(PyExc_TypeError, + "__getnewargs_ex__ should return a tuple, " + "not '%.200s'", Py_TYPE(newargs)->tp_name); + Py_DECREF(newargs); + return -1; + } + if (Py_SIZE(newargs) != 2) { + PyErr_Format(PyExc_ValueError, + "__getnewargs_ex__ should return a tuple of " + "length 2, not %zd", Py_SIZE(newargs)); + Py_DECREF(newargs); + return -1; + } + *args = PyTuple_GET_ITEM(newargs, 0); + Py_INCREF(*args); + *kwargs = PyTuple_GET_ITEM(newargs, 1); + Py_INCREF(*kwargs); + Py_DECREF(newargs); + + /* XXX We should perhaps allow None to be passed here. */ + if (!PyTuple_Check(*args)) { + PyErr_Format(PyExc_TypeError, + "first item of the tuple returned by " + "__getnewargs_ex__ must be a tuple, not '%.200s'", + Py_TYPE(*args)->tp_name); + Py_CLEAR(*args); + Py_CLEAR(*kwargs); + return -1; + } + if (!PyDict_Check(*kwargs)) { + PyErr_Format(PyExc_TypeError, + "second item of the tuple returned by " + "__getnewargs_ex__ must be a dict, not '%.200s'", + Py_TYPE(*kwargs)->tp_name); + Py_CLEAR(*args); + Py_CLEAR(*kwargs); + return -1; + } + return 0; + } else { + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) { + return -1; + } + PyErr_Clear(); + } + + /* The object does not have __getnewargs_ex__ so we fallback on using + __getnewargs__ instead. */ + getnewargs = _PyObject_GetAttrId(obj, &PyId___getnewargs__); + if (getnewargs != NULL) { + *args = PyObject_CallObject(getnewargs, NULL); + Py_DECREF(getnewargs); + if (*args == NULL) { + return -1; + } + if (!PyTuple_Check(*args)) { + PyErr_Format(PyExc_TypeError, + "__getnewargs__ should return a tuple, " + "not '%.200s'", Py_TYPE(*args)->tp_name); + Py_CLEAR(*args); + return -1; + } + *kwargs = NULL; + return 0; + } else { + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) { + return -1; } + PyErr_Clear(); + } + + /* The object does not have __getnewargs_ex__ and __getnewargs__. This may + means __new__ does not takes any arguments on this object, or that the + object does not implement the reduce protocol for pickling or + copying. */ + *args = NULL; + *kwargs = NULL; + return 0; +} + +Py_LOCAL(int) +_PyObject_GetItemsIter(PyObject *obj, PyObject **listitems, + PyObject **dictitems) +{ + if (listitems == NULL || dictitems == NULL) { + PyErr_BadInternalCall(); + return -1; } if (!PyList_Check(obj)) { - listitems = Py_None; - Py_INCREF(listitems); + *listitems = Py_None; + Py_INCREF(*listitems); } else { - listitems = PyObject_GetIter(obj); + *listitems = PyObject_GetIter(obj); if (listitems == NULL) - goto end; + return -1; } if (!PyDict_Check(obj)) { - dictitems = Py_None; - Py_INCREF(dictitems); + *dictitems = Py_None; + Py_INCREF(*dictitems); } else { + PyObject *items; _Py_IDENTIFIER(items); - PyObject *items = _PyObject_CallMethodId(obj, &PyId_items, ""); - if (items == NULL) - goto end; - dictitems = PyObject_GetIter(items); + + items = _PyObject_CallMethodIdObjArgs(obj, &PyId_items, NULL); + if (items == NULL) { + Py_CLEAR(*listitems); + return -1; + } + *dictitems = PyObject_GetIter(items); Py_DECREF(items); - if (dictitems == NULL) - goto end; + if (*dictitems == NULL) { + Py_CLEAR(*listitems); + return -1; + } + } + + assert(*listitems != NULL && *dictitems != NULL); + + return 0; +} + +static PyObject * +reduce_4(PyObject *obj) +{ + PyObject *args = NULL, *kwargs = NULL; + PyObject *copyreg; + PyObject *newobj, *newargs, *state, *listitems, *dictitems; + PyObject *result; + _Py_IDENTIFIER(__newobj_ex__); + + if (_PyObject_GetNewArguments(obj, &args, &kwargs) < 0) { + return NULL; + } + if (args == NULL) { + args = PyTuple_New(0); + if (args == NULL) + return NULL; + } + if (kwargs == NULL) { + kwargs = PyDict_New(); + if (kwargs == NULL) + return NULL; } copyreg = import_copyreg(); + if (copyreg == NULL) { + Py_DECREF(args); + Py_DECREF(kwargs); + return NULL; + } + newobj = _PyObject_GetAttrId(copyreg, &PyId___newobj_ex__); + Py_DECREF(copyreg); + if (newobj == NULL) { + Py_DECREF(args); + Py_DECREF(kwargs); + return NULL; + } + newargs = PyTuple_Pack(3, Py_TYPE(obj), args, kwargs); + Py_DECREF(args); + Py_DECREF(kwargs); + if (newargs == NULL) { + Py_DECREF(newobj); + return NULL; + } + state = _PyObject_GetState(obj); + if (state == NULL) { + Py_DECREF(newobj); + Py_DECREF(newargs); + return NULL; + } + if (_PyObject_GetItemsIter(obj, &listitems, &dictitems) < 0) { + Py_DECREF(newobj); + Py_DECREF(newargs); + Py_DECREF(state); + return NULL; + } + + result = PyTuple_Pack(5, newobj, newargs, state, listitems, dictitems); + Py_DECREF(newobj); + Py_DECREF(newargs); + Py_DECREF(state); + Py_DECREF(listitems); + Py_DECREF(dictitems); + return result; +} + +static PyObject * +reduce_2(PyObject *obj) +{ + PyObject *cls; + PyObject *args = NULL, *args2 = NULL, *kwargs = NULL; + PyObject *state = NULL, *listitems = NULL, *dictitems = NULL; + PyObject *copyreg = NULL, *newobj = NULL, *res = NULL; + Py_ssize_t i, n; + _Py_IDENTIFIER(__newobj__); + + if (_PyObject_GetNewArguments(obj, &args, &kwargs) < 0) { + return NULL; + } + if (args == NULL) { + assert(kwargs == NULL); + args = PyTuple_New(0); + if (args == NULL) { + return NULL; + } + } + else if (kwargs != NULL) { + if (PyDict_Size(kwargs) > 0) { + PyErr_SetString(PyExc_ValueError, + "must use protocol 4 or greater to copy this " + "object; since __getnewargs_ex__ returned " + "keyword arguments."); + Py_DECREF(args); + Py_DECREF(kwargs); + return NULL; + } + Py_CLEAR(kwargs); + } + + state = _PyObject_GetState(obj); + if (state == NULL) + goto end; + + if (_PyObject_GetItemsIter(obj, &listitems, &dictitems) < 0) + goto end; + + copyreg = import_copyreg(); if (copyreg == NULL) goto end; newobj = _PyObject_GetAttrId(copyreg, &PyId___newobj__); @@ -3560,6 +3839,7 @@ reduce_2(PyObject *obj) args2 = PyTuple_New(n+1); if (args2 == NULL) goto end; + cls = (PyObject *) Py_TYPE(obj); Py_INCREF(cls); PyTuple_SET_ITEM(args2, 0, cls); for (i = 0; i < n; i++) { @@ -3573,9 +3853,7 @@ reduce_2(PyObject *obj) end: Py_XDECREF(args); Py_XDECREF(args2); - Py_XDECREF(slots); Py_XDECREF(state); - Py_XDECREF(names); Py_XDECREF(listitems); Py_XDECREF(dictitems); Py_XDECREF(copyreg); @@ -3603,7 +3881,9 @@ _common_reduce(PyObject *self, int proto) { PyObject *copyreg, *res; - if (proto >= 2) + if (proto >= 4) + return reduce_4(self); + else if (proto >= 2) return reduce_2(self); copyreg = import_copyreg(); -- cgit v0.12