diff options
Diffstat (limited to 'Lib/test/test_crossinterp.py')
-rw-r--r-- | Lib/test/test_crossinterp.py | 541 |
1 files changed, 485 insertions, 56 deletions
diff --git a/Lib/test/test_crossinterp.py b/Lib/test/test_crossinterp.py index 5ebb78b..32d6fd4 100644 --- a/Lib/test/test_crossinterp.py +++ b/Lib/test/test_crossinterp.py @@ -1,3 +1,6 @@ +import contextlib +import importlib +import importlib.util import itertools import sys import types @@ -9,7 +12,7 @@ _testinternalcapi = import_helper.import_module('_testinternalcapi') _interpreters = import_helper.import_module('_interpreters') from _interpreters import NotShareableError - +from test import _code_definitions as code_defs from test import _crossinterp_definitions as defs @@ -21,6 +24,88 @@ OTHER_TYPES = [o for n, o in vars(types).items() if (isinstance(o, type) and n not in ('DynamicClassAttribute', '_GeneratorWrapper'))] +DEFS = defs +with open(code_defs.__file__) as infile: + _code_defs_text = infile.read() +with open(DEFS.__file__) as infile: + _defs_text = infile.read() + _defs_text = _defs_text.replace('from ', '# from ') +DEFS_TEXT = f""" +####################################### +# from {code_defs.__file__} + +{_code_defs_text} + +####################################### +# from {defs.__file__} + +{_defs_text} +""" +del infile, _code_defs_text, _defs_text + + +def load_defs(module=None): + """Return a new copy of the test._crossinterp_definitions module. + + The module's __name__ matches the "module" arg, which is either + a str or a module. + + If the "module" arg is a module then the just-loaded defs are also + copied into that module. + + Note that the new module is not added to sys.modules. + """ + if module is None: + modname = DEFS.__name__ + elif isinstance(module, str): + modname = module + module = None + else: + modname = module.__name__ + # Create the new module and populate it. + defs = import_helper.create_module(modname) + defs.__file__ = DEFS.__file__ + exec(DEFS_TEXT, defs.__dict__) + # Copy the defs into the module arg, if any. + if module is not None: + for name, value in defs.__dict__.items(): + if name.startswith('_'): + continue + assert not hasattr(module, name), (name, getattr(module, name)) + setattr(module, name, value) + return defs + + +@contextlib.contextmanager +def using___main__(): + """Make sure __main__ module exists (and clean up after).""" + modname = '__main__' + if modname not in sys.modules: + with import_helper.isolated_modules(): + yield import_helper.add_module(modname) + else: + with import_helper.module_restored(modname) as mod: + yield mod + + +@contextlib.contextmanager +def temp_module(modname): + """Create the module and add to sys.modules, then remove it after.""" + assert modname not in sys.modules, (modname,) + with import_helper.isolated_modules(): + yield import_helper.add_module(modname) + + +@contextlib.contextmanager +def missing_defs_module(modname, *, prep=False): + assert modname not in sys.modules, (modname,) + if prep: + with import_helper.ready_to_import(modname, DEFS_TEXT): + yield modname + else: + with import_helper.isolated_modules(): + yield modname + class _GetXIDataTests(unittest.TestCase): @@ -32,52 +117,49 @@ class _GetXIDataTests(unittest.TestCase): def get_roundtrip(self, obj, *, mode=None): mode = self._resolve_mode(mode) - xid =_testinternalcapi.get_crossinterp_data(obj, mode) + return self._get_roundtrip(obj, mode) + + def _get_roundtrip(self, obj, mode): + xid = _testinternalcapi.get_crossinterp_data(obj, mode) return _testinternalcapi.restore_crossinterp_data(xid) - def iter_roundtrip_values(self, values, *, mode=None): + def assert_roundtrip_identical(self, values, *, mode=None): mode = self._resolve_mode(mode) for obj in values: with self.subTest(obj): - xid = _testinternalcapi.get_crossinterp_data(obj, mode) - got = _testinternalcapi.restore_crossinterp_data(xid) - yield obj, got - - def assert_roundtrip_identical(self, values, *, mode=None): - for obj, got in self.iter_roundtrip_values(values, mode=mode): - # XXX What about between interpreters? - self.assertIs(got, obj) + got = self._get_roundtrip(obj, mode) + self.assertIs(got, obj) def assert_roundtrip_equal(self, values, *, mode=None, expecttype=None): - for obj, got in self.iter_roundtrip_values(values, mode=mode): - self.assertEqual(got, obj) - self.assertIs(type(got), - type(obj) if expecttype is None else expecttype) - -# def assert_roundtrip_equal_not_identical(self, values, *, -# mode=None, expecttype=None): -# mode = self._resolve_mode(mode) -# for obj in values: -# cls = type(obj) -# with self.subTest(obj): -# got = self._get_roundtrip(obj, mode) -# self.assertIsNot(got, obj) -# self.assertIs(type(got), type(obj)) -# self.assertEqual(got, obj) -# self.assertIs(type(got), -# cls if expecttype is None else expecttype) -# -# def assert_roundtrip_not_equal(self, values, *, mode=None, expecttype=None): -# mode = self._resolve_mode(mode) -# for obj in values: -# cls = type(obj) -# with self.subTest(obj): -# got = self._get_roundtrip(obj, mode) -# self.assertIsNot(got, obj) -# self.assertIs(type(got), type(obj)) -# self.assertNotEqual(got, obj) -# self.assertIs(type(got), -# cls if expecttype is None else expecttype) + mode = self._resolve_mode(mode) + for obj in values: + with self.subTest(obj): + got = self._get_roundtrip(obj, mode) + self.assertEqual(got, obj) + self.assertIs(type(got), + type(obj) if expecttype is None else expecttype) + + def assert_roundtrip_equal_not_identical(self, values, *, + mode=None, expecttype=None): + mode = self._resolve_mode(mode) + for obj in values: + with self.subTest(obj): + got = self._get_roundtrip(obj, mode) + self.assertIsNot(got, obj) + self.assertIs(type(got), + type(obj) if expecttype is None else expecttype) + self.assertEqual(got, obj) + + def assert_roundtrip_not_equal(self, values, *, + mode=None, expecttype=None): + mode = self._resolve_mode(mode) + for obj in values: + with self.subTest(obj): + got = self._get_roundtrip(obj, mode) + self.assertIsNot(got, obj) + self.assertIs(type(got), + type(obj) if expecttype is None else expecttype) + self.assertNotEqual(got, obj) def assert_not_shareable(self, values, exctype=None, *, mode=None): mode = self._resolve_mode(mode) @@ -95,6 +177,363 @@ class _GetXIDataTests(unittest.TestCase): return mode +class PickleTests(_GetXIDataTests): + + MODE = 'pickle' + + def test_shareable(self): + self.assert_roundtrip_equal([ + # singletons + None, + True, + False, + # bytes + *(i.to_bytes(2, 'little', signed=True) + for i in range(-1, 258)), + # str + 'hello world', + '你好世界', + '', + # int + sys.maxsize, + -sys.maxsize - 1, + *range(-1, 258), + # float + 0.0, + 1.1, + -1.0, + 0.12345678, + -0.12345678, + # tuple + (), + (1,), + ("hello", "world", ), + (1, True, "hello"), + ((1,),), + ((1, 2), (3, 4)), + ((1, 2), (3, 4), (5, 6)), + ]) + # not shareable using xidata + self.assert_roundtrip_equal([ + # int + sys.maxsize + 1, + -sys.maxsize - 2, + 2**1000, + # tuple + (0, 1.0, []), + (0, 1.0, {}), + (0, 1.0, ([],)), + (0, 1.0, ({},)), + ]) + + def test_list(self): + self.assert_roundtrip_equal_not_identical([ + [], + [1, 2, 3], + [[1], (2,), {3: 4}], + ]) + + def test_dict(self): + self.assert_roundtrip_equal_not_identical([ + {}, + {1: 7, 2: 8, 3: 9}, + {1: [1], 2: (2,), 3: {3: 4}}, + ]) + + def test_set(self): + self.assert_roundtrip_equal_not_identical([ + set(), + {1, 2, 3}, + {frozenset({1}), (2,)}, + ]) + + # classes + + def assert_class_defs_same(self, defs): + # Unpickle relative to the unchanged original module. + self.assert_roundtrip_identical(defs.TOP_CLASSES) + + instances = [] + for cls, args in defs.TOP_CLASSES.items(): + if cls in defs.CLASSES_WITHOUT_EQUALITY: + continue + instances.append(cls(*args)) + self.assert_roundtrip_equal_not_identical(instances) + + # these don't compare equal + instances = [] + for cls, args in defs.TOP_CLASSES.items(): + if cls not in defs.CLASSES_WITHOUT_EQUALITY: + continue + instances.append(cls(*args)) + self.assert_roundtrip_not_equal(instances) + + def assert_class_defs_other_pickle(self, defs, mod): + # Pickle relative to a different module than the original. + for cls in defs.TOP_CLASSES: + assert not hasattr(mod, cls.__name__), (cls, getattr(mod, cls.__name__)) + self.assert_not_shareable(defs.TOP_CLASSES) + + instances = [] + for cls, args in defs.TOP_CLASSES.items(): + instances.append(cls(*args)) + self.assert_not_shareable(instances) + + def assert_class_defs_other_unpickle(self, defs, mod, *, fail=False): + # Unpickle relative to a different module than the original. + for cls in defs.TOP_CLASSES: + assert not hasattr(mod, cls.__name__), (cls, getattr(mod, cls.__name__)) + + instances = [] + for cls, args in defs.TOP_CLASSES.items(): + with self.subTest(cls): + setattr(mod, cls.__name__, cls) + xid = self.get_xidata(cls) + inst = cls(*args) + instxid = self.get_xidata(inst) + instances.append( + (cls, xid, inst, instxid)) + + for cls, xid, inst, instxid in instances: + with self.subTest(cls): + delattr(mod, cls.__name__) + if fail: + with self.assertRaises(NotShareableError): + _testinternalcapi.restore_crossinterp_data(xid) + continue + got = _testinternalcapi.restore_crossinterp_data(xid) + self.assertIsNot(got, cls) + self.assertNotEqual(got, cls) + + gotcls = got + got = _testinternalcapi.restore_crossinterp_data(instxid) + self.assertIsNot(got, inst) + self.assertIs(type(got), gotcls) + if cls in defs.CLASSES_WITHOUT_EQUALITY: + self.assertNotEqual(got, inst) + elif cls in defs.BUILTIN_SUBCLASSES: + self.assertEqual(got, inst) + else: + self.assertNotEqual(got, inst) + + def assert_class_defs_not_shareable(self, defs): + self.assert_not_shareable(defs.TOP_CLASSES) + + instances = [] + for cls, args in defs.TOP_CLASSES.items(): + instances.append(cls(*args)) + self.assert_not_shareable(instances) + + def test_user_class_normal(self): + self.assert_class_defs_same(defs) + + def test_user_class_in___main__(self): + with using___main__() as mod: + defs = load_defs(mod) + self.assert_class_defs_same(defs) + + def test_user_class_not_in___main___with_filename(self): + with using___main__() as mod: + defs = load_defs('__main__') + assert defs.__file__ + mod.__file__ = defs.__file__ + self.assert_class_defs_not_shareable(defs) + + def test_user_class_not_in___main___without_filename(self): + with using___main__() as mod: + defs = load_defs('__main__') + defs.__file__ = None + mod.__file__ = None + self.assert_class_defs_not_shareable(defs) + + def test_user_class_not_in___main___unpickle_with_filename(self): + with using___main__() as mod: + defs = load_defs('__main__') + assert defs.__file__ + mod.__file__ = defs.__file__ + self.assert_class_defs_other_unpickle(defs, mod) + + def test_user_class_not_in___main___unpickle_without_filename(self): + with using___main__() as mod: + defs = load_defs('__main__') + defs.__file__ = None + mod.__file__ = None + self.assert_class_defs_other_unpickle(defs, mod, fail=True) + + def test_user_class_in_module(self): + with temp_module('__spam__') as mod: + defs = load_defs(mod) + self.assert_class_defs_same(defs) + + def test_user_class_not_in_module_with_filename(self): + with temp_module('__spam__') as mod: + defs = load_defs(mod.__name__) + assert defs.__file__ + # For now, we only address this case for __main__. + self.assert_class_defs_not_shareable(defs) + + def test_user_class_not_in_module_without_filename(self): + with temp_module('__spam__') as mod: + defs = load_defs(mod.__name__) + defs.__file__ = None + self.assert_class_defs_not_shareable(defs) + + def test_user_class_module_missing_then_imported(self): + with missing_defs_module('__spam__', prep=True) as modname: + defs = load_defs(modname) + # For now, we only address this case for __main__. + self.assert_class_defs_not_shareable(defs) + + def test_user_class_module_missing_not_available(self): + with missing_defs_module('__spam__') as modname: + defs = load_defs(modname) + self.assert_class_defs_not_shareable(defs) + + def test_nested_class(self): + eggs = defs.EggsNested() + with self.assertRaises(NotShareableError): + self.get_roundtrip(eggs) + + # functions + + def assert_func_defs_same(self, defs): + # Unpickle relative to the unchanged original module. + self.assert_roundtrip_identical(defs.TOP_FUNCTIONS) + + def assert_func_defs_other_pickle(self, defs, mod): + # Pickle relative to a different module than the original. + for func in defs.TOP_FUNCTIONS: + assert not hasattr(mod, func.__name__), (cls, getattr(mod, func.__name__)) + self.assert_not_shareable(defs.TOP_FUNCTIONS) + + def assert_func_defs_other_unpickle(self, defs, mod, *, fail=False): + # Unpickle relative to a different module than the original. + for func in defs.TOP_FUNCTIONS: + assert not hasattr(mod, func.__name__), (cls, getattr(mod, func.__name__)) + + captured = [] + for func in defs.TOP_FUNCTIONS: + with self.subTest(func): + setattr(mod, func.__name__, func) + xid = self.get_xidata(func) + captured.append( + (func, xid)) + + for func, xid in captured: + with self.subTest(func): + delattr(mod, func.__name__) + if fail: + with self.assertRaises(NotShareableError): + _testinternalcapi.restore_crossinterp_data(xid) + continue + got = _testinternalcapi.restore_crossinterp_data(xid) + self.assertIsNot(got, func) + self.assertNotEqual(got, func) + + def assert_func_defs_not_shareable(self, defs): + self.assert_not_shareable(defs.TOP_FUNCTIONS) + + def test_user_function_normal(self): +# self.assert_roundtrip_equal(defs.TOP_FUNCTIONS) + self.assert_func_defs_same(defs) + + def test_user_func_in___main__(self): + with using___main__() as mod: + defs = load_defs(mod) + self.assert_func_defs_same(defs) + + def test_user_func_not_in___main___with_filename(self): + with using___main__() as mod: + defs = load_defs('__main__') + assert defs.__file__ + mod.__file__ = defs.__file__ + self.assert_func_defs_not_shareable(defs) + + def test_user_func_not_in___main___without_filename(self): + with using___main__() as mod: + defs = load_defs('__main__') + defs.__file__ = None + mod.__file__ = None + self.assert_func_defs_not_shareable(defs) + + def test_user_func_not_in___main___unpickle_with_filename(self): + with using___main__() as mod: + defs = load_defs('__main__') + assert defs.__file__ + mod.__file__ = defs.__file__ + self.assert_func_defs_other_unpickle(defs, mod) + + def test_user_func_not_in___main___unpickle_without_filename(self): + with using___main__() as mod: + defs = load_defs('__main__') + defs.__file__ = None + mod.__file__ = None + self.assert_func_defs_other_unpickle(defs, mod, fail=True) + + def test_user_func_in_module(self): + with temp_module('__spam__') as mod: + defs = load_defs(mod) + self.assert_func_defs_same(defs) + + def test_user_func_not_in_module_with_filename(self): + with temp_module('__spam__') as mod: + defs = load_defs(mod.__name__) + assert defs.__file__ + # For now, we only address this case for __main__. + self.assert_func_defs_not_shareable(defs) + + def test_user_func_not_in_module_without_filename(self): + with temp_module('__spam__') as mod: + defs = load_defs(mod.__name__) + defs.__file__ = None + self.assert_func_defs_not_shareable(defs) + + def test_user_func_module_missing_then_imported(self): + with missing_defs_module('__spam__', prep=True) as modname: + defs = load_defs(modname) + # For now, we only address this case for __main__. + self.assert_func_defs_not_shareable(defs) + + def test_user_func_module_missing_not_available(self): + with missing_defs_module('__spam__') as modname: + defs = load_defs(modname) + self.assert_func_defs_not_shareable(defs) + + def test_nested_function(self): + self.assert_not_shareable(defs.NESTED_FUNCTIONS) + + # exceptions + + def test_user_exception_normal(self): + self.assert_roundtrip_not_equal([ + defs.MimimalError('error!'), + ]) + self.assert_roundtrip_equal_not_identical([ + defs.RichError('error!', 42), + ]) + + def test_builtin_exception(self): + msg = 'error!' + try: + raise Exception + except Exception as exc: + caught = exc + special = { + BaseExceptionGroup: (msg, [caught]), + ExceptionGroup: (msg, [caught]), +# UnicodeError: (None, msg, None, None, None), + UnicodeEncodeError: ('utf-8', '', 1, 3, msg), + UnicodeDecodeError: ('utf-8', b'', 1, 3, msg), + UnicodeTranslateError: ('', 1, 3, msg), + } + exceptions = [] + for cls in EXCEPTION_TYPES: + args = special.get(cls) or (msg,) + exceptions.append(cls(*args)) + + self.assert_roundtrip_not_equal(exceptions) + + class MarshalTests(_GetXIDataTests): MODE = 'marshal' @@ -444,22 +883,12 @@ class ShareableTypeTests(_GetXIDataTests): ]) def test_class(self): - self.assert_not_shareable([ - defs.Spam, - defs.SpamOkay, - defs.SpamFull, - defs.SubSpamFull, - defs.SubTuple, - defs.EggsNested, - ]) - self.assert_not_shareable([ - defs.Spam(), - defs.SpamOkay(), - defs.SpamFull(1, 2, 3), - defs.SubSpamFull(1, 2, 3), - defs.SubTuple([1, 2, 3]), - defs.EggsNested(), - ]) + self.assert_not_shareable(defs.CLASSES) + + instances = [] + for cls, args in defs.CLASSES.items(): + instances.append(cls(*args)) + self.assert_not_shareable(instances) def test_builtin_type(self): self.assert_not_shareable([ |