diff options
Diffstat (limited to 'Lib/test/pickletester.py')
| -rw-r--r-- | Lib/test/pickletester.py | 155 |
1 files changed, 135 insertions, 20 deletions
diff --git a/Lib/test/pickletester.py b/Lib/test/pickletester.py index b948c55..7922b54 100644 --- a/Lib/test/pickletester.py +++ b/Lib/test/pickletester.py @@ -1815,16 +1815,62 @@ class AbstractPickleTests(unittest.TestCase): self.assertGreaterEqual(num_additems, 2) def test_simple_newobj(self): - x = object.__new__(SimpleNewObj) # avoid __init__ + x = SimpleNewObj.__new__(SimpleNewObj, 0xface) # avoid __init__ x.abc = 666 for proto in protocols: - s = self.dumps(x, proto) - self.assertEqual(opcode_in_pickle(pickle.NEWOBJ, s), - 2 <= proto < 4) - self.assertEqual(opcode_in_pickle(pickle.NEWOBJ_EX, s), - proto >= 4) - y = self.loads(s) # will raise TypeError if __init__ called - self.assert_is_copy(x, y) + with self.subTest(proto=proto): + s = self.dumps(x, proto) + if proto < 1: + self.assertIn(b'\nL64206', s) # LONG + else: + self.assertIn(b'M\xce\xfa', s) # BININT2 + self.assertEqual(opcode_in_pickle(pickle.NEWOBJ, s), + 2 <= proto) + self.assertFalse(opcode_in_pickle(pickle.NEWOBJ_EX, s)) + y = self.loads(s) # will raise TypeError if __init__ called + self.assert_is_copy(x, y) + + def test_complex_newobj(self): + x = ComplexNewObj.__new__(ComplexNewObj, 0xface) # avoid __init__ + x.abc = 666 + for proto in protocols: + with self.subTest(proto=proto): + s = self.dumps(x, proto) + if proto < 1: + self.assertIn(b'\nL64206', s) # LONG + elif proto < 2: + self.assertIn(b'M\xce\xfa', s) # BININT2 + elif proto < 4: + self.assertIn(b'X\x04\x00\x00\x00FACE', s) # BINUNICODE + else: + self.assertIn(b'\x8c\x04FACE', s) # SHORT_BINUNICODE + self.assertEqual(opcode_in_pickle(pickle.NEWOBJ, s), + 2 <= proto) + self.assertFalse(opcode_in_pickle(pickle.NEWOBJ_EX, s)) + y = self.loads(s) # will raise TypeError if __init__ called + self.assert_is_copy(x, y) + + def test_complex_newobj_ex(self): + x = ComplexNewObjEx.__new__(ComplexNewObjEx, 0xface) # avoid __init__ + x.abc = 666 + for proto in protocols: + with self.subTest(proto=proto): + if 2 <= proto < 4: + self.assertRaises(ValueError, self.dumps, x, proto) + continue + s = self.dumps(x, proto) + if proto < 1: + self.assertIn(b'\nL64206', s) # LONG + elif proto < 2: + self.assertIn(b'M\xce\xfa', s) # BININT2 + else: + assert proto >= 4 + self.assertIn(b'\x8c\x04FACE', s) # SHORT_BINUNICODE + self.assertFalse(opcode_in_pickle(pickle.NEWOBJ, s)) + self.assertEqual(opcode_in_pickle(pickle.NEWOBJ_EX, s), + 4 <= proto) + y = self.loads(s) # will raise TypeError if __init__ called + self.assert_is_copy(x, y) def test_newobj_list_slots(self): x = SlotList([1, 2, 3]) @@ -1898,15 +1944,15 @@ class AbstractPickleTests(unittest.TestCase): # 5th item is not an iterator return dict, (), None, None, [] - # Protocol 0 is less strict and also accept iterables. + # Python implementation is less strict and also accepts iterables. for proto in protocols: try: self.dumps(C(), proto) - except (pickle.PickleError): + except pickle.PicklingError: pass try: self.dumps(D(), proto) - except (pickle.PickleError): + except pickle.PicklingError: pass def test_many_puts_and_gets(self): @@ -2088,13 +2134,24 @@ class AbstractPickleTests(unittest.TestCase): class B: class C: pass - - for proto in range(4, pickle.HIGHEST_PROTOCOL + 1): + for proto in range(pickle.HIGHEST_PROTOCOL + 1): for obj in [Nested.A, Nested.A.B, Nested.A.B.C]: with self.subTest(proto=proto, obj=obj): unpickled = self.loads(self.dumps(obj, proto)) self.assertIs(obj, unpickled) + def test_recursive_nested_names(self): + global Recursive + class Recursive: + pass + Recursive.mod = sys.modules[Recursive.__module__] + Recursive.__qualname__ = 'Recursive.mod.Recursive' + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto): + unpickled = self.loads(self.dumps(Recursive, proto)) + self.assertIs(unpickled, Recursive) + del Recursive.mod # break reference loop + def test_py_methods(self): global PyMethodsTest class PyMethodsTest: @@ -2133,7 +2190,7 @@ class AbstractPickleTests(unittest.TestCase): (PyMethodsTest.biscuits, PyMethodsTest), (PyMethodsTest.Nested.pie, PyMethodsTest.Nested) ) - for proto in range(4, pickle.HIGHEST_PROTOCOL + 1): + for proto in range(pickle.HIGHEST_PROTOCOL + 1): for method in py_methods: with self.subTest(proto=proto, method=method): unpickled = self.loads(self.dumps(method, proto)) @@ -2173,7 +2230,7 @@ class AbstractPickleTests(unittest.TestCase): (Subclass.Nested("sweet").count, ("e",)), (Subclass.Nested.count, (Subclass.Nested("sweet"), "e")), ) - for proto in range(4, pickle.HIGHEST_PROTOCOL + 1): + for proto in range(pickle.HIGHEST_PROTOCOL + 1): for method, args in c_methods: with self.subTest(proto=proto, method=method): unpickled = self.loads(self.dumps(method, proto)) @@ -2197,6 +2254,27 @@ class AbstractPickleTests(unittest.TestCase): self.assertIn(('c%s\n%s' % (mod, name)).encode(), pickled) self.assertIs(type(self.loads(pickled)), type(val)) + def test_local_lookup_error(self): + # Test that whichmodule() errors out cleanly when looking up + # an assumed globally-reachable object fails. + def f(): + pass + # Since the function is local, lookup will fail + for proto in range(0, pickle.HIGHEST_PROTOCOL + 1): + with self.assertRaises((AttributeError, pickle.PicklingError)): + pickletools.dis(self.dumps(f, proto)) + # Same without a __module__ attribute (exercises a different path + # in _pickle.c). + del f.__module__ + for proto in range(0, pickle.HIGHEST_PROTOCOL + 1): + with self.assertRaises((AttributeError, pickle.PicklingError)): + pickletools.dis(self.dumps(f, proto)) + # Yet a different path. + f.__name__ = f.__qualname__ + for proto in range(0, pickle.HIGHEST_PROTOCOL + 1): + with self.assertRaises((AttributeError, pickle.PicklingError)): + pickletools.dis(self.dumps(f, proto)) + class BigmemPickleTests(unittest.TestCase): @@ -2370,7 +2448,7 @@ class REX_six(object): def __init__(self, items=None): self.items = items if items is not None else [] def __eq__(self, other): - return type(self) is type(other) and self.items == self.items + return type(self) is type(other) and self.items == other.items def append(self, item): self.items.append(item) def __reduce__(self): @@ -2383,7 +2461,7 @@ class REX_seven(object): def __init__(self, table=None): self.table = table if table is not None else {} def __eq__(self, other): - return type(self) is type(other) and self.table == self.table + return type(self) is type(other) and self.table == other.table def __setitem__(self, key, value): self.table[key] = value def __reduce__(self): @@ -2431,12 +2509,20 @@ myclasses = [MyInt, MyFloat, class SlotList(MyList): __slots__ = ["foo"] -class SimpleNewObj(object): - def __init__(self, a, b, c): +class SimpleNewObj(int): + def __init__(self, *args, **kwargs): # raise an error, to make sure this isn't called raise TypeError("SimpleNewObj.__init__() didn't expect to get called") def __eq__(self, other): - return self.__dict__ == other.__dict__ + return int(self) == int(other) and self.__dict__ == other.__dict__ + +class ComplexNewObj(SimpleNewObj): + def __getnewargs__(self): + return ('%X' % self, 16) + +class ComplexNewObjEx(SimpleNewObj): + def __getnewargs_ex__(self): + return ('%X' % self,), {'base': 16} class BadGetattr: def __getattr__(self, key): @@ -2543,6 +2629,35 @@ class AbstractPersistentPicklerTests(unittest.TestCase): self.assertEqual(self.load_false_count, 1) +class AbstractIdentityPersistentPicklerTests(unittest.TestCase): + + def persistent_id(self, obj): + return obj + + def persistent_load(self, pid): + return pid + + def _check_return_correct_type(self, obj, proto): + unpickled = self.loads(self.dumps(obj, proto)) + self.assertIsInstance(unpickled, type(obj)) + self.assertEqual(unpickled, obj) + + def test_return_correct_type(self): + for proto in protocols: + # Protocol 0 supports only ASCII strings. + if proto == 0: + self._check_return_correct_type("abc", 0) + else: + for obj in [b"abc\n", "abc\n", -1, -1.1 * 0.1, str]: + self._check_return_correct_type(obj, proto) + + def test_protocol0_is_ascii_only(self): + non_ascii_str = "\N{EMPTY SET}" + self.assertRaises(pickle.PicklingError, self.dumps, non_ascii_str, 0) + pickled = pickle.PERSID + non_ascii_str.encode('utf-8') + b'\n.' + self.assertRaises(pickle.UnpicklingError, self.loads, pickled) + + class AbstractPicklerUnpicklerObjectTests(unittest.TestCase): pickler_class = None |
