summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_pickle.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_pickle.py')
-rw-r--r--Lib/test/test_pickle.py152
1 files changed, 151 insertions, 1 deletions
diff --git a/Lib/test/test_pickle.py b/Lib/test/test_pickle.py
index e1a88b6..0159b18 100644
--- a/Lib/test/test_pickle.py
+++ b/Lib/test/test_pickle.py
@@ -1,3 +1,6 @@
+from _compat_pickle import (IMPORT_MAPPING, REVERSE_IMPORT_MAPPING,
+ NAME_MAPPING, REVERSE_NAME_MAPPING)
+import builtins
import pickle
import io
import collections
@@ -207,9 +210,156 @@ if has_c_implementation:
check(u, stdsize + 32 * P + 2 + 1)
+ALT_IMPORT_MAPPING = {
+ ('_elementtree', 'xml.etree.ElementTree'),
+ ('cPickle', 'pickle'),
+}
+
+ALT_NAME_MAPPING = {
+ ('__builtin__', 'basestring', 'builtins', 'str'),
+ ('exceptions', 'StandardError', 'builtins', 'Exception'),
+ ('UserDict', 'UserDict', 'collections', 'UserDict'),
+ ('socket', '_socketobject', 'socket', 'SocketType'),
+}
+
+def mapping(module, name):
+ if (module, name) in NAME_MAPPING:
+ module, name = NAME_MAPPING[(module, name)]
+ elif module in IMPORT_MAPPING:
+ module = IMPORT_MAPPING[module]
+ return module, name
+
+def reverse_mapping(module, name):
+ if (module, name) in REVERSE_NAME_MAPPING:
+ module, name = REVERSE_NAME_MAPPING[(module, name)]
+ elif module in REVERSE_IMPORT_MAPPING:
+ module = REVERSE_IMPORT_MAPPING[module]
+ return module, name
+
+def getmodule(module):
+ try:
+ return sys.modules[module]
+ except KeyError:
+ __import__(module)
+ return sys.modules[module]
+
+def getattribute(module, name):
+ obj = getmodule(module)
+ for n in name.split('.'):
+ obj = getattr(obj, n)
+ return obj
+
+def get_exceptions(mod):
+ for name in dir(mod):
+ attr = getattr(mod, name)
+ if isinstance(attr, type) and issubclass(attr, BaseException):
+ yield name, attr
+
+class CompatPickleTests(unittest.TestCase):
+ def test_import(self):
+ modules = set(IMPORT_MAPPING.values())
+ modules |= set(REVERSE_IMPORT_MAPPING)
+ modules |= {module for module, name in REVERSE_NAME_MAPPING}
+ modules |= {module for module, name in NAME_MAPPING.values()}
+ for module in modules:
+ try:
+ getmodule(module)
+ except ImportError as exc:
+ if support.verbose:
+ print(exc)
+
+ def test_import_mapping(self):
+ for module3, module2 in REVERSE_IMPORT_MAPPING.items():
+ with self.subTest((module3, module2)):
+ try:
+ getmodule(module3)
+ except ImportError as exc:
+ if support.verbose:
+ print(exc)
+ if module3[:1] != '_':
+ self.assertIn(module2, IMPORT_MAPPING)
+ self.assertEqual(IMPORT_MAPPING[module2], module3)
+
+ def test_name_mapping(self):
+ for (module3, name3), (module2, name2) in REVERSE_NAME_MAPPING.items():
+ with self.subTest(((module3, name3), (module2, name2))):
+ attr = getattribute(module3, name3)
+ if (module2, name2) == ('exceptions', 'OSError'):
+ self.assertTrue(issubclass(attr, OSError))
+ else:
+ module, name = mapping(module2, name2)
+ if module3[:1] != '_':
+ self.assertEqual((module, name), (module3, name3))
+ self.assertEqual(getattribute(module, name), attr)
+
+ def test_reverse_import_mapping(self):
+ for module2, module3 in IMPORT_MAPPING.items():
+ with self.subTest((module2, module3)):
+ try:
+ getmodule(module3)
+ except ImportError as exc:
+ if support.verbose:
+ print(exc)
+ if ((module2, module3) not in ALT_IMPORT_MAPPING and
+ REVERSE_IMPORT_MAPPING.get(module3, None) != module2):
+ for (m3, n3), (m2, n2) in REVERSE_NAME_MAPPING.items():
+ if (module3, module2) == (m3, m2):
+ break
+ else:
+ self.fail('No reverse mapping from %r to %r' %
+ (module3, module2))
+ module = REVERSE_IMPORT_MAPPING.get(module3, module3)
+ module = IMPORT_MAPPING.get(module, module)
+ self.assertEqual(module, module3)
+
+ def test_reverse_name_mapping(self):
+ for (module2, name2), (module3, name3) in NAME_MAPPING.items():
+ with self.subTest(((module2, name2), (module3, name3))):
+ attr = getattribute(module3, name3)
+ module, name = reverse_mapping(module3, name3)
+ if (module2, name2, module3, name3) not in ALT_NAME_MAPPING:
+ self.assertEqual((module, name), (module2, name2))
+ module, name = mapping(module, name)
+ self.assertEqual((module, name), (module3, name3))
+
+ def test_exceptions(self):
+ self.assertEqual(mapping('exceptions', 'StandardError'),
+ ('builtins', 'Exception'))
+ self.assertEqual(mapping('exceptions', 'Exception'),
+ ('builtins', 'Exception'))
+ self.assertEqual(reverse_mapping('builtins', 'Exception'),
+ ('exceptions', 'Exception'))
+ self.assertEqual(mapping('exceptions', 'OSError'),
+ ('builtins', 'OSError'))
+ self.assertEqual(reverse_mapping('builtins', 'OSError'),
+ ('exceptions', 'OSError'))
+
+ for name, exc in get_exceptions(builtins):
+ with self.subTest(name):
+ if exc in (BlockingIOError, ResourceWarning):
+ continue
+ if exc is not OSError and issubclass(exc, OSError):
+ self.assertEqual(reverse_mapping('builtins', name),
+ ('exceptions', 'OSError'))
+ else:
+ self.assertEqual(reverse_mapping('builtins', name),
+ ('exceptions', name))
+ self.assertEqual(mapping('exceptions', name),
+ ('builtins', name))
+
+ import multiprocessing.context
+ for name, exc in get_exceptions(multiprocessing.context):
+ with self.subTest(name):
+ self.assertEqual(reverse_mapping('multiprocessing.context', name),
+ ('multiprocessing', name))
+ self.assertEqual(mapping('multiprocessing', name),
+ ('multiprocessing.context', name))
+
+
def test_main():
tests = [PickleTests, PyPicklerTests, PyPersPicklerTests,
- PyDispatchTableTests, PyChainDispatchTableTests]
+ PyDispatchTableTests, PyChainDispatchTableTests,
+ CompatPickleTests]
if has_c_implementation:
tests.extend([CPicklerTests, CPersPicklerTests,
CDumpPickle_LoadPickle, DumpPickle_CLoadPickle,