diff options
author | Serhiy Storchaka <storchaka@gmail.com> | 2020-05-07 01:09:33 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-05-07 01:09:33 (GMT) |
commit | c1c7d8ead9eb214a6149a43e31a3213c52448877 (patch) | |
tree | 87c656b19a20d5c5ade41f8b2088d0b84a9fccec /Lib | |
parent | 470aac4d8e76556bd8f820f3f3928dca2b4d2849 (diff) | |
download | cpython-c1c7d8ead9eb214a6149a43e31a3213c52448877.zip cpython-c1c7d8ead9eb214a6149a43e31a3213c52448877.tar.gz cpython-c1c7d8ead9eb214a6149a43e31a3213c52448877.tar.bz2 |
bpo-40397: Refactor typing._GenericAlias (GH-19719)
Make the design more object-oriented.
Split _GenericAlias on two almost independent classes: for special
generic aliases like List and for parametrized generic aliases like List[int].
Add specialized subclasses for Callable, Callable[...], Tuple and Union[...].
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/typing.py | 389 |
1 files changed, 198 insertions, 191 deletions
diff --git a/Lib/typing.py b/Lib/typing.py index f3cd280..681ab6d 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -181,34 +181,11 @@ def _collect_type_vars(types): for t in types: if isinstance(t, TypeVar) and t not in tvars: tvars.append(t) - if ((isinstance(t, _GenericAlias) and not t._special) - or isinstance(t, GenericAlias)): + if isinstance(t, (_GenericAlias, GenericAlias)): tvars.extend([t for t in t.__parameters__ if t not in tvars]) return tuple(tvars) -def _subs_tvars(tp, tvars, subs): - """Substitute type variables 'tvars' with substitutions 'subs'. - These two must have the same length. - """ - if not isinstance(tp, (_GenericAlias, GenericAlias)): - return tp - new_args = list(tp.__args__) - for a, arg in enumerate(tp.__args__): - if isinstance(arg, TypeVar): - for i, tvar in enumerate(tvars): - if arg == tvar: - new_args[a] = subs[i] - else: - new_args[a] = _subs_tvars(arg, tvars, subs) - if tp.__origin__ is Union: - return Union[tuple(new_args)] - if isinstance(tp, GenericAlias): - return GenericAlias(tp.__origin__, tuple(new_args)) - else: - return tp.copy_with(tuple(new_args)) - - def _check_generic(cls, parameters): """Check correct count for parameters of a generic cls (internal helper). This gives a nice error message in case of count mismatch. @@ -229,7 +206,7 @@ def _remove_dups_flatten(parameters): # Flatten out Union[Union[...], ...]. params = [] for p in parameters: - if isinstance(p, _GenericAlias) and p.__origin__ is Union: + if isinstance(p, _UnionGenericAlias): params.extend(p.__args__) elif isinstance(p, tuple) and len(p) > 0 and p[0] is Union: params.extend(p[1:]) @@ -274,18 +251,14 @@ def _eval_type(t, globalns, localns): """ if isinstance(t, ForwardRef): return t._evaluate(globalns, localns) - if isinstance(t, _GenericAlias): + if isinstance(t, (_GenericAlias, GenericAlias)): ev_args = tuple(_eval_type(a, globalns, localns) for a in t.__args__) if ev_args == t.__args__: return t - res = t.copy_with(ev_args) - res._special = t._special - return res - if isinstance(t, GenericAlias): - ev_args = tuple(_eval_type(a, globalns, localns) for a in t.__args__) - if ev_args == t.__args__: - return t - return GenericAlias(t.__origin__, ev_args) + if isinstance(t, GenericAlias): + return GenericAlias(t.__origin__, ev_args) + else: + return t.copy_with(ev_args) return t @@ -300,6 +273,7 @@ class _Final: class _Immutable: """Mixin to indicate that object should not be copied.""" + __slots__ = () def __copy__(self): return self @@ -446,7 +420,7 @@ def Union(self, parameters): parameters = _remove_dups_flatten(parameters) if len(parameters) == 1: return parameters[0] - return _GenericAlias(self, parameters) + return _UnionGenericAlias(self, parameters) @_SpecialForm def Optional(self, parameters): @@ -579,7 +553,7 @@ class TypeVar(_Final, _Immutable, _root=True): """ __slots__ = ('__name__', '__bound__', '__constraints__', - '__covariant__', '__contravariant__') + '__covariant__', '__contravariant__', '__dict__') def __init__(self, name, *constraints, bound=None, covariant=False, contravariant=False): @@ -629,23 +603,10 @@ class TypeVar(_Final, _Immutable, _root=True): # e.g., Dict[T, int].__args__ == (T, int). -# Mapping from non-generic type names that have a generic alias in typing -# but with a different name. -_normalize_alias = {'list': 'List', - 'tuple': 'Tuple', - 'dict': 'Dict', - 'set': 'Set', - 'frozenset': 'FrozenSet', - 'deque': 'Deque', - 'defaultdict': 'DefaultDict', - 'type': 'Type', - 'Set': 'AbstractSet'} - def _is_dunder(attr): return attr.startswith('__') and attr.endswith('__') - -class _GenericAlias(_Final, _root=True): +class _BaseGenericAlias(_Final, _root=True): """The central part of internal API. This represents a generic version of type 'origin' with type arguments 'params'. @@ -654,12 +615,8 @@ class _GenericAlias(_Final, _root=True): have 'name' always set. If 'inst' is False, then the alias can't be instantiated, this is used by e.g. typing.List and typing.Dict. """ - def __init__(self, origin, params, *, inst=True, special=False, name=None): + def __init__(self, origin, params, *, inst=True, name=None): self._inst = inst - self._special = special - if special and name is None: - orig_name = origin.__name__ - name = _normalize_alias.get(orig_name, orig_name) self._name = name if not isinstance(params, tuple): params = (params,) @@ -671,68 +628,20 @@ class _GenericAlias(_Final, _root=True): self.__slots__ = None # This is not documented. if not name: self.__module__ = origin.__module__ - if special: - self.__doc__ = f'A generic version of {origin.__module__}.{origin.__qualname__}' - - @_tp_cache - def __getitem__(self, params): - if self.__origin__ in (Generic, Protocol): - # Can't subscript Generic[...] or Protocol[...]. - raise TypeError(f"Cannot subscript already-subscripted {self}") - if not isinstance(params, tuple): - params = (params,) - msg = "Parameters to generic types must be types." - params = tuple(_type_check(p, msg) for p in params) - _check_generic(self, params) - return _subs_tvars(self, self.__parameters__, params) - - def copy_with(self, params): - # We don't copy self._special. - return _GenericAlias(self.__origin__, params, name=self._name, inst=self._inst) - - def __repr__(self): - if (self.__origin__ == Union and len(self.__args__) == 2 - and type(None) in self.__args__): - if self.__args__[0] is not type(None): - arg = self.__args__[0] - else: - arg = self.__args__[1] - return (f'typing.Optional[{_type_repr(arg)}]') - if (self._name != 'Callable' or - len(self.__args__) == 2 and self.__args__[0] is Ellipsis): - if self._name: - name = 'typing.' + self._name - else: - name = _type_repr(self.__origin__) - if not self._special: - args = f'[{", ".join([_type_repr(a) for a in self.__args__])}]' - else: - args = '' - return (f'{name}{args}') - if self._special: - return 'typing.Callable' - return (f'typing.Callable' - f'[[{", ".join([_type_repr(a) for a in self.__args__[:-1]])}], ' - f'{_type_repr(self.__args__[-1])}]') def __eq__(self, other): - if not isinstance(other, _GenericAlias): + if not isinstance(other, _BaseGenericAlias): return NotImplemented - if self.__origin__ != other.__origin__: - return False - if self.__origin__ is Union and other.__origin__ is Union: - return frozenset(self.__args__) == frozenset(other.__args__) - return self.__args__ == other.__args__ + return (self.__origin__ == other.__origin__ + and self.__args__ == other.__args__) def __hash__(self): - if self.__origin__ is Union: - return hash((Union, frozenset(self.__args__))) return hash((self.__origin__, self.__args__)) def __call__(self, *args, **kwargs): if not self._inst: raise TypeError(f"Type {self._name} cannot be instantiated; " - f"use {self._name.lower()}() instead") + f"use {self.__origin__.__name__}() instead") result = self.__origin__(*args, **kwargs) try: result.__orig_class__ = self @@ -741,23 +650,16 @@ class _GenericAlias(_Final, _root=True): return result def __mro_entries__(self, bases): - if self._name: # generic version of an ABC or built-in class - res = [] - if self.__origin__ not in bases: - res.append(self.__origin__) - i = bases.index(self) - if not any(isinstance(b, _GenericAlias) or issubclass(b, Generic) - for b in bases[i+1:]): - res.append(Generic) - return tuple(res) - if self.__origin__ is Generic: - if Protocol in bases: - return () - i = bases.index(self) - for b in bases[i+1:]: - if isinstance(b, _GenericAlias) and b is not self: - return () - return (self.__origin__,) + res = [] + if self.__origin__ not in bases: + res.append(self.__origin__) + i = bases.index(self) + for b in bases[i+1:]: + if isinstance(b, _BaseGenericAlias) or issubclass(b, Generic): + break + else: + res.append(Generic) + return tuple(res) def __getattr__(self, attr): # We are careful for copy and pickle. @@ -767,7 +669,7 @@ class _GenericAlias(_Final, _root=True): raise AttributeError(attr) def __setattr__(self, attr, val): - if _is_dunder(attr) or attr in ('_name', '_inst', '_special'): + if _is_dunder(attr) or attr in ('_name', '_inst'): super().__setattr__(attr, val) else: setattr(self.__origin__, attr, val) @@ -776,39 +678,124 @@ class _GenericAlias(_Final, _root=True): return self.__subclasscheck__(type(obj)) def __subclasscheck__(self, cls): - if self._special: - if not isinstance(cls, _GenericAlias): - return issubclass(cls, self.__origin__) - if cls._special: - return issubclass(cls.__origin__, self.__origin__) raise TypeError("Subscripted generics cannot be used with" " class and instance checks") - def __reduce__(self): - if self._special: - return self._name +class _GenericAlias(_BaseGenericAlias, _root=True): + @_tp_cache + def __getitem__(self, params): + if self.__origin__ in (Generic, Protocol): + # Can't subscript Generic[...] or Protocol[...]. + raise TypeError(f"Cannot subscript already-subscripted {self}") + if not isinstance(params, tuple): + params = (params,) + msg = "Parameters to generic types must be types." + params = tuple(_type_check(p, msg) for p in params) + _check_generic(self, params) + + subst = dict(zip(self.__parameters__, params)) + new_args = [] + for arg in self.__args__: + if isinstance(arg, TypeVar): + arg = subst[arg] + elif isinstance(arg, (_BaseGenericAlias, GenericAlias)): + subargs = tuple(subst[x] for x in arg.__parameters__) + arg = arg[subargs] + new_args.append(arg) + return self.copy_with(tuple(new_args)) + + def copy_with(self, params): + return self.__class__(self.__origin__, params, name=self._name, inst=self._inst) + + def __repr__(self): + if self._name: + name = 'typing.' + self._name + else: + name = _type_repr(self.__origin__) + args = ", ".join([_type_repr(a) for a in self.__args__]) + return f'{name}[{args}]' + + def __reduce__(self): if self._name: origin = globals()[self._name] else: origin = self.__origin__ - if (origin is Callable and - not (len(self.__args__) == 2 and self.__args__[0] is Ellipsis)): - args = list(self.__args__[:-1]), self.__args__[-1] - else: - args = tuple(self.__args__) - if len(args) == 1 and not isinstance(args[0], tuple): - args, = args + args = tuple(self.__args__) + if len(args) == 1 and not isinstance(args[0], tuple): + args, = args return operator.getitem, (origin, args) + def __mro_entries__(self, bases): + if self._name: # generic version of an ABC or built-in class + return super().__mro_entries__(bases) + if self.__origin__ is Generic: + if Protocol in bases: + return () + i = bases.index(self) + for b in bases[i+1:]: + if isinstance(b, _BaseGenericAlias) and b is not self: + return () + return (self.__origin__,) + + +class _SpecialGenericAlias(_BaseGenericAlias, _root=True): + def __init__(self, origin, params, *, inst=True, name=None): + if name is None: + name = origin.__name__ + super().__init__(origin, params, inst=inst, name=name) + self.__doc__ = f'A generic version of {origin.__module__}.{origin.__qualname__}' + + @_tp_cache + def __getitem__(self, params): + if not isinstance(params, tuple): + params = (params,) + msg = "Parameters to generic types must be types." + params = tuple(_type_check(p, msg) for p in params) + _check_generic(self, params) + assert self.__args__ == self.__parameters__ + return self.copy_with(params) + + def copy_with(self, params): + return _GenericAlias(self.__origin__, params, + name=self._name, inst=self._inst) + + def __repr__(self): + return 'typing.' + self._name + + def __subclasscheck__(self, cls): + if isinstance(cls, _SpecialGenericAlias): + return issubclass(cls.__origin__, self.__origin__) + if not isinstance(cls, _GenericAlias): + return issubclass(cls, self.__origin__) + return super().__subclasscheck__(cls) + + def __reduce__(self): + return self._name + + +class _CallableGenericAlias(_GenericAlias, _root=True): + def __repr__(self): + assert self._name == 'Callable' + if len(self.__args__) == 2 and self.__args__[0] is Ellipsis: + return super().__repr__() + return (f'typing.Callable' + f'[[{", ".join([_type_repr(a) for a in self.__args__[:-1]])}], ' + f'{_type_repr(self.__args__[-1])}]') + + def __reduce__(self): + args = self.__args__ + if not (len(args) == 2 and args[0] is ...): + args = list(args[:-1]), args[-1] + return operator.getitem, (Callable, args) + + +class _CallableType(_SpecialGenericAlias, _root=True): + def copy_with(self, params): + return _CallableGenericAlias(self.__origin__, params, + name=self._name, inst=self._inst) -class _VariadicGenericAlias(_GenericAlias, _root=True): - """Same as _GenericAlias above but for variadic aliases. Currently, - this is used only by special internal aliases: Tuple and Callable. - """ def __getitem__(self, params): - if self._name != 'Callable' or not self._special: - return self.__getitem_inner__(params) if not isinstance(params, tuple) or len(params) != 2: raise TypeError("Callable must be used as " "Callable[[arg, ...], result].") @@ -824,29 +811,53 @@ class _VariadicGenericAlias(_GenericAlias, _root=True): @_tp_cache def __getitem_inner__(self, params): - if self.__origin__ is tuple and self._special: - if params == (): - return self.copy_with((_TypingEmpty,)) - if not isinstance(params, tuple): - params = (params,) - if len(params) == 2 and params[1] is ...: - msg = "Tuple[t, ...]: t must be a type." - p = _type_check(params[0], msg) - return self.copy_with((p, _TypingEllipsis)) - msg = "Tuple[t0, t1, ...]: each t must be a type." - params = tuple(_type_check(p, msg) for p in params) - return self.copy_with(params) - if self.__origin__ is collections.abc.Callable and self._special: - args, result = params - msg = "Callable[args, result]: result must be a type." - result = _type_check(result, msg) - if args is Ellipsis: - return self.copy_with((_TypingEllipsis, result)) - msg = "Callable[[arg, ...], result]: each arg must be a type." - args = tuple(_type_check(arg, msg) for arg in args) - params = args + (result,) - return self.copy_with(params) - return super().__getitem__(params) + args, result = params + msg = "Callable[args, result]: result must be a type." + result = _type_check(result, msg) + if args is Ellipsis: + return self.copy_with((_TypingEllipsis, result)) + msg = "Callable[[arg, ...], result]: each arg must be a type." + args = tuple(_type_check(arg, msg) for arg in args) + params = args + (result,) + return self.copy_with(params) + + +class _TupleType(_SpecialGenericAlias, _root=True): + @_tp_cache + def __getitem__(self, params): + if params == (): + return self.copy_with((_TypingEmpty,)) + if not isinstance(params, tuple): + params = (params,) + if len(params) == 2 and params[1] is ...: + msg = "Tuple[t, ...]: t must be a type." + p = _type_check(params[0], msg) + return self.copy_with((p, _TypingEllipsis)) + msg = "Tuple[t0, t1, ...]: each t must be a type." + params = tuple(_type_check(p, msg) for p in params) + return self.copy_with(params) + + +class _UnionGenericAlias(_GenericAlias, _root=True): + def copy_with(self, params): + return Union[params] + + def __eq__(self, other): + if not isinstance(other, _UnionGenericAlias): + return NotImplemented + return set(self.__args__) == set(other.__args__) + + def __hash__(self): + return hash(frozenset(self.__args__)) + + def __repr__(self): + args = self.__args__ + if len(args) == 2: + if args[0] is type(None): + return f'typing.Optional[{_type_repr(args[1])}]' + elif args[1] is type(None): + return f'typing.Optional[{_type_repr(args[0])}]' + return super().__repr__() class Generic: @@ -1162,9 +1173,8 @@ class _AnnotatedAlias(_GenericAlias, _root=True): def __eq__(self, other): if not isinstance(other, _AnnotatedAlias): return NotImplemented - if self.__origin__ != other.__origin__: - return False - return self.__metadata__ == other.__metadata__ + return (self.__origin__ == other.__origin__ + and self.__metadata__ == other.__metadata__) def __hash__(self): return hash((self.__origin__, self.__metadata__)) @@ -1380,9 +1390,7 @@ def _strip_annotations(t): stripped_args = tuple(_strip_annotations(a) for a in t.__args__) if stripped_args == t.__args__: return t - res = t.copy_with(stripped_args) - res._special = t._special - return res + return t.copy_with(stripped_args) if isinstance(t, GenericAlias): stripped_args = tuple(_strip_annotations(a) for a in t.__args__) if stripped_args == t.__args__: @@ -1407,7 +1415,7 @@ def get_origin(tp): """ if isinstance(tp, _AnnotatedAlias): return Annotated - if isinstance(tp, (_GenericAlias, GenericAlias)): + if isinstance(tp, (_BaseGenericAlias, GenericAlias)): return tp.__origin__ if tp is Generic: return Generic @@ -1427,7 +1435,7 @@ def get_args(tp): """ if isinstance(tp, _AnnotatedAlias): return (tp.__origin__,) + tp.__metadata__ - if isinstance(tp, _GenericAlias) and not tp._special: + if isinstance(tp, _GenericAlias): res = tp.__args__ if tp.__origin__ is collections.abc.Callable and res[0] is not Ellipsis: res = (list(res[:-1]), res[-1]) @@ -1561,8 +1569,7 @@ AnyStr = TypeVar('AnyStr', bytes, str) # Various ABCs mimicking those in collections.abc. -def _alias(origin, params, inst=True): - return _GenericAlias(origin, params, special=True, inst=inst) +_alias = _SpecialGenericAlias Hashable = _alias(collections.abc.Hashable, ()) # Not generic. Awaitable = _alias(collections.abc.Awaitable, T_co) @@ -1575,7 +1582,7 @@ Reversible = _alias(collections.abc.Reversible, T_co) Sized = _alias(collections.abc.Sized, ()) # Not generic. Container = _alias(collections.abc.Container, T_co) Collection = _alias(collections.abc.Collection, T_co) -Callable = _VariadicGenericAlias(collections.abc.Callable, (), special=True) +Callable = _CallableType(collections.abc.Callable, ()) Callable.__doc__ = \ """Callable type; Callable[[int], str] is a function of (int) -> str. @@ -1586,7 +1593,7 @@ Callable.__doc__ = \ There is no syntax to indicate optional or keyword arguments, such function types are rarely used as callback types. """ -AbstractSet = _alias(collections.abc.Set, T_co) +AbstractSet = _alias(collections.abc.Set, T_co, name='AbstractSet') MutableSet = _alias(collections.abc.MutableSet, T) # NOTE: Mapping is only covariant in the value type. Mapping = _alias(collections.abc.Mapping, (KT, VT_co)) @@ -1594,7 +1601,7 @@ MutableMapping = _alias(collections.abc.MutableMapping, (KT, VT)) Sequence = _alias(collections.abc.Sequence, T_co) MutableSequence = _alias(collections.abc.MutableSequence, T) ByteString = _alias(collections.abc.ByteString, ()) # Not generic -Tuple = _VariadicGenericAlias(tuple, (), inst=False, special=True) +Tuple = _TupleType(tuple, (), inst=False, name='Tuple') Tuple.__doc__ = \ """Tuple type; Tuple[X, Y] is the cross-product type of X and Y. @@ -1604,24 +1611,24 @@ Tuple.__doc__ = \ To specify a variable-length tuple of homogeneous type, use Tuple[T, ...]. """ -List = _alias(list, T, inst=False) -Deque = _alias(collections.deque, T) -Set = _alias(set, T, inst=False) -FrozenSet = _alias(frozenset, T_co, inst=False) +List = _alias(list, T, inst=False, name='List') +Deque = _alias(collections.deque, T, name='Deque') +Set = _alias(set, T, inst=False, name='Set') +FrozenSet = _alias(frozenset, T_co, inst=False, name='FrozenSet') MappingView = _alias(collections.abc.MappingView, T_co) KeysView = _alias(collections.abc.KeysView, KT) ItemsView = _alias(collections.abc.ItemsView, (KT, VT_co)) ValuesView = _alias(collections.abc.ValuesView, VT_co) -ContextManager = _alias(contextlib.AbstractContextManager, T_co) -AsyncContextManager = _alias(contextlib.AbstractAsyncContextManager, T_co) -Dict = _alias(dict, (KT, VT), inst=False) -DefaultDict = _alias(collections.defaultdict, (KT, VT)) +ContextManager = _alias(contextlib.AbstractContextManager, T_co, name='ContextManager') +AsyncContextManager = _alias(contextlib.AbstractAsyncContextManager, T_co, name='AsyncContextManager') +Dict = _alias(dict, (KT, VT), inst=False, name='Dict') +DefaultDict = _alias(collections.defaultdict, (KT, VT), name='DefaultDict') OrderedDict = _alias(collections.OrderedDict, (KT, VT)) Counter = _alias(collections.Counter, T) ChainMap = _alias(collections.ChainMap, (KT, VT)) Generator = _alias(collections.abc.Generator, (T_co, T_contra, V_co)) AsyncGenerator = _alias(collections.abc.AsyncGenerator, (T_co, T_contra)) -Type = _alias(type, CT_co, inst=False) +Type = _alias(type, CT_co, inst=False, name='Type') Type.__doc__ = \ """A special construct usable to annotate class objects. |