summaryrefslogtreecommitdiffstats
path: root/Lib/typing.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/typing.py')
-rw-r--r--Lib/typing.py18
1 files changed, 13 insertions, 5 deletions
diff --git a/Lib/typing.py b/Lib/typing.py
index 508f4b6..660ad35 100644
--- a/Lib/typing.py
+++ b/Lib/typing.py
@@ -195,7 +195,7 @@ def _type_repr(obj):
return repr(obj)
-def _collect_type_vars(types, typevar_types=None):
+def _collect_type_vars(types_, typevar_types=None):
"""Collect all type variable contained
in types in order of first appearance (lexicographic order). For example::
@@ -204,10 +204,10 @@ def _collect_type_vars(types, typevar_types=None):
if typevar_types is None:
typevar_types = TypeVar
tvars = []
- for t in types:
+ for t in types_:
if isinstance(t, typevar_types) and t not in tvars:
tvars.append(t)
- if isinstance(t, (_GenericAlias, GenericAlias)):
+ if isinstance(t, (_GenericAlias, GenericAlias, types.Union)):
tvars.extend([t for t in t.__parameters__ if t not in tvars])
return tuple(tvars)
@@ -314,12 +314,14 @@ def _eval_type(t, globalns, localns, recursive_guard=frozenset()):
"""
if isinstance(t, ForwardRef):
return t._evaluate(globalns, localns, recursive_guard)
- if isinstance(t, (_GenericAlias, GenericAlias)):
+ if isinstance(t, (_GenericAlias, GenericAlias, types.Union)):
ev_args = tuple(_eval_type(a, globalns, localns, recursive_guard) for a in t.__args__)
if ev_args == t.__args__:
return t
if isinstance(t, GenericAlias):
return GenericAlias(t.__origin__, ev_args)
+ if isinstance(t, types.Union):
+ return functools.reduce(operator.or_, ev_args)
else:
return t.copy_with(ev_args)
return t
@@ -1013,7 +1015,7 @@ class _GenericAlias(_BaseGenericAlias, _root=True):
for arg in self.__args__:
if isinstance(arg, self._typevar_types):
arg = subst[arg]
- elif isinstance(arg, (_GenericAlias, GenericAlias)):
+ elif isinstance(arg, (_GenericAlias, GenericAlias, types.Union)):
subparams = arg.__parameters__
if subparams:
subargs = tuple(subst[x] for x in subparams)
@@ -1779,6 +1781,12 @@ def _strip_annotations(t):
if stripped_args == t.__args__:
return t
return GenericAlias(t.__origin__, stripped_args)
+ if isinstance(t, types.Union):
+ stripped_args = tuple(_strip_annotations(a) for a in t.__args__)
+ if stripped_args == t.__args__:
+ return t
+ return functools.reduce(operator.or_, stripped_args)
+
return t