diff options
author | Ken Jin <28750310+Fidget-Spinner@users.noreply.github.com> | 2021-07-19 14:22:59 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-07-19 14:22:59 (GMT) |
commit | a2721649598eb715304a1ac8678a409585f73b27 (patch) | |
tree | 4a562541e287da5190b8348fa2c3a991a5742e91 /Lib | |
parent | 37bdd2221ce3607a81d5d7fafc4603d95ca3e8cb (diff) | |
download | cpython-a2721649598eb715304a1ac8678a409585f73b27.zip cpython-a2721649598eb715304a1ac8678a409585f73b27.tar.gz cpython-a2721649598eb715304a1ac8678a409585f73b27.tar.bz2 |
bpo-44490: Improve typing module compatibility with types.Union (GH-27048) (#27222)
(cherry picked from commit bf89ff96e6ba21bb52b8597b5e51e8ffc57e6589)
Co-authored-by: Yurii Karabas <1998uriyyo@gmail.com>
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/test/ann_module.py | 2 | ||||
-rw-r--r-- | Lib/test/test_grammar.py | 2 | ||||
-rw-r--r-- | Lib/test/test_typing.py | 22 | ||||
-rw-r--r-- | Lib/typing.py | 18 |
4 files changed, 37 insertions, 7 deletions
diff --git a/Lib/test/ann_module.py b/Lib/test/ann_module.py index 0567d6d..5081e6b 100644 --- a/Lib/test/ann_module.py +++ b/Lib/test/ann_module.py @@ -58,3 +58,5 @@ def dec(func): def wrapper(*args, **kwargs): return func(*args, **kwargs) return wrapper + +u: int | float diff --git a/Lib/test/test_grammar.py b/Lib/test/test_grammar.py index c0820fd..b6c4574 100644 --- a/Lib/test/test_grammar.py +++ b/Lib/test/test_grammar.py @@ -473,7 +473,7 @@ class GrammarTests(unittest.TestCase): def test_var_annot_module_semantics(self): self.assertEqual(test.__annotations__, {}) self.assertEqual(ann_module.__annotations__, - {1: 2, 'x': int, 'y': str, 'f': typing.Tuple[int, int]}) + {1: 2, 'x': int, 'y': str, 'f': typing.Tuple[int, int], 'u': int | float}) self.assertEqual(ann_module.M.__annotations__, {'123': 123, 'o': type}) self.assertEqual(ann_module2.__annotations__, {}) diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index 0c72784..5602150 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -315,6 +315,8 @@ class UnionTests(BaseTestCase): self.assertEqual(repr(u), 'typing.Union[typing.List[int], int]') u = Union[list[int], dict[str, float]] self.assertEqual(repr(u), 'typing.Union[list[int], dict[str, float]]') + u = Union[int | float] + self.assertEqual(repr(u), 'typing.Union[int, float]') def test_cannot_subclass(self): with self.assertRaises(TypeError): @@ -1449,6 +1451,8 @@ class GenericTests(BaseTestCase): with self.assertRaises(TypeError): issubclass(SM1, SimpleMapping) self.assertIsInstance(SM1(), SimpleMapping) + T = TypeVar("T") + self.assertEqual(List[list[T] | float].__parameters__, (T,)) def test_generic_errors(self): T = TypeVar('T') @@ -1785,6 +1789,7 @@ class GenericTests(BaseTestCase): def test_generic_forward_ref(self): def foobar(x: List[List['CC']]): ... def foobar2(x: list[list[ForwardRef('CC')]]): ... + def foobar3(x: list[ForwardRef('CC | int')] | int): ... class CC: ... self.assertEqual( get_type_hints(foobar, globals(), locals()), @@ -1794,6 +1799,10 @@ class GenericTests(BaseTestCase): get_type_hints(foobar2, globals(), locals()), {'x': list[list[CC]]} ) + self.assertEqual( + get_type_hints(foobar3, globals(), locals()), + {'x': list[CC | int] | int} + ) T = TypeVar('T') AT = Tuple[T, ...] @@ -2467,6 +2476,12 @@ class ForwardRefTests(BaseTestCase): self.assertEqual(get_type_hints(foo, globals(), locals()), {'a': Union[T]}) + def foo(a: tuple[ForwardRef('T')] | int): + pass + + self.assertEqual(get_type_hints(foo, globals(), locals()), + {'a': tuple[T] | int}) + def test_tuple_forward(self): def foo(a: Tuple['T']): @@ -2851,7 +2866,7 @@ class GetTypeHintTests(BaseTestCase): gth(None) def test_get_type_hints_modules(self): - ann_module_type_hints = {1: 2, 'f': Tuple[int, int], 'x': int, 'y': str} + ann_module_type_hints = {1: 2, 'f': Tuple[int, int], 'x': int, 'y': str, 'u': int | float} self.assertEqual(gth(ann_module), ann_module_type_hints) self.assertEqual(gth(ann_module2), {}) self.assertEqual(gth(ann_module3), {}) @@ -4393,6 +4408,9 @@ class ParamSpecTests(BaseTestCase): self.assertNotIn(P, list[P].__parameters__) self.assertIn(T, tuple[T, P].__parameters__) + self.assertNotIn(P, (list[P] | int).__parameters__) + self.assertIn(T, (tuple[T, P] | int).__parameters__) + def test_paramspec_in_nested_generics(self): # Although ParamSpec should not be found in __parameters__ of most # generics, they probably should be found when nested in @@ -4402,8 +4420,10 @@ class ParamSpecTests(BaseTestCase): C1 = Callable[P, T] G1 = List[C1] G2 = list[C1] + G3 = list[C1] | int self.assertEqual(G1.__parameters__, (P, T)) self.assertEqual(G2.__parameters__, (P, T)) + self.assertEqual(G3.__parameters__, (P, T)) class ConcatenateTests(BaseTestCase): diff --git a/Lib/typing.py b/Lib/typing.py index 508f4b6..660ad35 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -195,7 +195,7 @@ def _type_repr(obj): return repr(obj) -def _collect_type_vars(types, typevar_types=None): +def _collect_type_vars(types_, typevar_types=None): """Collect all type variable contained in types in order of first appearance (lexicographic order). For example:: @@ -204,10 +204,10 @@ def _collect_type_vars(types, typevar_types=None): if typevar_types is None: typevar_types = TypeVar tvars = [] - for t in types: + for t in types_: if isinstance(t, typevar_types) and t not in tvars: tvars.append(t) - if isinstance(t, (_GenericAlias, GenericAlias)): + if isinstance(t, (_GenericAlias, GenericAlias, types.Union)): tvars.extend([t for t in t.__parameters__ if t not in tvars]) return tuple(tvars) @@ -314,12 +314,14 @@ def _eval_type(t, globalns, localns, recursive_guard=frozenset()): """ if isinstance(t, ForwardRef): return t._evaluate(globalns, localns, recursive_guard) - if isinstance(t, (_GenericAlias, GenericAlias)): + if isinstance(t, (_GenericAlias, GenericAlias, types.Union)): ev_args = tuple(_eval_type(a, globalns, localns, recursive_guard) for a in t.__args__) if ev_args == t.__args__: return t if isinstance(t, GenericAlias): return GenericAlias(t.__origin__, ev_args) + if isinstance(t, types.Union): + return functools.reduce(operator.or_, ev_args) else: return t.copy_with(ev_args) return t @@ -1013,7 +1015,7 @@ class _GenericAlias(_BaseGenericAlias, _root=True): for arg in self.__args__: if isinstance(arg, self._typevar_types): arg = subst[arg] - elif isinstance(arg, (_GenericAlias, GenericAlias)): + elif isinstance(arg, (_GenericAlias, GenericAlias, types.Union)): subparams = arg.__parameters__ if subparams: subargs = tuple(subst[x] for x in subparams) @@ -1779,6 +1781,12 @@ def _strip_annotations(t): if stripped_args == t.__args__: return t return GenericAlias(t.__origin__, stripped_args) + if isinstance(t, types.Union): + stripped_args = tuple(_strip_annotations(a) for a in t.__args__) + if stripped_args == t.__args__: + return t + return functools.reduce(operator.or_, stripped_args) + return t |