diff options
author | Yurii Karabas <1998uriyyo@gmail.com> | 2021-07-17 03:33:40 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-07-17 03:33:40 (GMT) |
commit | bf89ff96e6ba21bb52b8597b5e51e8ffc57e6589 (patch) | |
tree | 9ffa39f9f8ad5786e7dd4dccc5b69118fa66077d /Lib/typing.py | |
parent | f783428a2313a729ca8b539c5a86ff114b9ff375 (diff) | |
download | cpython-bf89ff96e6ba21bb52b8597b5e51e8ffc57e6589.zip cpython-bf89ff96e6ba21bb52b8597b5e51e8ffc57e6589.tar.gz cpython-bf89ff96e6ba21bb52b8597b5e51e8ffc57e6589.tar.bz2 |
bpo-44490: Improve typing module compatibility with types.Union (GH-27048)
Diffstat (limited to 'Lib/typing.py')
-rw-r--r-- | Lib/typing.py | 18 |
1 files changed, 13 insertions, 5 deletions
diff --git a/Lib/typing.py b/Lib/typing.py index 2f22868..f7386ea 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -196,7 +196,7 @@ def _type_repr(obj): return repr(obj) -def _collect_type_vars(types, typevar_types=None): +def _collect_type_vars(types_, typevar_types=None): """Collect all type variable contained in types in order of first appearance (lexicographic order). For example:: @@ -205,10 +205,10 @@ def _collect_type_vars(types, typevar_types=None): if typevar_types is None: typevar_types = TypeVar tvars = [] - for t in types: + for t in types_: if isinstance(t, typevar_types) and t not in tvars: tvars.append(t) - if isinstance(t, (_GenericAlias, GenericAlias)): + if isinstance(t, (_GenericAlias, GenericAlias, types.Union)): tvars.extend([t for t in t.__parameters__ if t not in tvars]) return tuple(tvars) @@ -315,12 +315,14 @@ def _eval_type(t, globalns, localns, recursive_guard=frozenset()): """ if isinstance(t, ForwardRef): return t._evaluate(globalns, localns, recursive_guard) - if isinstance(t, (_GenericAlias, GenericAlias)): + if isinstance(t, (_GenericAlias, GenericAlias, types.Union)): ev_args = tuple(_eval_type(a, globalns, localns, recursive_guard) for a in t.__args__) if ev_args == t.__args__: return t if isinstance(t, GenericAlias): return GenericAlias(t.__origin__, ev_args) + if isinstance(t, types.Union): + return functools.reduce(operator.or_, ev_args) else: return t.copy_with(ev_args) return t @@ -1009,7 +1011,7 @@ class _GenericAlias(_BaseGenericAlias, _root=True): for arg in self.__args__: if isinstance(arg, self._typevar_types): arg = subst[arg] - elif isinstance(arg, (_GenericAlias, GenericAlias)): + elif isinstance(arg, (_GenericAlias, GenericAlias, types.Union)): subparams = arg.__parameters__ if subparams: subargs = tuple(subst[x] for x in subparams) @@ -1775,6 +1777,12 @@ def _strip_annotations(t): if stripped_args == t.__args__: return t return GenericAlias(t.__origin__, stripped_args) + if isinstance(t, types.Union): + stripped_args = tuple(_strip_annotations(a) for a in t.__args__) + if stripped_args == t.__args__: + return t + return functools.reduce(operator.or_, stripped_args) + return t |