diff options
Diffstat (limited to 'Lib/test/test_pickle.py')
-rw-r--r-- | Lib/test/test_pickle.py | 152 |
1 files changed, 151 insertions, 1 deletions
diff --git a/Lib/test/test_pickle.py b/Lib/test/test_pickle.py index e1a88b6..0159b18 100644 --- a/Lib/test/test_pickle.py +++ b/Lib/test/test_pickle.py @@ -1,3 +1,6 @@ +from _compat_pickle import (IMPORT_MAPPING, REVERSE_IMPORT_MAPPING, + NAME_MAPPING, REVERSE_NAME_MAPPING) +import builtins import pickle import io import collections @@ -207,9 +210,156 @@ if has_c_implementation: check(u, stdsize + 32 * P + 2 + 1) +ALT_IMPORT_MAPPING = { + ('_elementtree', 'xml.etree.ElementTree'), + ('cPickle', 'pickle'), +} + +ALT_NAME_MAPPING = { + ('__builtin__', 'basestring', 'builtins', 'str'), + ('exceptions', 'StandardError', 'builtins', 'Exception'), + ('UserDict', 'UserDict', 'collections', 'UserDict'), + ('socket', '_socketobject', 'socket', 'SocketType'), +} + +def mapping(module, name): + if (module, name) in NAME_MAPPING: + module, name = NAME_MAPPING[(module, name)] + elif module in IMPORT_MAPPING: + module = IMPORT_MAPPING[module] + return module, name + +def reverse_mapping(module, name): + if (module, name) in REVERSE_NAME_MAPPING: + module, name = REVERSE_NAME_MAPPING[(module, name)] + elif module in REVERSE_IMPORT_MAPPING: + module = REVERSE_IMPORT_MAPPING[module] + return module, name + +def getmodule(module): + try: + return sys.modules[module] + except KeyError: + __import__(module) + return sys.modules[module] + +def getattribute(module, name): + obj = getmodule(module) + for n in name.split('.'): + obj = getattr(obj, n) + return obj + +def get_exceptions(mod): + for name in dir(mod): + attr = getattr(mod, name) + if isinstance(attr, type) and issubclass(attr, BaseException): + yield name, attr + +class CompatPickleTests(unittest.TestCase): + def test_import(self): + modules = set(IMPORT_MAPPING.values()) + modules |= set(REVERSE_IMPORT_MAPPING) + modules |= {module for module, name in REVERSE_NAME_MAPPING} + modules |= {module for module, name in NAME_MAPPING.values()} + for module in modules: + try: + getmodule(module) + except ImportError as exc: + if support.verbose: + print(exc) + + def test_import_mapping(self): + for module3, module2 in REVERSE_IMPORT_MAPPING.items(): + with self.subTest((module3, module2)): + try: + getmodule(module3) + except ImportError as exc: + if support.verbose: + print(exc) + if module3[:1] != '_': + self.assertIn(module2, IMPORT_MAPPING) + self.assertEqual(IMPORT_MAPPING[module2], module3) + + def test_name_mapping(self): + for (module3, name3), (module2, name2) in REVERSE_NAME_MAPPING.items(): + with self.subTest(((module3, name3), (module2, name2))): + attr = getattribute(module3, name3) + if (module2, name2) == ('exceptions', 'OSError'): + self.assertTrue(issubclass(attr, OSError)) + else: + module, name = mapping(module2, name2) + if module3[:1] != '_': + self.assertEqual((module, name), (module3, name3)) + self.assertEqual(getattribute(module, name), attr) + + def test_reverse_import_mapping(self): + for module2, module3 in IMPORT_MAPPING.items(): + with self.subTest((module2, module3)): + try: + getmodule(module3) + except ImportError as exc: + if support.verbose: + print(exc) + if ((module2, module3) not in ALT_IMPORT_MAPPING and + REVERSE_IMPORT_MAPPING.get(module3, None) != module2): + for (m3, n3), (m2, n2) in REVERSE_NAME_MAPPING.items(): + if (module3, module2) == (m3, m2): + break + else: + self.fail('No reverse mapping from %r to %r' % + (module3, module2)) + module = REVERSE_IMPORT_MAPPING.get(module3, module3) + module = IMPORT_MAPPING.get(module, module) + self.assertEqual(module, module3) + + def test_reverse_name_mapping(self): + for (module2, name2), (module3, name3) in NAME_MAPPING.items(): + with self.subTest(((module2, name2), (module3, name3))): + attr = getattribute(module3, name3) + module, name = reverse_mapping(module3, name3) + if (module2, name2, module3, name3) not in ALT_NAME_MAPPING: + self.assertEqual((module, name), (module2, name2)) + module, name = mapping(module, name) + self.assertEqual((module, name), (module3, name3)) + + def test_exceptions(self): + self.assertEqual(mapping('exceptions', 'StandardError'), + ('builtins', 'Exception')) + self.assertEqual(mapping('exceptions', 'Exception'), + ('builtins', 'Exception')) + self.assertEqual(reverse_mapping('builtins', 'Exception'), + ('exceptions', 'Exception')) + self.assertEqual(mapping('exceptions', 'OSError'), + ('builtins', 'OSError')) + self.assertEqual(reverse_mapping('builtins', 'OSError'), + ('exceptions', 'OSError')) + + for name, exc in get_exceptions(builtins): + with self.subTest(name): + if exc in (BlockingIOError, ResourceWarning): + continue + if exc is not OSError and issubclass(exc, OSError): + self.assertEqual(reverse_mapping('builtins', name), + ('exceptions', 'OSError')) + else: + self.assertEqual(reverse_mapping('builtins', name), + ('exceptions', name)) + self.assertEqual(mapping('exceptions', name), + ('builtins', name)) + + import multiprocessing.context + for name, exc in get_exceptions(multiprocessing.context): + with self.subTest(name): + self.assertEqual(reverse_mapping('multiprocessing.context', name), + ('multiprocessing', name)) + self.assertEqual(mapping('multiprocessing', name), + ('multiprocessing.context', name)) + + def test_main(): tests = [PickleTests, PyPicklerTests, PyPersPicklerTests, - PyDispatchTableTests, PyChainDispatchTableTests] + PyDispatchTableTests, PyChainDispatchTableTests, + CompatPickleTests] if has_c_implementation: tests.extend([CPicklerTests, CPersPicklerTests, CDumpPickle_LoadPickle, DumpPickle_CLoadPickle, |