From 5e130a8da4e4f13444ec20dfe88a3e2e070005ca Mon Sep 17 00:00:00 2001 From: Matthew Rahtz Date: Fri, 22 Apr 2022 05:22:53 +0100 Subject: bpo-43224: Implement pickling of TypeVarTuples (#32119) Co-authored-by: Jelle Zijlstra --- Lib/test/test_typing.py | 56 ++++++++++++++++++++++++++++++++++++++++++++++++- Lib/typing.py | 25 ++++++++++++++++------ 2 files changed, 74 insertions(+), 7 deletions(-) diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index d480847..1fd99a0 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -1,7 +1,7 @@ import contextlib import collections from collections import defaultdict -from functools import lru_cache +from functools import lru_cache, wraps import inspect import pickle import re @@ -70,6 +70,18 @@ class BaseTestCase(TestCase): f() +def all_pickle_protocols(test_func): + """Runs `test_func` with various values for `proto` argument.""" + + @wraps(test_func) + def wrapper(self): + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(pickle_proto=proto): + test_func(self, proto=proto) + + return wrapper + + class Employee: pass @@ -911,6 +923,48 @@ class TypeVarTupleTests(BaseTestCase): self.assertNotEqual(C[Unpack[Ts1]], C[Unpack[Ts2]]) +class TypeVarTuplePicklingTests(BaseTestCase): + # These are slightly awkward tests to run, because TypeVarTuples are only + # picklable if defined in the global scope. We therefore need to push + # various things defined in these tests into the global scope with `global` + # statements at the start of each test. + + @all_pickle_protocols + def test_pickling_then_unpickling_results_in_same_identity(self, proto): + global Ts1 # See explanation at start of class. + Ts1 = TypeVarTuple('Ts1') + Ts2 = pickle.loads(pickle.dumps(Ts1, proto)) + self.assertIs(Ts1, Ts2) + + @all_pickle_protocols + def test_pickling_then_unpickling_unpacked_results_in_same_identity(self, proto): + global Ts # See explanation at start of class. + Ts = TypeVarTuple('Ts') + unpacked1 = Unpack[Ts] + unpacked2 = pickle.loads(pickle.dumps(unpacked1, proto)) + self.assertIs(unpacked1, unpacked2) + + @all_pickle_protocols + def test_pickling_then_unpickling_tuple_with_typevartuple_equality( + self, proto + ): + global T, Ts # See explanation at start of class. + T = TypeVar('T') + Ts = TypeVarTuple('Ts') + + a1 = Tuple[Unpack[Ts]] + a2 = pickle.loads(pickle.dumps(a1, proto)) + self.assertEqual(a1, a2) + + a1 = Tuple[T, Unpack[Ts]] + a2 = pickle.loads(pickle.dumps(a1, proto)) + self.assertEqual(a1, a2) + + a1 = Tuple[int, Unpack[Ts]] + a2 = pickle.loads(pickle.dumps(a1, proto)) + self.assertEqual(a1, a2) + + class UnionTests(BaseTestCase): def test_basics(self): diff --git a/Lib/typing.py b/Lib/typing.py index 3e0fbdb..a6f4fa9 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -867,6 +867,13 @@ def _is_typevar_like(x: Any) -> bool: return isinstance(x, (TypeVar, ParamSpec)) or _is_unpacked_typevartuple(x) +class _PickleUsingNameMixin: + """Mixin enabling pickling based on self.__name__.""" + + def __reduce__(self): + return self.__name__ + + class _BoundVarianceMixin: """Mixin giving __init__ bound and variance arguments. @@ -903,11 +910,9 @@ class _BoundVarianceMixin: prefix = '~' return prefix + self.__name__ - def __reduce__(self): - return self.__name__ - -class TypeVar(_Final, _Immutable, _BoundVarianceMixin, _root=True): +class TypeVar(_Final, _Immutable, _BoundVarianceMixin, _PickleUsingNameMixin, + _root=True): """Type variable. Usage:: @@ -973,7 +978,7 @@ class TypeVar(_Final, _Immutable, _BoundVarianceMixin, _root=True): return arg -class TypeVarTuple(_Final, _Immutable, _root=True): +class TypeVarTuple(_Final, _Immutable, _PickleUsingNameMixin, _root=True): """Type variable tuple. Usage: @@ -994,11 +999,18 @@ class TypeVarTuple(_Final, _Immutable, _root=True): C[()] # Even this is fine For more details, see PEP 646. + + Note that only TypeVarTuples defined in global scope can be pickled. """ def __init__(self, name): self.__name__ = name + # Used for pickling. + def_mod = _caller() + if def_mod != 'typing': + self.__module__ = def_mod + def __iter__(self): yield Unpack[self] @@ -1057,7 +1069,8 @@ class ParamSpecKwargs(_Final, _Immutable, _root=True): return self.__origin__ == other.__origin__ -class ParamSpec(_Final, _Immutable, _BoundVarianceMixin, _root=True): +class ParamSpec(_Final, _Immutable, _BoundVarianceMixin, _PickleUsingNameMixin, + _root=True): """Parameter specification variable. Usage:: -- cgit v0.12