diff options
author | Ivan Levkivskyi <levkivskyi@gmail.com> | 2019-05-28 07:40:15 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-05-28 07:40:15 (GMT) |
commit | 74d7f76e2c953fbfdb7ce01b7319d91d471cc5ef (patch) | |
tree | 6bba7b64dc4b4a88569809f0758113c87bb690b4 /Lib | |
parent | 3880f263d2994fb1eba25835dddccb0cf696fdf0 (diff) | |
download | cpython-74d7f76e2c953fbfdb7ce01b7319d91d471cc5ef.zip cpython-74d7f76e2c953fbfdb7ce01b7319d91d471cc5ef.tar.gz cpython-74d7f76e2c953fbfdb7ce01b7319d91d471cc5ef.tar.bz2 |
bpo-37058: PEP 544: Add Protocol to typing module (GH-13585)
I tried to get rid of the `_ProtocolMeta`, but unfortunately it didn'y work. My idea to return a generic alias from `@runtime_checkable` made runtime protocols unpickleable. I am not sure what is worse (a custom metaclass or having some classes unpickleable), so I decided to stick with the status quo (since there were no complains so far). So essentially this is a copy of the implementation in `typing_extensions` with two modifications:
* Rename `@runtime` to `@runtime_checkable` (plus corresponding updates).
* Allow protocols that extend `collections.abc.Iterable` etc.
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/test/test_typing.py | 731 | ||||
-rw-r--r-- | Lib/typing.py | 342 |
2 files changed, 955 insertions, 118 deletions
diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index 46b7621..2b4b934 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -12,8 +12,8 @@ from typing import T, KT, VT # Not in __all__. from typing import Union, Optional, Literal from typing import Tuple, List, MutableMapping from typing import Callable -from typing import Generic, ClassVar, Final, final -from typing import cast +from typing import Generic, ClassVar, Final, final, Protocol +from typing import cast, runtime_checkable from typing import get_type_hints from typing import no_type_check, no_type_check_decorator from typing import Type @@ -24,6 +24,7 @@ from typing import Pattern, Match import abc import typing import weakref +import types from test import mod_generics_cache @@ -585,7 +586,710 @@ class MySimpleMapping(SimpleMapping[XK, XV]): return default +class Coordinate(Protocol): + x: int + y: int + +@runtime_checkable +class Point(Coordinate, Protocol): + label: str + +class MyPoint: + x: int + y: int + label: str + +class XAxis(Protocol): + x: int + +class YAxis(Protocol): + y: int + +@runtime_checkable +class Position(XAxis, YAxis, Protocol): + pass + +@runtime_checkable +class Proto(Protocol): + attr: int + def meth(self, arg: str) -> int: + ... + +class Concrete(Proto): + pass + +class Other: + attr: int = 1 + def meth(self, arg: str) -> int: + if arg == 'this': + return 1 + return 0 + +class NT(NamedTuple): + x: int + y: int + +@runtime_checkable +class HasCallProtocol(Protocol): + __call__: typing.Callable + + class ProtocolTests(BaseTestCase): + def test_basic_protocol(self): + @runtime_checkable + class P(Protocol): + def meth(self): + pass + + class C: pass + + class D: + def meth(self): + pass + + def f(): + pass + + self.assertIsSubclass(D, P) + self.assertIsInstance(D(), P) + self.assertNotIsSubclass(C, P) + self.assertNotIsInstance(C(), P) + self.assertNotIsSubclass(types.FunctionType, P) + self.assertNotIsInstance(f, P) + + def test_everything_implements_empty_protocol(self): + @runtime_checkable + class Empty(Protocol): + pass + + class C: + pass + + def f(): + pass + + for thing in (object, type, tuple, C, types.FunctionType): + self.assertIsSubclass(thing, Empty) + for thing in (object(), 1, (), typing, f): + self.assertIsInstance(thing, Empty) + + def test_function_implements_protocol(self): + def f(): + pass + + self.assertIsInstance(f, HasCallProtocol) + + def test_no_inheritance_from_nominal(self): + class C: pass + + class BP(Protocol): pass + + with self.assertRaises(TypeError): + class P(C, Protocol): + pass + with self.assertRaises(TypeError): + class P(Protocol, C): + pass + with self.assertRaises(TypeError): + class P(BP, C, Protocol): + pass + + class D(BP, C): pass + + class E(C, BP): pass + + self.assertNotIsInstance(D(), E) + self.assertNotIsInstance(E(), D) + + def test_no_instantiation(self): + class P(Protocol): pass + + with self.assertRaises(TypeError): + P() + + class C(P): pass + + self.assertIsInstance(C(), C) + T = TypeVar('T') + + class PG(Protocol[T]): pass + + with self.assertRaises(TypeError): + PG() + with self.assertRaises(TypeError): + PG[int]() + with self.assertRaises(TypeError): + PG[T]() + + class CG(PG[T]): pass + + self.assertIsInstance(CG[int](), CG) + + def test_cannot_instantiate_abstract(self): + @runtime_checkable + class P(Protocol): + @abc.abstractmethod + def ameth(self) -> int: + raise NotImplementedError + + class B(P): + pass + + class C(B): + def ameth(self) -> int: + return 26 + + with self.assertRaises(TypeError): + B() + self.assertIsInstance(C(), P) + + def test_subprotocols_extending(self): + class P1(Protocol): + def meth1(self): + pass + + @runtime_checkable + class P2(P1, Protocol): + def meth2(self): + pass + + class C: + def meth1(self): + pass + + def meth2(self): + pass + + class C1: + def meth1(self): + pass + + class C2: + def meth2(self): + pass + + self.assertNotIsInstance(C1(), P2) + self.assertNotIsInstance(C2(), P2) + self.assertNotIsSubclass(C1, P2) + self.assertNotIsSubclass(C2, P2) + self.assertIsInstance(C(), P2) + self.assertIsSubclass(C, P2) + + def test_subprotocols_merging(self): + class P1(Protocol): + def meth1(self): + pass + + class P2(Protocol): + def meth2(self): + pass + + @runtime_checkable + class P(P1, P2, Protocol): + pass + + class C: + def meth1(self): + pass + + def meth2(self): + pass + + class C1: + def meth1(self): + pass + + class C2: + def meth2(self): + pass + + self.assertNotIsInstance(C1(), P) + self.assertNotIsInstance(C2(), P) + self.assertNotIsSubclass(C1, P) + self.assertNotIsSubclass(C2, P) + self.assertIsInstance(C(), P) + self.assertIsSubclass(C, P) + + def test_protocols_issubclass(self): + T = TypeVar('T') + + @runtime_checkable + class P(Protocol): + def x(self): ... + + @runtime_checkable + class PG(Protocol[T]): + def x(self): ... + + class BadP(Protocol): + def x(self): ... + + class BadPG(Protocol[T]): + def x(self): ... + + class C: + def x(self): ... + + self.assertIsSubclass(C, P) + self.assertIsSubclass(C, PG) + self.assertIsSubclass(BadP, PG) + + with self.assertRaises(TypeError): + issubclass(C, PG[T]) + with self.assertRaises(TypeError): + issubclass(C, PG[C]) + with self.assertRaises(TypeError): + issubclass(C, BadP) + with self.assertRaises(TypeError): + issubclass(C, BadPG) + with self.assertRaises(TypeError): + issubclass(P, PG[T]) + with self.assertRaises(TypeError): + issubclass(PG, PG[int]) + + def test_protocols_issubclass_non_callable(self): + class C: + x = 1 + + @runtime_checkable + class PNonCall(Protocol): + x = 1 + + with self.assertRaises(TypeError): + issubclass(C, PNonCall) + self.assertIsInstance(C(), PNonCall) + PNonCall.register(C) + with self.assertRaises(TypeError): + issubclass(C, PNonCall) + self.assertIsInstance(C(), PNonCall) + + # check that non-protocol subclasses are not affected + class D(PNonCall): ... + + self.assertNotIsSubclass(C, D) + self.assertNotIsInstance(C(), D) + D.register(C) + self.assertIsSubclass(C, D) + self.assertIsInstance(C(), D) + with self.assertRaises(TypeError): + issubclass(D, PNonCall) + + def test_protocols_isinstance(self): + T = TypeVar('T') + + @runtime_checkable + class P(Protocol): + def meth(x): ... + + @runtime_checkable + class PG(Protocol[T]): + def meth(x): ... + + class BadP(Protocol): + def meth(x): ... + + class BadPG(Protocol[T]): + def meth(x): ... + + class C: + def meth(x): ... + + self.assertIsInstance(C(), P) + self.assertIsInstance(C(), PG) + with self.assertRaises(TypeError): + isinstance(C(), PG[T]) + with self.assertRaises(TypeError): + isinstance(C(), PG[C]) + with self.assertRaises(TypeError): + isinstance(C(), BadP) + with self.assertRaises(TypeError): + isinstance(C(), BadPG) + + def test_protocols_isinstance_py36(self): + class APoint: + def __init__(self, x, y, label): + self.x = x + self.y = y + self.label = label + + class BPoint: + label = 'B' + + def __init__(self, x, y): + self.x = x + self.y = y + + class C: + def __init__(self, attr): + self.attr = attr + + def meth(self, arg): + return 0 + + class Bad: pass + + self.assertIsInstance(APoint(1, 2, 'A'), Point) + self.assertIsInstance(BPoint(1, 2), Point) + self.assertNotIsInstance(MyPoint(), Point) + self.assertIsInstance(BPoint(1, 2), Position) + self.assertIsInstance(Other(), Proto) + self.assertIsInstance(Concrete(), Proto) + self.assertIsInstance(C(42), Proto) + self.assertNotIsInstance(Bad(), Proto) + self.assertNotIsInstance(Bad(), Point) + self.assertNotIsInstance(Bad(), Position) + self.assertNotIsInstance(Bad(), Concrete) + self.assertNotIsInstance(Other(), Concrete) + self.assertIsInstance(NT(1, 2), Position) + + def test_protocols_isinstance_init(self): + T = TypeVar('T') + + @runtime_checkable + class P(Protocol): + x = 1 + + @runtime_checkable + class PG(Protocol[T]): + x = 1 + + class C: + def __init__(self, x): + self.x = x + + self.assertIsInstance(C(1), P) + self.assertIsInstance(C(1), PG) + + def test_protocols_support_register(self): + @runtime_checkable + class P(Protocol): + x = 1 + + class PM(Protocol): + def meth(self): pass + + class D(PM): pass + + class C: pass + + D.register(C) + P.register(C) + self.assertIsInstance(C(), P) + self.assertIsInstance(C(), D) + + def test_none_on_non_callable_doesnt_block_implementation(self): + @runtime_checkable + class P(Protocol): + x = 1 + + class A: + x = 1 + + class B(A): + x = None + + class C: + def __init__(self): + self.x = None + + self.assertIsInstance(B(), P) + self.assertIsInstance(C(), P) + + def test_none_on_callable_blocks_implementation(self): + @runtime_checkable + class P(Protocol): + def x(self): ... + + class A: + def x(self): ... + + class B(A): + x = None + + class C: + def __init__(self): + self.x = None + + self.assertNotIsInstance(B(), P) + self.assertNotIsInstance(C(), P) + + def test_non_protocol_subclasses(self): + class P(Protocol): + x = 1 + + @runtime_checkable + class PR(Protocol): + def meth(self): pass + + class NonP(P): + x = 1 + + class NonPR(PR): pass + + class C: + x = 1 + + class D: + def meth(self): pass + + self.assertNotIsInstance(C(), NonP) + self.assertNotIsInstance(D(), NonPR) + self.assertNotIsSubclass(C, NonP) + self.assertNotIsSubclass(D, NonPR) + self.assertIsInstance(NonPR(), PR) + self.assertIsSubclass(NonPR, PR) + + def test_custom_subclasshook(self): + class P(Protocol): + x = 1 + + class OKClass: pass + + class BadClass: + x = 1 + + class C(P): + @classmethod + def __subclasshook__(cls, other): + return other.__name__.startswith("OK") + + self.assertIsInstance(OKClass(), C) + self.assertNotIsInstance(BadClass(), C) + self.assertIsSubclass(OKClass, C) + self.assertNotIsSubclass(BadClass, C) + + def test_issubclass_fails_correctly(self): + @runtime_checkable + class P(Protocol): + x = 1 + + class C: pass + + with self.assertRaises(TypeError): + issubclass(C(), P) + + def test_defining_generic_protocols(self): + T = TypeVar('T') + S = TypeVar('S') + + @runtime_checkable + class PR(Protocol[T, S]): + def meth(self): pass + + class P(PR[int, T], Protocol[T]): + y = 1 + + with self.assertRaises(TypeError): + PR[int] + with self.assertRaises(TypeError): + P[int, str] + with self.assertRaises(TypeError): + PR[int, 1] + with self.assertRaises(TypeError): + PR[int, ClassVar] + + class C(PR[int, T]): pass + + self.assertIsInstance(C[str](), C) + + def test_defining_generic_protocols_old_style(self): + T = TypeVar('T') + S = TypeVar('S') + + @runtime_checkable + class PR(Protocol, Generic[T, S]): + def meth(self): pass + + class P(PR[int, str], Protocol): + y = 1 + + with self.assertRaises(TypeError): + issubclass(PR[int, str], PR) + self.assertIsSubclass(P, PR) + with self.assertRaises(TypeError): + PR[int] + with self.assertRaises(TypeError): + PR[int, 1] + + class P1(Protocol, Generic[T]): + def bar(self, x: T) -> str: ... + + class P2(Generic[T], Protocol): + def bar(self, x: T) -> str: ... + + @runtime_checkable + class PSub(P1[str], Protocol): + x = 1 + + class Test: + x = 1 + + def bar(self, x: str) -> str: + return x + + self.assertIsInstance(Test(), PSub) + with self.assertRaises(TypeError): + PR[int, ClassVar] + + def test_init_called(self): + T = TypeVar('T') + + class P(Protocol[T]): pass + + class C(P[T]): + def __init__(self): + self.test = 'OK' + + self.assertEqual(C[int]().test, 'OK') + + def test_protocols_bad_subscripts(self): + T = TypeVar('T') + S = TypeVar('S') + with self.assertRaises(TypeError): + class P(Protocol[T, T]): pass + with self.assertRaises(TypeError): + class P(Protocol[int]): pass + with self.assertRaises(TypeError): + class P(Protocol[T], Protocol[S]): pass + with self.assertRaises(TypeError): + class P(typing.Mapping[T, S], Protocol[T]): pass + + def test_generic_protocols_repr(self): + T = TypeVar('T') + S = TypeVar('S') + + class P(Protocol[T, S]): pass + + self.assertTrue(repr(P[T, S]).endswith('P[~T, ~S]')) + self.assertTrue(repr(P[int, str]).endswith('P[int, str]')) + + def test_generic_protocols_eq(self): + T = TypeVar('T') + S = TypeVar('S') + + class P(Protocol[T, S]): pass + + self.assertEqual(P, P) + self.assertEqual(P[int, T], P[int, T]) + self.assertEqual(P[T, T][Tuple[T, S]][int, str], + P[Tuple[int, str], Tuple[int, str]]) + + def test_generic_protocols_special_from_generic(self): + T = TypeVar('T') + + class P(Protocol[T]): pass + + self.assertEqual(P.__parameters__, (T,)) + self.assertEqual(P[int].__parameters__, ()) + self.assertEqual(P[int].__args__, (int,)) + self.assertIs(P[int].__origin__, P) + + def test_generic_protocols_special_from_protocol(self): + @runtime_checkable + class PR(Protocol): + x = 1 + + class P(Protocol): + def meth(self): + pass + + T = TypeVar('T') + + class PG(Protocol[T]): + x = 1 + + def meth(self): + pass + + self.assertTrue(P._is_protocol) + self.assertTrue(PR._is_protocol) + self.assertTrue(PG._is_protocol) + self.assertFalse(P._is_runtime_protocol) + self.assertTrue(PR._is_runtime_protocol) + self.assertTrue(PG[int]._is_protocol) + self.assertEqual(typing._get_protocol_attrs(P), {'meth'}) + self.assertEqual(typing._get_protocol_attrs(PR), {'x'}) + self.assertEqual(frozenset(typing._get_protocol_attrs(PG)), + frozenset({'x', 'meth'})) + + def test_no_runtime_deco_on_nominal(self): + with self.assertRaises(TypeError): + @runtime_checkable + class C: pass + + class Proto(Protocol): + x = 1 + + with self.assertRaises(TypeError): + @runtime_checkable + class Concrete(Proto): + pass + + def test_none_treated_correctly(self): + @runtime_checkable + class P(Protocol): + x = None # type: int + + class B(object): pass + + self.assertNotIsInstance(B(), P) + + class C: + x = 1 + + class D: + x = None + + self.assertIsInstance(C(), P) + self.assertIsInstance(D(), P) + + class CI: + def __init__(self): + self.x = 1 + + class DI: + def __init__(self): + self.x = None + + self.assertIsInstance(C(), P) + self.assertIsInstance(D(), P) + + def test_protocols_in_unions(self): + class P(Protocol): + x = None # type: int + + Alias = typing.Union[typing.Iterable, P] + Alias2 = typing.Union[P, typing.Iterable] + self.assertEqual(Alias, Alias2) + + def test_protocols_pickleable(self): + global P, CP # pickle wants to reference the class by name + T = TypeVar('T') + + @runtime_checkable + class P(Protocol[T]): + x = 1 + + class CP(P[int]): + pass + + c = CP() + c.foo = 42 + c.bar = 'abc' + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + z = pickle.dumps(c, proto) + x = pickle.loads(z) + self.assertEqual(x.foo, 42) + self.assertEqual(x.bar, 'abc') + self.assertEqual(x.x, 1) + self.assertEqual(x.__dict__, {'foo': 42, 'bar': 'abc'}) + s = pickle.dumps(P) + D = pickle.loads(s) + + class E: + x = 1 + + self.assertIsInstance(E(), D) def test_supports_int(self): self.assertIsSubclass(int, typing.SupportsInt) @@ -634,9 +1338,8 @@ class ProtocolTests(BaseTestCase): self.assertIsSubclass(int, typing.SupportsIndex) self.assertNotIsSubclass(str, typing.SupportsIndex) - def test_protocol_instance_type_error(self): - with self.assertRaises(TypeError): - isinstance(0, typing.SupportsAbs) + def test_bundled_protocol_instance_works(self): + self.assertIsInstance(0, typing.SupportsAbs) class C1(typing.SupportsInt): def __int__(self) -> int: return 42 @@ -645,6 +1348,20 @@ class ProtocolTests(BaseTestCase): c = C2() self.assertIsInstance(c, C1) + def test_collections_protocols_allowed(self): + @runtime_checkable + class Custom(collections.abc.Iterable, Protocol): + def close(self): ... + + class A: pass + class B: + def __iter__(self): + return [] + def close(self): + return 0 + + self.assertIsSubclass(B, Custom) + self.assertNotIsSubclass(A, Custom) class GenericTests(BaseTestCase): @@ -771,7 +1488,7 @@ class GenericTests(BaseTestCase): def test_new_repr_bare(self): T = TypeVar('T') self.assertEqual(repr(Generic[T]), 'typing.Generic[~T]') - self.assertEqual(repr(typing._Protocol[T]), 'typing._Protocol[~T]') + self.assertEqual(repr(typing.Protocol[T]), 'typing.Protocol[~T]') class C(typing.Dict[Any, Any]): ... # this line should just work repr(C.__mro__) @@ -1067,7 +1784,7 @@ class GenericTests(BaseTestCase): with self.assertRaises(TypeError): Tuple[Generic[T]] with self.assertRaises(TypeError): - List[typing._Protocol] + List[typing.Protocol] def test_type_erasure_special(self): T = TypeVar('T') diff --git a/Lib/typing.py b/Lib/typing.py index d3e84cd..14bd06b 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -9,8 +9,7 @@ At large scale, the structure of the module is following: * The core of internal generics API: _GenericAlias and _VariadicGenericAlias, the latter is currently only used by Tuple and Callable. All subscripted types like X[int], Union[int, str], etc., are instances of either of these classes. -* The public counterpart of the generics API consists of two classes: Generic and Protocol - (the latter is currently private, but will be made public after PEP 544 acceptance). +* The public counterpart of the generics API consists of two classes: Generic and Protocol. * Public helper functions: get_type_hints, overload, cast, no_type_check, no_type_check_decorator. * Generic aliases for collections.abc ABCs and few additional protocols. @@ -18,7 +17,7 @@ At large scale, the structure of the module is following: * Wrapper submodules for re and io related types. """ -from abc import abstractmethod, abstractproperty +from abc import abstractmethod, abstractproperty, ABCMeta import collections import collections.abc import contextlib @@ -39,6 +38,7 @@ __all__ = [ 'Generic', 'Literal', 'Optional', + 'Protocol', 'Tuple', 'Type', 'TypeVar', @@ -102,6 +102,7 @@ __all__ = [ 'no_type_check_decorator', 'NoReturn', 'overload', + 'runtime_checkable', 'Text', 'TYPE_CHECKING', ] @@ -123,7 +124,7 @@ def _type_check(arg, msg, is_argument=True): We append the repr() of the actual value (truncated to 100 chars). """ - invalid_generic_forms = (Generic, _Protocol) + invalid_generic_forms = (Generic, Protocol) if is_argument: invalid_generic_forms = invalid_generic_forms + (ClassVar, Final) @@ -135,7 +136,7 @@ def _type_check(arg, msg, is_argument=True): arg.__origin__ in invalid_generic_forms): raise TypeError(f"{arg} is not valid as type argument") if (isinstance(arg, _SpecialForm) and arg not in (Any, NoReturn) or - arg in (Generic, _Protocol)): + arg in (Generic, Protocol)): raise TypeError(f"Plain {arg} is not valid as type argument") if isinstance(arg, (type, TypeVar, ForwardRef)): return arg @@ -665,8 +666,8 @@ class _GenericAlias(_Final, _root=True): @_tp_cache def __getitem__(self, params): - if self.__origin__ in (Generic, _Protocol): - # Can't subscript Generic[...] or _Protocol[...]. + if self.__origin__ in (Generic, Protocol): + # Can't subscript Generic[...] or Protocol[...]. raise TypeError(f"Cannot subscript already-subscripted {self}") if not isinstance(params, tuple): params = (params,) @@ -733,6 +734,8 @@ class _GenericAlias(_Final, _root=True): res.append(Generic) return tuple(res) if self.__origin__ is Generic: + if Protocol in bases: + return () i = bases.index(self) for b in bases[i+1:]: if isinstance(b, _GenericAlias) and b is not self: @@ -850,10 +853,11 @@ class Generic: return default """ __slots__ = () + _is_protocol = False def __new__(cls, *args, **kwds): - if cls is Generic: - raise TypeError("Type Generic cannot be instantiated; " + if cls in (Generic, Protocol): + raise TypeError(f"Type {cls.__name__} cannot be instantiated; " "it can be used only as a base class") if super().__new__ is object.__new__ and cls.__init__ is not object.__init__: obj = super().__new__(cls) @@ -870,17 +874,14 @@ class Generic: f"Parameter list to {cls.__qualname__}[...] cannot be empty") msg = "Parameters to generic types must be types." params = tuple(_type_check(p, msg) for p in params) - if cls is Generic: - # Generic can only be subscripted with unique type variables. + if cls in (Generic, Protocol): + # Generic and Protocol can only be subscripted with unique type variables. if not all(isinstance(p, TypeVar) for p in params): raise TypeError( - "Parameters to Generic[...] must all be type variables") + f"Parameters to {cls.__name__}[...] must all be type variables") if len(set(params)) != len(params): raise TypeError( - "Parameters to Generic[...] must all be unique") - elif cls is _Protocol: - # _Protocol is internal at the moment, just skip the check - pass + f"Parameters to {cls.__name__}[...] must all be unique") else: # Subscripting a regular Generic subclass. _check_generic(cls, params) @@ -892,7 +893,7 @@ class Generic: if '__orig_bases__' in cls.__dict__: error = Generic in cls.__orig_bases__ else: - error = Generic in cls.__bases__ and cls.__name__ != '_Protocol' + error = Generic in cls.__bases__ and cls.__name__ != 'Protocol' if error: raise TypeError("Cannot inherit from plain Generic") if '__orig_bases__' in cls.__dict__: @@ -910,9 +911,7 @@ class Generic: raise TypeError( "Cannot inherit from Generic[...] multiple types.") gvars = base.__parameters__ - if gvars is None: - gvars = tvars - else: + if gvars is not None: tvarset = set(tvars) gvarset = set(gvars) if not tvarset <= gvarset: @@ -935,6 +934,204 @@ class _TypingEllipsis: """Internal placeholder for ... (ellipsis).""" +_TYPING_INTERNALS = ['__parameters__', '__orig_bases__', '__orig_class__', + '_is_protocol', '_is_runtime_protocol'] + +_SPECIAL_NAMES = ['__abstractmethods__', '__annotations__', '__dict__', '__doc__', + '__init__', '__module__', '__new__', '__slots__', + '__subclasshook__', '__weakref__'] + +# These special attributes will be not collected as protocol members. +EXCLUDED_ATTRIBUTES = _TYPING_INTERNALS + _SPECIAL_NAMES + ['_MutableMapping__marker'] + + +def _get_protocol_attrs(cls): + """Collect protocol members from a protocol class objects. + + This includes names actually defined in the class dictionary, as well + as names that appear in annotations. Special names (above) are skipped. + """ + attrs = set() + for base in cls.__mro__[:-1]: # without object + if base.__name__ in ('Protocol', 'Generic'): + continue + annotations = getattr(base, '__annotations__', {}) + for attr in list(base.__dict__.keys()) + list(annotations.keys()): + if not attr.startswith('_abc_') and attr not in EXCLUDED_ATTRIBUTES: + attrs.add(attr) + return attrs + + +def _is_callable_members_only(cls): + # PEP 544 prohibits using issubclass() with protocols that have non-method members. + return all(callable(getattr(cls, attr, None)) for attr in _get_protocol_attrs(cls)) + + +def _no_init(self, *args, **kwargs): + if type(self)._is_protocol: + raise TypeError('Protocols cannot be instantiated') + + +def _allow_reckless_class_cheks(): + """Allow instnance and class checks for special stdlib modules. + + The abc and functools modules indiscriminately call isinstance() and + issubclass() on the whole MRO of a user class, which may contain protocols. + """ + try: + return sys._getframe(3).f_globals['__name__'] in ['abc', 'functools'] + except (AttributeError, ValueError): # For platforms without _getframe(). + return True + + +_PROTO_WHITELIST = ['Callable', 'Awaitable', + 'Iterable', 'Iterator', 'AsyncIterable', 'AsyncIterator', + 'Hashable', 'Sized', 'Container', 'Collection', 'Reversible', + 'ContextManager', 'AsyncContextManager'] + + +class _ProtocolMeta(ABCMeta): + # This metaclass is really unfortunate and exists only because of + # the lack of __instancehook__. + def __instancecheck__(cls, instance): + # We need this method for situations where attributes are + # assigned in __init__. + if ((not getattr(cls, '_is_protocol', False) or + _is_callable_members_only(cls)) and + issubclass(instance.__class__, cls)): + return True + if cls._is_protocol: + if all(hasattr(instance, attr) and + # All *methods* can be blocked by setting them to None. + (not callable(getattr(cls, attr, None)) or + getattr(instance, attr) is not None) + for attr in _get_protocol_attrs(cls)): + return True + return super().__instancecheck__(instance) + + +class Protocol(Generic, metaclass=_ProtocolMeta): + """Base class for protocol classes. + + Protocol classes are defined as:: + + class Proto(Protocol): + def meth(self) -> int: + ... + + Such classes are primarily used with static type checkers that recognize + structural subtyping (static duck-typing), for example:: + + class C: + def meth(self) -> int: + return 0 + + def func(x: Proto) -> int: + return x.meth() + + func(C()) # Passes static type check + + See PEP 544 for details. Protocol classes decorated with + @typing.runtime_checkable act as simple-minded runtime protocols that check + only the presence of given attributes, ignoring their type signatures. + Protocol classes can be generic, they are defined as:: + + class GenProto(Protocol[T]): + def meth(self) -> T: + ... + """ + __slots__ = () + _is_protocol = True + _is_runtime_protocol = False + + def __init_subclass__(cls, *args, **kwargs): + super().__init_subclass__(*args, **kwargs) + + # Determine if this is a protocol or a concrete subclass. + if not cls.__dict__.get('_is_protocol', False): + cls._is_protocol = any(b is Protocol for b in cls.__bases__) + + # Set (or override) the protocol subclass hook. + def _proto_hook(other): + if not cls.__dict__.get('_is_protocol', False): + return NotImplemented + + # First, perform various sanity checks. + if not getattr(cls, '_is_runtime_protocol', False): + if _allow_reckless_class_cheks(): + return NotImplemented + raise TypeError("Instance and class checks can only be used with" + " @runtime_checkable protocols") + if not _is_callable_members_only(cls): + if _allow_reckless_class_cheks(): + return NotImplemented + raise TypeError("Protocols with non-method members" + " don't support issubclass()") + if not isinstance(other, type): + # Same error message as for issubclass(1, int). + raise TypeError('issubclass() arg 1 must be a class') + + # Second, perform the actual structural compatibility check. + for attr in _get_protocol_attrs(cls): + for base in other.__mro__: + # Check if the members appears in the class dictionary... + if attr in base.__dict__: + if base.__dict__[attr] is None: + return NotImplemented + break + + # ...or in annotations, if it is a sub-protocol. + annotations = getattr(base, '__annotations__', {}) + if (isinstance(annotations, collections.abc.Mapping) and + attr in annotations and + issubclass(other, Generic) and other._is_protocol): + break + else: + return NotImplemented + return True + + if '__subclasshook__' not in cls.__dict__: + cls.__subclasshook__ = _proto_hook + + # We have nothing more to do for non-protocols... + if not cls._is_protocol: + return + + # ... otherwise check consistency of bases, and prohibit instantiation. + for base in cls.__bases__: + if not (base in (object, Generic) or + base.__module__ == 'collections.abc' and base.__name__ in _PROTO_WHITELIST or + issubclass(base, Generic) and base._is_protocol): + raise TypeError('Protocols can only inherit from other' + ' protocols, got %r' % base) + cls.__init__ = _no_init + + +def runtime_checkable(cls): + """Mark a protocol class as a runtime protocol. + + Such protocol can be used with isinstance() and issubclass(). + Raise TypeError if applied to a non-protocol class. + This allows a simple-minded structural check very similar to + one trick ponies in collections.abc such as Iterable. + For example:: + + @runtime_checkable + class Closable(Protocol): + def close(self): ... + + assert isinstance(open('/some/file'), Closable) + + Warning: this will check only the presence of the required methods, + not their type signatures! + """ + if not issubclass(cls, Generic) or not cls._is_protocol: + raise TypeError('@runtime_checkable can be only applied to protocol classes,' + ' got %r' % cls) + cls._is_runtime_protocol = True + return cls + + def cast(typ, val): """Cast a value to a type. @@ -1159,90 +1356,6 @@ def final(f): return f -class _ProtocolMeta(type): - """Internal metaclass for _Protocol. - - This exists so _Protocol classes can be generic without deriving - from Generic. - """ - - def __instancecheck__(self, obj): - if _Protocol not in self.__bases__: - return super().__instancecheck__(obj) - raise TypeError("Protocols cannot be used with isinstance().") - - def __subclasscheck__(self, cls): - if not self._is_protocol: - # No structural checks since this isn't a protocol. - return NotImplemented - - if self is _Protocol: - # Every class is a subclass of the empty protocol. - return True - - # Find all attributes defined in the protocol. - attrs = self._get_protocol_attrs() - - for attr in attrs: - if not any(attr in d.__dict__ for d in cls.__mro__): - return False - return True - - def _get_protocol_attrs(self): - # Get all Protocol base classes. - protocol_bases = [] - for c in self.__mro__: - if getattr(c, '_is_protocol', False) and c.__name__ != '_Protocol': - protocol_bases.append(c) - - # Get attributes included in protocol. - attrs = set() - for base in protocol_bases: - for attr in base.__dict__.keys(): - # Include attributes not defined in any non-protocol bases. - for c in self.__mro__: - if (c is not base and attr in c.__dict__ and - not getattr(c, '_is_protocol', False)): - break - else: - if (not attr.startswith('_abc_') and - attr != '__abstractmethods__' and - attr != '__annotations__' and - attr != '__weakref__' and - attr != '_is_protocol' and - attr != '_gorg' and - attr != '__dict__' and - attr != '__args__' and - attr != '__slots__' and - attr != '_get_protocol_attrs' and - attr != '__next_in_mro__' and - attr != '__parameters__' and - attr != '__origin__' and - attr != '__orig_bases__' and - attr != '__extra__' and - attr != '__tree_hash__' and - attr != '__module__'): - attrs.add(attr) - - return attrs - - -class _Protocol(Generic, metaclass=_ProtocolMeta): - """Internal base class for protocol classes. - - This implements a simple-minded structural issubclass check - (similar but more general than the one-offs in collections.abc - such as Hashable). - """ - - __slots__ = () - - _is_protocol = True - - def __class_getitem__(cls, params): - return super().__class_getitem__(params) - - # Some unconstrained type variables. These are used by the container types. # (These are not for export.) T = TypeVar('T') # Any type. @@ -1347,7 +1460,8 @@ Type.__doc__ = \ """ -class SupportsInt(_Protocol): +@runtime_checkable +class SupportsInt(Protocol): __slots__ = () @abstractmethod @@ -1355,7 +1469,8 @@ class SupportsInt(_Protocol): pass -class SupportsFloat(_Protocol): +@runtime_checkable +class SupportsFloat(Protocol): __slots__ = () @abstractmethod @@ -1363,7 +1478,8 @@ class SupportsFloat(_Protocol): pass -class SupportsComplex(_Protocol): +@runtime_checkable +class SupportsComplex(Protocol): __slots__ = () @abstractmethod @@ -1371,7 +1487,8 @@ class SupportsComplex(_Protocol): pass -class SupportsBytes(_Protocol): +@runtime_checkable +class SupportsBytes(Protocol): __slots__ = () @abstractmethod @@ -1379,7 +1496,8 @@ class SupportsBytes(_Protocol): pass -class SupportsIndex(_Protocol): +@runtime_checkable +class SupportsIndex(Protocol): __slots__ = () @abstractmethod @@ -1387,7 +1505,8 @@ class SupportsIndex(_Protocol): pass -class SupportsAbs(_Protocol[T_co]): +@runtime_checkable +class SupportsAbs(Protocol[T_co]): __slots__ = () @abstractmethod @@ -1395,7 +1514,8 @@ class SupportsAbs(_Protocol[T_co]): pass -class SupportsRound(_Protocol[T_co]): +@runtime_checkable +class SupportsRound(Protocol[T_co]): __slots__ = () @abstractmethod |