summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorEric Snow <ericsnowcurrently@gmail.com>2025-04-30 23:34:05 (GMT)
committerGitHub <noreply@github.com>2025-04-30 23:34:05 (GMT)
commitcb35c11d82efd2959bda0397abcc1719bf6bb0cb (patch)
tree24401105bae0254b93e5d31488fa0a814d0f90be
parent6c522debc218d441756bf631abe8ec8d6c6f1c45 (diff)
downloadcpython-cb35c11d82efd2959bda0397abcc1719bf6bb0cb.zip
cpython-cb35c11d82efd2959bda0397abcc1719bf6bb0cb.tar.gz
cpython-cb35c11d82efd2959bda0397abcc1719bf6bb0cb.tar.bz2
gh-132775: Add _PyPickle_GetXIData() (gh-133107)
There's some extra complexity due to making sure we get things right when handling functions and classes defined in the __main__ module. This is also reflected in the tests, including the addition of extra functions in test.support.import_helper.
-rw-r--r--Include/internal/pycore_crossinterp.h7
-rw-r--r--Lib/test/support/import_helper.py108
-rw-r--r--Lib/test/test_crossinterp.py541
-rw-r--r--Modules/_testinternalcapi.c5
-rw-r--r--Python/crossinterp.c452
5 files changed, 1057 insertions, 56 deletions
diff --git a/Include/internal/pycore_crossinterp.h b/Include/internal/pycore_crossinterp.h
index 4b7446a..4b4617f 100644
--- a/Include/internal/pycore_crossinterp.h
+++ b/Include/internal/pycore_crossinterp.h
@@ -171,6 +171,13 @@ PyAPI_FUNC(_PyBytes_data_t *) _PyBytes_GetXIDataWrapped(
xid_newobjfunc,
_PyXIData_t *);
+// _PyObject_GetXIData() for pickle
+// (Both are functions, so both are declared with PyAPI_FUNC,
+// matching the marshal declarations below.)
+PyAPI_FUNC(PyObject *) _PyPickle_LoadFromXIData(_PyXIData_t *);
+PyAPI_FUNC(int) _PyPickle_GetXIData(
+        PyThreadState *,
+        PyObject *,
+        _PyXIData_t *);
+
// _PyObject_GetXIData() for marshal
PyAPI_FUNC(PyObject *) _PyMarshal_ReadObjectFromXIData(_PyXIData_t *);
PyAPI_FUNC(int) _PyMarshal_GetXIData(
diff --git a/Lib/test/support/import_helper.py b/Lib/test/support/import_helper.py
index 42cfe9c..edb734d 100644
--- a/Lib/test/support/import_helper.py
+++ b/Lib/test/support/import_helper.py
@@ -1,6 +1,7 @@
import contextlib
import _imp
import importlib
+import importlib.machinery
import importlib.util
import os
import shutil
@@ -332,3 +333,110 @@ def ensure_lazy_imports(imported_module, modules_to_block):
)
from .script_helper import assert_python_ok
assert_python_ok("-S", "-c", script)
+
+
+@contextlib.contextmanager
+def module_restored(name):
+    """A context manager that restores a module to the original state.
+
+    While the context is active, sys.modules[name] is a stand-in:
+    a new module object holding a shallow copy of the original
+    module's namespace, or the result of a regular import if there
+    was no module to copy.  On exit the original sys.modules entry
+    is put back, or the entry is removed if there was none.
+    """
+    missing = object()
+    orig = sys.modules.get(name, missing)
+    if orig is missing or orig is None:
+        # There is no module namespace to copy (neither the sentinel
+        # nor None has a module __dict__), so do a regular import.
+        # Any failure is left to the import system to report.
+        mod = importlib.import_module(name)
+    else:
+        mod = type(sys)(name)
+        mod.__dict__.update(orig.__dict__)
+        sys.modules[name] = mod
+    try:
+        yield mod
+    finally:
+        if orig is missing:
+            sys.modules.pop(name, None)
+        else:
+            sys.modules[name] = orig
+
+
+def create_module(name, loader=None, *, ispkg=False):
+ """Return a new, empty module."""
+ spec = importlib.machinery.ModuleSpec(
+ name,
+ loader,
+ origin='<import_helper>',
+ is_package=ispkg,
+ )
+ return importlib.util.module_from_spec(spec)
+
+
+def _ensure_module(name, ispkg, addparent, clearnone):
+ try:
+ mod = orig = sys.modules[name]
+ except KeyError:
+ mod = orig = None
+ missing = True
+ else:
+ missing = False
+ if mod is not None:
+ # It was already imported.
+ return mod, orig, missing
+ # Otherwise, None means it was explicitly disabled.
+
+ assert name != '__main__'
+ if not missing:
+ assert orig is None, (name, sys.modules[name])
+ if not clearnone:
+ raise ModuleNotFoundError(name)
+ del sys.modules[name]
+ # Try normal import, then fall back to adding the module.
+ try:
+ mod = importlib.import_module(name)
+ except ModuleNotFoundError:
+ if addparent and not clearnone:
+ addparent = None
+ mod = _add_module(name, ispkg, addparent)
+ return mod, orig, missing
+
+
+def _add_module(spec, ispkg, addparent):
+ if isinstance(spec, str):
+ name = spec
+ mod = create_module(name, ispkg=ispkg)
+ spec = mod.__spec__
+ else:
+ name = spec.name
+ mod = importlib.util.module_from_spec(spec)
+ sys.modules[name] = mod
+ if addparent is not False and spec.parent:
+ _ensure_module(spec.parent, True, addparent, bool(addparent))
+ return mod
+
+
+def add_module(spec, *, parents=True):
+ """Return the module after creating it and adding it to sys.modules.
+
+ If parents is True then also create any missing parents.
+ """
+ return _add_module(spec, False, parents)
+
+
+def add_package(spec, *, parents=True):
+ """Return the module after creating it and adding it to sys.modules.
+
+ If parents is True then also create any missing parents.
+ """
+ return _add_module(spec, True, parents)
+
+
+def ensure_module_imported(name, *, clearnone=True):
+    """Return the corresponding module.
+
+    If it was already imported then return that.  Otherwise, try
+    importing it (optionally clear it first if None).  If that fails
+    then create a new empty module.
+
+    It can be helpful to combine this with ready_to_import() and/or
+    isolated_modules().
+    """
+    if sys.modules.get(name) is not None:
+        mod = sys.modules[name]
+    else:
+        # _ensure_module() returns (module, original entry, was-missing);
+        # only the module is needed here.
+        mod, _, _ = _ensure_module(name, False, True, clearnone)
+    return mod
diff --git a/Lib/test/test_crossinterp.py b/Lib/test/test_crossinterp.py
index 5ebb78b..32d6fd4 100644
--- a/Lib/test/test_crossinterp.py
+++ b/Lib/test/test_crossinterp.py
@@ -1,3 +1,6 @@
+import contextlib
+import importlib
+import importlib.util
import itertools
import sys
import types
@@ -9,7 +12,7 @@ _testinternalcapi = import_helper.import_module('_testinternalcapi')
_interpreters = import_helper.import_module('_interpreters')
from _interpreters import NotShareableError
-
+from test import _code_definitions as code_defs
from test import _crossinterp_definitions as defs
@@ -21,6 +24,88 @@ OTHER_TYPES = [o for n, o in vars(types).items()
if (isinstance(o, type) and
n not in ('DynamicClassAttribute', '_GeneratorWrapper'))]
+DEFS = defs
+with open(code_defs.__file__) as infile:
+ _code_defs_text = infile.read()
+with open(DEFS.__file__) as infile:
+ _defs_text = infile.read()
+ _defs_text = _defs_text.replace('from ', '# from ')
+DEFS_TEXT = f"""
+#######################################
+# from {code_defs.__file__}
+
+{_code_defs_text}
+
+#######################################
+# from {defs.__file__}
+
+{_defs_text}
+"""
+del infile, _code_defs_text, _defs_text
+
+
+def load_defs(module=None):
+ """Return a new copy of the test._crossinterp_definitions module.
+
+ The module's __name__ matches the "module" arg, which is either
+ a str or a module.
+
+ If the "module" arg is a module then the just-loaded defs are also
+ copied into that module.
+
+ Note that the new module is not added to sys.modules.
+ """
+ if module is None:
+ modname = DEFS.__name__
+ elif isinstance(module, str):
+ modname = module
+ module = None
+ else:
+ modname = module.__name__
+ # Create the new module and populate it.
+ defs = import_helper.create_module(modname)
+ defs.__file__ = DEFS.__file__
+ exec(DEFS_TEXT, defs.__dict__)
+ # Copy the defs into the module arg, if any.
+ if module is not None:
+ for name, value in defs.__dict__.items():
+ if name.startswith('_'):
+ continue
+ assert not hasattr(module, name), (name, getattr(module, name))
+ setattr(module, name, value)
+ return defs
+
+
+@contextlib.contextmanager
+def using___main__():
+ """Make sure __main__ module exists (and clean up after)."""
+ modname = '__main__'
+ if modname not in sys.modules:
+ with import_helper.isolated_modules():
+ yield import_helper.add_module(modname)
+ else:
+ with import_helper.module_restored(modname) as mod:
+ yield mod
+
+
+@contextlib.contextmanager
+def temp_module(modname):
+ """Create the module and add to sys.modules, then remove it after."""
+ assert modname not in sys.modules, (modname,)
+ with import_helper.isolated_modules():
+ yield import_helper.add_module(modname)
+
+
+@contextlib.contextmanager
+def missing_defs_module(modname, *, prep=False):
+ assert modname not in sys.modules, (modname,)
+ if prep:
+ with import_helper.ready_to_import(modname, DEFS_TEXT):
+ yield modname
+ else:
+ with import_helper.isolated_modules():
+ yield modname
+
class _GetXIDataTests(unittest.TestCase):
@@ -32,52 +117,49 @@ class _GetXIDataTests(unittest.TestCase):
def get_roundtrip(self, obj, *, mode=None):
mode = self._resolve_mode(mode)
- xid =_testinternalcapi.get_crossinterp_data(obj, mode)
+ return self._get_roundtrip(obj, mode)
+
+ def _get_roundtrip(self, obj, mode):
+ xid = _testinternalcapi.get_crossinterp_data(obj, mode)
return _testinternalcapi.restore_crossinterp_data(xid)
- def iter_roundtrip_values(self, values, *, mode=None):
+ def assert_roundtrip_identical(self, values, *, mode=None):
mode = self._resolve_mode(mode)
for obj in values:
with self.subTest(obj):
- xid = _testinternalcapi.get_crossinterp_data(obj, mode)
- got = _testinternalcapi.restore_crossinterp_data(xid)
- yield obj, got
-
- def assert_roundtrip_identical(self, values, *, mode=None):
- for obj, got in self.iter_roundtrip_values(values, mode=mode):
- # XXX What about between interpreters?
- self.assertIs(got, obj)
+ got = self._get_roundtrip(obj, mode)
+ self.assertIs(got, obj)
def assert_roundtrip_equal(self, values, *, mode=None, expecttype=None):
- for obj, got in self.iter_roundtrip_values(values, mode=mode):
- self.assertEqual(got, obj)
- self.assertIs(type(got),
- type(obj) if expecttype is None else expecttype)
-
-# def assert_roundtrip_equal_not_identical(self, values, *,
-# mode=None, expecttype=None):
-# mode = self._resolve_mode(mode)
-# for obj in values:
-# cls = type(obj)
-# with self.subTest(obj):
-# got = self._get_roundtrip(obj, mode)
-# self.assertIsNot(got, obj)
-# self.assertIs(type(got), type(obj))
-# self.assertEqual(got, obj)
-# self.assertIs(type(got),
-# cls if expecttype is None else expecttype)
-#
-# def assert_roundtrip_not_equal(self, values, *, mode=None, expecttype=None):
-# mode = self._resolve_mode(mode)
-# for obj in values:
-# cls = type(obj)
-# with self.subTest(obj):
-# got = self._get_roundtrip(obj, mode)
-# self.assertIsNot(got, obj)
-# self.assertIs(type(got), type(obj))
-# self.assertNotEqual(got, obj)
-# self.assertIs(type(got),
-# cls if expecttype is None else expecttype)
+ mode = self._resolve_mode(mode)
+ for obj in values:
+ with self.subTest(obj):
+ got = self._get_roundtrip(obj, mode)
+ self.assertEqual(got, obj)
+ self.assertIs(type(got),
+ type(obj) if expecttype is None else expecttype)
+
+ def assert_roundtrip_equal_not_identical(self, values, *,
+ mode=None, expecttype=None):
+ mode = self._resolve_mode(mode)
+ for obj in values:
+ with self.subTest(obj):
+ got = self._get_roundtrip(obj, mode)
+ self.assertIsNot(got, obj)
+ self.assertIs(type(got),
+ type(obj) if expecttype is None else expecttype)
+ self.assertEqual(got, obj)
+
+ def assert_roundtrip_not_equal(self, values, *,
+ mode=None, expecttype=None):
+ mode = self._resolve_mode(mode)
+ for obj in values:
+ with self.subTest(obj):
+ got = self._get_roundtrip(obj, mode)
+ self.assertIsNot(got, obj)
+ self.assertIs(type(got),
+ type(obj) if expecttype is None else expecttype)
+ self.assertNotEqual(got, obj)
def assert_not_shareable(self, values, exctype=None, *, mode=None):
mode = self._resolve_mode(mode)
@@ -95,6 +177,363 @@ class _GetXIDataTests(unittest.TestCase):
return mode
+class PickleTests(_GetXIDataTests):
+
+ MODE = 'pickle'
+
+ def test_shareable(self):
+ self.assert_roundtrip_equal([
+ # singletons
+ None,
+ True,
+ False,
+ # bytes
+ *(i.to_bytes(2, 'little', signed=True)
+ for i in range(-1, 258)),
+ # str
+ 'hello world',
+ '你好世界',
+ '',
+ # int
+ sys.maxsize,
+ -sys.maxsize - 1,
+ *range(-1, 258),
+ # float
+ 0.0,
+ 1.1,
+ -1.0,
+ 0.12345678,
+ -0.12345678,
+ # tuple
+ (),
+ (1,),
+ ("hello", "world", ),
+ (1, True, "hello"),
+ ((1,),),
+ ((1, 2), (3, 4)),
+ ((1, 2), (3, 4), (5, 6)),
+ ])
+ # not shareable using xidata
+ self.assert_roundtrip_equal([
+ # int
+ sys.maxsize + 1,
+ -sys.maxsize - 2,
+ 2**1000,
+ # tuple
+ (0, 1.0, []),
+ (0, 1.0, {}),
+ (0, 1.0, ([],)),
+ (0, 1.0, ({},)),
+ ])
+
+ def test_list(self):
+ self.assert_roundtrip_equal_not_identical([
+ [],
+ [1, 2, 3],
+ [[1], (2,), {3: 4}],
+ ])
+
+ def test_dict(self):
+ self.assert_roundtrip_equal_not_identical([
+ {},
+ {1: 7, 2: 8, 3: 9},
+ {1: [1], 2: (2,), 3: {3: 4}},
+ ])
+
+ def test_set(self):
+ self.assert_roundtrip_equal_not_identical([
+ set(),
+ {1, 2, 3},
+ {frozenset({1}), (2,)},
+ ])
+
+ # classes
+
+ def assert_class_defs_same(self, defs):
+ # Unpickle relative to the unchanged original module.
+ self.assert_roundtrip_identical(defs.TOP_CLASSES)
+
+ instances = []
+ for cls, args in defs.TOP_CLASSES.items():
+ if cls in defs.CLASSES_WITHOUT_EQUALITY:
+ continue
+ instances.append(cls(*args))
+ self.assert_roundtrip_equal_not_identical(instances)
+
+ # these don't compare equal
+ instances = []
+ for cls, args in defs.TOP_CLASSES.items():
+ if cls not in defs.CLASSES_WITHOUT_EQUALITY:
+ continue
+ instances.append(cls(*args))
+ self.assert_roundtrip_not_equal(instances)
+
+ def assert_class_defs_other_pickle(self, defs, mod):
+ # Pickle relative to a different module than the original.
+ for cls in defs.TOP_CLASSES:
+ assert not hasattr(mod, cls.__name__), (cls, getattr(mod, cls.__name__))
+ self.assert_not_shareable(defs.TOP_CLASSES)
+
+ instances = []
+ for cls, args in defs.TOP_CLASSES.items():
+ instances.append(cls(*args))
+ self.assert_not_shareable(instances)
+
+ def assert_class_defs_other_unpickle(self, defs, mod, *, fail=False):
+ # Unpickle relative to a different module than the original.
+ for cls in defs.TOP_CLASSES:
+ assert not hasattr(mod, cls.__name__), (cls, getattr(mod, cls.__name__))
+
+ instances = []
+ for cls, args in defs.TOP_CLASSES.items():
+ with self.subTest(cls):
+ setattr(mod, cls.__name__, cls)
+ xid = self.get_xidata(cls)
+ inst = cls(*args)
+ instxid = self.get_xidata(inst)
+ instances.append(
+ (cls, xid, inst, instxid))
+
+ for cls, xid, inst, instxid in instances:
+ with self.subTest(cls):
+ delattr(mod, cls.__name__)
+ if fail:
+ with self.assertRaises(NotShareableError):
+ _testinternalcapi.restore_crossinterp_data(xid)
+ continue
+ got = _testinternalcapi.restore_crossinterp_data(xid)
+ self.assertIsNot(got, cls)
+ self.assertNotEqual(got, cls)
+
+ gotcls = got
+ got = _testinternalcapi.restore_crossinterp_data(instxid)
+ self.assertIsNot(got, inst)
+ self.assertIs(type(got), gotcls)
+ if cls in defs.CLASSES_WITHOUT_EQUALITY:
+ self.assertNotEqual(got, inst)
+ elif cls in defs.BUILTIN_SUBCLASSES:
+ self.assertEqual(got, inst)
+ else:
+ self.assertNotEqual(got, inst)
+
+ def assert_class_defs_not_shareable(self, defs):
+ self.assert_not_shareable(defs.TOP_CLASSES)
+
+ instances = []
+ for cls, args in defs.TOP_CLASSES.items():
+ instances.append(cls(*args))
+ self.assert_not_shareable(instances)
+
+ def test_user_class_normal(self):
+ self.assert_class_defs_same(defs)
+
+ def test_user_class_in___main__(self):
+ with using___main__() as mod:
+ defs = load_defs(mod)
+ self.assert_class_defs_same(defs)
+
+ def test_user_class_not_in___main___with_filename(self):
+ with using___main__() as mod:
+ defs = load_defs('__main__')
+ assert defs.__file__
+ mod.__file__ = defs.__file__
+ self.assert_class_defs_not_shareable(defs)
+
+ def test_user_class_not_in___main___without_filename(self):
+ with using___main__() as mod:
+ defs = load_defs('__main__')
+ defs.__file__ = None
+ mod.__file__ = None
+ self.assert_class_defs_not_shareable(defs)
+
+ def test_user_class_not_in___main___unpickle_with_filename(self):
+ with using___main__() as mod:
+ defs = load_defs('__main__')
+ assert defs.__file__
+ mod.__file__ = defs.__file__
+ self.assert_class_defs_other_unpickle(defs, mod)
+
+ def test_user_class_not_in___main___unpickle_without_filename(self):
+ with using___main__() as mod:
+ defs = load_defs('__main__')
+ defs.__file__ = None
+ mod.__file__ = None
+ self.assert_class_defs_other_unpickle(defs, mod, fail=True)
+
+ def test_user_class_in_module(self):
+ with temp_module('__spam__') as mod:
+ defs = load_defs(mod)
+ self.assert_class_defs_same(defs)
+
+ def test_user_class_not_in_module_with_filename(self):
+ with temp_module('__spam__') as mod:
+ defs = load_defs(mod.__name__)
+ assert defs.__file__
+ # For now, we only address this case for __main__.
+ self.assert_class_defs_not_shareable(defs)
+
+ def test_user_class_not_in_module_without_filename(self):
+ with temp_module('__spam__') as mod:
+ defs = load_defs(mod.__name__)
+ defs.__file__ = None
+ self.assert_class_defs_not_shareable(defs)
+
+ def test_user_class_module_missing_then_imported(self):
+ with missing_defs_module('__spam__', prep=True) as modname:
+ defs = load_defs(modname)
+ # For now, we only address this case for __main__.
+ self.assert_class_defs_not_shareable(defs)
+
+ def test_user_class_module_missing_not_available(self):
+ with missing_defs_module('__spam__') as modname:
+ defs = load_defs(modname)
+ self.assert_class_defs_not_shareable(defs)
+
+ def test_nested_class(self):
+ eggs = defs.EggsNested()
+ with self.assertRaises(NotShareableError):
+ self.get_roundtrip(eggs)
+
+ # functions
+
+ def assert_func_defs_same(self, defs):
+ # Unpickle relative to the unchanged original module.
+ self.assert_roundtrip_identical(defs.TOP_FUNCTIONS)
+
+    def assert_func_defs_other_pickle(self, defs, mod):
+        """Check that the defs' functions are not shareable via "mod".
+
+        None of the functions may already be set as attributes of
+        "mod"; that precondition is asserted first.
+        """
+        # Pickle relative to a different module than the original.
+        for func in defs.TOP_FUNCTIONS:
+            assert not hasattr(mod, func.__name__), (func, getattr(mod, func.__name__))
+        self.assert_not_shareable(defs.TOP_FUNCTIONS)
+
+    def assert_func_defs_other_unpickle(self, defs, mod, *, fail=False):
+        """Check unpickling functions after they are removed from "mod".
+
+        Each function is temporarily set on "mod" while its xidata is
+        captured, then removed again before the xidata is restored.
+        If "fail" is true then restoring must raise NotShareableError;
+        otherwise the restored object must be a different, unequal
+        object.
+        """
+        # Unpickle relative to a different module than the original.
+        for func in defs.TOP_FUNCTIONS:
+            assert not hasattr(mod, func.__name__), (func, getattr(mod, func.__name__))
+
+        captured = []
+        for func in defs.TOP_FUNCTIONS:
+            with self.subTest(func):
+                setattr(mod, func.__name__, func)
+                xid = self.get_xidata(func)
+                captured.append(
+                    (func, xid))
+
+        for func, xid in captured:
+            with self.subTest(func):
+                delattr(mod, func.__name__)
+                if fail:
+                    with self.assertRaises(NotShareableError):
+                        _testinternalcapi.restore_crossinterp_data(xid)
+                    continue
+                got = _testinternalcapi.restore_crossinterp_data(xid)
+                self.assertIsNot(got, func)
+                self.assertNotEqual(got, func)
+
+ def assert_func_defs_not_shareable(self, defs):
+ self.assert_not_shareable(defs.TOP_FUNCTIONS)
+
+ def test_user_function_normal(self):
+# self.assert_roundtrip_equal(defs.TOP_FUNCTIONS)
+ self.assert_func_defs_same(defs)
+
+ def test_user_func_in___main__(self):
+ with using___main__() as mod:
+ defs = load_defs(mod)
+ self.assert_func_defs_same(defs)
+
+ def test_user_func_not_in___main___with_filename(self):
+ with using___main__() as mod:
+ defs = load_defs('__main__')
+ assert defs.__file__
+ mod.__file__ = defs.__file__
+ self.assert_func_defs_not_shareable(defs)
+
+ def test_user_func_not_in___main___without_filename(self):
+ with using___main__() as mod:
+ defs = load_defs('__main__')
+ defs.__file__ = None
+ mod.__file__ = None
+ self.assert_func_defs_not_shareable(defs)
+
+ def test_user_func_not_in___main___unpickle_with_filename(self):
+ with using___main__() as mod:
+ defs = load_defs('__main__')
+ assert defs.__file__
+ mod.__file__ = defs.__file__
+ self.assert_func_defs_other_unpickle(defs, mod)
+
+ def test_user_func_not_in___main___unpickle_without_filename(self):
+ with using___main__() as mod:
+ defs = load_defs('__main__')
+ defs.__file__ = None
+ mod.__file__ = None
+ self.assert_func_defs_other_unpickle(defs, mod, fail=True)
+
+ def test_user_func_in_module(self):
+ with temp_module('__spam__') as mod:
+ defs = load_defs(mod)
+ self.assert_func_defs_same(defs)
+
+ def test_user_func_not_in_module_with_filename(self):
+ with temp_module('__spam__') as mod:
+ defs = load_defs(mod.__name__)
+ assert defs.__file__
+ # For now, we only address this case for __main__.
+ self.assert_func_defs_not_shareable(defs)
+
+ def test_user_func_not_in_module_without_filename(self):
+ with temp_module('__spam__') as mod:
+ defs = load_defs(mod.__name__)
+ defs.__file__ = None
+ self.assert_func_defs_not_shareable(defs)
+
+ def test_user_func_module_missing_then_imported(self):
+ with missing_defs_module('__spam__', prep=True) as modname:
+ defs = load_defs(modname)
+ # For now, we only address this case for __main__.
+ self.assert_func_defs_not_shareable(defs)
+
+ def test_user_func_module_missing_not_available(self):
+ with missing_defs_module('__spam__') as modname:
+ defs = load_defs(modname)
+ self.assert_func_defs_not_shareable(defs)
+
+ def test_nested_function(self):
+ self.assert_not_shareable(defs.NESTED_FUNCTIONS)
+
+ # exceptions
+
+ def test_user_exception_normal(self):
+ self.assert_roundtrip_not_equal([
+ defs.MimimalError('error!'),
+ ])
+ self.assert_roundtrip_equal_not_identical([
+ defs.RichError('error!', 42),
+ ])
+
+ def test_builtin_exception(self):
+ msg = 'error!'
+ try:
+ raise Exception
+ except Exception as exc:
+ caught = exc
+ special = {
+ BaseExceptionGroup: (msg, [caught]),
+ ExceptionGroup: (msg, [caught]),
+# UnicodeError: (None, msg, None, None, None),
+ UnicodeEncodeError: ('utf-8', '', 1, 3, msg),
+ UnicodeDecodeError: ('utf-8', b'', 1, 3, msg),
+ UnicodeTranslateError: ('', 1, 3, msg),
+ }
+ exceptions = []
+ for cls in EXCEPTION_TYPES:
+ args = special.get(cls) or (msg,)
+ exceptions.append(cls(*args))
+
+ self.assert_roundtrip_not_equal(exceptions)
+
+
class MarshalTests(_GetXIDataTests):
MODE = 'marshal'
@@ -444,22 +883,12 @@ class ShareableTypeTests(_GetXIDataTests):
])
def test_class(self):
- self.assert_not_shareable([
- defs.Spam,
- defs.SpamOkay,
- defs.SpamFull,
- defs.SubSpamFull,
- defs.SubTuple,
- defs.EggsNested,
- ])
- self.assert_not_shareable([
- defs.Spam(),
- defs.SpamOkay(),
- defs.SpamFull(1, 2, 3),
- defs.SubSpamFull(1, 2, 3),
- defs.SubTuple([1, 2, 3]),
- defs.EggsNested(),
- ])
+ self.assert_not_shareable(defs.CLASSES)
+
+ instances = []
+ for cls, args in defs.CLASSES.items():
+ instances.append(cls(*args))
+ self.assert_not_shareable(instances)
def test_builtin_type(self):
self.assert_not_shareable([
diff --git a/Modules/_testinternalcapi.c b/Modules/_testinternalcapi.c
index 4bfe88f..812737e 100644
--- a/Modules/_testinternalcapi.c
+++ b/Modules/_testinternalcapi.c
@@ -1939,6 +1939,11 @@ get_crossinterp_data(PyObject *self, PyObject *args, PyObject *kwargs)
goto error;
}
}
+ else if (strcmp(mode, "pickle") == 0) {
+ if (_PyPickle_GetXIData(tstate, obj, xidata) != 0) {
+ goto error;
+ }
+ }
else if (strcmp(mode, "marshal") == 0) {
if (_PyMarshal_GetXIData(tstate, obj, xidata) != 0) {
goto error;
diff --git a/Python/crossinterp.c b/Python/crossinterp.c
index 753d784..a9f9b78 100644
--- a/Python/crossinterp.c
+++ b/Python/crossinterp.c
@@ -3,6 +3,7 @@
#include "Python.h"
#include "marshal.h" // PyMarshal_WriteObjectToString()
+#include "osdefs.h" // MAXPATHLEN
#include "pycore_ceval.h" // _Py_simple_func
#include "pycore_crossinterp.h" // _PyXIData_t
#include "pycore_initconfig.h" // _PyStatus_OK()
@@ -10,6 +11,155 @@
#include "pycore_typeobject.h" // _PyStaticType_InitBuiltin()
+/* Copy the __main__ module's filename into "buffer" as UTF-8.
+
+   Returns whatever _PyModule_GetFilenameUTF8() returns for the module,
+   or -1 if the __main__ module fails _Py_CheckMainModule(). */
+static Py_ssize_t
+_Py_GetMainfile(char *buffer, size_t maxlen)
+{
+    // We don't expect subinterpreters to have the __main__ module's
+    // __name__ set, but proceed just in case.
+    PyThreadState *tstate = _PyThreadState_GET();
+    PyObject *module = _Py_GetMainModule(tstate);
+    if (_Py_CheckMainModule(module) < 0) {
+        // We own the reference (see the Py_DECREF() below), so it
+        // must be released on this path too.  Py_XDECREF() is used
+        // in case _Py_GetMainModule() returned NULL.
+        Py_XDECREF(module);
+        return -1;
+    }
+    Py_ssize_t size = _PyModule_GetFilenameUTF8(module, buffer, maxlen);
+    Py_DECREF(module);
+    return size;
+}
+
+
+/* Return a new reference to the named module, or NULL with an
+   exception set.
+
+   "__main__" is looked up with _Py_GetMainModule() and validated with
+   _Py_CheckMainModule(); any other name goes through
+   PyImport_ImportModule(). */
+static PyObject *
+import_get_module(PyThreadState *tstate, const char *modname)
+{
+    if (strcmp(modname, "__main__") != 0) {
+        // PyImport_ImportModule() already returns NULL on failure.
+        return PyImport_ImportModule(modname);
+    }
+
+    PyObject *module = _Py_GetMainModule(tstate);
+    if (_Py_CheckMainModule(module) < 0) {
+        assert(_PyErr_Occurred(tstate));
+        // Don't leak the object that failed the check
+        // (Py_XDECREF() in case it is NULL).
+        Py_XDECREF(module);
+        return NULL;
+    }
+    return module;
+}
+
+
+static PyObject *
+runpy_run_path(const char *filename, const char *modname)
+{
+ PyObject *run_path = PyImport_ImportModuleAttrString("runpy", "run_path");
+ if (run_path == NULL) {
+ return NULL;
+ }
+ PyObject *args = Py_BuildValue("(sOs)", filename, Py_None, modname);
+ if (args == NULL) {
+ Py_DECREF(run_path);
+ return NULL;
+ }
+ PyObject *ns = PyObject_Call(run_path, args, NULL);
+ Py_DECREF(run_path);
+ Py_DECREF(args);
+ return ns;
+}
+
+
+/* Return the exception's message as a new reference to a str.
+
+   The message is exc.args itself if that is a non-empty str, otherwise
+   exc.args[0] if that is a str.  Returns NULL if there is no such
+   message.  Errors raised while inspecting the args are cleared. */
+static PyObject *
+pyerr_get_message(PyObject *exc)
+{
+    assert(!PyErr_Occurred());
+    PyObject *args = PyException_GetArgs(exc);
+    if (args == NULL) {
+        return NULL;
+    }
+    if (args == Py_None) {
+        Py_DECREF(args);
+        return NULL;
+    }
+    Py_ssize_t nargs = PyObject_Size(args);
+    if (nargs < 1) {
+        if (nargs < 0) {
+            // PyObject_Size() failed; the caller only wants
+            // a message, so don't leave its error behind.
+            PyErr_Clear();
+        }
+        // We own "args", so release it before bailing out.
+        Py_DECREF(args);
+        return NULL;
+    }
+    if (PyUnicode_Check(args)) {
+        // Ownership of "args" passes to the caller.
+        return args;
+    }
+    PyObject *msg = PySequence_GetItem(args, 0);
+    Py_DECREF(args);
+    if (msg == NULL) {
+        PyErr_Clear();
+        return NULL;
+    }
+    if (!PyUnicode_Check(msg)) {
+        Py_DECREF(msg);
+        return NULL;
+    }
+    return msg;
+}
+
+#define MAX_MODNAME (255)
+#define MAX_ATTRNAME (255)
+
+struct attributeerror_info {
+ char modname[MAX_MODNAME+1];
+ char attrname[MAX_ATTRNAME+1];
+};
+
+/* Extract the module and attribute names from an AttributeError.
+
+   The exception's message must have exactly the form
+
+       module '<modname>' has no attribute '<attrname>'
+
+   with nothing after the closing quote, and each name must fit in the
+   corresponding "info" buffer.  Returns 0 on success and -1 otherwise.
+   On failure "info" may have been partially written, and an exception
+   may or may not be set. */
+static int
+_parse_attributeerror(PyObject *exc, struct attributeerror_info *info)
+{
+    assert(exc != NULL);
+    assert(PyErr_GivenExceptionMatches(exc, PyExc_AttributeError));
+    int res = -1;
+
+    PyObject *msgobj = pyerr_get_message(exc);
+    if (msgobj == NULL) {
+        return -1;
+    }
+    const char *err = PyUnicode_AsUTF8(msgobj);
+    if (err == NULL) {
+        // The message could not be encoded as UTF-8, so there is
+        // nothing to parse.  The resulting exception is left for
+        // the caller to deal with.
+        goto finally;
+    }
+
+    if (strncmp(err, "module '", 8) != 0) {
+        goto finally;
+    }
+    err += 8;
+
+    // The module name runs up to the next single quote.
+    const char *matched = strchr(err, '\'');
+    if (matched == NULL) {
+        goto finally;
+    }
+    Py_ssize_t len = matched - err;
+    if (len > MAX_MODNAME) {
+        goto finally;
+    }
+    (void)strncpy(info->modname, err, len);
+    info->modname[len] = '\0';
+    // Leave "err" on the closing quote; the next prefix starts with it.
+    err = matched;
+
+    if (strncmp(err, "' has no attribute '", 20) != 0) {
+        goto finally;
+    }
+    err += 20;
+
+    // The attribute name runs up to the next single quote.
+    matched = strchr(err, '\'');
+    if (matched == NULL) {
+        goto finally;
+    }
+    len = matched - err;
+    if (len > MAX_ATTRNAME) {
+        goto finally;
+    }
+    (void)strncpy(info->attrname, err, len);
+    info->attrname[len] = '\0';
+    err = matched + 1;
+
+    // Nothing may follow the closing quote.
+    if (strlen(err) > 0) {
+        goto finally;
+    }
+    res = 0;
+
+finally:
+    Py_DECREF(msgobj);
+    return res;
+}
+
+#undef MAX_MODNAME
+#undef MAX_ATTRNAME
+
+
/**************/
/* exceptions */
/**************/
@@ -287,6 +437,308 @@ _PyObject_GetXIData(PyThreadState *tstate,
}
+/* pickle C-API */
+
+struct _pickle_context {
+ PyThreadState *tstate;
+};
+
+static PyObject *
+_PyPickle_Dumps(struct _pickle_context *ctx, PyObject *obj)
+{
+ PyObject *dumps = PyImport_ImportModuleAttrString("pickle", "dumps");
+ if (dumps == NULL) {
+ return NULL;
+ }
+ PyObject *bytes = PyObject_CallOneArg(dumps, obj);
+ Py_DECREF(dumps);
+ return bytes;
+}
+
+
+struct sync_module_result {
+ PyObject *module;
+ PyObject *loaded;
+ PyObject *failed;
+};
+
+struct sync_module {
+ const char *filename;
+ char _filename[MAXPATHLEN+1];
+ struct sync_module_result cached;
+};
+
+static void
+sync_module_clear(struct sync_module *data)
+{
+ data->filename = NULL;
+ Py_CLEAR(data->cached.module);
+ Py_CLEAR(data->cached.loaded);
+ Py_CLEAR(data->cached.failed);
+}
+
+
+struct _unpickle_context {
+ PyThreadState *tstate;
+ // We only special-case the __main__ module,
+ // since other modules behave consistently.
+ struct sync_module main;
+};
+
+static void
+_unpickle_context_clear(struct _unpickle_context *ctx)
+{
+ sync_module_clear(&ctx->main);
+}
+
+static struct sync_module_result
+_unpickle_context_get_module(struct _unpickle_context *ctx,
+ const char *modname)
+{
+ if (strcmp(modname, "__main__") == 0) {
+ return ctx->main.cached;
+ }
+ else {
+ return (struct sync_module_result){
+ .failed = PyExc_NotImplementedError,
+ };
+ }
+}
+
+/* Load the named module for unpickling and cache the outcome.
+
+   Only "__main__" is supported: the module is looked up with
+   import_get_module() and the file recorded in the context is run with
+   runpy_run_path() to produce the "loaded" namespace.  For any other
+   name, or if there is no filename, "failed" is set to
+   PyExc_NotImplementedError.
+
+   For "__main__" the result is stored in the context's cache, which
+   owns the references (they are released by sync_module_clear()).  The
+   returned struct holds the same pointers, borrowed. */
+static struct sync_module_result
+_unpickle_context_set_module(struct _unpickle_context *ctx,
+                             const char *modname)
+{
+    struct sync_module_result res = {0};
+    struct sync_module_result *cached = NULL;
+    const char *filename = NULL;
+    if (strcmp(modname, "__main__") == 0) {
+        cached = &ctx->main.cached;
+        filename = ctx->main.filename;
+    }
+    else {
+        // Nothing is cached for this case, so a borrowed reference
+        // is fine here.
+        res.failed = PyExc_NotImplementedError;
+        goto finally;
+    }
+
+    res.module = import_get_module(ctx->tstate, modname);
+    if (res.module == NULL) {
+        res.failed = _PyErr_GetRaisedException(ctx->tstate);
+        assert(res.failed != NULL);
+        goto finally;
+    }
+
+    if (filename == NULL) {
+        Py_CLEAR(res.module);
+        // This result gets cached, and the cache releases "failed"
+        // with Py_CLEAR(), so it must be a strong reference.
+        res.failed = Py_NewRef(PyExc_NotImplementedError);
+        goto finally;
+    }
+    res.loaded = runpy_run_path(filename, modname);
+    if (res.loaded == NULL) {
+        Py_CLEAR(res.module);
+        res.failed = _PyErr_GetRaisedException(ctx->tstate);
+        assert(res.failed != NULL);
+        goto finally;
+    }
+
+finally:
+    if (cached != NULL) {
+        assert(cached->module == NULL);
+        assert(cached->loaded == NULL);
+        assert(cached->failed == NULL);
+        *cached = res;
+    }
+    return res;
+}
+
+
+static int
+_handle_unpickle_missing_attr(struct _unpickle_context *ctx, PyObject *exc)
+{
+ // The caller must check if an exception is set or not when -1 is returned.
+ assert(!_PyErr_Occurred(ctx->tstate));
+ assert(PyErr_GivenExceptionMatches(exc, PyExc_AttributeError));
+ struct attributeerror_info info;
+ if (_parse_attributeerror(exc, &info) < 0) {
+ return -1;
+ }
+
+ // Get the module.
+ struct sync_module_result mod = _unpickle_context_get_module(ctx, info.modname);
+ if (mod.failed != NULL) {
+ // It must have failed previously.
+ return -1;
+ }
+ if (mod.module == NULL) {
+ mod = _unpickle_context_set_module(ctx, info.modname);
+ if (mod.failed != NULL) {
+ return -1;
+ }
+ assert(mod.module != NULL);
+ }
+
+ // Bail out if it is unexpectedly set already.
+ if (PyObject_HasAttrString(mod.module, info.attrname)) {
+ return -1;
+ }
+
+ // Try setting the attribute.
+ PyObject *value = NULL;
+ if (PyDict_GetItemStringRef(mod.loaded, info.attrname, &value) <= 0) {
+ return -1;
+ }
+ assert(value != NULL);
+ int res = PyObject_SetAttrString(mod.module, info.attrname, value);
+ Py_DECREF(value);
+ if (res < 0) {
+ return -1;
+ }
+
+ return 0;
+}
+
+static PyObject *
+_PyPickle_Loads(struct _unpickle_context *ctx, PyObject *pickled)
+{
+ PyObject *loads = PyImport_ImportModuleAttrString("pickle", "loads");
+ if (loads == NULL) {
+ return NULL;
+ }
+ PyObject *obj = PyObject_CallOneArg(loads, pickled);
+ if (ctx != NULL) {
+ while (obj == NULL) {
+ assert(_PyErr_Occurred(ctx->tstate));
+ if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
+ // We leave other failures unhandled.
+ break;
+ }
+ // Try setting the attr if not set.
+ PyObject *exc = _PyErr_GetRaisedException(ctx->tstate);
+ if (_handle_unpickle_missing_attr(ctx, exc) < 0) {
+ // Any resulting exceptions are ignored
+ // in favor of the original.
+ _PyErr_SetRaisedException(ctx->tstate, exc);
+ break;
+ }
+ Py_CLEAR(exc);
+ // Retry with the attribute set.
+ obj = PyObject_CallOneArg(loads, pickled);
+ }
+ }
+ Py_DECREF(loads);
+ return obj;
+}
+
+
+/* pickle wrapper */
+
+struct _pickle_xid_context {
+ // __main__.__file__
+ struct {
+ const char *utf8;
+ size_t len;
+ char _utf8[MAXPATHLEN+1];
+ } mainfile;
+};
+
+/* Fill in the pickle xidata context.
+
+   Records the __main__ module's filename when one is available;
+   otherwise mainfile.utf8 is NULL and mainfile.len is 0.  Errors from
+   looking up the filename are ignored.  Currently always returns 0. */
+static int
+_set_pickle_xid_context(PyThreadState *tstate, struct _pickle_xid_context *ctx)
+{
+    // Start from a known "no filename" state rather than relying on
+    // the caller to have zero-initialized the struct.
+    ctx->mainfile.utf8 = NULL;
+    ctx->mainfile.len = 0;
+
+    // Set mainfile if possible.
+    Py_ssize_t len = _Py_GetMainfile(ctx->mainfile._utf8, MAXPATHLEN);
+    if (len < 0) {
+        // For now we ignore any exceptions.
+        PyErr_Clear();
+    }
+    else if (len > 0) {
+        ctx->mainfile.utf8 = ctx->mainfile._utf8;
+        ctx->mainfile.len = (size_t)len;
+    }
+
+    return 0;
+}
+
+
+struct _shared_pickle_data {
+ _PyBytes_data_t pickled; // Must be first if we use _PyBytes_FromXIData().
+ struct _pickle_xid_context ctx;
+};
+
+PyObject *
+_PyPickle_LoadFromXIData(_PyXIData_t *xidata)
+{
+ PyThreadState *tstate = _PyThreadState_GET();
+ struct _shared_pickle_data *shared =
+ (struct _shared_pickle_data *)xidata->data;
+ // We avoid copying the pickled data by wrapping it in a memoryview.
+ // The alternative is to get a bytes object using _PyBytes_FromXIData().
+ PyObject *pickled = PyMemoryView_FromMemory(
+ (char *)shared->pickled.bytes, shared->pickled.len, PyBUF_READ);
+ if (pickled == NULL) {
+ return NULL;
+ }
+
+ // Unpickle the object.
+ struct _unpickle_context ctx = {
+ .tstate = tstate,
+ .main = {
+ .filename = shared->ctx.mainfile.utf8,
+ },
+ };
+ PyObject *obj = _PyPickle_Loads(&ctx, pickled);
+ Py_DECREF(pickled);
+ _unpickle_context_clear(&ctx);
+ if (obj == NULL) {
+ PyObject *cause = _PyErr_GetRaisedException(tstate);
+ assert(cause != NULL);
+ _set_xid_lookup_failure(
+ tstate, NULL, "object could not be unpickled", cause);
+ Py_DECREF(cause);
+ }
+ return obj;
+}
+
+
+/* Populate "xidata" with the pickled form of "obj".
+
+   The object is pickled with pickle.dumps() and the resulting bytes
+   are wrapped so that _PyPickle_LoadFromXIData() is used to restore
+   the object.  Returns 0 on success and -1 with an exception set
+   on failure. */
+int
+_PyPickle_GetXIData(PyThreadState *tstate, PyObject *obj, _PyXIData_t *xidata)
+{
+    // Pickle the object.
+    struct _pickle_context ctx = {
+        .tstate = tstate,
+    };
+    PyObject *bytes = _PyPickle_Dumps(&ctx, obj);
+    if (bytes == NULL) {
+        // Report the pickling failure through the xid lookup-failure
+        // helper, with the original exception passed as the cause.
+        PyObject *cause = _PyErr_GetRaisedException(tstate);
+        assert(cause != NULL);
+        _set_xid_lookup_failure(
+                tstate, NULL, "object could not be pickled", cause);
+        Py_DECREF(cause);
+        return -1;
+    }
+
+    // If we had an "unwrapper" mechanism, we could call
+    // _PyObject_GetXIData() on the bytes object directly and add
+    // a simple unwrapper to call pickle.loads() on the bytes.
+    size_t size = sizeof(struct _shared_pickle_data);
+    struct _shared_pickle_data *shared =
+            (struct _shared_pickle_data *)_PyBytes_GetXIDataWrapped(
+                    tstate, bytes, size, _PyPickle_LoadFromXIData, xidata);
+    Py_DECREF(bytes);
+    if (shared == NULL) {
+        return -1;
+    }
+
+    // If it mattered, we could skip getting __main__.__file__
+    // when "__main__" doesn't show up in the pickle bytes.
+    if (_set_pickle_xid_context(tstate, &shared->ctx) < 0) {
+        _xidata_clear(xidata);
+        return -1;
+    }
+
+    return 0;
+}
+
+
/* marshal wrapper */
PyObject *