summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_crossinterp.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_crossinterp.py')
-rw-r--r--Lib/test/test_crossinterp.py541
1 files changed, 485 insertions, 56 deletions
diff --git a/Lib/test/test_crossinterp.py b/Lib/test/test_crossinterp.py
index 5ebb78b..32d6fd4 100644
--- a/Lib/test/test_crossinterp.py
+++ b/Lib/test/test_crossinterp.py
@@ -1,3 +1,6 @@
+import contextlib
+import importlib
+import importlib.util
import itertools
import sys
import types
@@ -9,7 +12,7 @@ _testinternalcapi = import_helper.import_module('_testinternalcapi')
_interpreters = import_helper.import_module('_interpreters')
from _interpreters import NotShareableError
-
+from test import _code_definitions as code_defs
from test import _crossinterp_definitions as defs
@@ -21,6 +24,88 @@ OTHER_TYPES = [o for n, o in vars(types).items()
if (isinstance(o, type) and
n not in ('DynamicClassAttribute', '_GeneratorWrapper'))]
+DEFS = defs
+with open(code_defs.__file__) as infile:
+ _code_defs_text = infile.read()
+with open(DEFS.__file__) as infile:
+ _defs_text = infile.read()
+ _defs_text = _defs_text.replace('from ', '# from ')
+DEFS_TEXT = f"""
+#######################################
+# from {code_defs.__file__}
+
+{_code_defs_text}
+
+#######################################
+# from {defs.__file__}
+
+{_defs_text}
+"""
+del infile, _code_defs_text, _defs_text
+
+
+def load_defs(module=None):
+ """Return a new copy of the test._crossinterp_definitions module.
+
+ The module's __name__ matches the "module" arg, which is either
+ a str or a module.
+
+ If the "module" arg is a module then the just-loaded defs are also
+ copied into that module.
+
+ Note that the new module is not added to sys.modules.
+ """
+ if module is None:
+ modname = DEFS.__name__
+ elif isinstance(module, str):
+ modname = module
+ module = None
+ else:
+ modname = module.__name__
+ # Create the new module and populate it.
+ defs = import_helper.create_module(modname)
+ defs.__file__ = DEFS.__file__
+ exec(DEFS_TEXT, defs.__dict__)
+ # Copy the defs into the module arg, if any.
+ if module is not None:
+ for name, value in defs.__dict__.items():
+ if name.startswith('_'):
+ continue
+ assert not hasattr(module, name), (name, getattr(module, name))
+ setattr(module, name, value)
+ return defs
+
+
+@contextlib.contextmanager
+def using___main__():
+ """Make sure __main__ module exists (and clean up after)."""
+ modname = '__main__'
+ if modname not in sys.modules:
+ with import_helper.isolated_modules():
+ yield import_helper.add_module(modname)
+ else:
+ with import_helper.module_restored(modname) as mod:
+ yield mod
+
+
+@contextlib.contextmanager
+def temp_module(modname):
+ """Create the module and add to sys.modules, then remove it after."""
+ assert modname not in sys.modules, (modname,)
+ with import_helper.isolated_modules():
+ yield import_helper.add_module(modname)
+
+
+@contextlib.contextmanager
+def missing_defs_module(modname, *, prep=False):
+ assert modname not in sys.modules, (modname,)
+ if prep:
+ with import_helper.ready_to_import(modname, DEFS_TEXT):
+ yield modname
+ else:
+ with import_helper.isolated_modules():
+ yield modname
+
class _GetXIDataTests(unittest.TestCase):
@@ -32,52 +117,49 @@ class _GetXIDataTests(unittest.TestCase):
def get_roundtrip(self, obj, *, mode=None):
mode = self._resolve_mode(mode)
- xid =_testinternalcapi.get_crossinterp_data(obj, mode)
+ return self._get_roundtrip(obj, mode)
+
+ def _get_roundtrip(self, obj, mode):
+ xid = _testinternalcapi.get_crossinterp_data(obj, mode)
return _testinternalcapi.restore_crossinterp_data(xid)
- def iter_roundtrip_values(self, values, *, mode=None):
+ def assert_roundtrip_identical(self, values, *, mode=None):
mode = self._resolve_mode(mode)
for obj in values:
with self.subTest(obj):
- xid = _testinternalcapi.get_crossinterp_data(obj, mode)
- got = _testinternalcapi.restore_crossinterp_data(xid)
- yield obj, got
-
- def assert_roundtrip_identical(self, values, *, mode=None):
- for obj, got in self.iter_roundtrip_values(values, mode=mode):
- # XXX What about between interpreters?
- self.assertIs(got, obj)
+ got = self._get_roundtrip(obj, mode)
+ self.assertIs(got, obj)
def assert_roundtrip_equal(self, values, *, mode=None, expecttype=None):
- for obj, got in self.iter_roundtrip_values(values, mode=mode):
- self.assertEqual(got, obj)
- self.assertIs(type(got),
- type(obj) if expecttype is None else expecttype)
-
-# def assert_roundtrip_equal_not_identical(self, values, *,
-# mode=None, expecttype=None):
-# mode = self._resolve_mode(mode)
-# for obj in values:
-# cls = type(obj)
-# with self.subTest(obj):
-# got = self._get_roundtrip(obj, mode)
-# self.assertIsNot(got, obj)
-# self.assertIs(type(got), type(obj))
-# self.assertEqual(got, obj)
-# self.assertIs(type(got),
-# cls if expecttype is None else expecttype)
-#
-# def assert_roundtrip_not_equal(self, values, *, mode=None, expecttype=None):
-# mode = self._resolve_mode(mode)
-# for obj in values:
-# cls = type(obj)
-# with self.subTest(obj):
-# got = self._get_roundtrip(obj, mode)
-# self.assertIsNot(got, obj)
-# self.assertIs(type(got), type(obj))
-# self.assertNotEqual(got, obj)
-# self.assertIs(type(got),
-# cls if expecttype is None else expecttype)
+ mode = self._resolve_mode(mode)
+ for obj in values:
+ with self.subTest(obj):
+ got = self._get_roundtrip(obj, mode)
+ self.assertEqual(got, obj)
+ self.assertIs(type(got),
+ type(obj) if expecttype is None else expecttype)
+
+ def assert_roundtrip_equal_not_identical(self, values, *,
+ mode=None, expecttype=None):
+ mode = self._resolve_mode(mode)
+ for obj in values:
+ with self.subTest(obj):
+ got = self._get_roundtrip(obj, mode)
+ self.assertIsNot(got, obj)
+ self.assertIs(type(got),
+ type(obj) if expecttype is None else expecttype)
+ self.assertEqual(got, obj)
+
+ def assert_roundtrip_not_equal(self, values, *,
+ mode=None, expecttype=None):
+ mode = self._resolve_mode(mode)
+ for obj in values:
+ with self.subTest(obj):
+ got = self._get_roundtrip(obj, mode)
+ self.assertIsNot(got, obj)
+ self.assertIs(type(got),
+ type(obj) if expecttype is None else expecttype)
+ self.assertNotEqual(got, obj)
def assert_not_shareable(self, values, exctype=None, *, mode=None):
mode = self._resolve_mode(mode)
@@ -95,6 +177,363 @@ class _GetXIDataTests(unittest.TestCase):
return mode
+class PickleTests(_GetXIDataTests):
+
+ MODE = 'pickle'
+
+ def test_shareable(self):
+ self.assert_roundtrip_equal([
+ # singletons
+ None,
+ True,
+ False,
+ # bytes
+ *(i.to_bytes(2, 'little', signed=True)
+ for i in range(-1, 258)),
+ # str
+ 'hello world',
+ '你好世界',
+ '',
+ # int
+ sys.maxsize,
+ -sys.maxsize - 1,
+ *range(-1, 258),
+ # float
+ 0.0,
+ 1.1,
+ -1.0,
+ 0.12345678,
+ -0.12345678,
+ # tuple
+ (),
+ (1,),
+ ("hello", "world", ),
+ (1, True, "hello"),
+ ((1,),),
+ ((1, 2), (3, 4)),
+ ((1, 2), (3, 4), (5, 6)),
+ ])
+ # not shareable using xidata
+ self.assert_roundtrip_equal([
+ # int
+ sys.maxsize + 1,
+ -sys.maxsize - 2,
+ 2**1000,
+ # tuple
+ (0, 1.0, []),
+ (0, 1.0, {}),
+ (0, 1.0, ([],)),
+ (0, 1.0, ({},)),
+ ])
+
+ def test_list(self):
+ self.assert_roundtrip_equal_not_identical([
+ [],
+ [1, 2, 3],
+ [[1], (2,), {3: 4}],
+ ])
+
+ def test_dict(self):
+ self.assert_roundtrip_equal_not_identical([
+ {},
+ {1: 7, 2: 8, 3: 9},
+ {1: [1], 2: (2,), 3: {3: 4}},
+ ])
+
+ def test_set(self):
+ self.assert_roundtrip_equal_not_identical([
+ set(),
+ {1, 2, 3},
+ {frozenset({1}), (2,)},
+ ])
+
+ # classes
+
+ def assert_class_defs_same(self, defs):
+ # Unpickle relative to the unchanged original module.
+ self.assert_roundtrip_identical(defs.TOP_CLASSES)
+
+ instances = []
+ for cls, args in defs.TOP_CLASSES.items():
+ if cls in defs.CLASSES_WITHOUT_EQUALITY:
+ continue
+ instances.append(cls(*args))
+ self.assert_roundtrip_equal_not_identical(instances)
+
+ # these don't compare equal
+ instances = []
+ for cls, args in defs.TOP_CLASSES.items():
+ if cls not in defs.CLASSES_WITHOUT_EQUALITY:
+ continue
+ instances.append(cls(*args))
+ self.assert_roundtrip_not_equal(instances)
+
+ def assert_class_defs_other_pickle(self, defs, mod):
+ # Pickle relative to a different module than the original.
+ for cls in defs.TOP_CLASSES:
+ assert not hasattr(mod, cls.__name__), (cls, getattr(mod, cls.__name__))
+ self.assert_not_shareable(defs.TOP_CLASSES)
+
+ instances = []
+ for cls, args in defs.TOP_CLASSES.items():
+ instances.append(cls(*args))
+ self.assert_not_shareable(instances)
+
+ def assert_class_defs_other_unpickle(self, defs, mod, *, fail=False):
+ # Unpickle relative to a different module than the original.
+ for cls in defs.TOP_CLASSES:
+ assert not hasattr(mod, cls.__name__), (cls, getattr(mod, cls.__name__))
+
+ instances = []
+ for cls, args in defs.TOP_CLASSES.items():
+ with self.subTest(cls):
+ setattr(mod, cls.__name__, cls)
+ xid = self.get_xidata(cls)
+ inst = cls(*args)
+ instxid = self.get_xidata(inst)
+ instances.append(
+ (cls, xid, inst, instxid))
+
+ for cls, xid, inst, instxid in instances:
+ with self.subTest(cls):
+ delattr(mod, cls.__name__)
+ if fail:
+ with self.assertRaises(NotShareableError):
+ _testinternalcapi.restore_crossinterp_data(xid)
+ continue
+ got = _testinternalcapi.restore_crossinterp_data(xid)
+ self.assertIsNot(got, cls)
+ self.assertNotEqual(got, cls)
+
+ gotcls = got
+ got = _testinternalcapi.restore_crossinterp_data(instxid)
+ self.assertIsNot(got, inst)
+ self.assertIs(type(got), gotcls)
+ if cls in defs.CLASSES_WITHOUT_EQUALITY:
+ self.assertNotEqual(got, inst)
+ elif cls in defs.BUILTIN_SUBCLASSES:
+ self.assertEqual(got, inst)
+ else:
+ self.assertNotEqual(got, inst)
+
+ def assert_class_defs_not_shareable(self, defs):
+ self.assert_not_shareable(defs.TOP_CLASSES)
+
+ instances = []
+ for cls, args in defs.TOP_CLASSES.items():
+ instances.append(cls(*args))
+ self.assert_not_shareable(instances)
+
+ def test_user_class_normal(self):
+ self.assert_class_defs_same(defs)
+
+ def test_user_class_in___main__(self):
+ with using___main__() as mod:
+ defs = load_defs(mod)
+ self.assert_class_defs_same(defs)
+
+ def test_user_class_not_in___main___with_filename(self):
+ with using___main__() as mod:
+ defs = load_defs('__main__')
+ assert defs.__file__
+ mod.__file__ = defs.__file__
+ self.assert_class_defs_not_shareable(defs)
+
+ def test_user_class_not_in___main___without_filename(self):
+ with using___main__() as mod:
+ defs = load_defs('__main__')
+ defs.__file__ = None
+ mod.__file__ = None
+ self.assert_class_defs_not_shareable(defs)
+
+ def test_user_class_not_in___main___unpickle_with_filename(self):
+ with using___main__() as mod:
+ defs = load_defs('__main__')
+ assert defs.__file__
+ mod.__file__ = defs.__file__
+ self.assert_class_defs_other_unpickle(defs, mod)
+
+ def test_user_class_not_in___main___unpickle_without_filename(self):
+ with using___main__() as mod:
+ defs = load_defs('__main__')
+ defs.__file__ = None
+ mod.__file__ = None
+ self.assert_class_defs_other_unpickle(defs, mod, fail=True)
+
+ def test_user_class_in_module(self):
+ with temp_module('__spam__') as mod:
+ defs = load_defs(mod)
+ self.assert_class_defs_same(defs)
+
+ def test_user_class_not_in_module_with_filename(self):
+ with temp_module('__spam__') as mod:
+ defs = load_defs(mod.__name__)
+ assert defs.__file__
+ # For now, we only address this case for __main__.
+ self.assert_class_defs_not_shareable(defs)
+
+ def test_user_class_not_in_module_without_filename(self):
+ with temp_module('__spam__') as mod:
+ defs = load_defs(mod.__name__)
+ defs.__file__ = None
+ self.assert_class_defs_not_shareable(defs)
+
+ def test_user_class_module_missing_then_imported(self):
+ with missing_defs_module('__spam__', prep=True) as modname:
+ defs = load_defs(modname)
+ # For now, we only address this case for __main__.
+ self.assert_class_defs_not_shareable(defs)
+
+ def test_user_class_module_missing_not_available(self):
+ with missing_defs_module('__spam__') as modname:
+ defs = load_defs(modname)
+ self.assert_class_defs_not_shareable(defs)
+
+ def test_nested_class(self):
+ eggs = defs.EggsNested()
+ with self.assertRaises(NotShareableError):
+ self.get_roundtrip(eggs)
+
+ # functions
+
+ def assert_func_defs_same(self, defs):
+ # Unpickle relative to the unchanged original module.
+ self.assert_roundtrip_identical(defs.TOP_FUNCTIONS)
+
+ def assert_func_defs_other_pickle(self, defs, mod):
+ # Pickle relative to a different module than the original.
+ for func in defs.TOP_FUNCTIONS:
+ assert not hasattr(mod, func.__name__), (cls, getattr(mod, func.__name__))
+ self.assert_not_shareable(defs.TOP_FUNCTIONS)
+
+ def assert_func_defs_other_unpickle(self, defs, mod, *, fail=False):
+ # Unpickle relative to a different module than the original.
+ for func in defs.TOP_FUNCTIONS:
+ assert not hasattr(mod, func.__name__), (cls, getattr(mod, func.__name__))
+
+ captured = []
+ for func in defs.TOP_FUNCTIONS:
+ with self.subTest(func):
+ setattr(mod, func.__name__, func)
+ xid = self.get_xidata(func)
+ captured.append(
+ (func, xid))
+
+ for func, xid in captured:
+ with self.subTest(func):
+ delattr(mod, func.__name__)
+ if fail:
+ with self.assertRaises(NotShareableError):
+ _testinternalcapi.restore_crossinterp_data(xid)
+ continue
+ got = _testinternalcapi.restore_crossinterp_data(xid)
+ self.assertIsNot(got, func)
+ self.assertNotEqual(got, func)
+
+ def assert_func_defs_not_shareable(self, defs):
+ self.assert_not_shareable(defs.TOP_FUNCTIONS)
+
+ def test_user_function_normal(self):
+# self.assert_roundtrip_equal(defs.TOP_FUNCTIONS)
+ self.assert_func_defs_same(defs)
+
+ def test_user_func_in___main__(self):
+ with using___main__() as mod:
+ defs = load_defs(mod)
+ self.assert_func_defs_same(defs)
+
+ def test_user_func_not_in___main___with_filename(self):
+ with using___main__() as mod:
+ defs = load_defs('__main__')
+ assert defs.__file__
+ mod.__file__ = defs.__file__
+ self.assert_func_defs_not_shareable(defs)
+
+ def test_user_func_not_in___main___without_filename(self):
+ with using___main__() as mod:
+ defs = load_defs('__main__')
+ defs.__file__ = None
+ mod.__file__ = None
+ self.assert_func_defs_not_shareable(defs)
+
+ def test_user_func_not_in___main___unpickle_with_filename(self):
+ with using___main__() as mod:
+ defs = load_defs('__main__')
+ assert defs.__file__
+ mod.__file__ = defs.__file__
+ self.assert_func_defs_other_unpickle(defs, mod)
+
+ def test_user_func_not_in___main___unpickle_without_filename(self):
+ with using___main__() as mod:
+ defs = load_defs('__main__')
+ defs.__file__ = None
+ mod.__file__ = None
+ self.assert_func_defs_other_unpickle(defs, mod, fail=True)
+
+ def test_user_func_in_module(self):
+ with temp_module('__spam__') as mod:
+ defs = load_defs(mod)
+ self.assert_func_defs_same(defs)
+
+ def test_user_func_not_in_module_with_filename(self):
+ with temp_module('__spam__') as mod:
+ defs = load_defs(mod.__name__)
+ assert defs.__file__
+ # For now, we only address this case for __main__.
+ self.assert_func_defs_not_shareable(defs)
+
+ def test_user_func_not_in_module_without_filename(self):
+ with temp_module('__spam__') as mod:
+ defs = load_defs(mod.__name__)
+ defs.__file__ = None
+ self.assert_func_defs_not_shareable(defs)
+
+ def test_user_func_module_missing_then_imported(self):
+ with missing_defs_module('__spam__', prep=True) as modname:
+ defs = load_defs(modname)
+ # For now, we only address this case for __main__.
+ self.assert_func_defs_not_shareable(defs)
+
+ def test_user_func_module_missing_not_available(self):
+ with missing_defs_module('__spam__') as modname:
+ defs = load_defs(modname)
+ self.assert_func_defs_not_shareable(defs)
+
+ def test_nested_function(self):
+ self.assert_not_shareable(defs.NESTED_FUNCTIONS)
+
+ # exceptions
+
+ def test_user_exception_normal(self):
+ self.assert_roundtrip_not_equal([
+ defs.MimimalError('error!'),
+ ])
+ self.assert_roundtrip_equal_not_identical([
+ defs.RichError('error!', 42),
+ ])
+
+ def test_builtin_exception(self):
+ msg = 'error!'
+ try:
+ raise Exception
+ except Exception as exc:
+ caught = exc
+ special = {
+ BaseExceptionGroup: (msg, [caught]),
+ ExceptionGroup: (msg, [caught]),
+# UnicodeError: (None, msg, None, None, None),
+ UnicodeEncodeError: ('utf-8', '', 1, 3, msg),
+ UnicodeDecodeError: ('utf-8', b'', 1, 3, msg),
+ UnicodeTranslateError: ('', 1, 3, msg),
+ }
+ exceptions = []
+ for cls in EXCEPTION_TYPES:
+ args = special.get(cls) or (msg,)
+ exceptions.append(cls(*args))
+
+ self.assert_roundtrip_not_equal(exceptions)
+
+
class MarshalTests(_GetXIDataTests):
MODE = 'marshal'
@@ -444,22 +883,12 @@ class ShareableTypeTests(_GetXIDataTests):
])
def test_class(self):
- self.assert_not_shareable([
- defs.Spam,
- defs.SpamOkay,
- defs.SpamFull,
- defs.SubSpamFull,
- defs.SubTuple,
- defs.EggsNested,
- ])
- self.assert_not_shareable([
- defs.Spam(),
- defs.SpamOkay(),
- defs.SpamFull(1, 2, 3),
- defs.SubSpamFull(1, 2, 3),
- defs.SubTuple([1, 2, 3]),
- defs.EggsNested(),
- ])
+ self.assert_not_shareable(defs.CLASSES)
+
+ instances = []
+ for cls, args in defs.CLASSES.items():
+ instances.append(cls(*args))
+ self.assert_not_shareable(instances)
def test_builtin_type(self):
self.assert_not_shareable([