From 5b1b9eacb92dd47d10793a8868246df6ea477ed6 Mon Sep 17 00:00:00 2001 From: Serhiy Storchaka Date: Fri, 11 Mar 2022 21:43:58 +0200 Subject: bpo-43224: Implement substitution of unpacked TypeVarTuple (GH-31800) --- Lib/test/test_typing.py | 143 ++++++++++++++++++++++++++++-------------------- Lib/typing.py | 49 +++++++++++------ 2 files changed, 117 insertions(+), 75 deletions(-) diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index 91b2e77..a693665 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -411,6 +411,10 @@ class UnpackTests(BaseTestCase): class TypeVarTupleTests(BaseTestCase): + def assertEndsWith(self, string, tail): + if not string.endswith(tail): + self.fail(f"String {string!r} does not end with {tail!r}") + def test_instance_is_equal_to_itself(self): Ts = TypeVarTuple('Ts') self.assertEqual(Ts, Ts) @@ -457,6 +461,56 @@ class TypeVarTupleTests(BaseTestCase): self.assertEqual(t2.__args__, (Unpack[Ts],)) self.assertEqual(t2.__parameters__, (Ts,)) + def test_var_substitution(self): + Ts = TypeVarTuple('Ts') + T = TypeVar('T') + T2 = TypeVar('T2') + class G(Generic[Unpack[Ts]]): pass + + for A in G, Tuple: + B = A[Unpack[Ts]] + if A != Tuple: + self.assertEqual(B[()], A[()]) + self.assertEqual(B[float], A[float]) + self.assertEqual(B[float, str], A[float, str]) + + C = List[A[Unpack[Ts]]] + if A != Tuple: + self.assertEqual(C[()], List[A[()]]) + self.assertEqual(C[float], List[A[float]]) + self.assertEqual(C[float, str], List[A[float, str]]) + + D = A[T, Unpack[Ts], T2] + with self.assertRaises(TypeError): + D[()] + with self.assertRaises(TypeError): + D[float] + self.assertEqual(D[float, str], A[float, str]) + self.assertEqual(D[float, str, int], A[float, str, int]) + self.assertEqual(D[float, str, int, bytes], A[float, str, int, bytes]) + + E = Tuple[List[T], A[Unpack[Ts]], List[T2]] + with self.assertRaises(TypeError): + E[()] + with self.assertRaises(TypeError): + E[float] + if A != Tuple: + self.assertEqual(E[float, str], + Tuple[List[float], A[()], List[str]]) + self.assertEqual(E[float, str, int], + Tuple[List[float], A[str], List[int]]) + self.assertEqual(E[float, str, int, bytes], + Tuple[List[float], A[str, int], List[bytes]]) + + def test_repr_is_correct(self): + Ts = TypeVarTuple('Ts') + self.assertEqual(repr(Ts), 'Ts') + self.assertEqual(repr(Unpack[Ts]), '*Ts') + self.assertEqual(repr(tuple[Unpack[Ts]]), 'tuple[*Ts]') + self.assertEqual(repr(Tuple[Unpack[Ts]]), 'typing.Tuple[*Ts]') + self.assertEqual(repr(Unpack[tuple[Unpack[Ts]]]), '*tuple[*Ts]') + self.assertEqual(repr(Unpack[Tuple[Unpack[Ts]]]), '*typing.Tuple[*Ts]') + def test_repr_is_correct(self): Ts = TypeVarTuple('Ts') self.assertEqual(repr(Ts), 'Ts') @@ -470,78 +524,51 @@ class TypeVarTupleTests(BaseTestCase): Ts = TypeVarTuple('Ts') class A(Generic[Unpack[Ts]]): pass - self.assertTrue(repr(A[()]).endswith('A[()]')) - self.assertTrue(repr(A[float]).endswith('A[float]')) - self.assertTrue(repr(A[float, str]).endswith('A[float, str]')) - self.assertTrue(repr( - A[Unpack[tuple[int, ...]]] - ).endswith( - 'A[*tuple[int, ...]]' - )) - self.assertTrue(repr( - A[float, Unpack[tuple[int, ...]]] - ).endswith( - 'A[float, *tuple[int, ...]]' - )) - self.assertTrue(repr( - A[Unpack[tuple[int, ...]], str] - ).endswith( - 'A[*tuple[int, ...], str]' - )) - self.assertTrue(repr( - A[float, Unpack[tuple[int, ...]], str] - ).endswith( - 'A[float, *tuple[int, ...], str]' - )) + self.assertEndsWith(repr(A[()]), 'A[()]') + self.assertEndsWith(repr(A[float]), 'A[float]') + self.assertEndsWith(repr(A[float, str]), 'A[float, str]') + self.assertEndsWith(repr(A[Unpack[tuple[int, ...]]]), + 'A[*tuple[int, ...]]') + self.assertEndsWith(repr(A[float, Unpack[tuple[int, ...]]]), + 'A[float, *tuple[int, ...]]') + self.assertEndsWith(repr(A[Unpack[tuple[int, ...]], str]), + 'A[*tuple[int, ...], str]') + self.assertEndsWith(repr(A[float, Unpack[tuple[int, ...]], str]), + 'A[float, *tuple[int, ...], str]') def test_variadic_class_alias_repr_is_correct(self): Ts = TypeVarTuple('Ts') class A(Generic[Unpack[Ts]]): pass B = A[Unpack[Ts]] - self.assertTrue(repr(B).endswith('A[*Ts]')) - with self.assertRaises(NotImplementedError): - B[()] - with self.assertRaises(NotImplementedError): - B[float] - with self.assertRaises(NotImplementedError): - B[float, str] + self.assertEndsWith(repr(B), 'A[*Ts]') + self.assertEndsWith(repr(B[()]), 'A[()]') + self.assertEndsWith(repr(B[float]), 'A[float]') + self.assertEndsWith(repr(B[float, str]), 'A[float, str]') C = A[Unpack[Ts], int] - self.assertTrue(repr(C).endswith('A[*Ts, int]')) - with self.assertRaises(NotImplementedError): - C[()] - with self.assertRaises(NotImplementedError): - C[float] - with self.assertRaises(NotImplementedError): - C[float, str] + self.assertEndsWith(repr(C), 'A[*Ts, int]') + self.assertEndsWith(repr(C[()]), 'A[int]') + self.assertEndsWith(repr(C[float]), 'A[float, int]') + self.assertEndsWith(repr(C[float, str]), 'A[float, str, int]') D = A[int, Unpack[Ts]] - self.assertTrue(repr(D).endswith('A[int, *Ts]')) - with self.assertRaises(NotImplementedError): - D[()] - with self.assertRaises(NotImplementedError): - D[float] - with self.assertRaises(NotImplementedError): - D[float, str] + self.assertEndsWith(repr(D), 'A[int, *Ts]') + self.assertEndsWith(repr(D[()]), 'A[int]') + self.assertEndsWith(repr(D[float]), 'A[int, float]') + self.assertEndsWith(repr(D[float, str]), 'A[int, float, str]') E = A[int, Unpack[Ts], str] - self.assertTrue(repr(E).endswith('A[int, *Ts, str]')) - with self.assertRaises(NotImplementedError): - E[()] - with self.assertRaises(NotImplementedError): - E[float] - with self.assertRaises(NotImplementedError): - E[float, bool] + self.assertEndsWith(repr(E), 'A[int, *Ts, str]') + self.assertEndsWith(repr(E[()]), 'A[int, str]') + self.assertEndsWith(repr(E[float]), 'A[int, float, str]') + self.assertEndsWith(repr(E[float, str]), 'A[int, float, str, str]') F = A[Unpack[Ts], Unpack[tuple[str, ...]]] - self.assertTrue(repr(F).endswith('A[*Ts, *tuple[str, ...]]')) - with self.assertRaises(NotImplementedError): - F[()] - with self.assertRaises(NotImplementedError): - F[float] - with self.assertRaises(NotImplementedError): - F[float, int] + self.assertEndsWith(repr(F), 'A[*Ts, *tuple[str, ...]]') + self.assertEndsWith(repr(F[()]), 'A[*tuple[str, ...]]') + self.assertEndsWith(repr(F[float]), 'A[float, *tuple[str, ...]]') + self.assertEndsWith(repr(F[float, str]), 'A[float, str, *tuple[str, ...]]') def test_cannot_subclass_class(self): with self.assertRaises(TypeError): diff --git a/Lib/typing.py b/Lib/typing.py index 062c01e..842554f 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -1297,30 +1297,39 @@ class _GenericAlias(_BaseGenericAlias, _root=True): # anything more exotic than a plain `TypeVar`, we need to consider # edge cases. - if any(isinstance(p, TypeVarTuple) for p in self.__parameters__): - raise NotImplementedError( - "Type substitution for TypeVarTuples is not yet implemented" - ) + params = self.__parameters__ # In the example above, this would be {T3: str} - new_arg_by_param = dict(zip(self.__parameters__, args)) + new_arg_by_param = {} + for i, param in enumerate(params): + if isinstance(param, TypeVarTuple): + j = len(args) - (len(params) - i - 1) + if j < i: + raise TypeError(f"Too few arguments for {self}") + new_arg_by_param.update(zip(params[:i], args[:i])) + new_arg_by_param[param] = args[i: j] + new_arg_by_param.update(zip(params[i + 1:], args[j:])) + break + else: + new_arg_by_param.update(zip(params, args)) new_args = [] for old_arg in self.__args__: - if _is_unpacked_typevartuple(old_arg): - original_typevartuple = old_arg.__parameters__[0] - new_arg = new_arg_by_param[original_typevartuple] + substfunc = getattr(old_arg, '__typing_subst__', None) + if substfunc: + new_arg = substfunc(new_arg_by_param[old_arg]) else: - substfunc = getattr(old_arg, '__typing_subst__', None) - if substfunc: - new_arg = substfunc(new_arg_by_param[old_arg]) + subparams = getattr(old_arg, '__parameters__', ()) + if not subparams: + new_arg = old_arg else: - subparams = getattr(old_arg, '__parameters__', ()) - if not subparams: - new_arg = old_arg - else: - subargs = tuple(new_arg_by_param[x] for x in subparams) - new_arg = old_arg[subargs] + subargs = [] + for x in subparams: + if isinstance(x, TypeVarTuple): + subargs.extend(new_arg_by_param[x]) + else: + subargs.append(new_arg_by_param[x]) + new_arg = old_arg[tuple(subargs)] if self.__origin__ == collections.abc.Callable and isinstance(new_arg, tuple): # Consider the following `Callable`. @@ -1612,6 +1621,12 @@ class _UnpackGenericAlias(_GenericAlias, _root=True): # a single item. return '*' + repr(self.__args__[0]) + def __getitem__(self, args): + if (len(self.__parameters__) == 1 and + isinstance(self.__parameters__[0], TypeVarTuple)): + return args + return super().__getitem__(args) + class Generic: """Abstract base class for generic types. -- cgit v0.12