diff options
Diffstat (limited to 'Lib/typing.py')
-rw-r--r-- | Lib/typing.py | 18 |
1 files changed, 13 insertions, 5 deletions
diff --git a/Lib/typing.py b/Lib/typing.py index 508f4b6..660ad35 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -195,7 +195,7 @@ def _type_repr(obj): return repr(obj) -def _collect_type_vars(types, typevar_types=None): +def _collect_type_vars(types_, typevar_types=None): """Collect all type variable contained in types in order of first appearance (lexicographic order). For example:: @@ -204,10 +204,10 @@ def _collect_type_vars(types, typevar_types=None): if typevar_types is None: typevar_types = TypeVar tvars = [] - for t in types: + for t in types_: if isinstance(t, typevar_types) and t not in tvars: tvars.append(t) - if isinstance(t, (_GenericAlias, GenericAlias)): + if isinstance(t, (_GenericAlias, GenericAlias, types.Union)): tvars.extend([t for t in t.__parameters__ if t not in tvars]) return tuple(tvars) @@ -314,12 +314,14 @@ def _eval_type(t, globalns, localns, recursive_guard=frozenset()): """ if isinstance(t, ForwardRef): return t._evaluate(globalns, localns, recursive_guard) - if isinstance(t, (_GenericAlias, GenericAlias)): + if isinstance(t, (_GenericAlias, GenericAlias, types.Union)): ev_args = tuple(_eval_type(a, globalns, localns, recursive_guard) for a in t.__args__) if ev_args == t.__args__: return t if isinstance(t, GenericAlias): return GenericAlias(t.__origin__, ev_args) + if isinstance(t, types.Union): + return functools.reduce(operator.or_, ev_args) else: return t.copy_with(ev_args) return t @@ -1013,7 +1015,7 @@ class _GenericAlias(_BaseGenericAlias, _root=True): for arg in self.__args__: if isinstance(arg, self._typevar_types): arg = subst[arg] - elif isinstance(arg, (_GenericAlias, GenericAlias)): + elif isinstance(arg, (_GenericAlias, GenericAlias, types.Union)): subparams = arg.__parameters__ if subparams: subargs = tuple(subst[x] for x in subparams) @@ -1779,6 +1781,12 @@ def _strip_annotations(t): if stripped_args == t.__args__: return t return GenericAlias(t.__origin__, stripped_args) + if isinstance(t, types.Union): + stripped_args = tuple(_strip_annotations(a) for a in t.__args__) + if stripped_args == t.__args__: + return t + return functools.reduce(operator.or_, stripped_args) + return t |