summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAntoine Pitrou <solipsis@pitrou.net>2013-11-23 17:59:12 (GMT)
committerAntoine Pitrou <solipsis@pitrou.net>2013-11-23 17:59:12 (GMT)
commitc9dc4a2a8a6dcfe1674685bea4a4af935c0e37ca (patch)
treeafbde5318538e73815668dc73a0fb91dfb88ca61
parent95401c5f6b9f07b094924559177c9b30a1c38998 (diff)
downloadcpython-c9dc4a2a8a6dcfe1674685bea4a4af935c0e37ca.zip
cpython-c9dc4a2a8a6dcfe1674685bea4a4af935c0e37ca.tar.gz
cpython-c9dc4a2a8a6dcfe1674685bea4a4af935c0e37ca.tar.bz2
Issue #17810: Implement PEP 3154, pickle protocol 4.
Most of the work is by Alexandre.
-rw-r--r--Doc/library/pickle.rst35
-rw-r--r--Doc/whatsnew/3.4.rst15
-rw-r--r--Lib/copyreg.py6
-rw-r--r--Lib/pickle.py582
-rw-r--r--Lib/pickletools.py471
-rw-r--r--Lib/test/pickletester.py487
-rw-r--r--Lib/test/test_descr.py605
-rw-r--r--Misc/NEWS2
-rw-r--r--Modules/_pickle.c1380
-rw-r--r--Objects/classobject.c26
-rw-r--r--Objects/descrobject.c45
-rw-r--r--Objects/typeobject.c466
12 files changed, 3123 insertions, 997 deletions
diff --git a/Doc/library/pickle.rst b/Doc/library/pickle.rst
index 9404a47..49ec9c1 100644
--- a/Doc/library/pickle.rst
+++ b/Doc/library/pickle.rst
@@ -459,12 +459,29 @@ implementation of this behaviour::
Classes can alter the default behaviour by providing one or several special
methods:
+.. method:: object.__getnewargs_ex__()
+
+ In protocols 4 and newer, classes that implements the
+ :meth:`__getnewargs_ex__` method can dictate the values passed to the
+ :meth:`__new__` method upon unpickling. The method must return a pair
+ ``(args, kwargs)`` where *args* is a tuple of positional arguments
+ and *kwargs* a dictionary of named arguments for constructing the
+ object. Those will be passed to the :meth:`__new__` method upon
+ unpickling.
+
+ You should implement this method if the :meth:`__new__` method of your
+ class requires keyword-only arguments. Otherwise, it is recommended for
+ compatibility to implement :meth:`__getnewargs__`.
+
+
.. method:: object.__getnewargs__()
- In protocol 2 and newer, classes that implements the :meth:`__getnewargs__`
- method can dictate the values passed to the :meth:`__new__` method upon
- unpickling. This is often needed for classes whose :meth:`__new__` method
- requires arguments.
+ This method serve a similar purpose as :meth:`__getnewargs_ex__` but
+ for protocols 2 and newer. It must return a tuple of arguments `args`
+ which will be passed to the :meth:`__new__` method upon unpickling.
+
+ In protocols 4 and newer, :meth:`__getnewargs__` will not be called if
+ :meth:`__getnewargs_ex__` is defined.
.. method:: object.__getstate__()
@@ -496,10 +513,10 @@ the methods :meth:`__getstate__` and :meth:`__setstate__`.
At unpickling time, some methods like :meth:`__getattr__`,
:meth:`__getattribute__`, or :meth:`__setattr__` may be called upon the
- instance. In case those methods rely on some internal invariant being true,
- the type should implement :meth:`__getnewargs__` to establish such an
- invariant; otherwise, neither :meth:`__new__` nor :meth:`__init__` will be
- called.
+ instance. In case those methods rely on some internal invariant being
+ true, the type should implement :meth:`__getnewargs__` or
+ :meth:`__getnewargs_ex__` to establish such an invariant; otherwise,
+ neither :meth:`__new__` nor :meth:`__init__` will be called.
.. index:: pair: copy; protocol
@@ -511,7 +528,7 @@ objects. [#]_
Although powerful, implementing :meth:`__reduce__` directly in your classes is
error prone. For this reason, class designers should use the high-level
-interface (i.e., :meth:`__getnewargs__`, :meth:`__getstate__` and
+interface (i.e., :meth:`__getnewargs_ex__`, :meth:`__getstate__` and
:meth:`__setstate__`) whenever possible. We will show, however, cases where
using :meth:`__reduce__` is the only option or leads to more efficient pickling
or both.
diff --git a/Doc/whatsnew/3.4.rst b/Doc/whatsnew/3.4.rst
index b509516..6f949a9 100644
--- a/Doc/whatsnew/3.4.rst
+++ b/Doc/whatsnew/3.4.rst
@@ -109,6 +109,7 @@ New expected features for Python implementations:
Significantly Improved Library Modules:
* Single-dispatch generic functions in :mod:`functoools` (:pep:`443`)
+* New :mod:`pickle` protocol 4 (:pep:`3154`)
* SHA-3 (Keccak) support for :mod:`hashlib`.
* TLSv1.1 and TLSv1.2 support for :mod:`ssl`.
* :mod:`multiprocessing` now has option to avoid using :func:`os.fork`
@@ -285,6 +286,20 @@ described in the PEP. Existing importers should be updated to implement
the new methods.
+Pickle protocol 4
+=================
+
+The new :mod:`pickle` protocol addresses a number of issues that were present
+in previous protocols, such as the serialization of nested classes, very
+large strings and containers, or classes whose :meth:`__new__` method takes
+keyword-only arguments. It also brings a couple efficiency improvements.
+
+.. seealso::
+
+ :pep:`3154` - Pickle protocol 4
+ PEP written by Antoine Pitrou and implemented by Alexandre Vassalotti.
+
+
Other Language Changes
======================
diff --git a/Lib/copyreg.py b/Lib/copyreg.py
index 66c0f8a..67f5bb0 100644
--- a/Lib/copyreg.py
+++ b/Lib/copyreg.py
@@ -87,6 +87,12 @@ def _reduce_ex(self, proto):
def __newobj__(cls, *args):
return cls.__new__(cls, *args)
+def __newobj_ex__(cls, args, kwargs):
+ """Used by pickle protocol 4, instead of __newobj__ to allow classes with
+ keyword-only arguments to be pickled correctly.
+ """
+ return cls.__new__(cls, *args, **kwargs)
+
def _slotnames(cls):
"""Return a list of slot names for a given class.
diff --git a/Lib/pickle.py b/Lib/pickle.py
index dbc196a..d1f1538 100644
--- a/Lib/pickle.py
+++ b/Lib/pickle.py
@@ -23,7 +23,7 @@ Misc variables:
"""
-from types import FunctionType, BuiltinFunctionType
+from types import FunctionType, BuiltinFunctionType, ModuleType
from copyreg import dispatch_table
from copyreg import _extension_registry, _inverted_registry, _extension_cache
from itertools import islice
@@ -42,17 +42,18 @@ __all__ = ["PickleError", "PicklingError", "UnpicklingError", "Pickler",
bytes_types = (bytes, bytearray)
# These are purely informational; no code uses these.
-format_version = "3.0" # File format version we write
+format_version = "4.0" # File format version we write
compatible_formats = ["1.0", # Original protocol 0
"1.1", # Protocol 0 with INST added
"1.2", # Original protocol 1
"1.3", # Protocol 1 with BINFLOAT added
"2.0", # Protocol 2
"3.0", # Protocol 3
+ "4.0", # Protocol 4
] # Old format versions we can read
# This is the highest protocol number we know how to read.
-HIGHEST_PROTOCOL = 3
+HIGHEST_PROTOCOL = 4
# The protocol we write by default. May be less than HIGHEST_PROTOCOL.
# We intentionally write a protocol that Python 2.x cannot read;
@@ -164,7 +165,196 @@ _tuplesize2code = [EMPTY_TUPLE, TUPLE1, TUPLE2, TUPLE3]
BINBYTES = b'B' # push bytes; counted binary string argument
SHORT_BINBYTES = b'C' # " " ; " " " " < 256 bytes
-__all__.extend([x for x in dir() if re.match("[A-Z][A-Z0-9_]+$",x)])
+# Protocol 4
+SHORT_BINUNICODE = b'\x8c' # push short string; UTF-8 length < 256 bytes
+BINUNICODE8 = b'\x8d' # push very long string
+BINBYTES8 = b'\x8e' # push very long bytes string
+EMPTY_SET = b'\x8f' # push empty set on the stack
+ADDITEMS = b'\x90' # modify set by adding topmost stack items
+FROZENSET = b'\x91' # build frozenset from topmost stack items
+NEWOBJ_EX = b'\x92' # like NEWOBJ but work with keyword only arguments
+STACK_GLOBAL = b'\x93' # same as GLOBAL but using names on the stacks
+MEMOIZE = b'\x94' # store top of the stack in memo
+FRAME = b'\x95' # indicate the beginning of a new frame
+
+__all__.extend([x for x in dir() if re.match("[A-Z][A-Z0-9_]+$", x)])
+
+
+class _Framer:
+
+ _FRAME_SIZE_TARGET = 64 * 1024
+
+ def __init__(self, file_write):
+ self.file_write = file_write
+ self.current_frame = None
+
+ def _commit_frame(self):
+ f = self.current_frame
+ with f.getbuffer() as data:
+ n = len(data)
+ write = self.file_write
+ write(FRAME)
+ write(pack("<Q", n))
+ write(data)
+ f.seek(0)
+ f.truncate()
+
+ def start_framing(self):
+ self.current_frame = io.BytesIO()
+
+ def end_framing(self):
+ if self.current_frame is not None:
+ self._commit_frame()
+ self.current_frame = None
+
+ def write(self, data):
+ f = self.current_frame
+ if f is None:
+ return self.file_write(data)
+ else:
+ n = len(data)
+ if f.tell() >= self._FRAME_SIZE_TARGET:
+ self._commit_frame()
+ return f.write(data)
+
+class _Unframer:
+
+ def __init__(self, file_read, file_readline, file_tell=None):
+ self.file_read = file_read
+ self.file_readline = file_readline
+ self.file_tell = file_tell
+ self.framing_enabled = False
+ self.current_frame = None
+ self.frame_start = None
+
+ def read(self, n):
+ if n == 0:
+ return b''
+ _file_read = self.file_read
+ if not self.framing_enabled:
+ return _file_read(n)
+ f = self.current_frame
+ if f is not None:
+ data = f.read(n)
+ if data:
+ if len(data) < n:
+ raise UnpicklingError(
+ "pickle exhausted before end of frame")
+ return data
+ frame_opcode = _file_read(1)
+ if frame_opcode != FRAME:
+ raise UnpicklingError(
+ "expected a FRAME opcode, got {} instead".format(frame_opcode))
+ frame_size, = unpack("<Q", _file_read(8))
+ if frame_size > sys.maxsize:
+ raise ValueError("frame size > sys.maxsize: %d" % frame_size)
+ if self.file_tell is not None:
+ self.frame_start = self.file_tell()
+ f = self.current_frame = io.BytesIO(_file_read(frame_size))
+ self.readline = f.readline
+ data = f.read(n)
+ assert len(data) == n, (len(data), n)
+ return data
+
+ def readline(self):
+ if not self.framing_enabled:
+ return self.file_readline()
+ else:
+ return self.current_frame.readline()
+
+ def tell(self):
+ if self.file_tell is None:
+ return None
+ elif self.current_frame is None:
+ return self.file_tell()
+ else:
+ return self.frame_start + self.current_frame.tell()
+
+
+# Tools used for pickling.
+
+def _getattribute(obj, name, allow_qualname=False):
+ dotted_path = name.split(".")
+ if not allow_qualname and len(dotted_path) > 1:
+ raise AttributeError("Can't get qualified attribute {!r} on {!r}; " +
+ "use protocols >= 4 to enable support"
+ .format(name, obj))
+ for subpath in dotted_path:
+ if subpath == '<locals>':
+ raise AttributeError("Can't get local attribute {!r} on {!r}"
+ .format(name, obj))
+ try:
+ obj = getattr(obj, subpath)
+ except AttributeError:
+ raise AttributeError("Can't get attribute {!r} on {!r}"
+ .format(name, obj))
+ return obj
+
+def whichmodule(obj, name, allow_qualname=False):
+ """Find the module an object belong to."""
+ module_name = getattr(obj, '__module__', None)
+ if module_name is not None:
+ return module_name
+ for module_name, module in sys.modules.items():
+ if module_name == '__main__' or module is None:
+ continue
+ try:
+ if _getattribute(module, name, allow_qualname) is obj:
+ return module_name
+ except AttributeError:
+ pass
+ return '__main__'
+
+def encode_long(x):
+ r"""Encode a long to a two's complement little-endian binary string.
+ Note that 0 is a special case, returning an empty string, to save a
+ byte in the LONG1 pickling context.
+
+ >>> encode_long(0)
+ b''
+ >>> encode_long(255)
+ b'\xff\x00'
+ >>> encode_long(32767)
+ b'\xff\x7f'
+ >>> encode_long(-256)
+ b'\x00\xff'
+ >>> encode_long(-32768)
+ b'\x00\x80'
+ >>> encode_long(-128)
+ b'\x80'
+ >>> encode_long(127)
+ b'\x7f'
+ >>>
+ """
+ if x == 0:
+ return b''
+ nbytes = (x.bit_length() >> 3) + 1
+ result = x.to_bytes(nbytes, byteorder='little', signed=True)
+ if x < 0 and nbytes > 1:
+ if result[-1] == 0xff and (result[-2] & 0x80) != 0:
+ result = result[:-1]
+ return result
+
+def decode_long(data):
+ r"""Decode a long from a two's complement little-endian binary string.
+
+ >>> decode_long(b'')
+ 0
+ >>> decode_long(b"\xff\x00")
+ 255
+ >>> decode_long(b"\xff\x7f")
+ 32767
+ >>> decode_long(b"\x00\xff")
+ -256
+ >>> decode_long(b"\x00\x80")
+ -32768
+ >>> decode_long(b"\x80")
+ -128
+ >>> decode_long(b"\x7f")
+ 127
+ """
+ return int.from_bytes(data, byteorder='little', signed=True)
+
# Pickling machinery
@@ -174,9 +364,9 @@ class _Pickler:
"""This takes a binary file for writing a pickle data stream.
The optional protocol argument tells the pickler to use the
- given protocol; supported protocols are 0, 1, 2, 3. The default
- protocol is 3; a backward-incompatible protocol designed for
- Python 3.0.
+ given protocol; supported protocols are 0, 1, 2, 3 and 4. The
+ default protocol is 3; a backward-incompatible protocol designed for
+ Python 3.
Specifying a negative protocol version selects the highest
protocol version supported. The higher the protocol used, the
@@ -189,8 +379,8 @@ class _Pickler:
meets this interface.
If fix_imports is True and protocol is less than 3, pickle will try to
- map the new Python 3.x names to the old module names used in Python
- 2.x, so that the pickle data stream is readable with Python 2.x.
+ map the new Python 3 names to the old module names used in Python 2,
+ so that the pickle data stream is readable with Python 2.
"""
if protocol is None:
protocol = DEFAULT_PROTOCOL
@@ -199,7 +389,7 @@ class _Pickler:
elif not 0 <= protocol <= HIGHEST_PROTOCOL:
raise ValueError("pickle protocol must be <= %d" % HIGHEST_PROTOCOL)
try:
- self.write = file.write
+ self._file_write = file.write
except AttributeError:
raise TypeError("file must have a 'write' attribute")
self.memo = {}
@@ -223,13 +413,22 @@ class _Pickler:
"""Write a pickled representation of obj to the open file."""
# Check whether Pickler was initialized correctly. This is
# only needed to mimic the behavior of _pickle.Pickler.dump().
- if not hasattr(self, "write"):
+ if not hasattr(self, "_file_write"):
raise PicklingError("Pickler.__init__() was not called by "
"%s.__init__()" % (self.__class__.__name__,))
if self.proto >= 2:
- self.write(PROTO + pack("<B", self.proto))
+ self._file_write(PROTO + pack("<B", self.proto))
+ if self.proto >= 4:
+ framer = _Framer(self._file_write)
+ framer.start_framing()
+ self.write = framer.write
+ else:
+ framer = None
+ self.write = self._file_write
self.save(obj)
self.write(STOP)
+ if framer is not None:
+ framer.end_framing()
def memoize(self, obj):
"""Store an object in the memo."""
@@ -249,19 +448,21 @@ class _Pickler:
if self.fast:
return
assert id(obj) not in self.memo
- memo_len = len(self.memo)
- self.write(self.put(memo_len))
- self.memo[id(obj)] = memo_len, obj
+ idx = len(self.memo)
+ self.write(self.put(idx))
+ self.memo[id(obj)] = idx, obj
# Return a PUT (BINPUT, LONG_BINPUT) opcode string, with argument i.
- def put(self, i):
- if self.bin:
- if i < 256:
- return BINPUT + pack("<B", i)
+ def put(self, idx):
+ if self.proto >= 4:
+ return MEMOIZE
+ elif self.bin:
+ if idx < 256:
+ return BINPUT + pack("<B", idx)
else:
- return LONG_BINPUT + pack("<I", i)
-
- return PUT + repr(i).encode("ascii") + b'\n'
+ return LONG_BINPUT + pack("<I", idx)
+ else:
+ return PUT + repr(idx).encode("ascii") + b'\n'
# Return a GET (BINGET, LONG_BINGET) opcode string, with argument i.
def get(self, i):
@@ -349,24 +550,33 @@ class _Pickler:
else:
self.write(PERSID + str(pid).encode("ascii") + b'\n')
- def save_reduce(self, func, args, state=None,
- listitems=None, dictitems=None, obj=None):
+ def save_reduce(self, func, args, state=None, listitems=None,
+ dictitems=None, obj=None):
# This API is called by some subclasses
- # Assert that args is a tuple
if not isinstance(args, tuple):
- raise PicklingError("args from save_reduce() should be a tuple")
-
- # Assert that func is callable
+ raise PicklingError("args from save_reduce() must be a tuple")
if not callable(func):
- raise PicklingError("func from save_reduce() should be callable")
+ raise PicklingError("func from save_reduce() must be callable")
save = self.save
write = self.write
- # Protocol 2 special case: if func's name is __newobj__, use NEWOBJ
- if self.proto >= 2 and getattr(func, "__name__", "") == "__newobj__":
- # A __reduce__ implementation can direct protocol 2 to
+ func_name = getattr(func, "__name__", "")
+ if self.proto >= 4 and func_name == "__newobj_ex__":
+ cls, args, kwargs = args
+ if not hasattr(cls, "__new__"):
+ raise PicklingError("args[0] from {} args has no __new__"
+ .format(func_name))
+ if obj is not None and cls is not obj.__class__:
+ raise PicklingError("args[0] from {} args has the wrong class"
+ .format(func_name))
+ save(cls)
+ save(args)
+ save(kwargs)
+ write(NEWOBJ_EX)
+ elif self.proto >= 2 and func_name == "__newobj__":
+ # A __reduce__ implementation can direct protocol 2 or newer to
# use the more efficient NEWOBJ opcode, while still
# allowing protocol 0 and 1 to work normally. For this to
# work, the function returned by __reduce__ should be
@@ -409,7 +619,13 @@ class _Pickler:
write(REDUCE)
if obj is not None:
- self.memoize(obj)
+ # If the object is already in the memo, this means it is
+ # recursive. In this case, throw away everything we put on the
+ # stack, and fetch the object back from the memo.
+ if id(obj) in self.memo:
+ write(POP + self.get(self.memo[id(obj)][0]))
+ else:
+ self.memoize(obj)
# More new special cases (that work with older protocols as
# well): when __reduce__ returns a tuple with 4 or 5 items,
@@ -493,8 +709,10 @@ class _Pickler:
(str(obj, 'latin1'), 'latin1'), obj=obj)
return
n = len(obj)
- if n < 256:
+ if n <= 0xff:
self.write(SHORT_BINBYTES + pack("<B", n) + obj)
+ elif n > 0xffffffff and self.proto >= 4:
+ self.write(BINBYTES8 + pack("<Q", n) + obj)
else:
self.write(BINBYTES + pack("<I", n) + obj)
self.memoize(obj)
@@ -504,11 +722,17 @@ class _Pickler:
if self.bin:
encoded = obj.encode('utf-8', 'surrogatepass')
n = len(encoded)
- self.write(BINUNICODE + pack("<I", n) + encoded)
+ if n <= 0xff and self.proto >= 4:
+ self.write(SHORT_BINUNICODE + pack("<B", n) + encoded)
+ elif n > 0xffffffff and self.proto >= 4:
+ self.write(BINUNICODE8 + pack("<Q", n) + encoded)
+ else:
+ self.write(BINUNICODE + pack("<I", n) + encoded)
else:
obj = obj.replace("\\", "\\u005c")
obj = obj.replace("\n", "\\u000a")
- self.write(UNICODE + obj.encode('raw-unicode-escape') + b'\n')
+ self.write(UNICODE + obj.encode('raw-unicode-escape') +
+ b'\n')
self.memoize(obj)
dispatch[str] = save_str
@@ -647,33 +871,79 @@ class _Pickler:
if n < self._BATCHSIZE:
return
+ def save_set(self, obj):
+ save = self.save
+ write = self.write
+
+ if self.proto < 4:
+ self.save_reduce(set, (list(obj),), obj=obj)
+ return
+
+ write(EMPTY_SET)
+ self.memoize(obj)
+
+ it = iter(obj)
+ while True:
+ batch = list(islice(it, self._BATCHSIZE))
+ n = len(batch)
+ if n > 0:
+ write(MARK)
+ for item in batch:
+ save(item)
+ write(ADDITEMS)
+ if n < self._BATCHSIZE:
+ return
+ dispatch[set] = save_set
+
+ def save_frozenset(self, obj):
+ save = self.save
+ write = self.write
+
+ if self.proto < 4:
+ self.save_reduce(frozenset, (list(obj),), obj=obj)
+ return
+
+ write(MARK)
+ for item in obj:
+ save(item)
+
+ if id(obj) in self.memo:
+ # If the object is already in the memo, this means it is
+ # recursive. In this case, throw away everything we put on the
+ # stack, and fetch the object back from the memo.
+ write(POP_MARK + self.get(self.memo[id(obj)][0]))
+ return
+
+ write(FROZENSET)
+ self.memoize(obj)
+ dispatch[frozenset] = save_frozenset
+
def save_global(self, obj, name=None):
write = self.write
memo = self.memo
+ if name is None and self.proto >= 4:
+ name = getattr(obj, '__qualname__', None)
if name is None:
name = obj.__name__
- module = getattr(obj, "__module__", None)
- if module is None:
- module = whichmodule(obj, name)
-
+ module_name = whichmodule(obj, name, allow_qualname=self.proto >= 4)
try:
- __import__(module, level=0)
- mod = sys.modules[module]
- klass = getattr(mod, name)
+ __import__(module_name, level=0)
+ module = sys.modules[module_name]
+ obj2 = _getattribute(module, name, allow_qualname=self.proto >= 4)
except (ImportError, KeyError, AttributeError):
raise PicklingError(
"Can't pickle %r: it's not found as %s.%s" %
- (obj, module, name))
+ (obj, module_name, name))
else:
- if klass is not obj:
+ if obj2 is not obj:
raise PicklingError(
"Can't pickle %r: it's not the same object as %s.%s" %
- (obj, module, name))
+ (obj, module_name, name))
if self.proto >= 2:
- code = _extension_registry.get((module, name))
+ code = _extension_registry.get((module_name, name))
if code:
assert code > 0
if code <= 0xff:
@@ -684,17 +954,23 @@ class _Pickler:
write(EXT4 + pack("<i", code))
return
# Non-ASCII identifiers are supported only with protocols >= 3.
- if self.proto >= 3:
- write(GLOBAL + bytes(module, "utf-8") + b'\n' +
+ if self.proto >= 4:
+ self.save(module_name)
+ self.save(name)
+ write(STACK_GLOBAL)
+ elif self.proto >= 3:
+ write(GLOBAL + bytes(module_name, "utf-8") + b'\n' +
bytes(name, "utf-8") + b'\n')
else:
if self.fix_imports:
- if (module, name) in _compat_pickle.REVERSE_NAME_MAPPING:
- module, name = _compat_pickle.REVERSE_NAME_MAPPING[(module, name)]
- if module in _compat_pickle.REVERSE_IMPORT_MAPPING:
- module = _compat_pickle.REVERSE_IMPORT_MAPPING[module]
+ r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING
+ r_import_mapping = _compat_pickle.REVERSE_IMPORT_MAPPING
+ if (module_name, name) in r_name_mapping:
+ module_name, name = r_name_mapping[(module_name, name)]
+ if module_name in r_import_mapping:
+ module_name = r_import_mapping[module_name]
try:
- write(GLOBAL + bytes(module, "ascii") + b'\n' +
+ write(GLOBAL + bytes(module_name, "ascii") + b'\n' +
bytes(name, "ascii") + b'\n')
except UnicodeEncodeError:
raise PicklingError(
@@ -703,40 +979,16 @@ class _Pickler:
self.memoize(obj)
+ def save_method(self, obj):
+ if obj.__self__ is None or type(obj.__self__) is ModuleType:
+ self.save_global(obj)
+ else:
+ self.save_reduce(getattr, (obj.__self__, obj.__name__), obj=obj)
+
dispatch[FunctionType] = save_global
- dispatch[BuiltinFunctionType] = save_global
+ dispatch[BuiltinFunctionType] = save_method
dispatch[type] = save_global
-# A cache for whichmodule(), mapping a function object to the name of
-# the module in which the function was found.
-
-classmap = {} # called classmap for backwards compatibility
-
-def whichmodule(func, funcname):
- """Figure out the module in which a function occurs.
-
- Search sys.modules for the module.
- Cache in classmap.
- Return a module name.
- If the function cannot be found, return "__main__".
- """
- # Python functions should always get an __module__ from their globals.
- mod = getattr(func, "__module__", None)
- if mod is not None:
- return mod
- if func in classmap:
- return classmap[func]
-
- for name, module in list(sys.modules.items()):
- if module is None:
- continue # skip dummy package entries
- if name != '__main__' and getattr(module, funcname, None) is func:
- break
- else:
- name = '__main__'
- classmap[func] = name
- return name
-
# Unpickling machinery
@@ -764,8 +1016,8 @@ class _Unpickler:
instances pickled by Python 2.x; these default to 'ASCII' and
'strict', respectively.
"""
- self.readline = file.readline
- self.read = file.read
+ self._file_readline = file.readline
+ self._file_read = file.read
self.memo = {}
self.encoding = encoding
self.errors = errors
@@ -779,12 +1031,16 @@ class _Unpickler:
"""
# Check whether Unpickler was initialized correctly. This is
# only needed to mimic the behavior of _pickle.Unpickler.dump().
- if not hasattr(self, "read"):
+ if not hasattr(self, "_file_read"):
raise UnpicklingError("Unpickler.__init__() was not called by "
"%s.__init__()" % (self.__class__.__name__,))
+ self._unframer = _Unframer(self._file_read, self._file_readline)
+ self.read = self._unframer.read
+ self.readline = self._unframer.readline
self.mark = object() # any new unique object
self.stack = []
self.append = self.stack.append
+ self.proto = 0
read = self.read
dispatch = self.dispatch
try:
@@ -822,6 +1078,8 @@ class _Unpickler:
if not 0 <= proto <= HIGHEST_PROTOCOL:
raise ValueError("unsupported pickle protocol: %d" % proto)
self.proto = proto
+ if proto >= 4:
+ self._unframer.framing_enabled = True
dispatch[PROTO[0]] = load_proto
def load_persid(self):
@@ -940,6 +1198,14 @@ class _Unpickler:
self.append(str(self.read(len), 'utf-8', 'surrogatepass'))
dispatch[BINUNICODE[0]] = load_binunicode
+ def load_binunicode8(self):
+ len, = unpack('<Q', self.read(8))
+ if len > maxsize:
+ raise UnpicklingError("BINUNICODE8 exceeds system's maximum size "
+ "of %d bytes" % maxsize)
+ self.append(str(self.read(len), 'utf-8', 'surrogatepass'))
+ dispatch[BINUNICODE8[0]] = load_binunicode8
+
def load_short_binstring(self):
len = self.read(1)[0]
data = self.read(len)
@@ -952,6 +1218,11 @@ class _Unpickler:
self.append(self.read(len))
dispatch[SHORT_BINBYTES[0]] = load_short_binbytes
+ def load_short_binunicode(self):
+ len = self.read(1)[0]
+ self.append(str(self.read(len), 'utf-8', 'surrogatepass'))
+ dispatch[SHORT_BINUNICODE[0]] = load_short_binunicode
+
def load_tuple(self):
k = self.marker()
self.stack[k:] = [tuple(self.stack[k+1:])]
@@ -981,6 +1252,15 @@ class _Unpickler:
self.append({})
dispatch[EMPTY_DICT[0]] = load_empty_dictionary
+ def load_empty_set(self):
+ self.append(set())
+ dispatch[EMPTY_SET[0]] = load_empty_set
+
+ def load_frozenset(self):
+ k = self.marker()
+ self.stack[k:] = [frozenset(self.stack[k+1:])]
+ dispatch[FROZENSET[0]] = load_frozenset
+
def load_list(self):
k = self.marker()
self.stack[k:] = [self.stack[k+1:]]
@@ -1029,11 +1309,19 @@ class _Unpickler:
def load_newobj(self):
args = self.stack.pop()
- cls = self.stack[-1]
+ cls = self.stack.pop()
obj = cls.__new__(cls, *args)
- self.stack[-1] = obj
+ self.append(obj)
dispatch[NEWOBJ[0]] = load_newobj
+ def load_newobj_ex(self):
+ kwargs = self.stack.pop()
+ args = self.stack.pop()
+ cls = self.stack.pop()
+ obj = cls.__new__(cls, *args, **kwargs)
+ self.append(obj)
+ dispatch[NEWOBJ_EX[0]] = load_newobj_ex
+
def load_global(self):
module = self.readline()[:-1].decode("utf-8")
name = self.readline()[:-1].decode("utf-8")
@@ -1041,6 +1329,14 @@ class _Unpickler:
self.append(klass)
dispatch[GLOBAL[0]] = load_global
+ def load_stack_global(self):
+ name = self.stack.pop()
+ module = self.stack.pop()
+ if type(name) is not str or type(module) is not str:
+ raise UnpicklingError("STACK_GLOBAL requires str")
+ self.append(self.find_class(module, name))
+ dispatch[STACK_GLOBAL[0]] = load_stack_global
+
def load_ext1(self):
code = self.read(1)[0]
self.get_extension(code)
@@ -1080,9 +1376,8 @@ class _Unpickler:
if module in _compat_pickle.IMPORT_MAPPING:
module = _compat_pickle.IMPORT_MAPPING[module]
__import__(module, level=0)
- mod = sys.modules[module]
- klass = getattr(mod, name)
- return klass
+ return _getattribute(sys.modules[module], name,
+ allow_qualname=self.proto >= 4)
def load_reduce(self):
stack = self.stack
@@ -1146,6 +1441,11 @@ class _Unpickler:
self.memo[i] = self.stack[-1]
dispatch[LONG_BINPUT[0]] = load_long_binput
+ def load_memoize(self):
+ memo = self.memo
+ memo[len(memo)] = self.stack[-1]
+ dispatch[MEMOIZE[0]] = load_memoize
+
def load_append(self):
stack = self.stack
value = stack.pop()
@@ -1185,6 +1485,20 @@ class _Unpickler:
del stack[mark:]
dispatch[SETITEMS[0]] = load_setitems
+ def load_additems(self):
+ stack = self.stack
+ mark = self.marker()
+ set_obj = stack[mark - 1]
+ items = stack[mark + 1:]
+ if isinstance(set_obj, set):
+ set_obj.update(items)
+ else:
+ add = set_obj.add
+ for item in items:
+ add(item)
+ del stack[mark:]
+ dispatch[ADDITEMS[0]] = load_additems
+
def load_build(self):
stack = self.stack
state = stack.pop()
@@ -1218,86 +1532,46 @@ class _Unpickler:
raise _Stop(value)
dispatch[STOP[0]] = load_stop
-# Encode/decode ints.
-
-def encode_long(x):
- r"""Encode a long to a two's complement little-endian binary string.
- Note that 0 is a special case, returning an empty string, to save a
- byte in the LONG1 pickling context.
-
- >>> encode_long(0)
- b''
- >>> encode_long(255)
- b'\xff\x00'
- >>> encode_long(32767)
- b'\xff\x7f'
- >>> encode_long(-256)
- b'\x00\xff'
- >>> encode_long(-32768)
- b'\x00\x80'
- >>> encode_long(-128)
- b'\x80'
- >>> encode_long(127)
- b'\x7f'
- >>>
- """
- if x == 0:
- return b''
- nbytes = (x.bit_length() >> 3) + 1
- result = x.to_bytes(nbytes, byteorder='little', signed=True)
- if x < 0 and nbytes > 1:
- if result[-1] == 0xff and (result[-2] & 0x80) != 0:
- result = result[:-1]
- return result
-
-def decode_long(data):
- r"""Decode an int from a two's complement little-endian binary string.
-
- >>> decode_long(b'')
- 0
- >>> decode_long(b"\xff\x00")
- 255
- >>> decode_long(b"\xff\x7f")
- 32767
- >>> decode_long(b"\x00\xff")
- -256
- >>> decode_long(b"\x00\x80")
- -32768
- >>> decode_long(b"\x80")
- -128
- >>> decode_long(b"\x7f")
- 127
- """
- return int.from_bytes(data, byteorder='little', signed=True)
# Shorthands
-def dump(obj, file, protocol=None, *, fix_imports=True):
- Pickler(file, protocol, fix_imports=fix_imports).dump(obj)
+def _dump(obj, file, protocol=None, *, fix_imports=True):
+ _Pickler(file, protocol, fix_imports=fix_imports).dump(obj)
-def dumps(obj, protocol=None, *, fix_imports=True):
+def _dumps(obj, protocol=None, *, fix_imports=True):
f = io.BytesIO()
- Pickler(f, protocol, fix_imports=fix_imports).dump(obj)
+ _Pickler(f, protocol, fix_imports=fix_imports).dump(obj)
res = f.getvalue()
assert isinstance(res, bytes_types)
return res
-def load(file, *, fix_imports=True, encoding="ASCII", errors="strict"):
- return Unpickler(file, fix_imports=fix_imports,
+def _load(file, *, fix_imports=True, encoding="ASCII", errors="strict"):
+ return _Unpickler(file, fix_imports=fix_imports,
encoding=encoding, errors=errors).load()
-def loads(s, *, fix_imports=True, encoding="ASCII", errors="strict"):
+def _loads(s, *, fix_imports=True, encoding="ASCII", errors="strict"):
if isinstance(s, str):
raise TypeError("Can't load pickle from unicode string")
file = io.BytesIO(s)
- return Unpickler(file, fix_imports=fix_imports,
- encoding=encoding, errors=errors).load()
+ return _Unpickler(file, fix_imports=fix_imports,
+ encoding=encoding, errors=errors).load()
# Use the faster _pickle if possible
try:
- from _pickle import *
+ from _pickle import (
+ PickleError,
+ PicklingError,
+ UnpicklingError,
+ Pickler,
+ Unpickler,
+ dump,
+ dumps,
+ load,
+ loads
+ )
except ImportError:
Pickler, Unpickler = _Pickler, _Unpickler
+ dump, dumps, load, loads = _dump, _dumps, _load, _loads
# Doctest
def _test():
diff --git a/Lib/pickletools.py b/Lib/pickletools.py
index e92146d..d711bf0 100644
--- a/Lib/pickletools.py
+++ b/Lib/pickletools.py
@@ -11,6 +11,7 @@ dis(pickle, out=None, memo=None, indentlevel=4)
'''
import codecs
+import io
import pickle
import re
import sys
@@ -168,6 +169,7 @@ UP_TO_NEWLINE = -1
TAKEN_FROM_ARGUMENT1 = -2 # num bytes is 1-byte unsigned int
TAKEN_FROM_ARGUMENT4 = -3 # num bytes is 4-byte signed little-endian int
TAKEN_FROM_ARGUMENT4U = -4 # num bytes is 4-byte unsigned little-endian int
+TAKEN_FROM_ARGUMENT8U = -5 # num bytes is 8-byte unsigned little-endian int
class ArgumentDescriptor(object):
__slots__ = (
@@ -175,7 +177,7 @@ class ArgumentDescriptor(object):
'name',
# length of argument, in bytes; an int; UP_TO_NEWLINE and
- # TAKEN_FROM_ARGUMENT{1,4} are negative values for variable-length
+ # TAKEN_FROM_ARGUMENT{1,4,8} are negative values for variable-length
# cases
'n',
@@ -196,7 +198,8 @@ class ArgumentDescriptor(object):
n in (UP_TO_NEWLINE,
TAKEN_FROM_ARGUMENT1,
TAKEN_FROM_ARGUMENT4,
- TAKEN_FROM_ARGUMENT4U))
+ TAKEN_FROM_ARGUMENT4U,
+ TAKEN_FROM_ARGUMENT8U))
self.n = n
self.reader = reader
@@ -288,6 +291,27 @@ uint4 = ArgumentDescriptor(
doc="Four-byte unsigned integer, little-endian.")
+def read_uint8(f):
+ r"""
+ >>> import io
+ >>> read_uint8(io.BytesIO(b'\xff\x00\x00\x00\x00\x00\x00\x00'))
+ 255
+ >>> read_uint8(io.BytesIO(b'\xff' * 8)) == 2**64-1
+ True
+ """
+
+ data = f.read(8)
+ if len(data) == 8:
+ return _unpack("<Q", data)[0]
+ raise ValueError("not enough data in stream to read uint8")
+
+uint8 = ArgumentDescriptor(
+ name='uint8',
+ n=8,
+ reader=read_uint8,
+ doc="Eight-byte unsigned integer, little-endian.")
+
+
def read_stringnl(f, decode=True, stripquotes=True):
r"""
>>> import io
@@ -381,6 +405,36 @@ stringnl_noescape_pair = ArgumentDescriptor(
a single blank separating the two strings.
""")
+
+def read_string1(f):
+ r"""
+ >>> import io
+ >>> read_string1(io.BytesIO(b"\x00"))
+ ''
+ >>> read_string1(io.BytesIO(b"\x03abcdef"))
+ 'abc'
+ """
+
+ n = read_uint1(f)
+ assert n >= 0
+ data = f.read(n)
+ if len(data) == n:
+ return data.decode("latin-1")
+ raise ValueError("expected %d bytes in a string1, but only %d remain" %
+ (n, len(data)))
+
+string1 = ArgumentDescriptor(
+ name="string1",
+ n=TAKEN_FROM_ARGUMENT1,
+ reader=read_string1,
+ doc="""A counted string.
+
+ The first argument is a 1-byte unsigned int giving the number
+ of bytes in the string, and the second argument is that many
+ bytes.
+ """)
+
+
def read_string4(f):
r"""
>>> import io
@@ -415,28 +469,28 @@ string4 = ArgumentDescriptor(
""")
-def read_string1(f):
+def read_bytes1(f):
r"""
>>> import io
- >>> read_string1(io.BytesIO(b"\x00"))
- ''
- >>> read_string1(io.BytesIO(b"\x03abcdef"))
- 'abc'
+ >>> read_bytes1(io.BytesIO(b"\x00"))
+ b''
+ >>> read_bytes1(io.BytesIO(b"\x03abcdef"))
+ b'abc'
"""
n = read_uint1(f)
assert n >= 0
data = f.read(n)
if len(data) == n:
- return data.decode("latin-1")
- raise ValueError("expected %d bytes in a string1, but only %d remain" %
+ return data
+ raise ValueError("expected %d bytes in a bytes1, but only %d remain" %
(n, len(data)))
-string1 = ArgumentDescriptor(
- name="string1",
+bytes1 = ArgumentDescriptor(
+ name="bytes1",
n=TAKEN_FROM_ARGUMENT1,
- reader=read_string1,
- doc="""A counted string.
+ reader=read_bytes1,
+ doc="""A counted bytes string.
The first argument is a 1-byte unsigned int giving the number
of bytes in the string, and the second argument is that many
@@ -486,6 +540,7 @@ def read_bytes4(f):
"""
n = read_uint4(f)
+ assert n >= 0
if n > sys.maxsize:
raise ValueError("bytes4 byte count > sys.maxsize: %d" % n)
data = f.read(n)
@@ -505,6 +560,39 @@ bytes4 = ArgumentDescriptor(
""")
+def read_bytes8(f):
+ r"""
+ >>> import io
+ >>> read_bytes8(io.BytesIO(b"\x00\x00\x00\x00\x00\x00\x00\x00abc"))
+ b''
+ >>> read_bytes8(io.BytesIO(b"\x03\x00\x00\x00\x00\x00\x00\x00abcdef"))
+ b'abc'
+ >>> read_bytes8(io.BytesIO(b"\x00\x00\x00\x00\x00\x00\x03\x00abcdef"))
+ Traceback (most recent call last):
+ ...
+ ValueError: expected 844424930131968 bytes in a bytes8, but only 6 remain
+ """
+
+ n = read_uint8(f)
+ assert n >= 0
+ if n > sys.maxsize:
+ raise ValueError("bytes8 byte count > sys.maxsize: %d" % n)
+ data = f.read(n)
+ if len(data) == n:
+ return data
+ raise ValueError("expected %d bytes in a bytes8, but only %d remain" %
+ (n, len(data)))
+
+bytes8 = ArgumentDescriptor(
+ name="bytes8",
+ n=TAKEN_FROM_ARGUMENT8U,
+ reader=read_bytes8,
+ doc="""A counted bytes string.
+
+ The first argument is a 8-byte little-endian unsigned int giving
+ the number of bytes, and the second argument is that many bytes.
+ """)
+
def read_unicodestringnl(f):
r"""
>>> import io
@@ -530,6 +618,46 @@ unicodestringnl = ArgumentDescriptor(
escape sequences.
""")
+
+def read_unicodestring1(f):
+ r"""
+ >>> import io
+ >>> s = 'abcd\uabcd'
+ >>> enc = s.encode('utf-8')
+ >>> enc
+ b'abcd\xea\xaf\x8d'
+ >>> n = bytes([len(enc)]) # little-endian 1-byte length
+ >>> t = read_unicodestring1(io.BytesIO(n + enc + b'junk'))
+ >>> s == t
+ True
+
+ >>> read_unicodestring1(io.BytesIO(n + enc[:-1]))
+ Traceback (most recent call last):
+ ...
+ ValueError: expected 7 bytes in a unicodestring1, but only 6 remain
+ """
+
+ n = read_uint1(f)
+ assert n >= 0
+ data = f.read(n)
+ if len(data) == n:
+ return str(data, 'utf-8', 'surrogatepass')
+ raise ValueError("expected %d bytes in a unicodestring1, but only %d "
+ "remain" % (n, len(data)))
+
+unicodestring1 = ArgumentDescriptor(
+ name="unicodestring1",
+ n=TAKEN_FROM_ARGUMENT1,
+ reader=read_unicodestring1,
+ doc="""A counted Unicode string.
+
+ The first argument is a 1-byte little-endian signed int
+ giving the number of bytes in the string, and the second
+ argument-- the UTF-8 encoding of the Unicode string --
+ contains that many bytes.
+ """)
+
+
def read_unicodestring4(f):
r"""
>>> import io
@@ -549,6 +677,7 @@ def read_unicodestring4(f):
"""
n = read_uint4(f)
+ assert n >= 0
if n > sys.maxsize:
raise ValueError("unicodestring4 byte count > sys.maxsize: %d" % n)
data = f.read(n)
@@ -570,6 +699,47 @@ unicodestring4 = ArgumentDescriptor(
""")
+def read_unicodestring8(f):
+ r"""
+ >>> import io
+ >>> s = 'abcd\uabcd'
+ >>> enc = s.encode('utf-8')
+ >>> enc
+ b'abcd\xea\xaf\x8d'
+ >>> n = bytes([len(enc)]) + bytes(7) # little-endian 8-byte length
+ >>> t = read_unicodestring8(io.BytesIO(n + enc + b'junk'))
+ >>> s == t
+ True
+
+ >>> read_unicodestring8(io.BytesIO(n + enc[:-1]))
+ Traceback (most recent call last):
+ ...
+ ValueError: expected 7 bytes in a unicodestring8, but only 6 remain
+ """
+
+ n = read_uint8(f)
+ assert n >= 0
+ if n > sys.maxsize:
+ raise ValueError("unicodestring8 byte count > sys.maxsize: %d" % n)
+ data = f.read(n)
+ if len(data) == n:
+ return str(data, 'utf-8', 'surrogatepass')
+ raise ValueError("expected %d bytes in a unicodestring8, but only %d "
+ "remain" % (n, len(data)))
+
+unicodestring8 = ArgumentDescriptor(
+ name="unicodestring8",
+ n=TAKEN_FROM_ARGUMENT8U,
+ reader=read_unicodestring8,
+ doc="""A counted Unicode string.
+
+ The first argument is a 8-byte little-endian signed int
+ giving the number of bytes in the string, and the second
+ argument-- the UTF-8 encoding of the Unicode string --
+ contains that many bytes.
+ """)
+
+
def read_decimalnl_short(f):
r"""
>>> import io
@@ -859,6 +1029,16 @@ pydict = StackObject(
obtype=dict,
doc="A Python dict object.")
+pyset = StackObject(
+ name="set",
+ obtype=set,
+ doc="A Python set object.")
+
+pyfrozenset = StackObject(
+ name="frozenset",
+ obtype=set,
+ doc="A Python frozenset object.")
+
anyobject = StackObject(
name='any',
obtype=object,
@@ -1142,6 +1322,19 @@ opcodes = [
literally as the string content.
"""),
+ I(name='BINBYTES8',
+ code='\x8e',
+ arg=bytes8,
+ stack_before=[],
+ stack_after=[pybytes],
+ proto=4,
+ doc="""Push a Python bytes object.
+
+ There are two arguments: the first is a 8-byte unsigned int giving
+ the number of bytes in the string, and the second is that many bytes,
+ which are taken literally as the string content.
+ """),
+
# Ways to spell None.
I(name='NONE',
@@ -1190,6 +1383,19 @@ opcodes = [
until the next newline character.
"""),
+ I(name='SHORT_BINUNICODE',
+ code='\x8c',
+ arg=unicodestring1,
+ stack_before=[],
+ stack_after=[pyunicode],
+ proto=4,
+ doc="""Push a Python Unicode string object.
+
+ There are two arguments: the first is a 1-byte little-endian signed int
+ giving the number of bytes in the string. The second is that many
+ bytes, and is the UTF-8 encoding of the Unicode string.
+ """),
+
I(name='BINUNICODE',
code='X',
arg=unicodestring4,
@@ -1203,6 +1409,19 @@ opcodes = [
bytes, and is the UTF-8 encoding of the Unicode string.
"""),
+ I(name='BINUNICODE8',
+ code='\x8d',
+ arg=unicodestring8,
+ stack_before=[],
+ stack_after=[pyunicode],
+ proto=4,
+ doc="""Push a Python Unicode string object.
+
+ There are two arguments: the first is a 8-byte little-endian signed int
+ giving the number of bytes in the string. The second is that many
+ bytes, and is the UTF-8 encoding of the Unicode string.
+ """),
+
# Ways to spell floats.
I(name='FLOAT',
@@ -1428,6 +1647,54 @@ opcodes = [
1, 2, ..., n, and in that order.
"""),
+ # Ways to build sets
+
+ I(name='EMPTY_SET',
+ code='\x8f',
+ arg=None,
+ stack_before=[],
+ stack_after=[pyset],
+ proto=4,
+ doc="Push an empty set."),
+
+ I(name='ADDITEMS',
+ code='\x90',
+ arg=None,
+ stack_before=[pyset, markobject, stackslice],
+ stack_after=[pyset],
+ proto=4,
+ doc="""Add an arbitrary number of items to an existing set.
+
+ The slice of the stack following the topmost markobject is taken as
+ a sequence of items, added to the set immediately under the topmost
+ markobject. Everything at and after the topmost markobject is popped,
+ leaving the mutated set at the top of the stack.
+
+ Stack before: ... pyset markobject item_1 ... item_n
+ Stack after: ... pyset
+
+ where pyset has been modified via pyset.add(item_i) = item_i for i in
+ 1, 2, ..., n, and in that order.
+ """),
+
+ # Way to build frozensets
+
+ I(name='FROZENSET',
+ code='\x91',
+ arg=None,
+ stack_before=[markobject, stackslice],
+ stack_after=[pyfrozenset],
+ proto=4,
+ doc="""Build a frozenset out of the topmost slice, after markobject.
+
+ All the stack entries following the topmost markobject are placed into
+ a single Python frozenset, which single frozenset object replaces all
+ of the stack from the topmost markobject onward. For example,
+
+ Stack before: ... markobject 1 2 3
+ Stack after: ... frozenset({1, 2, 3})
+ """),
+
# Stack manipulation.
I(name='POP',
@@ -1549,6 +1816,18 @@ opcodes = [
unsigned little-endian integer following.
"""),
+ I(name='MEMOIZE',
+ code='\x94',
+ arg=None,
+ stack_before=[anyobject],
+ stack_after=[anyobject],
+ proto=4,
+ doc="""Store the stack top into the memo. The stack is not popped.
+
+ The index of the memo location to write is the number of
+ elements currently present in the memo.
+ """),
+
# Access the extension registry (predefined objects). Akin to the GET
# family.
@@ -1614,6 +1893,15 @@ opcodes = [
stack, so unpickling subclasses can override this form of lookup.
"""),
+ I(name='STACK_GLOBAL',
+ code='\x93',
+ arg=None,
+ stack_before=[pyunicode, pyunicode],
+ stack_after=[anyobject],
+ proto=0,
+ doc="""Push a global object (module.attr) on the stack.
+ """),
+
# Ways to build objects of classes pickle doesn't know about directly
# (user-defined classes). I despair of documenting this accurately
# and comprehensibly -- you really have to read the pickle code to
@@ -1770,6 +2058,21 @@ opcodes = [
onto the stack.
"""),
+ I(name='NEWOBJ_EX',
+ code='\x92',
+ arg=None,
+ stack_before=[anyobject, anyobject, anyobject],
+ stack_after=[anyobject],
+ proto=4,
+ doc="""Build an object instance.
+
+ The stack before should be thought of as containing a class
+ object followed by an argument tuple and by a keyword argument dict
+ (the dict being the stack top). Call these cls and args. They are
+ popped off the stack, and the value returned by
+ cls.__new__(cls, *args, *kwargs) is pushed back onto the stack.
+ """),
+
# Machine control.
I(name='PROTO',
@@ -1797,6 +2100,20 @@ opcodes = [
empty then.
"""),
+ # Framing support.
+
+ I(name='FRAME',
+ code='\x95',
+ arg=uint8,
+ stack_before=[],
+ stack_after=[],
+ proto=4,
+ doc="""Indicate the beginning of a new frame.
+
+ The unpickler may use this opcode to safely prefetch data from its
+ underlying stream.
+ """),
+
# Ways to deal with persistent IDs.
I(name='PERSID',
@@ -1903,6 +2220,38 @@ del assure_pickle_consistency
##############################################################################
# A pickle opcode generator.
+def _genops(data, yield_end_pos=False):
+ if isinstance(data, bytes_types):
+ data = io.BytesIO(data)
+
+ if hasattr(data, "tell"):
+ getpos = data.tell
+ else:
+ getpos = lambda: None
+
+ while True:
+ pos = getpos()
+ code = data.read(1)
+ opcode = code2op.get(code.decode("latin-1"))
+ if opcode is None:
+ if code == b"":
+ raise ValueError("pickle exhausted before seeing STOP")
+ else:
+ raise ValueError("at position %s, opcode %r unknown" % (
+ "<unknown>" if pos is None else pos,
+ code))
+ if opcode.arg is None:
+ arg = None
+ else:
+ arg = opcode.arg.reader(data)
+ if yield_end_pos:
+ yield opcode, arg, pos, getpos()
+ else:
+ yield opcode, arg, pos
+ if code == b'.':
+ assert opcode.name == 'STOP'
+ break
+
def genops(pickle):
"""Generate all the opcodes in a pickle.
@@ -1926,62 +2275,47 @@ def genops(pickle):
used. Else (the pickle doesn't have a tell(), and it's not obvious how
to query its current position) pos is None.
"""
-
- if isinstance(pickle, bytes_types):
- import io
- pickle = io.BytesIO(pickle)
-
- if hasattr(pickle, "tell"):
- getpos = pickle.tell
- else:
- getpos = lambda: None
-
- while True:
- pos = getpos()
- code = pickle.read(1)
- opcode = code2op.get(code.decode("latin-1"))
- if opcode is None:
- if code == b"":
- raise ValueError("pickle exhausted before seeing STOP")
- else:
- raise ValueError("at position %s, opcode %r unknown" % (
- pos is None and "<unknown>" or pos,
- code))
- if opcode.arg is None:
- arg = None
- else:
- arg = opcode.arg.reader(pickle)
- yield opcode, arg, pos
- if code == b'.':
- assert opcode.name == 'STOP'
- break
+ return _genops(pickle)
##############################################################################
# A pickle optimizer.
def optimize(p):
'Optimize a pickle string by removing unused PUT opcodes'
- gets = set() # set of args used by a GET opcode
- puts = [] # (arg, startpos, stoppos) for the PUT opcodes
- prevpos = None # set to pos if previous opcode was a PUT
- for opcode, arg, pos in genops(p):
- if prevpos is not None:
- puts.append((prevarg, prevpos, pos))
- prevpos = None
+ not_a_put = object()
+ gets = { not_a_put } # set of args used by a GET opcode
+ opcodes = [] # (startpos, stoppos, putid)
+ proto = 0
+ for opcode, arg, pos, end_pos in _genops(p, yield_end_pos=True):
if 'PUT' in opcode.name:
- prevarg, prevpos = arg, pos
- elif 'GET' in opcode.name:
- gets.add(arg)
-
- # Copy the pickle string except for PUTS without a corresponding GET
- s = []
- i = 0
- for arg, start, stop in puts:
- j = stop if (arg in gets) else start
- s.append(p[i:j])
- i = stop
- s.append(p[i:])
- return b''.join(s)
+ opcodes.append((pos, end_pos, arg))
+ elif 'FRAME' in opcode.name:
+ pass
+ else:
+ if 'GET' in opcode.name:
+ gets.add(arg)
+ elif opcode.name == 'PROTO':
+ assert pos == 0, pos
+ proto = arg
+ opcodes.append((pos, end_pos, not_a_put))
+ prevpos, prevarg = pos, None
+
+ # Copy the opcodes except for PUTS without a corresponding GET
+ out = io.BytesIO()
+ opcodes = iter(opcodes)
+ if proto >= 2:
+ # Write the PROTO header before any framing
+ start, stop, _ = next(opcodes)
+ out.write(p[start:stop])
+ buf = pickle._Framer(out.write)
+ if proto >= 4:
+ buf.start_framing()
+ for start, stop, putid in opcodes:
+ if putid in gets:
+ buf.write(p[start:stop])
+ if proto >= 4:
+ buf.end_framing()
+ return out.getvalue()
##############################################################################
# A symbolic pickle disassembler.
@@ -2081,17 +2415,20 @@ def dis(pickle, out=None, memo=None, indentlevel=4, annotate=0):
errormsg = markmsg = "no MARK exists on stack"
# Check for correct memo usage.
- if opcode.name in ("PUT", "BINPUT", "LONG_BINPUT"):
- assert arg is not None
- if arg in memo:
+ if opcode.name in ("PUT", "BINPUT", "LONG_BINPUT", "MEMOIZE"):
+ if opcode.name == "MEMOIZE":
+ memo_idx = len(memo)
+ else:
+ assert arg is not None
+ memo_idx = arg
+ if memo_idx in memo:
errormsg = "memo key %r already defined" % arg
elif not stack:
errormsg = "stack is empty -- can't store into memo"
elif stack[-1] is markobject:
errormsg = "can't store markobject in the memo"
else:
- memo[arg] = stack[-1]
-
+ memo[memo_idx] = stack[-1]
elif opcode.name in ("GET", "BINGET", "LONG_BINGET"):
if arg in memo:
assert len(after) == 1
diff --git a/Lib/test/pickletester.py b/Lib/test/pickletester.py
index 1971120..cadc5a7 100644
--- a/Lib/test/pickletester.py
+++ b/Lib/test/pickletester.py
@@ -1,9 +1,10 @@
+import copyreg
import io
-import unittest
import pickle
import pickletools
+import random
import sys
-import copyreg
+import unittest
import weakref
from http.cookies import SimpleCookie
@@ -95,6 +96,9 @@ class E(C):
def __getinitargs__(self):
return ()
+class H(object):
+ pass
+
import __main__
__main__.C = C
C.__module__ = "__main__"
@@ -102,6 +106,8 @@ __main__.D = D
D.__module__ = "__main__"
__main__.E = E
E.__module__ = "__main__"
+__main__.H = H
+H.__module__ = "__main__"
class myint(int):
def __init__(self, x):
@@ -428,6 +434,7 @@ def create_data():
x.append(5)
return x
+
class AbstractPickleTests(unittest.TestCase):
# Subclass must define self.dumps, self.loads.
@@ -436,23 +443,41 @@ class AbstractPickleTests(unittest.TestCase):
def setUp(self):
pass
+ def assert_is_copy(self, obj, objcopy, msg=None):
+ """Utility method to verify if two objects are copies of each others.
+ """
+ if msg is None:
+ msg = "{!r} is not a copy of {!r}".format(obj, objcopy)
+ self.assertEqual(obj, objcopy, msg=msg)
+ self.assertIs(type(obj), type(objcopy), msg=msg)
+ if hasattr(obj, '__dict__'):
+ self.assertDictEqual(obj.__dict__, objcopy.__dict__, msg=msg)
+ self.assertIsNot(obj.__dict__, objcopy.__dict__, msg=msg)
+ if hasattr(obj, '__slots__'):
+ self.assertListEqual(obj.__slots__, objcopy.__slots__, msg=msg)
+ for slot in obj.__slots__:
+ self.assertEqual(
+ hasattr(obj, slot), hasattr(objcopy, slot), msg=msg)
+ self.assertEqual(getattr(obj, slot, None),
+ getattr(objcopy, slot, None), msg=msg)
+
def test_misc(self):
# test various datatypes not tested by testdata
for proto in protocols:
x = myint(4)
s = self.dumps(x, proto)
y = self.loads(s)
- self.assertEqual(x, y)
+ self.assert_is_copy(x, y)
x = (1, ())
s = self.dumps(x, proto)
y = self.loads(s)
- self.assertEqual(x, y)
+ self.assert_is_copy(x, y)
x = initarg(1, x)
s = self.dumps(x, proto)
y = self.loads(s)
- self.assertEqual(x, y)
+ self.assert_is_copy(x, y)
# XXX test __reduce__ protocol?
@@ -461,16 +486,16 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols:
s = self.dumps(expected, proto)
got = self.loads(s)
- self.assertEqual(expected, got)
+ self.assert_is_copy(expected, got)
def test_load_from_data0(self):
- self.assertEqual(self._testdata, self.loads(DATA0))
+ self.assert_is_copy(self._testdata, self.loads(DATA0))
def test_load_from_data1(self):
- self.assertEqual(self._testdata, self.loads(DATA1))
+ self.assert_is_copy(self._testdata, self.loads(DATA1))
def test_load_from_data2(self):
- self.assertEqual(self._testdata, self.loads(DATA2))
+ self.assert_is_copy(self._testdata, self.loads(DATA2))
def test_load_classic_instance(self):
# See issue5180. Test loading 2.x pickles that
@@ -492,7 +517,7 @@ class AbstractPickleTests(unittest.TestCase):
b"X\n"
b"p0\n"
b"(dp1\nb.").replace(b'X', xname)
- self.assertEqual(X(*args), self.loads(pickle0))
+ self.assert_is_copy(X(*args), self.loads(pickle0))
# Protocol 1 (binary mode pickle)
"""
@@ -509,7 +534,7 @@ class AbstractPickleTests(unittest.TestCase):
pickle1 = (b'(c__main__\n'
b'X\n'
b'q\x00oq\x01}q\x02b.').replace(b'X', xname)
- self.assertEqual(X(*args), self.loads(pickle1))
+ self.assert_is_copy(X(*args), self.loads(pickle1))
# Protocol 2 (pickle2 = b'\x80\x02' + pickle1)
"""
@@ -527,7 +552,7 @@ class AbstractPickleTests(unittest.TestCase):
pickle2 = (b'\x80\x02(c__main__\n'
b'X\n'
b'q\x00oq\x01}q\x02b.').replace(b'X', xname)
- self.assertEqual(X(*args), self.loads(pickle2))
+ self.assert_is_copy(X(*args), self.loads(pickle2))
# There are gratuitous differences between pickles produced by
# pickle and cPickle, largely because cPickle starts PUT indices at
@@ -552,6 +577,7 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols:
s = self.dumps(l, proto)
x = self.loads(s)
+ self.assertIsInstance(x, list)
self.assertEqual(len(x), 1)
self.assertTrue(x is x[0])
@@ -561,6 +587,7 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols:
s = self.dumps(t, proto)
x = self.loads(s)
+ self.assertIsInstance(x, tuple)
self.assertEqual(len(x), 1)
self.assertEqual(len(x[0]), 1)
self.assertTrue(x is x[0][0])
@@ -571,15 +598,39 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols:
s = self.dumps(d, proto)
x = self.loads(s)
+ self.assertIsInstance(x, dict)
self.assertEqual(list(x.keys()), [1])
self.assertTrue(x[1] is x)
+ def test_recursive_set(self):
+ h = H()
+ y = set({h})
+ h.attr = y
+ for proto in protocols:
+ s = self.dumps(y, proto)
+ x = self.loads(s)
+ self.assertIsInstance(x, set)
+ self.assertIs(list(x)[0].attr, x)
+ self.assertEqual(len(x), 1)
+
+ def test_recursive_frozenset(self):
+ h = H()
+ y = frozenset({h})
+ h.attr = y
+ for proto in protocols:
+ s = self.dumps(y, proto)
+ x = self.loads(s)
+ self.assertIsInstance(x, frozenset)
+ self.assertIs(list(x)[0].attr, x)
+ self.assertEqual(len(x), 1)
+
def test_recursive_inst(self):
i = C()
i.attr = i
for proto in protocols:
s = self.dumps(i, proto)
x = self.loads(s)
+ self.assertIsInstance(x, C)
self.assertEqual(dir(x), dir(i))
self.assertIs(x.attr, x)
@@ -592,6 +643,7 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols:
s = self.dumps(l, proto)
x = self.loads(s)
+ self.assertIsInstance(x, list)
self.assertEqual(len(x), 1)
self.assertEqual(dir(x[0]), dir(i))
self.assertEqual(list(x[0].attr.keys()), [1])
@@ -599,7 +651,8 @@ class AbstractPickleTests(unittest.TestCase):
def test_get(self):
self.assertRaises(KeyError, self.loads, b'g0\np0')
- self.assertEqual(self.loads(b'((Kdtp0\nh\x00l.))'), [(100,), (100,)])
+ self.assert_is_copy([(100,), (100,)],
+ self.loads(b'((Kdtp0\nh\x00l.))'))
def test_unicode(self):
endcases = ['', '<\\u>', '<\\\u1234>', '<\n>',
@@ -610,26 +663,26 @@ class AbstractPickleTests(unittest.TestCase):
for u in endcases:
p = self.dumps(u, proto)
u2 = self.loads(p)
- self.assertEqual(u2, u)
+ self.assert_is_copy(u, u2)
def test_unicode_high_plane(self):
t = '\U00012345'
for proto in protocols:
p = self.dumps(t, proto)
t2 = self.loads(p)
- self.assertEqual(t2, t)
+ self.assert_is_copy(t, t2)
def test_bytes(self):
for proto in protocols:
for s in b'', b'xyz', b'xyz'*100:
p = self.dumps(s, proto)
- self.assertEqual(self.loads(p), s)
+ self.assert_is_copy(s, self.loads(p))
for s in [bytes([i]) for i in range(256)]:
p = self.dumps(s, proto)
- self.assertEqual(self.loads(p), s)
+ self.assert_is_copy(s, self.loads(p))
for s in [bytes([i, i]) for i in range(256)]:
p = self.dumps(s, proto)
- self.assertEqual(self.loads(p), s)
+ self.assert_is_copy(s, self.loads(p))
def test_ints(self):
import sys
@@ -639,14 +692,14 @@ class AbstractPickleTests(unittest.TestCase):
for expected in (-n, n):
s = self.dumps(expected, proto)
n2 = self.loads(s)
- self.assertEqual(expected, n2)
+ self.assert_is_copy(expected, n2)
n = n >> 1
def test_maxint64(self):
maxint64 = (1 << 63) - 1
data = b'I' + str(maxint64).encode("ascii") + b'\n.'
got = self.loads(data)
- self.assertEqual(got, maxint64)
+ self.assert_is_copy(maxint64, got)
# Try too with a bogus literal.
data = b'I' + str(maxint64).encode("ascii") + b'JUNK\n.'
@@ -661,7 +714,7 @@ class AbstractPickleTests(unittest.TestCase):
for n in npos, -npos:
pickle = self.dumps(n, proto)
got = self.loads(pickle)
- self.assertEqual(n, got)
+ self.assert_is_copy(n, got)
# Try a monster. This is quadratic-time in protos 0 & 1, so don't
# bother with those.
nbase = int("deadbeeffeedface", 16)
@@ -669,7 +722,7 @@ class AbstractPickleTests(unittest.TestCase):
for n in nbase, -nbase:
p = self.dumps(n, 2)
got = self.loads(p)
- self.assertEqual(n, got)
+ self.assert_is_copy(n, got)
def test_float(self):
test_values = [0.0, 4.94e-324, 1e-310, 7e-308, 6.626e-34, 0.1, 0.5,
@@ -679,7 +732,7 @@ class AbstractPickleTests(unittest.TestCase):
for value in test_values:
pickle = self.dumps(value, proto)
got = self.loads(pickle)
- self.assertEqual(value, got)
+ self.assert_is_copy(value, got)
@run_with_locale('LC_ALL', 'de_DE', 'fr_FR')
def test_float_format(self):
@@ -711,6 +764,7 @@ class AbstractPickleTests(unittest.TestCase):
s = self.dumps(a, proto)
b = self.loads(s)
self.assertEqual(a, b)
+ self.assertIs(type(a), type(b))
def test_structseq(self):
import time
@@ -720,48 +774,48 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols:
s = self.dumps(t, proto)
u = self.loads(s)
- self.assertEqual(t, u)
+ self.assert_is_copy(t, u)
if hasattr(os, "stat"):
t = os.stat(os.curdir)
s = self.dumps(t, proto)
u = self.loads(s)
- self.assertEqual(t, u)
+ self.assert_is_copy(t, u)
if hasattr(os, "statvfs"):
t = os.statvfs(os.curdir)
s = self.dumps(t, proto)
u = self.loads(s)
- self.assertEqual(t, u)
+ self.assert_is_copy(t, u)
def test_ellipsis(self):
for proto in protocols:
s = self.dumps(..., proto)
u = self.loads(s)
- self.assertEqual(..., u)
+ self.assertIs(..., u)
def test_notimplemented(self):
for proto in protocols:
s = self.dumps(NotImplemented, proto)
u = self.loads(s)
- self.assertEqual(NotImplemented, u)
+ self.assertIs(NotImplemented, u)
# Tests for protocol 2
def test_proto(self):
- build_none = pickle.NONE + pickle.STOP
for proto in protocols:
- expected = build_none
+ pickled = self.dumps(None, proto)
if proto >= 2:
- expected = pickle.PROTO + bytes([proto]) + expected
- p = self.dumps(None, proto)
- self.assertEqual(p, expected)
+ proto_header = pickle.PROTO + bytes([proto])
+ self.assertTrue(pickled.startswith(proto_header))
+ else:
+ self.assertEqual(count_opcode(pickle.PROTO, pickled), 0)
oob = protocols[-1] + 1 # a future protocol
+ build_none = pickle.NONE + pickle.STOP
badpickle = pickle.PROTO + bytes([oob]) + build_none
try:
self.loads(badpickle)
- except ValueError as detail:
- self.assertTrue(str(detail).startswith(
- "unsupported pickle protocol"))
+ except ValueError as err:
+ self.assertIn("unsupported pickle protocol", str(err))
else:
self.fail("expected bad protocol number to raise ValueError")
@@ -770,7 +824,7 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols:
s = self.dumps(x, proto)
y = self.loads(s)
- self.assertEqual(x, y)
+ self.assert_is_copy(x, y)
self.assertEqual(opcode_in_pickle(pickle.LONG1, s), proto >= 2)
def test_long4(self):
@@ -778,7 +832,7 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols:
s = self.dumps(x, proto)
y = self.loads(s)
- self.assertEqual(x, y)
+ self.assert_is_copy(x, y)
self.assertEqual(opcode_in_pickle(pickle.LONG4, s), proto >= 2)
def test_short_tuples(self):
@@ -816,9 +870,9 @@ class AbstractPickleTests(unittest.TestCase):
for x in a, b, c, d, e:
s = self.dumps(x, proto)
y = self.loads(s)
- self.assertEqual(x, y, (proto, x, s, y))
- expected = expected_opcode[proto, len(x)]
- self.assertEqual(opcode_in_pickle(expected, s), True)
+ self.assert_is_copy(x, y)
+ expected = expected_opcode[min(proto, 3), len(x)]
+ self.assertTrue(opcode_in_pickle(expected, s))
def test_singletons(self):
# Map (proto, singleton) to expected opcode.
@@ -842,8 +896,8 @@ class AbstractPickleTests(unittest.TestCase):
s = self.dumps(x, proto)
y = self.loads(s)
self.assertTrue(x is y, (proto, x, s, y))
- expected = expected_opcode[proto, x]
- self.assertEqual(opcode_in_pickle(expected, s), True)
+ expected = expected_opcode[min(proto, 3), x]
+ self.assertTrue(opcode_in_pickle(expected, s))
def test_newobj_tuple(self):
x = MyTuple([1, 2, 3])
@@ -852,8 +906,7 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols:
s = self.dumps(x, proto)
y = self.loads(s)
- self.assertEqual(tuple(x), tuple(y))
- self.assertEqual(x.__dict__, y.__dict__)
+ self.assert_is_copy(x, y)
def test_newobj_list(self):
x = MyList([1, 2, 3])
@@ -862,8 +915,7 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols:
s = self.dumps(x, proto)
y = self.loads(s)
- self.assertEqual(list(x), list(y))
- self.assertEqual(x.__dict__, y.__dict__)
+ self.assert_is_copy(x, y)
def test_newobj_generic(self):
for proto in protocols:
@@ -874,6 +926,7 @@ class AbstractPickleTests(unittest.TestCase):
s = self.dumps(x, proto)
y = self.loads(s)
detail = (proto, C, B, x, y, type(y))
+ self.assert_is_copy(x, y) # XXX revisit
self.assertEqual(B(x), B(y), detail)
self.assertEqual(x.__dict__, y.__dict__, detail)
@@ -912,11 +965,10 @@ class AbstractPickleTests(unittest.TestCase):
s1 = self.dumps(x, 1)
self.assertIn(__name__.encode("utf-8"), s1)
self.assertIn(b"MyList", s1)
- self.assertEqual(opcode_in_pickle(opcode, s1), False)
+ self.assertFalse(opcode_in_pickle(opcode, s1))
y = self.loads(s1)
- self.assertEqual(list(x), list(y))
- self.assertEqual(x.__dict__, y.__dict__)
+ self.assert_is_copy(x, y)
# Dump using protocol 2 for test.
s2 = self.dumps(x, 2)
@@ -925,9 +977,7 @@ class AbstractPickleTests(unittest.TestCase):
self.assertEqual(opcode_in_pickle(opcode, s2), True, repr(s2))
y = self.loads(s2)
- self.assertEqual(list(x), list(y))
- self.assertEqual(x.__dict__, y.__dict__)
-
+ self.assert_is_copy(x, y)
finally:
e.restore()
@@ -951,7 +1001,7 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols:
s = self.dumps(x, proto)
y = self.loads(s)
- self.assertEqual(x, y)
+ self.assert_is_copy(x, y)
num_appends = count_opcode(pickle.APPENDS, s)
self.assertEqual(num_appends, proto > 0)
@@ -960,7 +1010,7 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols:
s = self.dumps(x, proto)
y = self.loads(s)
- self.assertEqual(x, y)
+ self.assert_is_copy(x, y)
num_appends = count_opcode(pickle.APPENDS, s)
if proto == 0:
self.assertEqual(num_appends, 0)
@@ -974,7 +1024,7 @@ class AbstractPickleTests(unittest.TestCase):
s = self.dumps(x, proto)
self.assertIsInstance(s, bytes_types)
y = self.loads(s)
- self.assertEqual(x, y)
+ self.assert_is_copy(x, y)
num_setitems = count_opcode(pickle.SETITEMS, s)
self.assertEqual(num_setitems, proto > 0)
@@ -983,22 +1033,49 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols:
s = self.dumps(x, proto)
y = self.loads(s)
- self.assertEqual(x, y)
+ self.assert_is_copy(x, y)
num_setitems = count_opcode(pickle.SETITEMS, s)
if proto == 0:
self.assertEqual(num_setitems, 0)
else:
self.assertTrue(num_setitems >= 2)
+ def test_set_chunking(self):
+ n = 10 # too small to chunk
+ x = set(range(n))
+ for proto in protocols:
+ s = self.dumps(x, proto)
+ y = self.loads(s)
+ self.assert_is_copy(x, y)
+ num_additems = count_opcode(pickle.ADDITEMS, s)
+ if proto < 4:
+ self.assertEqual(num_additems, 0)
+ else:
+ self.assertEqual(num_additems, 1)
+
+ n = 2500 # expect at least two chunks when proto >= 4
+ x = set(range(n))
+ for proto in protocols:
+ s = self.dumps(x, proto)
+ y = self.loads(s)
+ self.assert_is_copy(x, y)
+ num_additems = count_opcode(pickle.ADDITEMS, s)
+ if proto < 4:
+ self.assertEqual(num_additems, 0)
+ else:
+ self.assertGreaterEqual(num_additems, 2)
+
def test_simple_newobj(self):
x = object.__new__(SimpleNewObj) # avoid __init__
x.abc = 666
for proto in protocols:
s = self.dumps(x, proto)
- self.assertEqual(opcode_in_pickle(pickle.NEWOBJ, s), proto >= 2)
+ self.assertEqual(opcode_in_pickle(pickle.NEWOBJ, s),
+ 2 <= proto < 4)
+ self.assertEqual(opcode_in_pickle(pickle.NEWOBJ_EX, s),
+ proto >= 4)
y = self.loads(s) # will raise TypeError if __init__ called
- self.assertEqual(y.abc, 666)
- self.assertEqual(x.__dict__, y.__dict__)
+ self.assert_is_copy(x, y)
def test_newobj_list_slots(self):
x = SlotList([1, 2, 3])
@@ -1006,10 +1083,7 @@ class AbstractPickleTests(unittest.TestCase):
x.bar = "hello"
s = self.dumps(x, 2)
y = self.loads(s)
- self.assertEqual(list(x), list(y))
- self.assertEqual(x.__dict__, y.__dict__)
- self.assertEqual(x.foo, y.foo)
- self.assertEqual(x.bar, y.bar)
+ self.assert_is_copy(x, y)
def test_reduce_overrides_default_reduce_ex(self):
for proto in protocols:
@@ -1058,11 +1132,10 @@ class AbstractPickleTests(unittest.TestCase):
@no_tracing
def test_bad_getattr(self):
+ # Issue #3514: crash when there is an infinite loop in __getattr__
x = BadGetattr()
- for proto in 0, 1:
+ for proto in protocols:
self.assertRaises(RuntimeError, self.dumps, x, proto)
- # protocol 2 don't raise a RuntimeError.
- d = self.dumps(x, 2)
def test_reduce_bad_iterator(self):
# Issue4176: crash when 4th and 5th items of __reduce__()
@@ -1095,11 +1168,10 @@ class AbstractPickleTests(unittest.TestCase):
obj = [dict(large_dict), dict(large_dict), dict(large_dict)]
for proto in protocols:
- dumped = self.dumps(obj, proto)
- loaded = self.loads(dumped)
- self.assertEqual(loaded, obj,
- "Failed protocol %d: %r != %r"
- % (proto, obj, loaded))
+ with self.subTest(proto=proto):
+ dumped = self.dumps(obj, proto)
+ loaded = self.loads(dumped)
+ self.assert_is_copy(obj, loaded)
def test_attribute_name_interning(self):
# Test that attribute names of pickled objects are interned when
@@ -1155,11 +1227,14 @@ class AbstractPickleTests(unittest.TestCase):
def test_int_pickling_efficiency(self):
# Test compacity of int representation (see issue #12744)
for proto in protocols:
- sizes = [len(self.dumps(2**n, proto)) for n in range(70)]
- # the size function is monotonic
- self.assertEqual(sorted(sizes), sizes)
- if proto >= 2:
- self.assertLessEqual(sizes[-1], 14)
+ with self.subTest(proto=proto):
+ pickles = [self.dumps(2**n, proto) for n in range(70)]
+ sizes = list(map(len, pickles))
+ # the size function is monotonic
+ self.assertEqual(sorted(sizes), sizes)
+ if proto >= 2:
+ for p in pickles:
+ self.assertFalse(opcode_in_pickle(pickle.LONG, p))
def check_negative_32b_binXXX(self, dumped):
if sys.maxsize > 2**32:
@@ -1242,6 +1317,137 @@ class AbstractPickleTests(unittest.TestCase):
else:
self._check_pickling_with_opcode(obj, pickle.SETITEMS, proto)
+ # Exercise framing (proto >= 4) for significant workloads
+
+ FRAME_SIZE_TARGET = 64 * 1024
+
+ def test_framing_many_objects(self):
+ obj = list(range(10**5))
+ for proto in range(4, pickle.HIGHEST_PROTOCOL + 1):
+ with self.subTest(proto=proto):
+ pickled = self.dumps(obj, proto)
+ unpickled = self.loads(pickled)
+ self.assertEqual(obj, unpickled)
+ # Test the framing heuristic is sane,
+ # assuming a given frame size target.
+ bytes_per_frame = (len(pickled) /
+ pickled.count(b'\x00\x00\x00\x00\x00'))
+ self.assertGreater(bytes_per_frame,
+ self.FRAME_SIZE_TARGET / 2)
+ self.assertLessEqual(bytes_per_frame,
+ self.FRAME_SIZE_TARGET * 1)
+
+ def test_framing_large_objects(self):
+ N = 1024 * 1024
+ obj = [b'x' * N, b'y' * N, b'z' * N]
+ for proto in range(4, pickle.HIGHEST_PROTOCOL + 1):
+ with self.subTest(proto=proto):
+ pickled = self.dumps(obj, proto)
+ unpickled = self.loads(pickled)
+ self.assertEqual(obj, unpickled)
+ # At least one frame was emitted per large bytes object.
+ n_frames = pickled.count(b'\x00\x00\x00\x00\x00')
+ self.assertGreaterEqual(n_frames, len(obj))
+
+ def test_nested_names(self):
+ global Nested
+ class Nested:
+ class A:
+ class B:
+ class C:
+ pass
+
+ for proto in range(4, pickle.HIGHEST_PROTOCOL + 1):
+ for obj in [Nested.A, Nested.A.B, Nested.A.B.C]:
+ with self.subTest(proto=proto, obj=obj):
+ unpickled = self.loads(self.dumps(obj, proto))
+ self.assertIs(obj, unpickled)
+
+ def test_py_methods(self):
+ global PyMethodsTest
+ class PyMethodsTest:
+ @staticmethod
+ def cheese():
+ return "cheese"
+ @classmethod
+ def wine(cls):
+ assert cls is PyMethodsTest
+ return "wine"
+ def biscuits(self):
+ assert isinstance(self, PyMethodsTest)
+ return "biscuits"
+ class Nested:
+ "Nested class"
+ @staticmethod
+ def ketchup():
+ return "ketchup"
+ @classmethod
+ def maple(cls):
+ assert cls is PyMethodsTest.Nested
+ return "maple"
+ def pie(self):
+ assert isinstance(self, PyMethodsTest.Nested)
+ return "pie"
+
+ py_methods = (
+ PyMethodsTest.cheese,
+ PyMethodsTest.wine,
+ PyMethodsTest().biscuits,
+ PyMethodsTest.Nested.ketchup,
+ PyMethodsTest.Nested.maple,
+ PyMethodsTest.Nested().pie
+ )
+ py_unbound_methods = (
+ (PyMethodsTest.biscuits, PyMethodsTest),
+ (PyMethodsTest.Nested.pie, PyMethodsTest.Nested)
+ )
+ for proto in range(4, pickle.HIGHEST_PROTOCOL + 1):
+ for method in py_methods:
+ with self.subTest(proto=proto, method=method):
+ unpickled = self.loads(self.dumps(method, proto))
+ self.assertEqual(method(), unpickled())
+ for method, cls in py_unbound_methods:
+ obj = cls()
+ with self.subTest(proto=proto, method=method):
+ unpickled = self.loads(self.dumps(method, proto))
+ self.assertEqual(method(obj), unpickled(obj))
+
+
+ def test_c_methods(self):
+ global Subclass
+ class Subclass(tuple):
+ class Nested(str):
+ pass
+
+ c_methods = (
+ # bound built-in method
+ ("abcd".index, ("c",)),
+ # unbound built-in method
+ (str.index, ("abcd", "c")),
+ # bound "slot" method
+ ([1, 2, 3].__len__, ()),
+ # unbound "slot" method
+ (list.__len__, ([1, 2, 3],)),
+ # bound "coexist" method
+ ({1, 2}.__contains__, (2,)),
+ # unbound "coexist" method
+ (set.__contains__, ({1, 2}, 2)),
+ # built-in class method
+ (dict.fromkeys, (("a", 1), ("b", 2))),
+ # built-in static method
+ (bytearray.maketrans, (b"abc", b"xyz")),
+ # subclass methods
+ (Subclass([1,2,2]).count, (2,)),
+ (Subclass.count, (Subclass([1,2,2]), 2)),
+ (Subclass.Nested("sweet").count, ("e",)),
+ (Subclass.Nested.count, (Subclass.Nested("sweet"), "e")),
+ )
+ for proto in range(4, pickle.HIGHEST_PROTOCOL + 1):
+ for method, args in c_methods:
+ with self.subTest(proto=proto, method=method):
+ unpickled = self.loads(self.dumps(method, proto))
+ self.assertEqual(method(*args), unpickled(*args))
+
class BigmemPickleTests(unittest.TestCase):
@@ -1252,10 +1458,11 @@ class BigmemPickleTests(unittest.TestCase):
data = 1 << (8 * size)
try:
for proto in protocols:
- if proto < 2:
- continue
- with self.assertRaises((ValueError, OverflowError)):
- self.dumps(data, protocol=proto)
+ with self.subTest(proto=proto):
+ if proto < 2:
+ continue
+ with self.assertRaises((ValueError, OverflowError)):
+ self.dumps(data, protocol=proto)
finally:
data = None
@@ -1268,14 +1475,15 @@ class BigmemPickleTests(unittest.TestCase):
data = b"abcd" * (size // 4)
try:
for proto in protocols:
- if proto < 3:
- continue
- try:
- pickled = self.dumps(data, protocol=proto)
- self.assertTrue(b"abcd" in pickled[:15])
- self.assertTrue(b"abcd" in pickled[-15:])
- finally:
- pickled = None
+ with self.subTest(proto=proto):
+ if proto < 3:
+ continue
+ try:
+ pickled = self.dumps(data, protocol=proto)
+ self.assertTrue(b"abcd" in pickled[:19])
+ self.assertTrue(b"abcd" in pickled[-18:])
+ finally:
+ pickled = None
finally:
data = None
@@ -1284,10 +1492,11 @@ class BigmemPickleTests(unittest.TestCase):
data = b"a" * size
try:
for proto in protocols:
- if proto < 3:
- continue
- with self.assertRaises((ValueError, OverflowError)):
- self.dumps(data, protocol=proto)
+ with self.subTest(proto=proto):
+ if proto < 3:
+ continue
+ with self.assertRaises((ValueError, OverflowError)):
+ self.dumps(data, protocol=proto)
finally:
data = None
@@ -1299,27 +1508,38 @@ class BigmemPickleTests(unittest.TestCase):
data = "abcd" * (size // 4)
try:
for proto in protocols:
- try:
- pickled = self.dumps(data, protocol=proto)
- self.assertTrue(b"abcd" in pickled[:15])
- self.assertTrue(b"abcd" in pickled[-15:])
- finally:
- pickled = None
+ with self.subTest(proto=proto):
+ try:
+ pickled = self.dumps(data, protocol=proto)
+ self.assertTrue(b"abcd" in pickled[:19])
+ self.assertTrue(b"abcd" in pickled[-18:])
+ finally:
+ pickled = None
finally:
data = None
- # BINUNICODE (protocols 1, 2 and 3) cannot carry more than
- # 2**32 - 1 bytes of utf-8 encoded unicode.
+ # BINUNICODE (protocols 1, 2 and 3) cannot carry more than 2**32 - 1 bytes
+ # of utf-8 encoded unicode. BINUNICODE8 (protocol 4) supports these huge
+ # unicode strings however.
- @bigmemtest(size=_4G, memuse=1 + ascii_char_size, dry_run=False)
+ @bigmemtest(size=_4G, memuse=2 + ascii_char_size, dry_run=False)
def test_huge_str_64b(self, size):
- data = "a" * size
+ data = "abcd" * (size // 4)
try:
for proto in protocols:
- if proto == 0:
- continue
- with self.assertRaises((ValueError, OverflowError)):
- self.dumps(data, protocol=proto)
+ with self.subTest(proto=proto):
+ if proto == 0:
+ continue
+ if proto < 4:
+ with self.assertRaises((ValueError, OverflowError)):
+ self.dumps(data, protocol=proto)
+ else:
+ try:
+ pickled = self.dumps(data, protocol=proto)
+ self.assertTrue(b"abcd" in pickled[:19])
+ self.assertTrue(b"abcd" in pickled[-18:])
+ finally:
+ pickled = None
finally:
data = None
@@ -1363,8 +1583,8 @@ class REX_five(object):
return object.__reduce__(self)
class REX_six(object):
- """This class is used to check the 4th argument (list iterator) of the reduce
- protocol.
+ """This class is used to check the 4th argument (list iterator) of
+ the reduce protocol.
"""
def __init__(self, items=None):
self.items = items if items is not None else []
@@ -1376,8 +1596,8 @@ class REX_six(object):
return type(self), (), None, iter(self.items), None
class REX_seven(object):
- """This class is used to check the 5th argument (dict iterator) of the reduce
- protocol.
+ """This class is used to check the 5th argument (dict iterator) of
+ the reduce protocol.
"""
def __init__(self, table=None):
self.table = table if table is not None else {}
@@ -1415,10 +1635,16 @@ class MyList(list):
class MyDict(dict):
sample = {"a": 1, "b": 2}
+class MySet(set):
+ sample = {"a", "b"}
+
+class MyFrozenSet(frozenset):
+ sample = frozenset({"a", "b"})
+
myclasses = [MyInt, MyFloat,
MyComplex,
MyStr, MyUnicode,
- MyTuple, MyList, MyDict]
+ MyTuple, MyList, MyDict, MySet, MyFrozenSet]
class SlotList(MyList):
@@ -1428,6 +1654,8 @@ class SimpleNewObj(object):
def __init__(self, a, b, c):
# raise an error, to make sure this isn't called
raise TypeError("SimpleNewObj.__init__() didn't expect to get called")
+ def __eq__(self, other):
+ return self.__dict__ == other.__dict__
class BadGetattr:
def __getattr__(self, key):
@@ -1464,7 +1692,7 @@ class AbstractPickleModuleTests(unittest.TestCase):
def test_highest_protocol(self):
# Of course this needs to be changed when HIGHEST_PROTOCOL changes.
- self.assertEqual(pickle.HIGHEST_PROTOCOL, 3)
+ self.assertEqual(pickle.HIGHEST_PROTOCOL, 4)
def test_callapi(self):
f = io.BytesIO()
@@ -1645,22 +1873,23 @@ class AbstractPicklerUnpicklerObjectTests(unittest.TestCase):
def _check_multiple_unpicklings(self, ioclass):
for proto in protocols:
- data1 = [(x, str(x)) for x in range(2000)] + [b"abcde", len]
- f = ioclass()
- pickler = self.pickler_class(f, protocol=proto)
- pickler.dump(data1)
- pickled = f.getvalue()
-
- N = 5
- f = ioclass(pickled * N)
- unpickler = self.unpickler_class(f)
- for i in range(N):
- if f.seekable():
- pos = f.tell()
- self.assertEqual(unpickler.load(), data1)
- if f.seekable():
- self.assertEqual(f.tell(), pos + len(pickled))
- self.assertRaises(EOFError, unpickler.load)
+ with self.subTest(proto=proto):
+ data1 = [(x, str(x)) for x in range(2000)] + [b"abcde", len]
+ f = ioclass()
+ pickler = self.pickler_class(f, protocol=proto)
+ pickler.dump(data1)
+ pickled = f.getvalue()
+
+ N = 5
+ f = ioclass(pickled * N)
+ unpickler = self.unpickler_class(f)
+ for i in range(N):
+ if f.seekable():
+ pos = f.tell()
+ self.assertEqual(unpickler.load(), data1)
+ if f.seekable():
+ self.assertEqual(f.tell(), pos + len(pickled))
+ self.assertRaises(EOFError, unpickler.load)
def test_multiple_unpicklings_seekable(self):
self._check_multiple_unpicklings(io.BytesIO)
diff --git a/Lib/test/test_descr.py b/Lib/test/test_descr.py
index 5aa4b0a..8ef51de 100644
--- a/Lib/test/test_descr.py
+++ b/Lib/test/test_descr.py
@@ -1,8 +1,11 @@
import builtins
+import copyreg
import gc
+import itertools
+import math
+import pickle
import sys
import types
-import math
import unittest
import weakref
@@ -3153,176 +3156,6 @@ order (MRO) for bases """
self.assertEqual(e.a, 1)
self.assertEqual(can_delete_dict(e), can_delete_dict(ValueError()))
- def test_pickles(self):
- # Testing pickling and copying new-style classes and objects...
- import pickle
-
- def sorteditems(d):
- L = list(d.items())
- L.sort()
- return L
-
- global C
- class C(object):
- def __init__(self, a, b):
- super(C, self).__init__()
- self.a = a
- self.b = b
- def __repr__(self):
- return "C(%r, %r)" % (self.a, self.b)
-
- global C1
- class C1(list):
- def __new__(cls, a, b):
- return super(C1, cls).__new__(cls)
- def __getnewargs__(self):
- return (self.a, self.b)
- def __init__(self, a, b):
- self.a = a
- self.b = b
- def __repr__(self):
- return "C1(%r, %r)<%r>" % (self.a, self.b, list(self))
-
- global C2
- class C2(int):
- def __new__(cls, a, b, val=0):
- return super(C2, cls).__new__(cls, val)
- def __getnewargs__(self):
- return (self.a, self.b, int(self))
- def __init__(self, a, b, val=0):
- self.a = a
- self.b = b
- def __repr__(self):
- return "C2(%r, %r)<%r>" % (self.a, self.b, int(self))
-
- global C3
- class C3(object):
- def __init__(self, foo):
- self.foo = foo
- def __getstate__(self):
- return self.foo
- def __setstate__(self, foo):
- self.foo = foo
-
- global C4classic, C4
- class C4classic: # classic
- pass
- class C4(C4classic, object): # mixed inheritance
- pass
-
- for bin in 0, 1:
- for cls in C, C1, C2:
- s = pickle.dumps(cls, bin)
- cls2 = pickle.loads(s)
- self.assertIs(cls2, cls)
-
- a = C1(1, 2); a.append(42); a.append(24)
- b = C2("hello", "world", 42)
- s = pickle.dumps((a, b), bin)
- x, y = pickle.loads(s)
- self.assertEqual(x.__class__, a.__class__)
- self.assertEqual(sorteditems(x.__dict__), sorteditems(a.__dict__))
- self.assertEqual(y.__class__, b.__class__)
- self.assertEqual(sorteditems(y.__dict__), sorteditems(b.__dict__))
- self.assertEqual(repr(x), repr(a))
- self.assertEqual(repr(y), repr(b))
- # Test for __getstate__ and __setstate__ on new style class
- u = C3(42)
- s = pickle.dumps(u, bin)
- v = pickle.loads(s)
- self.assertEqual(u.__class__, v.__class__)
- self.assertEqual(u.foo, v.foo)
- # Test for picklability of hybrid class
- u = C4()
- u.foo = 42
- s = pickle.dumps(u, bin)
- v = pickle.loads(s)
- self.assertEqual(u.__class__, v.__class__)
- self.assertEqual(u.foo, v.foo)
-
- # Testing copy.deepcopy()
- import copy
- for cls in C, C1, C2:
- cls2 = copy.deepcopy(cls)
- self.assertIs(cls2, cls)
-
- a = C1(1, 2); a.append(42); a.append(24)
- b = C2("hello", "world", 42)
- x, y = copy.deepcopy((a, b))
- self.assertEqual(x.__class__, a.__class__)
- self.assertEqual(sorteditems(x.__dict__), sorteditems(a.__dict__))
- self.assertEqual(y.__class__, b.__class__)
- self.assertEqual(sorteditems(y.__dict__), sorteditems(b.__dict__))
- self.assertEqual(repr(x), repr(a))
- self.assertEqual(repr(y), repr(b))
-
- def test_pickle_slots(self):
- # Testing pickling of classes with __slots__ ...
- import pickle
- # Pickling of classes with __slots__ but without __getstate__ should fail
- # (if using protocol 0 or 1)
- global B, C, D, E
- class B(object):
- pass
- for base in [object, B]:
- class C(base):
- __slots__ = ['a']
- class D(C):
- pass
- try:
- pickle.dumps(C(), 0)
- except TypeError:
- pass
- else:
- self.fail("should fail: pickle C instance - %s" % base)
- try:
- pickle.dumps(C(), 0)
- except TypeError:
- pass
- else:
- self.fail("should fail: pickle D instance - %s" % base)
- # Give C a nice generic __getstate__ and __setstate__
- class C(base):
- __slots__ = ['a']
- def __getstate__(self):
- try:
- d = self.__dict__.copy()
- except AttributeError:
- d = {}
- for cls in self.__class__.__mro__:
- for sn in cls.__dict__.get('__slots__', ()):
- try:
- d[sn] = getattr(self, sn)
- except AttributeError:
- pass
- return d
- def __setstate__(self, d):
- for k, v in list(d.items()):
- setattr(self, k, v)
- class D(C):
- pass
- # Now it should work
- x = C()
- y = pickle.loads(pickle.dumps(x))
- self.assertNotHasAttr(y, 'a')
- x.a = 42
- y = pickle.loads(pickle.dumps(x))
- self.assertEqual(y.a, 42)
- x = D()
- x.a = 42
- x.b = 100
- y = pickle.loads(pickle.dumps(x))
- self.assertEqual(y.a + y.b, 142)
- # A subclass that adds a slot should also work
- class E(C):
- __slots__ = ['b']
- x = E()
- x.a = 42
- x.b = "foo"
- y = pickle.loads(pickle.dumps(x))
- self.assertEqual(y.a, x.a)
- self.assertEqual(y.b, x.b)
-
def test_binary_operator_override(self):
# Testing overrides of binary operations...
class I(int):
@@ -4690,11 +4523,439 @@ class MiscTests(unittest.TestCase):
self.assertEqual(X.mykey2, 'from Base2')
+class PicklingTests(unittest.TestCase):
+
+ def _check_reduce(self, proto, obj, args=(), kwargs={}, state=None,
+ listitems=None, dictitems=None):
+ if proto >= 4:
+ reduce_value = obj.__reduce_ex__(proto)
+ self.assertEqual(reduce_value[:3],
+ (copyreg.__newobj_ex__,
+ (type(obj), args, kwargs),
+ state))
+ if listitems is not None:
+ self.assertListEqual(list(reduce_value[3]), listitems)
+ else:
+ self.assertIsNone(reduce_value[3])
+ if dictitems is not None:
+ self.assertDictEqual(dict(reduce_value[4]), dictitems)
+ else:
+ self.assertIsNone(reduce_value[4])
+ elif proto >= 2:
+ reduce_value = obj.__reduce_ex__(proto)
+ self.assertEqual(reduce_value[:3],
+ (copyreg.__newobj__,
+ (type(obj),) + args,
+ state))
+ if listitems is not None:
+ self.assertListEqual(list(reduce_value[3]), listitems)
+ else:
+ self.assertIsNone(reduce_value[3])
+ if dictitems is not None:
+ self.assertDictEqual(dict(reduce_value[4]), dictitems)
+ else:
+ self.assertIsNone(reduce_value[4])
+ else:
+ base_type = type(obj).__base__
+ reduce_value = (copyreg._reconstructor,
+ (type(obj),
+ base_type,
+ None if base_type is object else base_type(obj)))
+ if state is not None:
+ reduce_value += (state,)
+ self.assertEqual(obj.__reduce_ex__(proto), reduce_value)
+ self.assertEqual(obj.__reduce__(), reduce_value)
+
+ def test_reduce(self):
+ protocols = range(pickle.HIGHEST_PROTOCOL + 1)
+ args = (-101, "spam")
+ kwargs = {'bacon': -201, 'fish': -301}
+ state = {'cheese': -401}
+
+ class C1:
+ def __getnewargs__(self):
+ return args
+ obj = C1()
+ for proto in protocols:
+ self._check_reduce(proto, obj, args)
+
+ for name, value in state.items():
+ setattr(obj, name, value)
+ for proto in protocols:
+ self._check_reduce(proto, obj, args, state=state)
+
+ class C2:
+ def __getnewargs__(self):
+ return "bad args"
+ obj = C2()
+ for proto in protocols:
+ if proto >= 2:
+ with self.assertRaises(TypeError):
+ obj.__reduce_ex__(proto)
+
+ class C3:
+ def __getnewargs_ex__(self):
+ return (args, kwargs)
+ obj = C3()
+ for proto in protocols:
+ if proto >= 4:
+ self._check_reduce(proto, obj, args, kwargs)
+ elif proto >= 2:
+ with self.assertRaises(ValueError):
+ obj.__reduce_ex__(proto)
+
+ class C4:
+ def __getnewargs_ex__(self):
+ return (args, "bad dict")
+ class C5:
+ def __getnewargs_ex__(self):
+ return ("bad tuple", kwargs)
+ class C6:
+ def __getnewargs_ex__(self):
+ return ()
+ class C7:
+ def __getnewargs_ex__(self):
+ return "bad args"
+ for proto in protocols:
+ for cls in C4, C5, C6, C7:
+ obj = cls()
+ if proto >= 2:
+ with self.assertRaises((TypeError, ValueError)):
+ obj.__reduce_ex__(proto)
+
+ class C8:
+ def __getnewargs_ex__(self):
+ return (args, kwargs)
+ obj = C8()
+ for proto in protocols:
+ if 2 <= proto < 4:
+ with self.assertRaises(ValueError):
+ obj.__reduce_ex__(proto)
+ class C9:
+ def __getnewargs_ex__(self):
+ return (args, {})
+ obj = C9()
+ for proto in protocols:
+ self._check_reduce(proto, obj, args)
+
+ class C10:
+ def __getnewargs_ex__(self):
+ raise IndexError
+ obj = C10()
+ for proto in protocols:
+ if proto >= 2:
+ with self.assertRaises(IndexError):
+ obj.__reduce_ex__(proto)
+
+ class C11:
+ def __getstate__(self):
+ return state
+ obj = C11()
+ for proto in protocols:
+ self._check_reduce(proto, obj, state=state)
+
+ class C12:
+ def __getstate__(self):
+ return "not dict"
+ obj = C12()
+ for proto in protocols:
+ self._check_reduce(proto, obj, state="not dict")
+
+ class C13:
+ def __getstate__(self):
+ raise IndexError
+ obj = C13()
+ for proto in protocols:
+ with self.assertRaises(IndexError):
+ obj.__reduce_ex__(proto)
+ if proto < 2:
+ with self.assertRaises(IndexError):
+ obj.__reduce__()
+
+ class C14:
+ __slots__ = tuple(state)
+ def __init__(self):
+ for name, value in state.items():
+ setattr(self, name, value)
+
+ obj = C14()
+ for proto in protocols:
+ if proto >= 2:
+ self._check_reduce(proto, obj, state=(None, state))
+ else:
+ with self.assertRaises(TypeError):
+ obj.__reduce_ex__(proto)
+ with self.assertRaises(TypeError):
+ obj.__reduce__()
+
+ class C15(dict):
+ pass
+ obj = C15({"quebec": -601})
+ for proto in protocols:
+ self._check_reduce(proto, obj, dictitems=dict(obj))
+
+ class C16(list):
+ pass
+ obj = C16(["yukon"])
+ for proto in protocols:
+ self._check_reduce(proto, obj, listitems=list(obj))
+
+ def _assert_is_copy(self, obj, objcopy, msg=None):
+ """Utility method to verify if two objects are copies of each others.
+ """
+ if msg is None:
+ msg = "{!r} is not a copy of {!r}".format(obj, objcopy)
+ if type(obj).__repr__ is object.__repr__:
+ # We have this limitation for now because we use the object's repr
+ # to help us verify that the two objects are copies. This allows
+ # us to delegate the non-generic verification logic to the objects
+ # themselves.
+ raise ValueError("object passed to _assert_is_copy must " +
+ "override the __repr__ method.")
+ self.assertIsNot(obj, objcopy, msg=msg)
+ self.assertIs(type(obj), type(objcopy), msg=msg)
+ if hasattr(obj, '__dict__'):
+ self.assertDictEqual(obj.__dict__, objcopy.__dict__, msg=msg)
+ self.assertIsNot(obj.__dict__, objcopy.__dict__, msg=msg)
+ if hasattr(obj, '__slots__'):
+ self.assertListEqual(obj.__slots__, objcopy.__slots__, msg=msg)
+ for slot in obj.__slots__:
+ self.assertEqual(
+ hasattr(obj, slot), hasattr(objcopy, slot), msg=msg)
+ self.assertEqual(getattr(obj, slot, None),
+ getattr(objcopy, slot, None), msg=msg)
+ self.assertEqual(repr(obj), repr(objcopy), msg=msg)
+
+ @staticmethod
+ def _generate_pickle_copiers():
+ """Utility method to generate the many possible pickle configurations.
+ """
+ class PickleCopier:
+ "This class copies object using pickle."
+ def __init__(self, proto, dumps, loads):
+ self.proto = proto
+ self.dumps = dumps
+ self.loads = loads
+ def copy(self, obj):
+ return self.loads(self.dumps(obj, self.proto))
+ def __repr__(self):
+ # We try to be as descriptive as possible here since this is
+ # the string which we will allow us to tell the pickle
+ # configuration we are using during debugging.
+ return ("PickleCopier(proto={}, dumps={}.{}, loads={}.{})"
+ .format(self.proto,
+ self.dumps.__module__, self.dumps.__qualname__,
+ self.loads.__module__, self.loads.__qualname__))
+ return (PickleCopier(*args) for args in
+ itertools.product(range(pickle.HIGHEST_PROTOCOL + 1),
+ {pickle.dumps, pickle._dumps},
+ {pickle.loads, pickle._loads}))
+
+ def test_pickle_slots(self):
+ # Tests pickling of classes with __slots__.
+
+ # Pickling of classes with __slots__ but without __getstate__ should
+ # fail (if using protocol 0 or 1)
+ global C
+ class C:
+ __slots__ = ['a']
+ with self.assertRaises(TypeError):
+ pickle.dumps(C(), 0)
+
+ global D
+ class D(C):
+ pass
+ with self.assertRaises(TypeError):
+ pickle.dumps(D(), 0)
+
+ class C:
+ "A class with __getstate__ and __setstate__ implemented."
+ __slots__ = ['a']
+ def __getstate__(self):
+ state = getattr(self, '__dict__', {}).copy()
+ for cls in type(self).__mro__:
+ for slot in cls.__dict__.get('__slots__', ()):
+ try:
+ state[slot] = getattr(self, slot)
+ except AttributeError:
+ pass
+ return state
+ def __setstate__(self, state):
+ for k, v in state.items():
+ setattr(self, k, v)
+ def __repr__(self):
+ return "%s()<%r>" % (type(self).__name__, self.__getstate__())
+
+ class D(C):
+ "A subclass of a class with slots."
+ pass
+
+ global E
+ class E(C):
+ "A subclass with an extra slot."
+ __slots__ = ['b']
+
+ # Now it should work
+ for pickle_copier in self._generate_pickle_copiers():
+ with self.subTest(pickle_copier=pickle_copier):
+ x = C()
+ y = pickle_copier.copy(x)
+ self._assert_is_copy(x, y)
+
+ x.a = 42
+ y = pickle_copier.copy(x)
+ self._assert_is_copy(x, y)
+
+ x = D()
+ x.a = 42
+ x.b = 100
+ y = pickle_copier.copy(x)
+ self._assert_is_copy(x, y)
+
+ x = E()
+ x.a = 42
+ x.b = "foo"
+ y = pickle_copier.copy(x)
+ self._assert_is_copy(x, y)
+
+ def test_reduce_copying(self):
+ # Tests pickling and copying new-style classes and objects.
+ global C1
+ class C1:
+ "The state of this class is copyable via its instance dict."
+ ARGS = (1, 2)
+ NEED_DICT_COPYING = True
+ def __init__(self, a, b):
+ super().__init__()
+ self.a = a
+ self.b = b
+ def __repr__(self):
+ return "C1(%r, %r)" % (self.a, self.b)
+
+ global C2
+ class C2(list):
+ "A list subclass copyable via __getnewargs__."
+ ARGS = (1, 2)
+ NEED_DICT_COPYING = False
+ def __new__(cls, a, b):
+ self = super().__new__(cls)
+ self.a = a
+ self.b = b
+ return self
+ def __init__(self, *args):
+ super().__init__()
+ # This helps testing that __init__ is not called during the
+ # unpickling process, which would cause extra appends.
+ self.append("cheese")
+ @classmethod
+ def __getnewargs__(cls):
+ return cls.ARGS
+ def __repr__(self):
+ return "C2(%r, %r)<%r>" % (self.a, self.b, list(self))
+
+ global C3
+ class C3(list):
+ "A list subclass copyable via __getstate__."
+ ARGS = (1, 2)
+ NEED_DICT_COPYING = False
+ def __init__(self, a, b):
+ self.a = a
+ self.b = b
+ # This helps testing that __init__ is not called during the
+ # unpickling process, which would cause extra appends.
+ self.append("cheese")
+ @classmethod
+ def __getstate__(cls):
+ return cls.ARGS
+ def __setstate__(self, state):
+ a, b = state
+ self.a = a
+ self.b = b
+ def __repr__(self):
+ return "C3(%r, %r)<%r>" % (self.a, self.b, list(self))
+
+ global C4
+ class C4(int):
+ "An int subclass copyable via __getnewargs__."
+ ARGS = ("hello", "world", 1)
+ NEED_DICT_COPYING = False
+ def __new__(cls, a, b, value):
+ self = super().__new__(cls, value)
+ self.a = a
+ self.b = b
+ return self
+ @classmethod
+ def __getnewargs__(cls):
+ return cls.ARGS
+ def __repr__(self):
+ return "C4(%r, %r)<%r>" % (self.a, self.b, int(self))
+
+ global C5
+ class C5(int):
+ "An int subclass copyable via __getnewargs_ex__."
+ ARGS = (1, 2)
+ KWARGS = {'value': 3}
+ NEED_DICT_COPYING = False
+ def __new__(cls, a, b, *, value=0):
+ self = super().__new__(cls, value)
+ self.a = a
+ self.b = b
+ return self
+ @classmethod
+ def __getnewargs_ex__(cls):
+ return (cls.ARGS, cls.KWARGS)
+ def __repr__(self):
+ return "C5(%r, %r)<%r>" % (self.a, self.b, int(self))
+
+ test_classes = (C1, C2, C3, C4, C5)
+ # Testing copying through pickle
+ pickle_copiers = self._generate_pickle_copiers()
+ for cls, pickle_copier in itertools.product(test_classes, pickle_copiers):
+ with self.subTest(cls=cls, pickle_copier=pickle_copier):
+ kwargs = getattr(cls, 'KWARGS', {})
+ obj = cls(*cls.ARGS, **kwargs)
+ proto = pickle_copier.proto
+ if 2 <= proto < 4 and hasattr(cls, '__getnewargs_ex__'):
+ with self.assertRaises(ValueError):
+ pickle_copier.dumps(obj, proto)
+ continue
+ objcopy = pickle_copier.copy(obj)
+ self._assert_is_copy(obj, objcopy)
+ # For test classes that supports this, make sure we didn't go
+ # around the reduce protocol by simply copying the attribute
+ # dictionary. We clear attributes using the previous copy to
+ # not mutate the original argument.
+ if proto >= 2 and not cls.NEED_DICT_COPYING:
+ objcopy.__dict__.clear()
+ objcopy2 = pickle_copier.copy(objcopy)
+ self._assert_is_copy(obj, objcopy2)
+
+ # Testing copying through copy.deepcopy()
+ for cls in test_classes:
+ with self.subTest(cls=cls):
+ kwargs = getattr(cls, 'KWARGS', {})
+ obj = cls(*cls.ARGS, **kwargs)
+ # XXX: We need to modify the copy module to support PEP 3154's
+ # reduce protocol 4.
+ if hasattr(cls, '__getnewargs_ex__'):
+ continue
+ objcopy = deepcopy(obj)
+ self._assert_is_copy(obj, objcopy)
+ # For test classes that supports this, make sure we didn't go
+ # around the reduce protocol by simply copying the attribute
+ # dictionary. We clear attributes using the previous copy to
+ # not mutate the original argument.
+ if not cls.NEED_DICT_COPYING:
+ objcopy.__dict__.clear()
+ objcopy2 = deepcopy(objcopy)
+ self._assert_is_copy(obj, objcopy2)
+
+
def test_main():
# Run all local test cases, with PTypesLongInitTest first.
support.run_unittest(PTypesLongInitTest, OperatorsTest,
ClassPropertiesAndMethods, DictProxyTests,
- MiscTests)
+ MiscTests, PicklingTests)
if __name__ == "__main__":
test_main()
diff --git a/Misc/NEWS b/Misc/NEWS
index 336c3f4..06bb771 100644
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -68,6 +68,8 @@ Core and Builtins
Library
-------
+- Issue #17810: Implement PEP 3154, pickle protocol 4.
+
- Issue #19668: Added support for the cp1125 encoding.
- Issue #19689: Add ssl.create_default_context() factory function. It creates
diff --git a/Modules/_pickle.c b/Modules/_pickle.c
index 9852cd3..f9aa043 100644
--- a/Modules/_pickle.c
+++ b/Modules/_pickle.c
@@ -6,7 +6,7 @@ PyDoc_STRVAR(pickle_module_doc,
/* Bump this when new opcodes are added to the pickle protocol. */
enum {
- HIGHEST_PROTOCOL = 3,
+ HIGHEST_PROTOCOL = 4,
DEFAULT_PROTOCOL = 3
};
@@ -71,7 +71,19 @@ enum opcode {
/* Protocol 3 (Python 3.x) */
BINBYTES = 'B',
- SHORT_BINBYTES = 'C'
+ SHORT_BINBYTES = 'C',
+
+ /* Protocol 4 */
+ SHORT_BINUNICODE = '\x8c',
+ BINUNICODE8 = '\x8d',
+ BINBYTES8 = '\x8e',
+ EMPTY_SET = '\x8f',
+ ADDITEMS = '\x90',
+ FROZENSET = '\x91',
+ NEWOBJ_EX = '\x92',
+ STACK_GLOBAL = '\x93',
+ MEMOIZE = '\x94',
+ FRAME = '\x95'
};
/* These aren't opcodes -- they're ways to pickle bools before protocol 2
@@ -103,7 +115,11 @@ enum {
MAX_WRITE_BUF_SIZE = 64 * 1024,
/* Prefetch size when unpickling (disabled on unpeekable streams) */
- PREFETCH = 8192 * 16
+ PREFETCH = 8192 * 16,
+
+ FRAME_SIZE_TARGET = 64 * 1024,
+
+ FRAME_HEADER_SIZE = 9
};
/* Exception classes for pickle. These should override the ones defined in
@@ -136,9 +152,6 @@ static PyObject *empty_tuple = NULL;
/* For looking up name pairs in copyreg._extension_registry. */
static PyObject *two_tuple = NULL;
-_Py_IDENTIFIER(__name__);
-_Py_IDENTIFIER(modules);
-
static int
stack_underflow(void)
{
@@ -332,7 +345,12 @@ typedef struct PicklerObject {
Py_ssize_t max_output_len; /* Allocation size of output_buffer. */
int proto; /* Pickle protocol number, >= 0 */
int bin; /* Boolean, true if proto > 0 */
- Py_ssize_t buf_size; /* Size of the current buffered pickle data */
+ int framing; /* True when framing is enabled, proto >= 4 */
+ Py_ssize_t frame_start; /* Position in output_buffer where the
+ where the current frame begins. -1 if there
+ is no frame currently open. */
+
+ Py_ssize_t buf_size; /* Size of the current buffered pickle data */
int fast; /* Enable fast mode if set to a true value.
The fast mode disable the usage of memo,
therefore speeding the pickling process by
@@ -352,7 +370,8 @@ typedef struct UnpicklerObject {
/* The unpickler memo is just an array of PyObject *s. Using a dict
is unnecessary, since the keys are contiguous ints. */
PyObject **memo;
- Py_ssize_t memo_size;
+ Py_ssize_t memo_size; /* Capacity of the memo array */
+ Py_ssize_t memo_len; /* Number of objects in the memo */
PyObject *arg;
PyObject *pers_func; /* persistent_load() method, can be NULL. */
@@ -362,7 +381,9 @@ typedef struct UnpicklerObject {
char *input_line;
Py_ssize_t input_len;
Py_ssize_t next_read_idx;
+ Py_ssize_t frame_end_idx;
Py_ssize_t prefetched_idx; /* index of first prefetched byte */
+
PyObject *read; /* read() method of the input stream. */
PyObject *readline; /* readline() method of the input stream. */
PyObject *peek; /* peek() method of the input stream, or NULL */
@@ -380,6 +401,7 @@ typedef struct UnpicklerObject {
int proto; /* Protocol of the pickle loaded. */
int fix_imports; /* Indicate whether Unpickler should fix
the name of globals pickled by Python 2.x. */
+ int framing; /* True when framing is enabled, proto >= 4 */
} UnpicklerObject;
/* Forward declarations */
@@ -673,15 +695,63 @@ _Pickler_ClearBuffer(PicklerObject *self)
if (self->output_buffer == NULL)
return -1;
self->output_len = 0;
+ self->frame_start = -1;
+ return 0;
+}
+
+static void
+_Pickler_WriteFrameHeader(PicklerObject *self, char *qdata, size_t frame_len)
+{
+ qdata[0] = (unsigned char)FRAME;
+ qdata[1] = (unsigned char)(frame_len & 0xff);
+ qdata[2] = (unsigned char)((frame_len >> 8) & 0xff);
+ qdata[3] = (unsigned char)((frame_len >> 16) & 0xff);
+ qdata[4] = (unsigned char)((frame_len >> 24) & 0xff);
+ qdata[5] = (unsigned char)((frame_len >> 32) & 0xff);
+ qdata[6] = (unsigned char)((frame_len >> 40) & 0xff);
+ qdata[7] = (unsigned char)((frame_len >> 48) & 0xff);
+ qdata[8] = (unsigned char)((frame_len >> 56) & 0xff);
+}
+
+static int
+_Pickler_CommitFrame(PicklerObject *self)
+{
+ size_t frame_len;
+ char *qdata;
+
+ if (!self->framing || self->frame_start == -1)
+ return 0;
+ frame_len = self->output_len - self->frame_start - FRAME_HEADER_SIZE;
+ qdata = PyBytes_AS_STRING(self->output_buffer) + self->frame_start;
+ _Pickler_WriteFrameHeader(self, qdata, frame_len);
+ self->frame_start = -1;
return 0;
}
+static int
+_Pickler_OpcodeBoundary(PicklerObject *self)
+{
+ Py_ssize_t frame_len;
+
+ if (!self->framing || self->frame_start == -1)
+ return 0;
+ frame_len = self->output_len - self->frame_start - FRAME_HEADER_SIZE;
+ if (frame_len >= FRAME_SIZE_TARGET)
+ return _Pickler_CommitFrame(self);
+ else
+ return 0;
+}
+
static PyObject *
_Pickler_GetString(PicklerObject *self)
{
PyObject *output_buffer = self->output_buffer;
assert(self->output_buffer != NULL);
+
+ if (_Pickler_CommitFrame(self))
+ return NULL;
+
self->output_buffer = NULL;
/* Resize down to exact size */
if (_PyBytes_Resize(&output_buffer, self->output_len) < 0)
@@ -696,6 +766,7 @@ _Pickler_FlushToFile(PicklerObject *self)
assert(self->write != NULL);
+ /* This will commit the frame first */
output = _Pickler_GetString(self);
if (output == NULL)
return -1;
@@ -706,57 +777,93 @@ _Pickler_FlushToFile(PicklerObject *self)
}
static Py_ssize_t
-_Pickler_Write(PicklerObject *self, const char *s, Py_ssize_t n)
+_Pickler_Write(PicklerObject *self, const char *s, Py_ssize_t data_len)
{
- Py_ssize_t i, required;
+ Py_ssize_t i, n, required;
char *buffer;
+ int need_new_frame;
assert(s != NULL);
+ need_new_frame = (self->framing && self->frame_start == -1);
+
+ if (need_new_frame)
+ n = data_len + FRAME_HEADER_SIZE;
+ else
+ n = data_len;
required = self->output_len + n;
- if (required > self->max_output_len) {
- if (self->write != NULL && required > MAX_WRITE_BUF_SIZE) {
- /* XXX This reallocates a new buffer every time, which is a bit
- wasteful. */
- if (_Pickler_FlushToFile(self) < 0)
- return -1;
- if (_Pickler_ClearBuffer(self) < 0)
- return -1;
- }
- if (self->write != NULL && n > MAX_WRITE_BUF_SIZE) {
- /* we already flushed above, so the buffer is empty */
- PyObject *result;
- /* XXX we could spare an intermediate copy and pass
- a memoryview instead */
- PyObject *output = PyBytes_FromStringAndSize(s, n);
- if (s == NULL)
+ if (self->write != NULL && required > MAX_WRITE_BUF_SIZE) {
+ /* XXX This reallocates a new buffer every time, which is a bit
+ wasteful. */
+ if (_Pickler_FlushToFile(self) < 0)
+ return -1;
+ if (_Pickler_ClearBuffer(self) < 0)
+ return -1;
+ /* The previous frame was just committed by _Pickler_FlushToFile */
+ need_new_frame = self->framing;
+ if (need_new_frame)
+ n = data_len + FRAME_HEADER_SIZE;
+ else
+ n = data_len;
+ required = self->output_len + n;
+ }
+ if (self->write != NULL && n > MAX_WRITE_BUF_SIZE) {
+ /* For large pickle chunks, we write directly to the output
+ file instead of buffering. Note the buffer is empty at this
+ point (it was flushed above, since required >= n). */
+ PyObject *output, *result;
+ if (need_new_frame) {
+ char frame_header[FRAME_HEADER_SIZE];
+ _Pickler_WriteFrameHeader(self, frame_header, (size_t) data_len);
+ output = PyBytes_FromStringAndSize(frame_header, FRAME_HEADER_SIZE);
+ if (output == NULL)
return -1;
result = _Pickler_FastCall(self, self->write, output);
Py_XDECREF(result);
- return (result == NULL) ? -1 : 0;
- }
- else {
- if (self->output_len >= PY_SSIZE_T_MAX / 2 - n) {
- PyErr_NoMemory();
- return -1;
- }
- self->max_output_len = (self->output_len + n) / 2 * 3;
- if (_PyBytes_Resize(&self->output_buffer, self->max_output_len) < 0)
+ if (result == NULL)
return -1;
}
+ /* XXX we could spare an intermediate copy and pass
+ a memoryview instead */
+ output = PyBytes_FromStringAndSize(s, data_len);
+ if (output == NULL)
+ return -1;
+ result = _Pickler_FastCall(self, self->write, output);
+ Py_XDECREF(result);
+ return (result == NULL) ? -1 : 0;
+ }
+ if (required > self->max_output_len) {
+ /* Make place in buffer for the pickle chunk */
+ if (self->output_len >= PY_SSIZE_T_MAX / 2 - n) {
+ PyErr_NoMemory();
+ return -1;
+ }
+ self->max_output_len = (self->output_len + n) / 2 * 3;
+ if (_PyBytes_Resize(&self->output_buffer, self->max_output_len) < 0)
+ return -1;
}
buffer = PyBytes_AS_STRING(self->output_buffer);
- if (n < 8) {
+ if (need_new_frame) {
+ /* Setup new frame */
+ Py_ssize_t frame_start = self->output_len;
+ self->frame_start = frame_start;
+ for (i = 0; i < FRAME_HEADER_SIZE; i++) {
+ /* Write an invalid value, for debugging */
+ buffer[frame_start + i] = 0xFE;
+ }
+ self->output_len += FRAME_HEADER_SIZE;
+ }
+ if (data_len < 8) {
/* This is faster than memcpy when the string is short. */
- for (i = 0; i < n; i++) {
+ for (i = 0; i < data_len; i++) {
buffer[self->output_len + i] = s[i];
}
}
else {
- memcpy(buffer + self->output_len, s, n);
+ memcpy(buffer + self->output_len, s, data_len);
}
- self->output_len += n;
- return n;
+ self->output_len += data_len;
+ return data_len;
}
static PicklerObject *
@@ -774,6 +881,8 @@ _Pickler_New(void)
self->write = NULL;
self->proto = 0;
self->bin = 0;
+ self->framing = 0;
+ self->frame_start = -1;
self->fast = 0;
self->fast_nesting = 0;
self->fix_imports = 0;
@@ -868,6 +977,7 @@ _Unpickler_SetStringInput(UnpicklerObject *self, PyObject *input)
self->input_buffer = self->buffer.buf;
self->input_len = self->buffer.len;
self->next_read_idx = 0;
+ self->frame_end_idx = -1;
self->prefetched_idx = self->input_len;
return self->input_len;
}
@@ -932,7 +1042,7 @@ _Unpickler_ReadFromFile(UnpicklerObject *self, Py_ssize_t n)
return -1;
/* Prefetch some data without advancing the file pointer, if possible */
- if (self->peek) {
+ if (self->peek && !self->framing) {
PyObject *len, *prefetched;
len = PyLong_FromSsize_t(PREFETCH);
if (len == NULL) {
@@ -980,7 +1090,7 @@ _Unpickler_ReadFromFile(UnpicklerObject *self, Py_ssize_t n)
Returns -1 (with an exception set) on failure. On success, return the
number of chars read. */
static Py_ssize_t
-_Unpickler_Read(UnpicklerObject *self, char **s, Py_ssize_t n)
+_Unpickler_ReadUnframed(UnpicklerObject *self, char **s, Py_ssize_t n)
{
Py_ssize_t num_read;
@@ -1006,6 +1116,67 @@ _Unpickler_Read(UnpicklerObject *self, char **s, Py_ssize_t n)
}
static Py_ssize_t
+_Unpickler_Read(UnpicklerObject *self, char **s, Py_ssize_t n)
+{
+ if (self->framing &&
+ (self->frame_end_idx == -1 ||
+ self->frame_end_idx <= self->next_read_idx)) {
+ /* Need to read new frame */
+ char *dummy;
+ unsigned char *frame_start;
+ size_t frame_len;
+ if (_Unpickler_ReadUnframed(self, &dummy, FRAME_HEADER_SIZE) < 0)
+ return -1;
+ frame_start = (unsigned char *) dummy;
+ if (frame_start[0] != (unsigned char)FRAME) {
+ PyErr_Format(UnpicklingError,
+ "expected FRAME opcode, got 0x%x instead",
+ frame_start[0]);
+ return -1;
+ }
+ frame_len = (size_t) frame_start[1];
+ frame_len |= (size_t) frame_start[2] << 8;
+ frame_len |= (size_t) frame_start[3] << 16;
+ frame_len |= (size_t) frame_start[4] << 24;
+#if SIZEOF_SIZE_T >= 8
+ frame_len |= (size_t) frame_start[5] << 32;
+ frame_len |= (size_t) frame_start[6] << 40;
+ frame_len |= (size_t) frame_start[7] << 48;
+ frame_len |= (size_t) frame_start[8] << 56;
+#else
+ if (frame_start[5] || frame_start[6] ||
+ frame_start[7] || frame_start[8]) {
+ PyErr_Format(PyExc_OverflowError,
+ "Frame size too large for 32-bit build");
+ return -1;
+ }
+#endif
+ if (frame_len > PY_SSIZE_T_MAX) {
+ PyErr_Format(UnpicklingError, "Invalid frame length");
+ return -1;
+ }
+ if (frame_len < n) {
+ PyErr_Format(UnpicklingError, "Bad framing");
+ return -1;
+ }
+ if (_Unpickler_ReadUnframed(self, &dummy /* unused */,
+ frame_len) < 0)
+ return -1;
+ /* Rewind to start of frame */
+ self->frame_end_idx = self->next_read_idx;
+ self->next_read_idx -= frame_len;
+ }
+ if (self->framing) {
+ /* Check for bad input */
+ if (n + self->next_read_idx > self->frame_end_idx) {
+ PyErr_Format(UnpicklingError, "Bad framing");
+ return -1;
+ }
+ }
+ return _Unpickler_ReadUnframed(self, s, n);
+}
+
+static Py_ssize_t
_Unpickler_CopyLine(UnpicklerObject *self, char *line, Py_ssize_t len,
char **result)
{
@@ -1102,7 +1273,12 @@ _Unpickler_MemoPut(UnpicklerObject *self, Py_ssize_t idx, PyObject *value)
Py_INCREF(value);
old_item = self->memo[idx];
self->memo[idx] = value;
- Py_XDECREF(old_item);
+ if (old_item != NULL) {
+ Py_DECREF(old_item);
+ }
+ else {
+ self->memo_len++;
+ }
return 0;
}
@@ -1150,6 +1326,7 @@ _Unpickler_New(void)
self->input_line = NULL;
self->input_len = 0;
self->next_read_idx = 0;
+ self->frame_end_idx = -1;
self->prefetched_idx = 0;
self->read = NULL;
self->readline = NULL;
@@ -1160,9 +1337,11 @@ _Unpickler_New(void)
self->num_marks = 0;
self->marks_size = 0;
self->proto = 0;
+ self->framing = 0;
self->fix_imports = 0;
memset(&self->buffer, 0, sizeof(Py_buffer));
self->memo_size = 32;
+ self->memo_len = 0;
self->memo = _Unpickler_NewMemo(self->memo_size);
self->stack = (Pdata *)Pdata_New();
@@ -1277,36 +1456,44 @@ memo_get(PicklerObject *self, PyObject *key)
static int
memo_put(PicklerObject *self, PyObject *obj)
{
- Py_ssize_t x;
char pdata[30];
Py_ssize_t len;
- int status = 0;
+ Py_ssize_t idx;
+
+ const char memoize_op = MEMOIZE;
if (self->fast)
return 0;
+ if (_Pickler_OpcodeBoundary(self))
+ return -1;
- x = PyMemoTable_Size(self->memo);
- if (PyMemoTable_Set(self->memo, obj, x) < 0)
- goto error;
+ idx = PyMemoTable_Size(self->memo);
+ if (PyMemoTable_Set(self->memo, obj, idx) < 0)
+ return -1;
- if (!self->bin) {
+ if (self->proto >= 4) {
+ if (_Pickler_Write(self, &memoize_op, 1) < 0)
+ return -1;
+ return 0;
+ }
+ else if (!self->bin) {
pdata[0] = PUT;
PyOS_snprintf(pdata + 1, sizeof(pdata) - 1,
- "%" PY_FORMAT_SIZE_T "d\n", x);
+ "%" PY_FORMAT_SIZE_T "d\n", idx);
len = strlen(pdata);
}
else {
- if (x < 256) {
+ if (idx < 256) {
pdata[0] = BINPUT;
- pdata[1] = (unsigned char)x;
+ pdata[1] = (unsigned char)idx;
len = 2;
}
- else if (x <= 0xffffffffL) {
+ else if (idx <= 0xffffffffL) {
pdata[0] = LONG_BINPUT;
- pdata[1] = (unsigned char)(x & 0xff);
- pdata[2] = (unsigned char)((x >> 8) & 0xff);
- pdata[3] = (unsigned char)((x >> 16) & 0xff);
- pdata[4] = (unsigned char)((x >> 24) & 0xff);
+ pdata[1] = (unsigned char)(idx & 0xff);
+ pdata[2] = (unsigned char)((idx >> 8) & 0xff);
+ pdata[3] = (unsigned char)((idx >> 16) & 0xff);
+ pdata[4] = (unsigned char)((idx >> 24) & 0xff);
len = 5;
}
else { /* unlikely */
@@ -1315,57 +1502,94 @@ memo_put(PicklerObject *self, PyObject *obj)
return -1;
}
}
-
if (_Pickler_Write(self, pdata, len) < 0)
- goto error;
+ return -1;
- if (0) {
- error:
- status = -1;
- }
+ return 0;
+}
- return status;
+static PyObject *
+getattribute(PyObject *obj, PyObject *name, int allow_qualname) {
+ PyObject *dotted_path;
+ Py_ssize_t i;
+ _Py_static_string(PyId_dot, ".");
+ _Py_static_string(PyId_locals, "<locals>");
+
+ dotted_path = PyUnicode_Split(name, _PyUnicode_FromId(&PyId_dot), -1);
+ if (dotted_path == NULL) {
+ return NULL;
+ }
+ assert(Py_SIZE(dotted_path) >= 1);
+ if (!allow_qualname && Py_SIZE(dotted_path) > 1) {
+ PyErr_Format(PyExc_AttributeError,
+ "Can't get qualified attribute %R on %R;"
+ "use protocols >= 4 to enable support",
+ name, obj);
+ Py_DECREF(dotted_path);
+ return NULL;
+ }
+ Py_INCREF(obj);
+ for (i = 0; i < Py_SIZE(dotted_path); i++) {
+ PyObject *subpath = PyList_GET_ITEM(dotted_path, i);
+ PyObject *tmp;
+ PyObject *result = PyUnicode_RichCompare(
+ subpath, _PyUnicode_FromId(&PyId_locals), Py_EQ);
+ int is_equal = (result == Py_True);
+ assert(PyBool_Check(result));
+ Py_DECREF(result);
+ if (is_equal) {
+ PyErr_Format(PyExc_AttributeError,
+ "Can't get local attribute %R on %R", name, obj);
+ Py_DECREF(dotted_path);
+ Py_DECREF(obj);
+ return NULL;
+ }
+ tmp = PyObject_GetAttr(obj, subpath);
+ Py_DECREF(obj);
+ if (tmp == NULL) {
+ if (PyErr_ExceptionMatches(PyExc_AttributeError)) {
+ PyErr_Clear();
+ PyErr_Format(PyExc_AttributeError,
+ "Can't get attribute %R on %R", name, obj);
+ }
+ Py_DECREF(dotted_path);
+ return NULL;
+ }
+ obj = tmp;
+ }
+ Py_DECREF(dotted_path);
+ return obj;
}
static PyObject *
-whichmodule(PyObject *global, PyObject *global_name)
+whichmodule(PyObject *global, PyObject *global_name, int allow_qualname)
{
- Py_ssize_t i, j;
- static PyObject *module_str = NULL;
- static PyObject *main_str = NULL;
PyObject *module_name;
PyObject *modules_dict;
PyObject *module;
PyObject *obj;
+ Py_ssize_t i, j;
+ _Py_IDENTIFIER(__module__);
+ _Py_IDENTIFIER(modules);
+ _Py_IDENTIFIER(__main__);
- if (module_str == NULL) {
- module_str = PyUnicode_InternFromString("__module__");
- if (module_str == NULL)
- return NULL;
- main_str = PyUnicode_InternFromString("__main__");
- if (main_str == NULL)
- return NULL;
- }
-
- module_name = PyObject_GetAttr(global, module_str);
+ module_name = _PyObject_GetAttrId(global, &PyId___module__);
- /* In some rare cases (e.g., bound methods of extension types),
- __module__ can be None. If it is so, then search sys.modules
- for the module of global. */
- if (module_name == Py_None) {
- Py_DECREF(module_name);
- goto search;
+ if (module_name == NULL) {
+ if (!PyErr_ExceptionMatches(PyExc_AttributeError))
+ return NULL;
+ PyErr_Clear();
}
-
- if (module_name) {
- return module_name;
+ else {
+ /* In some rare cases (e.g., bound methods of extension types),
+ __module__ can be None. If it is so, then search sys.modules for
+ the module of global. */
+ if (module_name != Py_None)
+ return module_name;
+ Py_CLEAR(module_name);
}
- if (PyErr_ExceptionMatches(PyExc_AttributeError))
- PyErr_Clear();
- else
- return NULL;
+ assert(module_name == NULL);
- search:
modules_dict = _PySys_GetObjectId(&PyId_modules);
if (modules_dict == NULL) {
PyErr_SetString(PyExc_RuntimeError, "unable to get sys.modules");
@@ -1373,34 +1597,35 @@ whichmodule(PyObject *global, PyObject *global_name)
}
i = 0;
- module_name = NULL;
while ((j = PyDict_Next(modules_dict, &i, &module_name, &module))) {
- if (PyObject_RichCompareBool(module_name, main_str, Py_EQ) == 1)
+ PyObject *result = PyUnicode_RichCompare(
+ module_name, _PyUnicode_FromId(&PyId___main__), Py_EQ);
+ int is_equal = (result == Py_True);
+ assert(PyBool_Check(result));
+ Py_DECREF(result);
+ if (is_equal)
+ continue;
+ if (module == Py_None)
continue;
- obj = PyObject_GetAttr(module, global_name);
+ obj = getattribute(module, global_name, allow_qualname);
if (obj == NULL) {
- if (PyErr_ExceptionMatches(PyExc_AttributeError))
- PyErr_Clear();
- else
+ if (!PyErr_ExceptionMatches(PyExc_AttributeError))
return NULL;
+ PyErr_Clear();
continue;
}
- if (obj != global) {
+ if (obj == global) {
Py_DECREF(obj);
- continue;
+ Py_INCREF(module_name);
+ return module_name;
}
-
Py_DECREF(obj);
- break;
}
/* If no module is found, use __main__. */
- if (!j) {
- module_name = main_str;
- }
-
+ module_name = _PyUnicode_FromId(&PyId___main__);
Py_INCREF(module_name);
return module_name;
}
@@ -1744,22 +1969,17 @@ save_bytes(PicklerObject *self, PyObject *obj)
reduce_value = Py_BuildValue("(O())", (PyObject*)&PyBytes_Type);
}
else {
- static PyObject *latin1 = NULL;
PyObject *unicode_str =
PyUnicode_DecodeLatin1(PyBytes_AS_STRING(obj),
PyBytes_GET_SIZE(obj),
"strict");
+ _Py_IDENTIFIER(latin1);
+
if (unicode_str == NULL)
return -1;
- if (latin1 == NULL) {
- latin1 = PyUnicode_InternFromString("latin1");
- if (latin1 == NULL) {
- Py_DECREF(unicode_str);
- return -1;
- }
- }
reduce_value = Py_BuildValue("(O(OO))",
- codecs_encode, unicode_str, latin1);
+ codecs_encode, unicode_str,
+ _PyUnicode_FromId(&PyId_latin1));
Py_DECREF(unicode_str);
}
@@ -1773,14 +1993,14 @@ save_bytes(PicklerObject *self, PyObject *obj)
}
else {
Py_ssize_t size;
- char header[5];
+ char header[9];
Py_ssize_t len;
size = PyBytes_GET_SIZE(obj);
if (size < 0)
return -1;
- if (size < 256) {
+ if (size <= 0xff) {
header[0] = SHORT_BINBYTES;
header[1] = (unsigned char)size;
len = 2;
@@ -1793,6 +2013,14 @@ save_bytes(PicklerObject *self, PyObject *obj)
header[4] = (unsigned char)((size >> 24) & 0xff);
len = 5;
}
+ else if (self->proto >= 4) {
+ int i;
+ header[0] = BINBYTES8;
+ for (i = 0; i < 8; i++) {
+ header[i+1] = (unsigned char)((size >> (8 * i)) & 0xff);
+ }
+ len = 8;
+ }
else {
PyErr_SetString(PyExc_OverflowError,
"cannot serialize a bytes object larger than 4 GiB");
@@ -1882,26 +2110,39 @@ done:
static int
write_utf8(PicklerObject *self, char *data, Py_ssize_t size)
{
- char pdata[5];
+ char header[9];
+ Py_ssize_t len;
+
+ if (size <= 0xff && self->proto >= 4) {
+ header[0] = SHORT_BINUNICODE;
+ header[1] = (unsigned char)(size & 0xff);
+ len = 2;
+ }
+ else if (size <= 0xffffffffUL) {
+ header[0] = BINUNICODE;
+ header[1] = (unsigned char)(size & 0xff);
+ header[2] = (unsigned char)((size >> 8) & 0xff);
+ header[3] = (unsigned char)((size >> 16) & 0xff);
+ header[4] = (unsigned char)((size >> 24) & 0xff);
+ len = 5;
+ }
+ else if (self->proto >= 4) {
+ int i;
-#if SIZEOF_SIZE_T > 4
- if (size > 0xffffffffUL) {
- /* string too large */
+ header[0] = BINUNICODE8;
+ for (i = 0; i < 8; i++) {
+ header[i+1] = (unsigned char)((size >> (8 * i)) & 0xff);
+ }
+ len = 9;
+ }
+ else {
PyErr_SetString(PyExc_OverflowError,
"cannot serialize a string larger than 4GiB");
return -1;
}
-#endif
- pdata[0] = BINUNICODE;
- pdata[1] = (unsigned char)(size & 0xff);
- pdata[2] = (unsigned char)((size >> 8) & 0xff);
- pdata[3] = (unsigned char)((size >> 16) & 0xff);
- pdata[4] = (unsigned char)((size >> 24) & 0xff);
-
- if (_Pickler_Write(self, pdata, sizeof(pdata)) < 0)
+ if (_Pickler_Write(self, header, len) < 0)
return -1;
-
if (_Pickler_Write(self, data, size) < 0)
return -1;
@@ -2598,6 +2839,214 @@ save_dict(PicklerObject *self, PyObject *obj)
}
static int
+save_set(PicklerObject *self, PyObject *obj)
+{
+ PyObject *item;
+ int i;
+ Py_ssize_t set_size, ppos = 0;
+ Py_hash_t hash;
+
+ const char empty_set_op = EMPTY_SET;
+ const char mark_op = MARK;
+ const char additems_op = ADDITEMS;
+
+ if (self->proto < 4) {
+ PyObject *items;
+ PyObject *reduce_value;
+ int status;
+
+ items = PySequence_List(obj);
+ if (items == NULL) {
+ return -1;
+ }
+ reduce_value = Py_BuildValue("(O(O))", (PyObject*)&PySet_Type, items);
+ Py_DECREF(items);
+ if (reduce_value == NULL) {
+ return -1;
+ }
+ /* save_reduce() will memoize the object automatically. */
+ status = save_reduce(self, reduce_value, obj);
+ Py_DECREF(reduce_value);
+ return status;
+ }
+
+ if (_Pickler_Write(self, &empty_set_op, 1) < 0)
+ return -1;
+
+ if (memo_put(self, obj) < 0)
+ return -1;
+
+ set_size = PySet_GET_SIZE(obj);
+ if (set_size == 0)
+ return 0; /* nothing to do */
+
+ /* Write in batches of BATCHSIZE. */
+ do {
+ i = 0;
+ if (_Pickler_Write(self, &mark_op, 1) < 0)
+ return -1;
+ while (_PySet_NextEntry(obj, &ppos, &item, &hash)) {
+ if (save(self, item, 0) < 0)
+ return -1;
+ if (++i == BATCHSIZE)
+ break;
+ }
+ if (_Pickler_Write(self, &additems_op, 1) < 0)
+ return -1;
+ if (PySet_GET_SIZE(obj) != set_size) {
+ PyErr_Format(
+ PyExc_RuntimeError,
+ "set changed size during iteration");
+ return -1;
+ }
+ } while (i == BATCHSIZE);
+
+ return 0;
+}
+
+static int
+save_frozenset(PicklerObject *self, PyObject *obj)
+{
+ PyObject *iter;
+
+ const char mark_op = MARK;
+ const char frozenset_op = FROZENSET;
+
+ if (self->fast && !fast_save_enter(self, obj))
+ return -1;
+
+ if (self->proto < 4) {
+ PyObject *items;
+ PyObject *reduce_value;
+ int status;
+
+ items = PySequence_List(obj);
+ if (items == NULL) {
+ return -1;
+ }
+ reduce_value = Py_BuildValue("(O(O))", (PyObject*)&PyFrozenSet_Type,
+ items);
+ Py_DECREF(items);
+ if (reduce_value == NULL) {
+ return -1;
+ }
+ /* save_reduce() will memoize the object automatically. */
+ status = save_reduce(self, reduce_value, obj);
+ Py_DECREF(reduce_value);
+ return status;
+ }
+
+ if (_Pickler_Write(self, &mark_op, 1) < 0)
+ return -1;
+
+ iter = PyObject_GetIter(obj);
+ for (;;) {
+ PyObject *item;
+
+ item = PyIter_Next(iter);
+ if (item == NULL) {
+ if (PyErr_Occurred()) {
+ Py_DECREF(iter);
+ return -1;
+ }
+ break;
+ }
+ if (save(self, item, 0) < 0) {
+ Py_DECREF(item);
+ Py_DECREF(iter);
+ return -1;
+ }
+ Py_DECREF(item);
+ }
+ Py_DECREF(iter);
+
+ /* If the object is already in the memo, this means it is
+ recursive. In this case, throw away everything we put on the
+ stack, and fetch the object back from the memo. */
+ if (PyMemoTable_Get(self->memo, obj)) {
+ const char pop_mark_op = POP_MARK;
+
+ if (_Pickler_Write(self, &pop_mark_op, 1) < 0)
+ return -1;
+ if (memo_get(self, obj) < 0)
+ return -1;
+ return 0;
+ }
+
+ if (_Pickler_Write(self, &frozenset_op, 1) < 0)
+ return -1;
+ if (memo_put(self, obj) < 0)
+ return -1;
+
+ return 0;
+}
+
+static int
+fix_imports(PyObject **module_name, PyObject **global_name)
+{
+ PyObject *key;
+ PyObject *item;
+
+ key = PyTuple_Pack(2, *module_name, *global_name);
+ if (key == NULL)
+ return -1;
+ item = PyDict_GetItemWithError(name_mapping_3to2, key);
+ Py_DECREF(key);
+ if (item) {
+ PyObject *fixed_module_name;
+ PyObject *fixed_global_name;
+
+ if (!PyTuple_Check(item) || PyTuple_GET_SIZE(item) != 2) {
+ PyErr_Format(PyExc_RuntimeError,
+ "_compat_pickle.REVERSE_NAME_MAPPING values "
+ "should be 2-tuples, not %.200s",
+ Py_TYPE(item)->tp_name);
+ return -1;
+ }
+ fixed_module_name = PyTuple_GET_ITEM(item, 0);
+ fixed_global_name = PyTuple_GET_ITEM(item, 1);
+ if (!PyUnicode_Check(fixed_module_name) ||
+ !PyUnicode_Check(fixed_global_name)) {
+ PyErr_Format(PyExc_RuntimeError,
+ "_compat_pickle.REVERSE_NAME_MAPPING values "
+ "should be pairs of str, not (%.200s, %.200s)",
+ Py_TYPE(fixed_module_name)->tp_name,
+ Py_TYPE(fixed_global_name)->tp_name);
+ return -1;
+ }
+
+ Py_CLEAR(*module_name);
+ Py_CLEAR(*global_name);
+ Py_INCREF(fixed_module_name);
+ Py_INCREF(fixed_global_name);
+ *module_name = fixed_module_name;
+ *global_name = fixed_global_name;
+ }
+ else if (PyErr_Occurred()) {
+ return -1;
+ }
+
+ item = PyDict_GetItemWithError(import_mapping_3to2, *module_name);
+ if (item) {
+ if (!PyUnicode_Check(item)) {
+ PyErr_Format(PyExc_RuntimeError,
+ "_compat_pickle.REVERSE_IMPORT_MAPPING values "
+ "should be strings, not %.200s",
+ Py_TYPE(item)->tp_name);
+ return -1;
+ }
+ Py_CLEAR(*module_name);
+ Py_INCREF(item);
+ *module_name = item;
+ }
+ else if (PyErr_Occurred()) {
+ return -1;
+ }
+
+ return 0;
+}
+
+static int
save_global(PicklerObject *self, PyObject *obj, PyObject *name)
{
PyObject *global_name = NULL;
@@ -2605,20 +3054,32 @@ save_global(PicklerObject *self, PyObject *obj, PyObject *name)
PyObject *module = NULL;
PyObject *cls;
int status = 0;
+ _Py_IDENTIFIER(__name__);
+ _Py_IDENTIFIER(__qualname__);
const char global_op = GLOBAL;
if (name) {
+ Py_INCREF(name);
global_name = name;
- Py_INCREF(global_name);
}
else {
- global_name = _PyObject_GetAttrId(obj, &PyId___name__);
- if (global_name == NULL)
- goto error;
+ if (self->proto >= 4) {
+ global_name = _PyObject_GetAttrId(obj, &PyId___qualname__);
+ if (global_name == NULL) {
+ if (!PyErr_ExceptionMatches(PyExc_AttributeError))
+ goto error;
+ PyErr_Clear();
+ }
+ }
+ if (global_name == NULL) {
+ global_name = _PyObject_GetAttrId(obj, &PyId___name__);
+ if (global_name == NULL)
+ goto error;
+ }
}
- module_name = whichmodule(obj, global_name);
+ module_name = whichmodule(obj, global_name, self->proto >= 4);
if (module_name == NULL)
goto error;
@@ -2637,11 +3098,11 @@ save_global(PicklerObject *self, PyObject *obj, PyObject *name)
obj, module_name);
goto error;
}
- cls = PyObject_GetAttr(module, global_name);
+ cls = getattribute(module, global_name, self->proto >= 4);
if (cls == NULL) {
PyErr_Format(PicklingError,
- "Can't pickle %R: attribute lookup %S.%S failed",
- obj, module_name, global_name);
+ "Can't pickle %R: attribute lookup %S on %S failed",
+ obj, global_name, module_name);
goto error;
}
if (cls != obj) {
@@ -2715,120 +3176,82 @@ save_global(PicklerObject *self, PyObject *obj, PyObject *name)
goto error;
}
else {
- /* Generate a normal global opcode if we are using a pickle
- protocol <= 2, or if the object is not registered in the
- extension registry. */
- PyObject *encoded;
- PyObject *(*unicode_encoder)(PyObject *);
-
gen_global:
- if (_Pickler_Write(self, &global_op, 1) < 0)
- goto error;
+ if (self->proto >= 4) {
+ const char stack_global_op = STACK_GLOBAL;
+
+ save(self, module_name, 0);
+ save(self, global_name, 0);
- /* Since Python 3.0 now supports non-ASCII identifiers, we encode both
- the module name and the global name using UTF-8. We do so only when
- we are using the pickle protocol newer than version 3. This is to
- ensure compatibility with older Unpickler running on Python 2.x. */
- if (self->proto >= 3) {
- unicode_encoder = PyUnicode_AsUTF8String;
+ if (_Pickler_Write(self, &stack_global_op, 1) < 0)
+ goto error;
}
else {
- unicode_encoder = PyUnicode_AsASCIIString;
- }
+ /* Generate a normal global opcode if we are using a pickle
+ protocol < 4, or if the object is not registered in the
+ extension registry. */
+ PyObject *encoded;
+ PyObject *(*unicode_encoder)(PyObject *);
- /* For protocol < 3 and if the user didn't request against doing so,
- we convert module names to the old 2.x module names. */
- if (self->fix_imports) {
- PyObject *key;
- PyObject *item;
-
- key = PyTuple_Pack(2, module_name, global_name);
- if (key == NULL)
+ if (_Pickler_Write(self, &global_op, 1) < 0)
goto error;
- item = PyDict_GetItemWithError(name_mapping_3to2, key);
- Py_DECREF(key);
- if (item) {
- if (!PyTuple_Check(item) || PyTuple_GET_SIZE(item) != 2) {
- PyErr_Format(PyExc_RuntimeError,
- "_compat_pickle.REVERSE_NAME_MAPPING values "
- "should be 2-tuples, not %.200s",
- Py_TYPE(item)->tp_name);
- goto error;
- }
- Py_CLEAR(module_name);
- Py_CLEAR(global_name);
- module_name = PyTuple_GET_ITEM(item, 0);
- global_name = PyTuple_GET_ITEM(item, 1);
- if (!PyUnicode_Check(module_name) ||
- !PyUnicode_Check(global_name)) {
- PyErr_Format(PyExc_RuntimeError,
- "_compat_pickle.REVERSE_NAME_MAPPING values "
- "should be pairs of str, not (%.200s, %.200s)",
- Py_TYPE(module_name)->tp_name,
- Py_TYPE(global_name)->tp_name);
+
+ /* For protocol < 3 and if the user didn't request against doing
+ so, we convert module names to the old 2.x module names. */
+ if (self->proto < 3 && self->fix_imports) {
+ if (fix_imports(&module_name, &global_name) < 0) {
goto error;
}
- Py_INCREF(module_name);
- Py_INCREF(global_name);
- }
- else if (PyErr_Occurred()) {
- goto error;
}
- item = PyDict_GetItemWithError(import_mapping_3to2, module_name);
- if (item) {
- if (!PyUnicode_Check(item)) {
- PyErr_Format(PyExc_RuntimeError,
- "_compat_pickle.REVERSE_IMPORT_MAPPING values "
- "should be strings, not %.200s",
- Py_TYPE(item)->tp_name);
- goto error;
- }
- Py_CLEAR(module_name);
- module_name = item;
- Py_INCREF(module_name);
+ /* Since Python 3.0 now supports non-ASCII identifiers, we encode
+ both the module name and the global name using UTF-8. We do so
+ only when we are using the pickle protocol newer than version
+ 3. This is to ensure compatibility with older Unpickler running
+ on Python 2.x. */
+ if (self->proto == 3) {
+ unicode_encoder = PyUnicode_AsUTF8String;
+ }
+ else {
+ unicode_encoder = PyUnicode_AsASCIIString;
}
- else if (PyErr_Occurred()) {
+ encoded = unicode_encoder(module_name);
+ if (encoded == NULL) {
+ if (PyErr_ExceptionMatches(PyExc_UnicodeEncodeError))
+ PyErr_Format(PicklingError,
+ "can't pickle module identifier '%S' using "
+ "pickle protocol %i",
+ module_name, self->proto);
+ goto error;
+ }
+ if (_Pickler_Write(self, PyBytes_AS_STRING(encoded),
+ PyBytes_GET_SIZE(encoded)) < 0) {
+ Py_DECREF(encoded);
goto error;
}
- }
-
- /* Save the name of the module. */
- encoded = unicode_encoder(module_name);
- if (encoded == NULL) {
- if (PyErr_ExceptionMatches(PyExc_UnicodeEncodeError))
- PyErr_Format(PicklingError,
- "can't pickle module identifier '%S' using "
- "pickle protocol %i", module_name, self->proto);
- goto error;
- }
- if (_Pickler_Write(self, PyBytes_AS_STRING(encoded),
- PyBytes_GET_SIZE(encoded)) < 0) {
Py_DECREF(encoded);
- goto error;
- }
- Py_DECREF(encoded);
- if(_Pickler_Write(self, "\n", 1) < 0)
- goto error;
+ if(_Pickler_Write(self, "\n", 1) < 0)
+ goto error;
- /* Save the name of the module. */
- encoded = unicode_encoder(global_name);
- if (encoded == NULL) {
- if (PyErr_ExceptionMatches(PyExc_UnicodeEncodeError))
- PyErr_Format(PicklingError,
- "can't pickle global identifier '%S' using "
- "pickle protocol %i", global_name, self->proto);
- goto error;
- }
- if (_Pickler_Write(self, PyBytes_AS_STRING(encoded),
- PyBytes_GET_SIZE(encoded)) < 0) {
+ /* Save the name of the module. */
+ encoded = unicode_encoder(global_name);
+ if (encoded == NULL) {
+ if (PyErr_ExceptionMatches(PyExc_UnicodeEncodeError))
+ PyErr_Format(PicklingError,
+ "can't pickle global identifier '%S' using "
+ "pickle protocol %i",
+ global_name, self->proto);
+ goto error;
+ }
+ if (_Pickler_Write(self, PyBytes_AS_STRING(encoded),
+ PyBytes_GET_SIZE(encoded)) < 0) {
+ Py_DECREF(encoded);
+ goto error;
+ }
Py_DECREF(encoded);
- goto error;
+ if (_Pickler_Write(self, "\n", 1) < 0)
+ goto error;
}
- Py_DECREF(encoded);
- if(_Pickler_Write(self, "\n", 1) < 0)
- goto error;
-
/* Memoize the object. */
if (memo_put(self, obj) < 0)
goto error;
@@ -2927,14 +3350,9 @@ static PyObject *
get_class(PyObject *obj)
{
PyObject *cls;
- static PyObject *str_class;
+ _Py_IDENTIFIER(__class__);
- if (str_class == NULL) {
- str_class = PyUnicode_InternFromString("__class__");
- if (str_class == NULL)
- return NULL;
- }
- cls = PyObject_GetAttr(obj, str_class);
+ cls = _PyObject_GetAttrId(obj, &PyId___class__);
if (cls == NULL) {
if (PyErr_ExceptionMatches(PyExc_AttributeError)) {
PyErr_Clear();
@@ -2957,12 +3375,12 @@ save_reduce(PicklerObject *self, PyObject *args, PyObject *obj)
PyObject *listitems = Py_None;
PyObject *dictitems = Py_None;
Py_ssize_t size;
-
- int use_newobj = self->proto >= 2;
+ int use_newobj = 0, use_newobj_ex = 0;
const char reduce_op = REDUCE;
const char build_op = BUILD;
const char newobj_op = NEWOBJ;
+ const char newobj_ex_op = NEWOBJ_EX;
size = PyTuple_Size(args);
if (size < 2 || size > 5) {
@@ -3007,33 +3425,75 @@ save_reduce(PicklerObject *self, PyObject *args, PyObject *obj)
return -1;
}
- /* Protocol 2 special case: if callable's name is __newobj__, use
- NEWOBJ. */
- if (use_newobj) {
- static PyObject *newobj_str = NULL;
+ if (self->proto >= 2) {
PyObject *name;
-
- if (newobj_str == NULL) {
- newobj_str = PyUnicode_InternFromString("__newobj__");
- if (newobj_str == NULL)
- return -1;
- }
+ _Py_IDENTIFIER(__name__);
name = _PyObject_GetAttrId(callable, &PyId___name__);
if (name == NULL) {
- if (PyErr_ExceptionMatches(PyExc_AttributeError))
- PyErr_Clear();
- else
+ if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
return -1;
- use_newobj = 0;
+ }
+ PyErr_Clear();
+ }
+ else if (self->proto >= 4) {
+ _Py_IDENTIFIER(__newobj_ex__);
+ use_newobj_ex = PyUnicode_Check(name) &&
+ PyUnicode_Compare(
+ name, _PyUnicode_FromId(&PyId___newobj_ex__)) == 0;
+ Py_DECREF(name);
}
else {
+ _Py_IDENTIFIER(__newobj__);
use_newobj = PyUnicode_Check(name) &&
- PyUnicode_Compare(name, newobj_str) == 0;
+ PyUnicode_Compare(
+ name, _PyUnicode_FromId(&PyId___newobj__)) == 0;
Py_DECREF(name);
}
}
- if (use_newobj) {
+
+ if (use_newobj_ex) {
+ PyObject *cls;
+ PyObject *args;
+ PyObject *kwargs;
+
+ if (Py_SIZE(argtup) != 3) {
+ PyErr_Format(PicklingError,
+ "length of the NEWOBJ_EX argument tuple must be "
+ "exactly 3, not %zd", Py_SIZE(argtup));
+ return -1;
+ }
+
+ cls = PyTuple_GET_ITEM(argtup, 0);
+ if (!PyType_Check(cls)) {
+ PyErr_Format(PicklingError,
+ "first item from NEWOBJ_EX argument tuple must "
+ "be a class, not %.200s", Py_TYPE(cls)->tp_name);
+ return -1;
+ }
+ args = PyTuple_GET_ITEM(argtup, 1);
+ if (!PyTuple_Check(args)) {
+ PyErr_Format(PicklingError,
+ "second item from NEWOBJ_EX argument tuple must "
+ "be a tuple, not %.200s", Py_TYPE(args)->tp_name);
+ return -1;
+ }
+ kwargs = PyTuple_GET_ITEM(argtup, 2);
+ if (!PyDict_Check(kwargs)) {
+ PyErr_Format(PicklingError,
+ "third item from NEWOBJ_EX argument tuple must "
+ "be a dict, not %.200s", Py_TYPE(kwargs)->tp_name);
+ return -1;
+ }
+
+ if (save(self, cls, 0) < 0 ||
+ save(self, args, 0) < 0 ||
+ save(self, kwargs, 0) < 0 ||
+ _Pickler_Write(self, &newobj_ex_op, 1) < 0) {
+ return -1;
+ }
+ }
+ else if (use_newobj) {
PyObject *cls;
PyObject *newargtup;
PyObject *obj_class;
@@ -3117,8 +3577,23 @@ save_reduce(PicklerObject *self, PyObject *args, PyObject *obj)
the caller do not want to memoize the object. Not particularly useful,
but that is to mimic the behavior save_reduce() in pickle.py when
obj is None. */
- if (obj && memo_put(self, obj) < 0)
- return -1;
+ if (obj != NULL) {
+ /* If the object is already in the memo, this means it is
+ recursive. In this case, throw away everything we put on the
+ stack, and fetch the object back from the memo. */
+ if (PyMemoTable_Get(self->memo, obj)) {
+ const char pop_op = POP;
+
+ if (_Pickler_Write(self, &pop_op, 1) < 0)
+ return -1;
+ if (memo_get(self, obj) < 0)
+ return -1;
+
+ return 0;
+ }
+ else if (memo_put(self, obj) < 0)
+ return -1;
+ }
if (listitems && batch_list(self, listitems) < 0)
return -1;
@@ -3136,6 +3611,34 @@ save_reduce(PicklerObject *self, PyObject *args, PyObject *obj)
}
static int
+save_method(PicklerObject *self, PyObject *obj)
+{
+ PyObject *method_self = PyCFunction_GET_SELF(obj);
+
+ if (method_self == NULL || PyModule_Check(method_self)) {
+ return save_global(self, obj, NULL);
+ }
+ else {
+ PyObject *builtins;
+ PyObject *getattr;
+ PyObject *reduce_value;
+ int status = -1;
+ _Py_IDENTIFIER(getattr);
+
+ builtins = PyEval_GetBuiltins();
+ getattr = _PyDict_GetItemId(builtins, &PyId_getattr);
+ reduce_value = \
+ Py_BuildValue("O(Os)", getattr, method_self,
+ ((PyCFunctionObject *)obj)->m_ml->ml_name);
+ if (reduce_value != NULL) {
+ status = save_reduce(self, reduce_value, obj);
+ Py_DECREF(reduce_value);
+ }
+ return status;
+ }
+}
+
+static int
save(PicklerObject *self, PyObject *obj, int pers_save)
{
PyTypeObject *type;
@@ -3213,6 +3716,14 @@ save(PicklerObject *self, PyObject *obj, int pers_save)
status = save_dict(self, obj);
goto done;
}
+ else if (type == &PySet_Type) {
+ status = save_set(self, obj);
+ goto done;
+ }
+ else if (type == &PyFrozenSet_Type) {
+ status = save_frozenset(self, obj);
+ goto done;
+ }
else if (type == &PyList_Type) {
status = save_list(self, obj);
goto done;
@@ -3236,7 +3747,7 @@ save(PicklerObject *self, PyObject *obj, int pers_save)
}
}
else if (type == &PyCFunction_Type) {
- status = save_global(self, obj, NULL);
+ status = save_method(self, obj);
goto done;
}
@@ -3269,18 +3780,9 @@ save(PicklerObject *self, PyObject *obj, int pers_save)
goto done;
}
else {
- static PyObject *reduce_str = NULL;
- static PyObject *reduce_ex_str = NULL;
+ _Py_IDENTIFIER(__reduce__);
+ _Py_IDENTIFIER(__reduce_ex__);
- /* Cache the name of the reduce methods. */
- if (reduce_str == NULL) {
- reduce_str = PyUnicode_InternFromString("__reduce__");
- if (reduce_str == NULL)
- goto error;
- reduce_ex_str = PyUnicode_InternFromString("__reduce_ex__");
- if (reduce_ex_str == NULL)
- goto error;
- }
/* XXX: If the __reduce__ method is defined, __reduce_ex__ is
automatically defined as __reduce__. While this is convenient, this
@@ -3291,7 +3793,7 @@ save(PicklerObject *self, PyObject *obj, int pers_save)
don't actually have to check for a __reduce__ method. */
/* Check for a __reduce_ex__ method. */
- reduce_func = PyObject_GetAttr(obj, reduce_ex_str);
+ reduce_func = _PyObject_GetAttrId(obj, &PyId___reduce_ex__);
if (reduce_func != NULL) {
PyObject *proto;
proto = PyLong_FromLong(self->proto);
@@ -3305,7 +3807,7 @@ save(PicklerObject *self, PyObject *obj, int pers_save)
else
goto error;
/* Check for a __reduce__ method. */
- reduce_func = PyObject_GetAttr(obj, reduce_str);
+ reduce_func = _PyObject_GetAttrId(obj, &PyId___reduce__);
if (reduce_func != NULL) {
reduce_value = PyObject_Call(reduce_func, empty_tuple, NULL);
}
@@ -3338,6 +3840,8 @@ save(PicklerObject *self, PyObject *obj, int pers_save)
status = -1;
}
done:
+ if (status == 0)
+ status = _Pickler_OpcodeBoundary(self);
Py_LeaveRecursiveCall();
Py_XDECREF(reduce_func);
Py_XDECREF(reduce_value);
@@ -3358,6 +3862,8 @@ dump(PicklerObject *self, PyObject *obj)
header[1] = (unsigned char)self->proto;
if (_Pickler_Write(self, header, 2) < 0)
return -1;
+ if (self->proto >= 4)
+ self->framing = 1;
}
if (save(self, obj, 0) < 0 ||
@@ -3478,9 +3984,9 @@ PyDoc_STRVAR(Pickler_doc,
"This takes a binary file for writing a pickle data stream.\n"
"\n"
"The optional protocol argument tells the pickler to use the\n"
-"given protocol; supported protocols are 0, 1, 2, 3. The default\n"
-"protocol is 3; a backward-incompatible protocol designed for\n"
-"Python 3.0.\n"
+"given protocol; supported protocols are 0, 1, 2, 3 and 4. The\n"
+"default protocol is 3; a backward-incompatible protocol designed for\n"
+"Python 3.\n"
"\n"
"Specifying a negative protocol version selects the highest\n"
"protocol version supported. The higher the protocol used, the\n"
@@ -3493,8 +3999,8 @@ PyDoc_STRVAR(Pickler_doc,
"meets this interface.\n"
"\n"
"If fix_imports is True and protocol is less than 3, pickle will try to\n"
-"map the new Python 3.x names to the old module names used in Python\n"
-"2.x, so that the pickle data stream is readable with Python 2.x.\n");
+"map the new Python 3 names to the old module names used in Python 2,\n"
+"so that the pickle data stream is readable with Python 2.\n");
static int
Pickler_init(PicklerObject *self, PyObject *args, PyObject *kwds)
@@ -3987,17 +4493,15 @@ load_bool(UnpicklerObject *self, PyObject *boolean)
* as a C Py_ssize_t, or -1 if it's higher than PY_SSIZE_T_MAX.
*/
static Py_ssize_t
-calc_binsize(char *bytes, int size)
+calc_binsize(char *bytes, int nbytes)
{
unsigned char *s = (unsigned char *)bytes;
+ int i;
size_t x = 0;
- assert(size == 4);
-
- x = (size_t) s[0];
- x |= (size_t) s[1] << 8;
- x |= (size_t) s[2] << 16;
- x |= (size_t) s[3] << 24;
+ for (i = 0; i < nbytes; i++) {
+ x |= (size_t) s[i] << (8 * i);
+ }
if (x > PY_SSIZE_T_MAX)
return -1;
@@ -4011,21 +4515,21 @@ calc_binsize(char *bytes, int size)
* of x-platform bugs.
*/
static long
-calc_binint(char *bytes, int size)
+calc_binint(char *bytes, int nbytes)
{
unsigned char *s = (unsigned char *)bytes;
- int i = size;
+ int i;
long x = 0;
- for (i = 0; i < size; i++) {
- x |= (long)s[i] << (i * 8);
+ for (i = 0; i < nbytes; i++) {
+ x |= (long)s[i] << (8 * i);
}
/* Unlike BININT1 and BININT2, BININT (more accurately BININT4)
* is signed, so on a box with longs bigger than 4 bytes we need
* to extend a BININT's sign bit to the full width.
*/
- if (SIZEOF_LONG > 4 && size == 4) {
+ if (SIZEOF_LONG > 4 && nbytes == 4) {
x |= -(x & (1L << 31));
}
@@ -4233,49 +4737,27 @@ load_string(UnpicklerObject *self)
}
static int
-load_binbytes(UnpicklerObject *self)
+load_counted_binbytes(UnpicklerObject *self, int nbytes)
{
PyObject *bytes;
- Py_ssize_t x;
+ Py_ssize_t size;
char *s;
- if (_Unpickler_Read(self, &s, 4) < 0)
+ if (_Unpickler_Read(self, &s, nbytes) < 0)
return -1;
- x = calc_binsize(s, 4);
- if (x < 0) {
+ size = calc_binsize(s, nbytes);
+ if (size < 0) {
PyErr_Format(PyExc_OverflowError,
"BINBYTES exceeds system's maximum size of %zd bytes",
PY_SSIZE_T_MAX);
return -1;
}
- if (_Unpickler_Read(self, &s, x) < 0)
- return -1;
- bytes = PyBytes_FromStringAndSize(s, x);
- if (bytes == NULL)
- return -1;
-
- PDATA_PUSH(self->stack, bytes, -1);
- return 0;
-}
-
-static int
-load_short_binbytes(UnpicklerObject *self)
-{
- PyObject *bytes;
- Py_ssize_t x;
- char *s;
-
- if (_Unpickler_Read(self, &s, 1) < 0)
- return -1;
-
- x = (unsigned char)s[0];
-
- if (_Unpickler_Read(self, &s, x) < 0)
+ if (_Unpickler_Read(self, &s, size) < 0)
return -1;
- bytes = PyBytes_FromStringAndSize(s, x);
+ bytes = PyBytes_FromStringAndSize(s, size);
if (bytes == NULL)
return -1;
@@ -4284,51 +4766,27 @@ load_short_binbytes(UnpicklerObject *self)
}
static int
-load_binstring(UnpicklerObject *self)
+load_counted_binstring(UnpicklerObject *self, int nbytes)
{
PyObject *str;
- Py_ssize_t x;
+ Py_ssize_t size;
char *s;
- if (_Unpickler_Read(self, &s, 4) < 0)
+ if (_Unpickler_Read(self, &s, nbytes) < 0)
return -1;
- x = calc_binint(s, 4);
- if (x < 0) {
- PyErr_SetString(UnpicklingError,
- "BINSTRING pickle has negative byte count");
+ size = calc_binsize(s, nbytes);
+ if (size < 0) {
+ PyErr_Format(UnpicklingError,
+ "BINSTRING exceeds system's maximum size of %zd bytes",
+ PY_SSIZE_T_MAX);
return -1;
}
- if (_Unpickler_Read(self, &s, x) < 0)
- return -1;
-
- /* Convert Python 2.x strings to unicode. */
- str = PyUnicode_Decode(s, x, self->encoding, self->errors);
- if (str == NULL)
- return -1;
-
- PDATA_PUSH(self->stack, str, -1);
- return 0;
-}
-
-static int
-load_short_binstring(UnpicklerObject *self)
-{
- PyObject *str;
- Py_ssize_t x;
- char *s;
-
- if (_Unpickler_Read(self, &s, 1) < 0)
- return -1;
-
- x = (unsigned char)s[0];
-
- if (_Unpickler_Read(self, &s, x) < 0)
+ if (_Unpickler_Read(self, &s, size) < 0)
return -1;
-
/* Convert Python 2.x strings to unicode. */
- str = PyUnicode_Decode(s, x, self->encoding, self->errors);
+ str = PyUnicode_Decode(s, size, self->encoding, self->errors);
if (str == NULL)
return -1;
@@ -4357,16 +4815,16 @@ load_unicode(UnpicklerObject *self)
}
static int
-load_binunicode(UnpicklerObject *self)
+load_counted_binunicode(UnpicklerObject *self, int nbytes)
{
PyObject *str;
Py_ssize_t size;
char *s;
- if (_Unpickler_Read(self, &s, 4) < 0)
+ if (_Unpickler_Read(self, &s, nbytes) < 0)
return -1;
- size = calc_binsize(s, 4);
+ size = calc_binsize(s, nbytes);
if (size < 0) {
PyErr_Format(PyExc_OverflowError,
"BINUNICODE exceeds system's maximum size of %zd bytes",
@@ -4374,7 +4832,6 @@ load_binunicode(UnpicklerObject *self)
return -1;
}
-
if (_Unpickler_Read(self, &s, size) < 0)
return -1;
@@ -4446,6 +4903,17 @@ load_empty_dict(UnpicklerObject *self)
}
static int
+load_empty_set(UnpicklerObject *self)
+{
+ PyObject *set;
+
+ if ((set = PySet_New(NULL)) == NULL)
+ return -1;
+ PDATA_PUSH(self->stack, set, -1);
+ return 0;
+}
+
+static int
load_list(UnpicklerObject *self)
{
PyObject *list;
@@ -4487,6 +4955,29 @@ load_dict(UnpicklerObject *self)
return 0;
}
+static int
+load_frozenset(UnpicklerObject *self)
+{
+ PyObject *items;
+ PyObject *frozenset;
+ Py_ssize_t i;
+
+ if ((i = marker(self)) < 0)
+ return -1;
+
+ items = Pdata_poptuple(self->stack, i);
+ if (items == NULL)
+ return -1;
+
+ frozenset = PyFrozenSet_New(items);
+ Py_DECREF(items);
+ if (frozenset == NULL)
+ return -1;
+
+ PDATA_PUSH(self->stack, frozenset, -1);
+ return 0;
+}
+
static PyObject *
instantiate(PyObject *cls, PyObject *args)
{
@@ -4638,6 +5129,57 @@ load_newobj(UnpicklerObject *self)
}
static int
+load_newobj_ex(UnpicklerObject *self)
+{
+ PyObject *cls, *args, *kwargs;
+ PyObject *obj;
+
+ PDATA_POP(self->stack, kwargs);
+ if (kwargs == NULL) {
+ return -1;
+ }
+ PDATA_POP(self->stack, args);
+ if (args == NULL) {
+ Py_DECREF(kwargs);
+ return -1;
+ }
+ PDATA_POP(self->stack, cls);
+ if (cls == NULL) {
+ Py_DECREF(kwargs);
+ Py_DECREF(args);
+ return -1;
+ }
+
+ if (!PyType_Check(cls)) {
+ Py_DECREF(kwargs);
+ Py_DECREF(args);
+ Py_DECREF(cls);
+ PyErr_Format(UnpicklingError,
+ "NEWOBJ_EX class argument must be a type, not %.200s",
+ Py_TYPE(cls)->tp_name);
+ return -1;
+ }
+
+ if (((PyTypeObject *)cls)->tp_new == NULL) {
+ Py_DECREF(kwargs);
+ Py_DECREF(args);
+ Py_DECREF(cls);
+ PyErr_SetString(UnpicklingError,
+ "NEWOBJ_EX class argument doesn't have __new__");
+ return -1;
+ }
+ obj = ((PyTypeObject *)cls)->tp_new((PyTypeObject *)cls, args, kwargs);
+ Py_DECREF(kwargs);
+ Py_DECREF(args);
+ Py_DECREF(cls);
+ if (obj == NULL) {
+ return -1;
+ }
+ PDATA_PUSH(self->stack, obj, -1);
+ return 0;
+}
+
+static int
load_global(UnpicklerObject *self)
{
PyObject *global = NULL;
@@ -4674,6 +5216,31 @@ load_global(UnpicklerObject *self)
}
static int
+load_stack_global(UnpicklerObject *self)
+{
+ PyObject *global;
+ PyObject *module_name;
+ PyObject *global_name;
+
+ PDATA_POP(self->stack, global_name);
+ PDATA_POP(self->stack, module_name);
+ if (module_name == NULL || !PyUnicode_CheckExact(module_name) ||
+ global_name == NULL || !PyUnicode_CheckExact(global_name)) {
+ PyErr_SetString(UnpicklingError, "STACK_GLOBAL requires str");
+ Py_XDECREF(global_name);
+ Py_XDECREF(module_name);
+ return -1;
+ }
+ global = find_class(self, module_name, global_name);
+ Py_DECREF(global_name);
+ Py_DECREF(module_name);
+ if (global == NULL)
+ return -1;
+ PDATA_PUSH(self->stack, global, -1);
+ return 0;
+}
+
+static int
load_persid(UnpicklerObject *self)
{
PyObject *pid;
@@ -5017,6 +5584,18 @@ load_long_binput(UnpicklerObject *self)
}
static int
+load_memoize(UnpicklerObject *self)
+{
+ PyObject *value;
+
+ if (Py_SIZE(self->stack) <= 0)
+ return stack_underflow();
+ value = self->stack->data[Py_SIZE(self->stack) - 1];
+
+ return _Unpickler_MemoPut(self, self->memo_len, value);
+}
+
+static int
do_append(UnpicklerObject *self, Py_ssize_t x)
{
PyObject *value;
@@ -5132,6 +5711,59 @@ load_setitems(UnpicklerObject *self)
}
static int
+load_additems(UnpicklerObject *self)
+{
+ PyObject *set;
+ Py_ssize_t mark, len, i;
+
+ mark = marker(self);
+ len = Py_SIZE(self->stack);
+ if (mark > len || mark <= 0)
+ return stack_underflow();
+ if (len == mark) /* nothing to do */
+ return 0;
+
+ set = self->stack->data[mark - 1];
+
+ if (PySet_Check(set)) {
+ PyObject *items;
+ int status;
+
+ items = Pdata_poptuple(self->stack, mark);
+ if (items == NULL)
+ return -1;
+
+ status = _PySet_Update(set, items);
+ Py_DECREF(items);
+ return status;
+ }
+ else {
+ PyObject *add_func;
+ _Py_IDENTIFIER(add);
+
+ add_func = _PyObject_GetAttrId(set, &PyId_add);
+ if (add_func == NULL)
+ return -1;
+ for (i = mark; i < len; i++) {
+ PyObject *result;
+ PyObject *item;
+
+ item = self->stack->data[i];
+ result = _Unpickler_FastCall(self, add_func, item);
+ if (result == NULL) {
+ Pdata_clear(self->stack, i + 1);
+ Py_SIZE(self->stack) = mark;
+ return -1;
+ }
+ Py_DECREF(result);
+ }
+ Py_SIZE(self->stack) = mark;
+ }
+
+ return 0;
+}
+
+static int
load_build(UnpicklerObject *self)
{
PyObject *state, *inst, *slotstate;
@@ -5325,6 +5957,7 @@ load_proto(UnpicklerObject *self)
i = (unsigned char)s[0];
if (i <= HIGHEST_PROTOCOL) {
self->proto = i;
+ self->framing = (self->proto >= 4);
return 0;
}
@@ -5340,6 +5973,8 @@ load(UnpicklerObject *self)
char *s;
self->num_marks = 0;
+ self->proto = 0;
+ self->framing = 0;
if (Py_SIZE(self->stack))
Pdata_clear(self->stack, 0);
@@ -5365,13 +6000,16 @@ load(UnpicklerObject *self)
OP_ARG(LONG4, load_counted_long, 4)
OP(FLOAT, load_float)
OP(BINFLOAT, load_binfloat)
- OP(BINBYTES, load_binbytes)
- OP(SHORT_BINBYTES, load_short_binbytes)
- OP(BINSTRING, load_binstring)
- OP(SHORT_BINSTRING, load_short_binstring)
+ OP_ARG(SHORT_BINBYTES, load_counted_binbytes, 1)
+ OP_ARG(BINBYTES, load_counted_binbytes, 4)
+ OP_ARG(BINBYTES8, load_counted_binbytes, 8)
+ OP_ARG(SHORT_BINSTRING, load_counted_binstring, 1)
+ OP_ARG(BINSTRING, load_counted_binstring, 4)
OP(STRING, load_string)
OP(UNICODE, load_unicode)
- OP(BINUNICODE, load_binunicode)
+ OP_ARG(SHORT_BINUNICODE, load_counted_binunicode, 1)
+ OP_ARG(BINUNICODE, load_counted_binunicode, 4)
+ OP_ARG(BINUNICODE8, load_counted_binunicode, 8)
OP_ARG(EMPTY_TUPLE, load_counted_tuple, 0)
OP_ARG(TUPLE1, load_counted_tuple, 1)
OP_ARG(TUPLE2, load_counted_tuple, 2)
@@ -5381,10 +6019,15 @@ load(UnpicklerObject *self)
OP(LIST, load_list)
OP(EMPTY_DICT, load_empty_dict)
OP(DICT, load_dict)
+ OP(EMPTY_SET, load_empty_set)
+ OP(ADDITEMS, load_additems)
+ OP(FROZENSET, load_frozenset)
OP(OBJ, load_obj)
OP(INST, load_inst)
OP(NEWOBJ, load_newobj)
+ OP(NEWOBJ_EX, load_newobj_ex)
OP(GLOBAL, load_global)
+ OP(STACK_GLOBAL, load_stack_global)
OP(APPEND, load_append)
OP(APPENDS, load_appends)
OP(BUILD, load_build)
@@ -5396,6 +6039,7 @@ load(UnpicklerObject *self)
OP(BINPUT, load_binput)
OP(LONG_BINPUT, load_long_binput)
OP(PUT, load_put)
+ OP(MEMOIZE, load_memoize)
OP(POP, load_pop)
OP(POP_MARK, load_pop_mark)
OP(SETITEM, load_setitem)
@@ -5485,6 +6129,7 @@ Unpickler_find_class(UnpicklerObject *self, PyObject *args)
PyObject *modules_dict;
PyObject *module;
PyObject *module_name, *global_name;
+ _Py_IDENTIFIER(modules);
if (!PyArg_UnpackTuple(args, "find_class", 2, 2,
&module_name, &global_name))
@@ -5556,11 +6201,11 @@ Unpickler_find_class(UnpicklerObject *self, PyObject *args)
module = PyImport_Import(module_name);
if (module == NULL)
return NULL;
- global = PyObject_GetAttr(module, global_name);
+ global = getattribute(module, global_name, self->proto >= 4);
Py_DECREF(module);
}
else {
- global = PyObject_GetAttr(module, global_name);
+ global = getattribute(module, global_name, self->proto >= 4);
}
return global;
}
@@ -5723,6 +6368,7 @@ Unpickler_init(UnpicklerObject *self, PyObject *args, PyObject *kwds)
self->arg = NULL;
self->proto = 0;
+ self->framing = 0;
return 0;
}
diff --git a/Objects/classobject.c b/Objects/classobject.c
index 27f7ef4..272f575 100644
--- a/Objects/classobject.c
+++ b/Objects/classobject.c
@@ -69,6 +69,30 @@ PyMethod_New(PyObject *func, PyObject *self)
return (PyObject *)im;
}
+static PyObject *
+method_reduce(PyMethodObject *im)
+{
+ PyObject *self = PyMethod_GET_SELF(im);
+ PyObject *func = PyMethod_GET_FUNCTION(im);
+ PyObject *builtins;
+ PyObject *getattr;
+ PyObject *funcname;
+ _Py_IDENTIFIER(getattr);
+
+ funcname = _PyObject_GetAttrId(func, &PyId___name__);
+ if (funcname == NULL) {
+ return NULL;
+ }
+ builtins = PyEval_GetBuiltins();
+ getattr = _PyDict_GetItemId(builtins, &PyId_getattr);
+ return Py_BuildValue("O(ON)", getattr, self, funcname);
+}
+
+static PyMethodDef method_methods[] = {
+ {"__reduce__", (PyCFunction)method_reduce, METH_NOARGS, NULL},
+ {NULL, NULL}
+};
+
/* Descriptors for PyMethod attributes */
/* im_func and im_self are stored in the PyMethod object */
@@ -367,7 +391,7 @@ PyTypeObject PyMethod_Type = {
offsetof(PyMethodObject, im_weakreflist), /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
- 0, /* tp_methods */
+ method_methods, /* tp_methods */
method_memberlist, /* tp_members */
method_getset, /* tp_getset */
0, /* tp_base */
diff --git a/Objects/descrobject.c b/Objects/descrobject.c
index d4f8048..da88f86 100644
--- a/Objects/descrobject.c
+++ b/Objects/descrobject.c
@@ -398,6 +398,24 @@ descr_get_qualname(PyDescrObject *descr)
return descr->d_qualname;
}
+static PyObject *
+descr_reduce(PyDescrObject *descr)
+{
+ PyObject *builtins;
+ PyObject *getattr;
+ _Py_IDENTIFIER(getattr);
+
+ builtins = PyEval_GetBuiltins();
+ getattr = _PyDict_GetItemId(builtins, &PyId_getattr);
+ return Py_BuildValue("O(OO)", getattr, PyDescr_TYPE(descr),
+ PyDescr_NAME(descr));
+}
+
+static PyMethodDef descr_methods[] = {
+ {"__reduce__", (PyCFunction)descr_reduce, METH_NOARGS, NULL},
+ {NULL, NULL}
+};
+
static PyMemberDef descr_members[] = {
{"__objclass__", T_OBJECT, offsetof(PyDescrObject, d_type), READONLY},
{"__name__", T_OBJECT, offsetof(PyDescrObject, d_name), READONLY},
@@ -494,7 +512,7 @@ PyTypeObject PyMethodDescr_Type = {
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
- 0, /* tp_methods */
+ descr_methods, /* tp_methods */
descr_members, /* tp_members */
method_getset, /* tp_getset */
0, /* tp_base */
@@ -532,7 +550,7 @@ PyTypeObject PyClassMethodDescr_Type = {
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
- 0, /* tp_methods */
+ descr_methods, /* tp_methods */
descr_members, /* tp_members */
method_getset, /* tp_getset */
0, /* tp_base */
@@ -569,7 +587,7 @@ PyTypeObject PyMemberDescr_Type = {
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
- 0, /* tp_methods */
+ descr_methods, /* tp_methods */
descr_members, /* tp_members */
member_getset, /* tp_getset */
0, /* tp_base */
@@ -643,7 +661,7 @@ PyTypeObject PyWrapperDescr_Type = {
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
- 0, /* tp_methods */
+ descr_methods, /* tp_methods */
descr_members, /* tp_members */
wrapperdescr_getset, /* tp_getset */
0, /* tp_base */
@@ -1085,6 +1103,23 @@ wrapper_repr(wrapperobject *wp)
wp->self);
}
+static PyObject *
+wrapper_reduce(wrapperobject *wp)
+{
+ PyObject *builtins;
+ PyObject *getattr;
+ _Py_IDENTIFIER(getattr);
+
+ builtins = PyEval_GetBuiltins();
+ getattr = _PyDict_GetItemId(builtins, &PyId_getattr);
+ return Py_BuildValue("O(OO)", getattr, wp->self, PyDescr_NAME(wp->descr));
+}
+
+static PyMethodDef wrapper_methods[] = {
+ {"__reduce__", (PyCFunction)wrapper_reduce, METH_NOARGS, NULL},
+ {NULL, NULL}
+};
+
static PyMemberDef wrapper_members[] = {
{"__self__", T_OBJECT, offsetof(wrapperobject, self), READONLY},
{0}
@@ -1193,7 +1228,7 @@ PyTypeObject _PyMethodWrapper_Type = {
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
- 0, /* tp_methods */
+ wrapper_methods, /* tp_methods */
wrapper_members, /* tp_members */
wrapper_getsets, /* tp_getset */
0, /* tp_base */
diff --git a/Objects/typeobject.c b/Objects/typeobject.c
index 09f77fa..5e951de 100644
--- a/Objects/typeobject.c
+++ b/Objects/typeobject.c
@@ -3405,151 +3405,430 @@ import_copyreg(void)
return cached_copyreg_module;
}
-static PyObject *
-slotnames(PyObject *cls)
+Py_LOCAL(PyObject *)
+_PyType_GetSlotNames(PyTypeObject *cls)
{
- PyObject *clsdict;
PyObject *copyreg;
PyObject *slotnames;
_Py_IDENTIFIER(__slotnames__);
_Py_IDENTIFIER(_slotnames);
- clsdict = ((PyTypeObject *)cls)->tp_dict;
- slotnames = _PyDict_GetItemId(clsdict, &PyId___slotnames__);
- if (slotnames != NULL && PyList_Check(slotnames)) {
+ assert(PyType_Check(cls));
+
+ /* Get the slot names from the cache in the class if possible. */
+ slotnames = _PyDict_GetItemIdWithError(cls->tp_dict, &PyId___slotnames__);
+ if (slotnames != NULL) {
+ if (slotnames != Py_None && !PyList_Check(slotnames)) {
+ PyErr_Format(PyExc_TypeError,
+ "%.200s.__slotnames__ should be a list or None, "
+ "not %.200s",
+ cls->tp_name, Py_TYPE(slotnames)->tp_name);
+ return NULL;
+ }
Py_INCREF(slotnames);
return slotnames;
}
+ else {
+ if (PyErr_Occurred()) {
+ return NULL;
+ }
+ /* The class does not have the slot names cached yet. */
+ }
copyreg = import_copyreg();
if (copyreg == NULL)
return NULL;
- slotnames = _PyObject_CallMethodId(copyreg, &PyId__slotnames, "O", cls);
+ /* Use _slotnames function from the copyreg module to find the slots
+ by this class and its bases. This function will cache the result
+ in __slotnames__. */
+ slotnames = _PyObject_CallMethodIdObjArgs(copyreg, &PyId__slotnames,
+ cls, NULL);
Py_DECREF(copyreg);
- if (slotnames != NULL &&
- slotnames != Py_None &&
- !PyList_Check(slotnames))
- {
+ if (slotnames == NULL)
+ return NULL;
+
+ if (slotnames != Py_None && !PyList_Check(slotnames)) {
PyErr_SetString(PyExc_TypeError,
- "copyreg._slotnames didn't return a list or None");
+ "copyreg._slotnames didn't return a list or None");
Py_DECREF(slotnames);
- slotnames = NULL;
+ return NULL;
}
return slotnames;
}
-static PyObject *
-reduce_2(PyObject *obj)
+Py_LOCAL(PyObject *)
+_PyObject_GetState(PyObject *obj)
{
- PyObject *cls, *getnewargs;
- PyObject *args = NULL, *args2 = NULL;
- PyObject *getstate = NULL, *state = NULL, *names = NULL;
- PyObject *slots = NULL, *listitems = NULL, *dictitems = NULL;
- PyObject *copyreg = NULL, *newobj = NULL, *res = NULL;
- Py_ssize_t i, n;
- _Py_IDENTIFIER(__getnewargs__);
+ PyObject *state;
+ PyObject *getstate;
_Py_IDENTIFIER(__getstate__);
- _Py_IDENTIFIER(__newobj__);
- cls = (PyObject *) Py_TYPE(obj);
+ getstate = _PyObject_GetAttrId(obj, &PyId___getstate__);
+ if (getstate == NULL) {
+ PyObject *slotnames;
- getnewargs = _PyObject_GetAttrId(obj, &PyId___getnewargs__);
- if (getnewargs != NULL) {
- args = PyObject_CallObject(getnewargs, NULL);
- Py_DECREF(getnewargs);
- if (args != NULL && !PyTuple_Check(args)) {
- PyErr_Format(PyExc_TypeError,
- "__getnewargs__ should return a tuple, "
- "not '%.200s'", Py_TYPE(args)->tp_name);
- goto end;
+ if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
+ return NULL;
}
- }
- else {
PyErr_Clear();
- args = PyTuple_New(0);
- }
- if (args == NULL)
- goto end;
- getstate = _PyObject_GetAttrId(obj, &PyId___getstate__);
- if (getstate != NULL) {
- state = PyObject_CallObject(getstate, NULL);
- Py_DECREF(getstate);
- if (state == NULL)
- goto end;
- }
- else {
- PyObject **dict;
- PyErr_Clear();
- dict = _PyObject_GetDictPtr(obj);
- if (dict && *dict)
- state = *dict;
- else
- state = Py_None;
- Py_INCREF(state);
- names = slotnames(cls);
- if (names == NULL)
- goto end;
- if (names != Py_None && PyList_GET_SIZE(names) > 0) {
- assert(PyList_Check(names));
+ {
+ PyObject **dict;
+ dict = _PyObject_GetDictPtr(obj);
+ /* It is possible that the object's dict is not initialized
+ yet. In this case, we will return None for the state.
+ We also return None if the dict is empty to make the behavior
+ consistent regardless whether the dict was initialized or not.
+ This make unit testing easier. */
+ if (dict != NULL && *dict != NULL && PyDict_Size(*dict) > 0) {
+ state = *dict;
+ }
+ else {
+ state = Py_None;
+ }
+ Py_INCREF(state);
+ }
+
+ slotnames = _PyType_GetSlotNames(Py_TYPE(obj));
+ if (slotnames == NULL) {
+ Py_DECREF(state);
+ return NULL;
+ }
+
+ assert(slotnames == Py_None || PyList_Check(slotnames));
+ if (slotnames != Py_None && Py_SIZE(slotnames) > 0) {
+ PyObject *slots;
+ Py_ssize_t slotnames_size, i;
+
slots = PyDict_New();
- if (slots == NULL)
- goto end;
- n = 0;
- /* Can't pre-compute the list size; the list
- is stored on the class so accessible to other
- threads, which may be run by DECREF */
- for (i = 0; i < PyList_GET_SIZE(names); i++) {
+ if (slots == NULL) {
+ Py_DECREF(slotnames);
+ Py_DECREF(state);
+ return NULL;
+ }
+
+ slotnames_size = Py_SIZE(slotnames);
+ for (i = 0; i < slotnames_size; i++) {
PyObject *name, *value;
- name = PyList_GET_ITEM(names, i);
+
+ name = PyList_GET_ITEM(slotnames, i);
value = PyObject_GetAttr(obj, name);
- if (value == NULL)
+ if (value == NULL) {
+ if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
+ goto error;
+ }
+ /* It is not an error if the attribute is not present. */
PyErr_Clear();
+ }
else {
- int err = PyDict_SetItem(slots, name,
- value);
+ int err = PyDict_SetItem(slots, name, value);
Py_DECREF(value);
- if (err)
- goto end;
- n++;
+ if (err) {
+ goto error;
+ }
+ }
+
+ /* The list is stored on the class so it may mutates while we
+ iterate over it */
+ if (slotnames_size != Py_SIZE(slotnames)) {
+ PyErr_Format(PyExc_RuntimeError,
+ "__slotsname__ changed size during iteration");
+ goto error;
+ }
+
+ /* We handle errors within the loop here. */
+ if (0) {
+ error:
+ Py_DECREF(slotnames);
+ Py_DECREF(slots);
+ Py_DECREF(state);
+ return NULL;
}
}
- if (n) {
- state = Py_BuildValue("(NO)", state, slots);
- if (state == NULL)
- goto end;
+
+ /* If we found some slot attributes, pack them in a tuple along
+ the orginal attribute dictionary. */
+ if (PyDict_Size(slots) > 0) {
+ PyObject *state2;
+
+ state2 = PyTuple_Pack(2, state, slots);
+ Py_DECREF(state);
+ if (state2 == NULL) {
+ Py_DECREF(slotnames);
+ Py_DECREF(slots);
+ return NULL;
+ }
+ state = state2;
}
+ Py_DECREF(slots);
+ }
+ Py_DECREF(slotnames);
+ }
+ else { /* getstate != NULL */
+ state = PyObject_CallObject(getstate, NULL);
+ Py_DECREF(getstate);
+ if (state == NULL)
+ return NULL;
+ }
+
+ return state;
+}
+
+Py_LOCAL(int)
+_PyObject_GetNewArguments(PyObject *obj, PyObject **args, PyObject **kwargs)
+{
+ PyObject *getnewargs, *getnewargs_ex;
+ _Py_IDENTIFIER(__getnewargs_ex__);
+ _Py_IDENTIFIER(__getnewargs__);
+
+ if (args == NULL || kwargs == NULL) {
+ PyErr_BadInternalCall();
+ return -1;
+ }
+
+ /* We first attempt to fetch the arguments for __new__ by calling
+ __getnewargs_ex__ on the object. */
+ getnewargs_ex = _PyObject_GetAttrId(obj, &PyId___getnewargs_ex__);
+ if (getnewargs_ex != NULL) {
+ PyObject *newargs = PyObject_CallObject(getnewargs_ex, NULL);
+ Py_DECREF(getnewargs_ex);
+ if (newargs == NULL) {
+ return -1;
+ }
+ if (!PyTuple_Check(newargs)) {
+ PyErr_Format(PyExc_TypeError,
+ "__getnewargs_ex__ should return a tuple, "
+ "not '%.200s'", Py_TYPE(newargs)->tp_name);
+ Py_DECREF(newargs);
+ return -1;
+ }
+ if (Py_SIZE(newargs) != 2) {
+ PyErr_Format(PyExc_ValueError,
+ "__getnewargs_ex__ should return a tuple of "
+ "length 2, not %zd", Py_SIZE(newargs));
+ Py_DECREF(newargs);
+ return -1;
+ }
+ *args = PyTuple_GET_ITEM(newargs, 0);
+ Py_INCREF(*args);
+ *kwargs = PyTuple_GET_ITEM(newargs, 1);
+ Py_INCREF(*kwargs);
+ Py_DECREF(newargs);
+
+ /* XXX We should perhaps allow None to be passed here. */
+ if (!PyTuple_Check(*args)) {
+ PyErr_Format(PyExc_TypeError,
+ "first item of the tuple returned by "
+ "__getnewargs_ex__ must be a tuple, not '%.200s'",
+ Py_TYPE(*args)->tp_name);
+ Py_CLEAR(*args);
+ Py_CLEAR(*kwargs);
+ return -1;
+ }
+ if (!PyDict_Check(*kwargs)) {
+ PyErr_Format(PyExc_TypeError,
+ "second item of the tuple returned by "
+ "__getnewargs_ex__ must be a dict, not '%.200s'",
+ Py_TYPE(*kwargs)->tp_name);
+ Py_CLEAR(*args);
+ Py_CLEAR(*kwargs);
+ return -1;
+ }
+ return 0;
+ } else {
+ if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
+ return -1;
+ }
+ PyErr_Clear();
+ }
+
+ /* The object does not have __getnewargs_ex__ so we fallback on using
+ __getnewargs__ instead. */
+ getnewargs = _PyObject_GetAttrId(obj, &PyId___getnewargs__);
+ if (getnewargs != NULL) {
+ *args = PyObject_CallObject(getnewargs, NULL);
+ Py_DECREF(getnewargs);
+ if (*args == NULL) {
+ return -1;
+ }
+ if (!PyTuple_Check(*args)) {
+ PyErr_Format(PyExc_TypeError,
+ "__getnewargs__ should return a tuple, "
+ "not '%.200s'", Py_TYPE(*args)->tp_name);
+ Py_CLEAR(*args);
+ return -1;
+ }
+ *kwargs = NULL;
+ return 0;
+ } else {
+ if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
+ return -1;
}
+ PyErr_Clear();
+ }
+
+ /* The object does not have __getnewargs_ex__ and __getnewargs__. This may
+ means __new__ does not takes any arguments on this object, or that the
+ object does not implement the reduce protocol for pickling or
+ copying. */
+ *args = NULL;
+ *kwargs = NULL;
+ return 0;
+}
+
+Py_LOCAL(int)
+_PyObject_GetItemsIter(PyObject *obj, PyObject **listitems,
+ PyObject **dictitems)
+{
+ if (listitems == NULL || dictitems == NULL) {
+ PyErr_BadInternalCall();
+ return -1;
}
if (!PyList_Check(obj)) {
- listitems = Py_None;
- Py_INCREF(listitems);
+ *listitems = Py_None;
+ Py_INCREF(*listitems);
}
else {
- listitems = PyObject_GetIter(obj);
+ *listitems = PyObject_GetIter(obj);
if (listitems == NULL)
- goto end;
+ return -1;
}
if (!PyDict_Check(obj)) {
- dictitems = Py_None;
- Py_INCREF(dictitems);
+ *dictitems = Py_None;
+ Py_INCREF(*dictitems);
}
else {
+ PyObject *items;
_Py_IDENTIFIER(items);
- PyObject *items = _PyObject_CallMethodId(obj, &PyId_items, "");
- if (items == NULL)
- goto end;
- dictitems = PyObject_GetIter(items);
+
+ items = _PyObject_CallMethodIdObjArgs(obj, &PyId_items, NULL);
+ if (items == NULL) {
+ Py_CLEAR(*listitems);
+ return -1;
+ }
+ *dictitems = PyObject_GetIter(items);
Py_DECREF(items);
- if (dictitems == NULL)
- goto end;
+ if (*dictitems == NULL) {
+ Py_CLEAR(*listitems);
+ return -1;
+ }
+ }
+
+ assert(*listitems != NULL && *dictitems != NULL);
+
+ return 0;
+}
+
+static PyObject *
+reduce_4(PyObject *obj)
+{
+ PyObject *args = NULL, *kwargs = NULL;
+ PyObject *copyreg;
+ PyObject *newobj, *newargs, *state, *listitems, *dictitems;
+ PyObject *result;
+ _Py_IDENTIFIER(__newobj_ex__);
+
+ if (_PyObject_GetNewArguments(obj, &args, &kwargs) < 0) {
+ return NULL;
+ }
+ if (args == NULL) {
+ args = PyTuple_New(0);
+ if (args == NULL)
+ return NULL;
+ }
+ if (kwargs == NULL) {
+ kwargs = PyDict_New();
+ if (kwargs == NULL)
+ return NULL;
}
copyreg = import_copyreg();
+ if (copyreg == NULL) {
+ Py_DECREF(args);
+ Py_DECREF(kwargs);
+ return NULL;
+ }
+ newobj = _PyObject_GetAttrId(copyreg, &PyId___newobj_ex__);
+ Py_DECREF(copyreg);
+ if (newobj == NULL) {
+ Py_DECREF(args);
+ Py_DECREF(kwargs);
+ return NULL;
+ }
+ newargs = PyTuple_Pack(3, Py_TYPE(obj), args, kwargs);
+ Py_DECREF(args);
+ Py_DECREF(kwargs);
+ if (newargs == NULL) {
+ Py_DECREF(newobj);
+ return NULL;
+ }
+ state = _PyObject_GetState(obj);
+ if (state == NULL) {
+ Py_DECREF(newobj);
+ Py_DECREF(newargs);
+ return NULL;
+ }
+ if (_PyObject_GetItemsIter(obj, &listitems, &dictitems) < 0) {
+ Py_DECREF(newobj);
+ Py_DECREF(newargs);
+ Py_DECREF(state);
+ return NULL;
+ }
+
+ result = PyTuple_Pack(5, newobj, newargs, state, listitems, dictitems);
+ Py_DECREF(newobj);
+ Py_DECREF(newargs);
+ Py_DECREF(state);
+ Py_DECREF(listitems);
+ Py_DECREF(dictitems);
+ return result;
+}
+
+static PyObject *
+reduce_2(PyObject *obj)
+{
+ PyObject *cls;
+ PyObject *args = NULL, *args2 = NULL, *kwargs = NULL;
+ PyObject *state = NULL, *listitems = NULL, *dictitems = NULL;
+ PyObject *copyreg = NULL, *newobj = NULL, *res = NULL;
+ Py_ssize_t i, n;
+ _Py_IDENTIFIER(__newobj__);
+
+ if (_PyObject_GetNewArguments(obj, &args, &kwargs) < 0) {
+ return NULL;
+ }
+ if (args == NULL) {
+ assert(kwargs == NULL);
+ args = PyTuple_New(0);
+ if (args == NULL) {
+ return NULL;
+ }
+ }
+ else if (kwargs != NULL) {
+ if (PyDict_Size(kwargs) > 0) {
+ PyErr_SetString(PyExc_ValueError,
+ "must use protocol 4 or greater to copy this "
+ "object; since __getnewargs_ex__ returned "
+ "keyword arguments.");
+ Py_DECREF(args);
+ Py_DECREF(kwargs);
+ return NULL;
+ }
+ Py_CLEAR(kwargs);
+ }
+
+ state = _PyObject_GetState(obj);
+ if (state == NULL)
+ goto end;
+
+ if (_PyObject_GetItemsIter(obj, &listitems, &dictitems) < 0)
+ goto end;
+
+ copyreg = import_copyreg();
if (copyreg == NULL)
goto end;
newobj = _PyObject_GetAttrId(copyreg, &PyId___newobj__);
@@ -3560,6 +3839,7 @@ reduce_2(PyObject *obj)
args2 = PyTuple_New(n+1);
if (args2 == NULL)
goto end;
+ cls = (PyObject *) Py_TYPE(obj);
Py_INCREF(cls);
PyTuple_SET_ITEM(args2, 0, cls);
for (i = 0; i < n; i++) {
@@ -3573,9 +3853,7 @@ reduce_2(PyObject *obj)
end:
Py_XDECREF(args);
Py_XDECREF(args2);
- Py_XDECREF(slots);
Py_XDECREF(state);
- Py_XDECREF(names);
Py_XDECREF(listitems);
Py_XDECREF(dictitems);
Py_XDECREF(copyreg);
@@ -3603,7 +3881,9 @@ _common_reduce(PyObject *self, int proto)
{
PyObject *copyreg, *res;
- if (proto >= 2)
+ if (proto >= 4)
+ return reduce_4(self);
+ else if (proto >= 2)
return reduce_2(self);
copyreg = import_copyreg();