diff options
Diffstat (limited to 'Lib/test/test_pickle.py')
| -rw-r--r-- | Lib/test/test_pickle.py | 279 |
1 files changed, 274 insertions, 5 deletions
diff --git a/Lib/test/test_pickle.py b/Lib/test/test_pickle.py index fbe96ac..f6d9cc0 100644 --- a/Lib/test/test_pickle.py +++ b/Lib/test/test_pickle.py @@ -1,9 +1,16 @@ +from _compat_pickle import (IMPORT_MAPPING, REVERSE_IMPORT_MAPPING, + NAME_MAPPING, REVERSE_NAME_MAPPING) +import builtins import pickle import io import collections +import struct +import sys +import unittest from test import support +from test.pickletester import AbstractUnpickleTests from test.pickletester import AbstractPickleTests from test.pickletester import AbstractPickleModuleTests from test.pickletester import AbstractPersistentPicklerTests @@ -22,6 +29,22 @@ class PickleTests(AbstractPickleModuleTests): pass +class PyUnpicklerTests(AbstractUnpickleTests): + + unpickler = pickle._Unpickler + bad_stack_errors = (IndexError,) + bad_mark_errors = (IndexError, pickle.UnpicklingError, + TypeError, AttributeError, EOFError) + truncated_errors = (pickle.UnpicklingError, EOFError, + AttributeError, ValueError, + struct.error, IndexError, ImportError) + + def loads(self, buf, **kwds): + f = io.BytesIO(buf) + u = self.unpickler(f, **kwds) + return u.load() + + class PyPicklerTests(AbstractPickleTests): pickler = pickle._Pickler @@ -40,10 +63,17 @@ class PyPicklerTests(AbstractPickleTests): return u.load() -class InMemoryPickleTests(AbstractPickleTests, BigmemPickleTests): +class InMemoryPickleTests(AbstractPickleTests, AbstractUnpickleTests, + BigmemPickleTests): pickler = pickle._Pickler unpickler = pickle._Unpickler + bad_stack_errors = (pickle.UnpicklingError, IndexError) + bad_mark_errors = (pickle.UnpicklingError, IndexError, + TypeError, AttributeError, EOFError) + truncated_errors = (pickle.UnpicklingError, EOFError, + AttributeError, ValueError, + struct.error, IndexError, ImportError) def dumps(self, arg, protocol=None): return pickle.dumps(arg, protocol) @@ -83,18 +113,29 @@ class PyPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests): class PyDispatchTableTests(AbstractDispatchTableTests): + pickler_class = pickle._Pickler + def get_dispatch_table(self): return pickle.dispatch_table.copy() class PyChainDispatchTableTests(AbstractDispatchTableTests): + pickler_class = pickle._Pickler + def get_dispatch_table(self): return collections.ChainMap({}, pickle.dispatch_table) if has_c_implementation: + class CUnpicklerTests(PyUnpicklerTests): + unpickler = _pickle.Unpickler + bad_stack_errors = (pickle.UnpicklingError,) + bad_mark_errors = (EOFError,) + truncated_errors = (pickle.UnpicklingError, EOFError, + AttributeError, ValueError) + class CPicklerTests(PyPicklerTests): pickler = _pickle.Pickler unpickler = _pickle.Unpickler @@ -134,17 +175,245 @@ if has_c_implementation: def get_dispatch_table(self): return collections.ChainMap({}, pickle.dispatch_table) + @support.cpython_only + class SizeofTests(unittest.TestCase): + check_sizeof = support.check_sizeof + + def test_pickler(self): + basesize = support.calcobjsize('5P2n3i2n3iP') + p = _pickle.Pickler(io.BytesIO()) + self.assertEqual(object.__sizeof__(p), basesize) + MT_size = struct.calcsize('3nP0n') + ME_size = struct.calcsize('Pn0P') + check = self.check_sizeof + check(p, basesize + + MT_size + 8 * ME_size + # Minimal memo table size. + sys.getsizeof(b'x'*4096)) # Minimal write buffer size. + for i in range(6): + p.dump(chr(i)) + check(p, basesize + + MT_size + 32 * ME_size + # Size of memo table required to + # save references to 6 objects. + 0) # Write buffer is cleared after every dump(). + + def test_unpickler(self): + basesize = support.calcobjsize('2Pn2P 2P2n2i5P 2P3n6P2n2i') + unpickler = _pickle.Unpickler + P = struct.calcsize('P') # Size of memo table entry. + n = struct.calcsize('n') # Size of mark table entry. + check = self.check_sizeof + for encoding in 'ASCII', 'UTF-16', 'latin-1': + for errors in 'strict', 'replace': + u = unpickler(io.BytesIO(), + encoding=encoding, errors=errors) + self.assertEqual(object.__sizeof__(u), basesize) + check(u, basesize + + 32 * P + # Minimal memo table size. + len(encoding) + 1 + len(errors) + 1) + + stdsize = basesize + len('ASCII') + 1 + len('strict') + 1 + def check_unpickler(data, memo_size, marks_size): + dump = pickle.dumps(data) + u = unpickler(io.BytesIO(dump), + encoding='ASCII', errors='strict') + u.load() + check(u, stdsize + memo_size * P + marks_size * n) + + check_unpickler(0, 32, 0) + # 20 is minimal non-empty mark stack size. + check_unpickler([0] * 100, 32, 20) + # 128 is memo table size required to save references to 100 objects. + check_unpickler([chr(i) for i in range(100)], 128, 20) + def recurse(deep): + data = 0 + for i in range(deep): + data = [data, data] + return data + check_unpickler(recurse(0), 32, 0) + check_unpickler(recurse(1), 32, 20) + check_unpickler(recurse(20), 32, 58) + check_unpickler(recurse(50), 64, 58) + check_unpickler(recurse(100), 128, 134) + + u = unpickler(io.BytesIO(pickle.dumps('a', 0)), + encoding='ASCII', errors='strict') + u.load() + check(u, stdsize + 32 * P + 2 + 1) + + +ALT_IMPORT_MAPPING = { + ('_elementtree', 'xml.etree.ElementTree'), + ('cPickle', 'pickle'), +} + +ALT_NAME_MAPPING = { + ('__builtin__', 'basestring', 'builtins', 'str'), + ('exceptions', 'StandardError', 'builtins', 'Exception'), + ('UserDict', 'UserDict', 'collections', 'UserDict'), + ('socket', '_socketobject', 'socket', 'SocketType'), +} + +def mapping(module, name): + if (module, name) in NAME_MAPPING: + module, name = NAME_MAPPING[(module, name)] + elif module in IMPORT_MAPPING: + module = IMPORT_MAPPING[module] + return module, name + +def reverse_mapping(module, name): + if (module, name) in REVERSE_NAME_MAPPING: + module, name = REVERSE_NAME_MAPPING[(module, name)] + elif module in REVERSE_IMPORT_MAPPING: + module = REVERSE_IMPORT_MAPPING[module] + return module, name + +def getmodule(module): + try: + return sys.modules[module] + except KeyError: + try: + __import__(module) + except AttributeError as exc: + if support.verbose: + print("Can't import module %r: %s" % (module, exc)) + raise ImportError + except ImportError as exc: + if support.verbose: + print(exc) + raise + return sys.modules[module] + +def getattribute(module, name): + obj = getmodule(module) + for n in name.split('.'): + obj = getattr(obj, n) + return obj + +def get_exceptions(mod): + for name in dir(mod): + attr = getattr(mod, name) + if isinstance(attr, type) and issubclass(attr, BaseException): + yield name, attr + +class CompatPickleTests(unittest.TestCase): + def test_import(self): + modules = set(IMPORT_MAPPING.values()) + modules |= set(REVERSE_IMPORT_MAPPING) + modules |= {module for module, name in REVERSE_NAME_MAPPING} + modules |= {module for module, name in NAME_MAPPING.values()} + for module in modules: + try: + getmodule(module) + except ImportError: + pass + + def test_import_mapping(self): + for module3, module2 in REVERSE_IMPORT_MAPPING.items(): + with self.subTest((module3, module2)): + try: + getmodule(module3) + except ImportError: + pass + if module3[:1] != '_': + self.assertIn(module2, IMPORT_MAPPING) + self.assertEqual(IMPORT_MAPPING[module2], module3) + + def test_name_mapping(self): + for (module3, name3), (module2, name2) in REVERSE_NAME_MAPPING.items(): + with self.subTest(((module3, name3), (module2, name2))): + if (module2, name2) == ('exceptions', 'OSError'): + attr = getattribute(module3, name3) + self.assertTrue(issubclass(attr, OSError)) + else: + module, name = mapping(module2, name2) + if module3[:1] != '_': + self.assertEqual((module, name), (module3, name3)) + try: + attr = getattribute(module3, name3) + except ImportError: + pass + else: + self.assertEqual(getattribute(module, name), attr) + + def test_reverse_import_mapping(self): + for module2, module3 in IMPORT_MAPPING.items(): + with self.subTest((module2, module3)): + try: + getmodule(module3) + except ImportError as exc: + if support.verbose: + print(exc) + if ((module2, module3) not in ALT_IMPORT_MAPPING and + REVERSE_IMPORT_MAPPING.get(module3, None) != module2): + for (m3, n3), (m2, n2) in REVERSE_NAME_MAPPING.items(): + if (module3, module2) == (m3, m2): + break + else: + self.fail('No reverse mapping from %r to %r' % + (module3, module2)) + module = REVERSE_IMPORT_MAPPING.get(module3, module3) + module = IMPORT_MAPPING.get(module, module) + self.assertEqual(module, module3) + + def test_reverse_name_mapping(self): + for (module2, name2), (module3, name3) in NAME_MAPPING.items(): + with self.subTest(((module2, name2), (module3, name3))): + try: + attr = getattribute(module3, name3) + except ImportError: + pass + module, name = reverse_mapping(module3, name3) + if (module2, name2, module3, name3) not in ALT_NAME_MAPPING: + self.assertEqual((module, name), (module2, name2)) + module, name = mapping(module, name) + self.assertEqual((module, name), (module3, name3)) + + def test_exceptions(self): + self.assertEqual(mapping('exceptions', 'StandardError'), + ('builtins', 'Exception')) + self.assertEqual(mapping('exceptions', 'Exception'), + ('builtins', 'Exception')) + self.assertEqual(reverse_mapping('builtins', 'Exception'), + ('exceptions', 'Exception')) + self.assertEqual(mapping('exceptions', 'OSError'), + ('builtins', 'OSError')) + self.assertEqual(reverse_mapping('builtins', 'OSError'), + ('exceptions', 'OSError')) + + for name, exc in get_exceptions(builtins): + with self.subTest(name): + if exc in (BlockingIOError, ResourceWarning): + continue + if exc is not OSError and issubclass(exc, OSError): + self.assertEqual(reverse_mapping('builtins', name), + ('exceptions', 'OSError')) + else: + self.assertEqual(reverse_mapping('builtins', name), + ('exceptions', name)) + self.assertEqual(mapping('exceptions', name), + ('builtins', name)) + + def test_multiprocessing_exceptions(self): + module = support.import_module('multiprocessing.context') + for name, exc in get_exceptions(module): + with self.subTest(name): + self.assertEqual(reverse_mapping('multiprocessing.context', name), + ('multiprocessing', name)) + self.assertEqual(mapping('multiprocessing', name), + ('multiprocessing.context', name)) + def test_main(): - tests = [PickleTests, PyPicklerTests, PyPersPicklerTests, - PyDispatchTableTests, PyChainDispatchTableTests] + tests = [PickleTests, PyUnpicklerTests, PyPicklerTests, PyPersPicklerTests, + PyDispatchTableTests, PyChainDispatchTableTests, + CompatPickleTests] if has_c_implementation: - tests.extend([CPicklerTests, CPersPicklerTests, + tests.extend([CUnpicklerTests, CPicklerTests, CPersPicklerTests, CDumpPickle_LoadPickle, DumpPickle_CLoadPickle, PyPicklerUnpicklerObjectTests, CPicklerUnpicklerObjectTests, CDispatchTableTests, CChainDispatchTableTests, - InMemoryPickleTests]) + InMemoryPickleTests, SizeofTests]) support.run_unittest(*tests) support.run_doctest(pickle) |
