summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorIvan Levkivskyi <levkivskyi@gmail.com>2019-05-28 07:40:15 (GMT)
committerGitHub <noreply@github.com>2019-05-28 07:40:15 (GMT)
commit74d7f76e2c953fbfdb7ce01b7319d91d471cc5ef (patch)
tree6bba7b64dc4b4a88569809f0758113c87bb690b4 /Lib
parent3880f263d2994fb1eba25835dddccb0cf696fdf0 (diff)
downloadcpython-74d7f76e2c953fbfdb7ce01b7319d91d471cc5ef.zip
cpython-74d7f76e2c953fbfdb7ce01b7319d91d471cc5ef.tar.gz
cpython-74d7f76e2c953fbfdb7ce01b7319d91d471cc5ef.tar.bz2
bpo-37058: PEP 544: Add Protocol to typing module (GH-13585)
I tried to get rid of the `_ProtocolMeta`, but unfortunately it didn'y work. My idea to return a generic alias from `@runtime_checkable` made runtime protocols unpickleable. I am not sure what is worse (a custom metaclass or having some classes unpickleable), so I decided to stick with the status quo (since there were no complains so far). So essentially this is a copy of the implementation in `typing_extensions` with two modifications: * Rename `@runtime` to `@runtime_checkable` (plus corresponding updates). * Allow protocols that extend `collections.abc.Iterable` etc.
Diffstat (limited to 'Lib')
-rw-r--r--Lib/test/test_typing.py731
-rw-r--r--Lib/typing.py342
2 files changed, 955 insertions, 118 deletions
diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py
index 46b7621..2b4b934 100644
--- a/Lib/test/test_typing.py
+++ b/Lib/test/test_typing.py
@@ -12,8 +12,8 @@ from typing import T, KT, VT # Not in __all__.
from typing import Union, Optional, Literal
from typing import Tuple, List, MutableMapping
from typing import Callable
-from typing import Generic, ClassVar, Final, final
-from typing import cast
+from typing import Generic, ClassVar, Final, final, Protocol
+from typing import cast, runtime_checkable
from typing import get_type_hints
from typing import no_type_check, no_type_check_decorator
from typing import Type
@@ -24,6 +24,7 @@ from typing import Pattern, Match
import abc
import typing
import weakref
+import types
from test import mod_generics_cache
@@ -585,7 +586,710 @@ class MySimpleMapping(SimpleMapping[XK, XV]):
return default
+class Coordinate(Protocol):
+ x: int
+ y: int
+
+@runtime_checkable
+class Point(Coordinate, Protocol):
+ label: str
+
+class MyPoint:
+ x: int
+ y: int
+ label: str
+
+class XAxis(Protocol):
+ x: int
+
+class YAxis(Protocol):
+ y: int
+
+@runtime_checkable
+class Position(XAxis, YAxis, Protocol):
+ pass
+
+@runtime_checkable
+class Proto(Protocol):
+ attr: int
+ def meth(self, arg: str) -> int:
+ ...
+
+class Concrete(Proto):
+ pass
+
+class Other:
+ attr: int = 1
+ def meth(self, arg: str) -> int:
+ if arg == 'this':
+ return 1
+ return 0
+
+class NT(NamedTuple):
+ x: int
+ y: int
+
+@runtime_checkable
+class HasCallProtocol(Protocol):
+ __call__: typing.Callable
+
+
class ProtocolTests(BaseTestCase):
+ def test_basic_protocol(self):
+ @runtime_checkable
+ class P(Protocol):
+ def meth(self):
+ pass
+
+ class C: pass
+
+ class D:
+ def meth(self):
+ pass
+
+ def f():
+ pass
+
+ self.assertIsSubclass(D, P)
+ self.assertIsInstance(D(), P)
+ self.assertNotIsSubclass(C, P)
+ self.assertNotIsInstance(C(), P)
+ self.assertNotIsSubclass(types.FunctionType, P)
+ self.assertNotIsInstance(f, P)
+
+ def test_everything_implements_empty_protocol(self):
+ @runtime_checkable
+ class Empty(Protocol):
+ pass
+
+ class C:
+ pass
+
+ def f():
+ pass
+
+ for thing in (object, type, tuple, C, types.FunctionType):
+ self.assertIsSubclass(thing, Empty)
+ for thing in (object(), 1, (), typing, f):
+ self.assertIsInstance(thing, Empty)
+
+ def test_function_implements_protocol(self):
+ def f():
+ pass
+
+ self.assertIsInstance(f, HasCallProtocol)
+
+ def test_no_inheritance_from_nominal(self):
+ class C: pass
+
+ class BP(Protocol): pass
+
+ with self.assertRaises(TypeError):
+ class P(C, Protocol):
+ pass
+ with self.assertRaises(TypeError):
+ class P(Protocol, C):
+ pass
+ with self.assertRaises(TypeError):
+ class P(BP, C, Protocol):
+ pass
+
+ class D(BP, C): pass
+
+ class E(C, BP): pass
+
+ self.assertNotIsInstance(D(), E)
+ self.assertNotIsInstance(E(), D)
+
+ def test_no_instantiation(self):
+ class P(Protocol): pass
+
+ with self.assertRaises(TypeError):
+ P()
+
+ class C(P): pass
+
+ self.assertIsInstance(C(), C)
+ T = TypeVar('T')
+
+ class PG(Protocol[T]): pass
+
+ with self.assertRaises(TypeError):
+ PG()
+ with self.assertRaises(TypeError):
+ PG[int]()
+ with self.assertRaises(TypeError):
+ PG[T]()
+
+ class CG(PG[T]): pass
+
+ self.assertIsInstance(CG[int](), CG)
+
+ def test_cannot_instantiate_abstract(self):
+ @runtime_checkable
+ class P(Protocol):
+ @abc.abstractmethod
+ def ameth(self) -> int:
+ raise NotImplementedError
+
+ class B(P):
+ pass
+
+ class C(B):
+ def ameth(self) -> int:
+ return 26
+
+ with self.assertRaises(TypeError):
+ B()
+ self.assertIsInstance(C(), P)
+
+ def test_subprotocols_extending(self):
+ class P1(Protocol):
+ def meth1(self):
+ pass
+
+ @runtime_checkable
+ class P2(P1, Protocol):
+ def meth2(self):
+ pass
+
+ class C:
+ def meth1(self):
+ pass
+
+ def meth2(self):
+ pass
+
+ class C1:
+ def meth1(self):
+ pass
+
+ class C2:
+ def meth2(self):
+ pass
+
+ self.assertNotIsInstance(C1(), P2)
+ self.assertNotIsInstance(C2(), P2)
+ self.assertNotIsSubclass(C1, P2)
+ self.assertNotIsSubclass(C2, P2)
+ self.assertIsInstance(C(), P2)
+ self.assertIsSubclass(C, P2)
+
+ def test_subprotocols_merging(self):
+ class P1(Protocol):
+ def meth1(self):
+ pass
+
+ class P2(Protocol):
+ def meth2(self):
+ pass
+
+ @runtime_checkable
+ class P(P1, P2, Protocol):
+ pass
+
+ class C:
+ def meth1(self):
+ pass
+
+ def meth2(self):
+ pass
+
+ class C1:
+ def meth1(self):
+ pass
+
+ class C2:
+ def meth2(self):
+ pass
+
+ self.assertNotIsInstance(C1(), P)
+ self.assertNotIsInstance(C2(), P)
+ self.assertNotIsSubclass(C1, P)
+ self.assertNotIsSubclass(C2, P)
+ self.assertIsInstance(C(), P)
+ self.assertIsSubclass(C, P)
+
+ def test_protocols_issubclass(self):
+ T = TypeVar('T')
+
+ @runtime_checkable
+ class P(Protocol):
+ def x(self): ...
+
+ @runtime_checkable
+ class PG(Protocol[T]):
+ def x(self): ...
+
+ class BadP(Protocol):
+ def x(self): ...
+
+ class BadPG(Protocol[T]):
+ def x(self): ...
+
+ class C:
+ def x(self): ...
+
+ self.assertIsSubclass(C, P)
+ self.assertIsSubclass(C, PG)
+ self.assertIsSubclass(BadP, PG)
+
+ with self.assertRaises(TypeError):
+ issubclass(C, PG[T])
+ with self.assertRaises(TypeError):
+ issubclass(C, PG[C])
+ with self.assertRaises(TypeError):
+ issubclass(C, BadP)
+ with self.assertRaises(TypeError):
+ issubclass(C, BadPG)
+ with self.assertRaises(TypeError):
+ issubclass(P, PG[T])
+ with self.assertRaises(TypeError):
+ issubclass(PG, PG[int])
+
+ def test_protocols_issubclass_non_callable(self):
+ class C:
+ x = 1
+
+ @runtime_checkable
+ class PNonCall(Protocol):
+ x = 1
+
+ with self.assertRaises(TypeError):
+ issubclass(C, PNonCall)
+ self.assertIsInstance(C(), PNonCall)
+ PNonCall.register(C)
+ with self.assertRaises(TypeError):
+ issubclass(C, PNonCall)
+ self.assertIsInstance(C(), PNonCall)
+
+ # check that non-protocol subclasses are not affected
+ class D(PNonCall): ...
+
+ self.assertNotIsSubclass(C, D)
+ self.assertNotIsInstance(C(), D)
+ D.register(C)
+ self.assertIsSubclass(C, D)
+ self.assertIsInstance(C(), D)
+ with self.assertRaises(TypeError):
+ issubclass(D, PNonCall)
+
+ def test_protocols_isinstance(self):
+ T = TypeVar('T')
+
+ @runtime_checkable
+ class P(Protocol):
+ def meth(x): ...
+
+ @runtime_checkable
+ class PG(Protocol[T]):
+ def meth(x): ...
+
+ class BadP(Protocol):
+ def meth(x): ...
+
+ class BadPG(Protocol[T]):
+ def meth(x): ...
+
+ class C:
+ def meth(x): ...
+
+ self.assertIsInstance(C(), P)
+ self.assertIsInstance(C(), PG)
+ with self.assertRaises(TypeError):
+ isinstance(C(), PG[T])
+ with self.assertRaises(TypeError):
+ isinstance(C(), PG[C])
+ with self.assertRaises(TypeError):
+ isinstance(C(), BadP)
+ with self.assertRaises(TypeError):
+ isinstance(C(), BadPG)
+
+ def test_protocols_isinstance_py36(self):
+ class APoint:
+ def __init__(self, x, y, label):
+ self.x = x
+ self.y = y
+ self.label = label
+
+ class BPoint:
+ label = 'B'
+
+ def __init__(self, x, y):
+ self.x = x
+ self.y = y
+
+ class C:
+ def __init__(self, attr):
+ self.attr = attr
+
+ def meth(self, arg):
+ return 0
+
+ class Bad: pass
+
+ self.assertIsInstance(APoint(1, 2, 'A'), Point)
+ self.assertIsInstance(BPoint(1, 2), Point)
+ self.assertNotIsInstance(MyPoint(), Point)
+ self.assertIsInstance(BPoint(1, 2), Position)
+ self.assertIsInstance(Other(), Proto)
+ self.assertIsInstance(Concrete(), Proto)
+ self.assertIsInstance(C(42), Proto)
+ self.assertNotIsInstance(Bad(), Proto)
+ self.assertNotIsInstance(Bad(), Point)
+ self.assertNotIsInstance(Bad(), Position)
+ self.assertNotIsInstance(Bad(), Concrete)
+ self.assertNotIsInstance(Other(), Concrete)
+ self.assertIsInstance(NT(1, 2), Position)
+
+ def test_protocols_isinstance_init(self):
+ T = TypeVar('T')
+
+ @runtime_checkable
+ class P(Protocol):
+ x = 1
+
+ @runtime_checkable
+ class PG(Protocol[T]):
+ x = 1
+
+ class C:
+ def __init__(self, x):
+ self.x = x
+
+ self.assertIsInstance(C(1), P)
+ self.assertIsInstance(C(1), PG)
+
+ def test_protocols_support_register(self):
+ @runtime_checkable
+ class P(Protocol):
+ x = 1
+
+ class PM(Protocol):
+ def meth(self): pass
+
+ class D(PM): pass
+
+ class C: pass
+
+ D.register(C)
+ P.register(C)
+ self.assertIsInstance(C(), P)
+ self.assertIsInstance(C(), D)
+
+ def test_none_on_non_callable_doesnt_block_implementation(self):
+ @runtime_checkable
+ class P(Protocol):
+ x = 1
+
+ class A:
+ x = 1
+
+ class B(A):
+ x = None
+
+ class C:
+ def __init__(self):
+ self.x = None
+
+ self.assertIsInstance(B(), P)
+ self.assertIsInstance(C(), P)
+
+ def test_none_on_callable_blocks_implementation(self):
+ @runtime_checkable
+ class P(Protocol):
+ def x(self): ...
+
+ class A:
+ def x(self): ...
+
+ class B(A):
+ x = None
+
+ class C:
+ def __init__(self):
+ self.x = None
+
+ self.assertNotIsInstance(B(), P)
+ self.assertNotIsInstance(C(), P)
+
+ def test_non_protocol_subclasses(self):
+ class P(Protocol):
+ x = 1
+
+ @runtime_checkable
+ class PR(Protocol):
+ def meth(self): pass
+
+ class NonP(P):
+ x = 1
+
+ class NonPR(PR): pass
+
+ class C:
+ x = 1
+
+ class D:
+ def meth(self): pass
+
+ self.assertNotIsInstance(C(), NonP)
+ self.assertNotIsInstance(D(), NonPR)
+ self.assertNotIsSubclass(C, NonP)
+ self.assertNotIsSubclass(D, NonPR)
+ self.assertIsInstance(NonPR(), PR)
+ self.assertIsSubclass(NonPR, PR)
+
+ def test_custom_subclasshook(self):
+ class P(Protocol):
+ x = 1
+
+ class OKClass: pass
+
+ class BadClass:
+ x = 1
+
+ class C(P):
+ @classmethod
+ def __subclasshook__(cls, other):
+ return other.__name__.startswith("OK")
+
+ self.assertIsInstance(OKClass(), C)
+ self.assertNotIsInstance(BadClass(), C)
+ self.assertIsSubclass(OKClass, C)
+ self.assertNotIsSubclass(BadClass, C)
+
+ def test_issubclass_fails_correctly(self):
+ @runtime_checkable
+ class P(Protocol):
+ x = 1
+
+ class C: pass
+
+ with self.assertRaises(TypeError):
+ issubclass(C(), P)
+
+ def test_defining_generic_protocols(self):
+ T = TypeVar('T')
+ S = TypeVar('S')
+
+ @runtime_checkable
+ class PR(Protocol[T, S]):
+ def meth(self): pass
+
+ class P(PR[int, T], Protocol[T]):
+ y = 1
+
+ with self.assertRaises(TypeError):
+ PR[int]
+ with self.assertRaises(TypeError):
+ P[int, str]
+ with self.assertRaises(TypeError):
+ PR[int, 1]
+ with self.assertRaises(TypeError):
+ PR[int, ClassVar]
+
+ class C(PR[int, T]): pass
+
+ self.assertIsInstance(C[str](), C)
+
+ def test_defining_generic_protocols_old_style(self):
+ T = TypeVar('T')
+ S = TypeVar('S')
+
+ @runtime_checkable
+ class PR(Protocol, Generic[T, S]):
+ def meth(self): pass
+
+ class P(PR[int, str], Protocol):
+ y = 1
+
+ with self.assertRaises(TypeError):
+ issubclass(PR[int, str], PR)
+ self.assertIsSubclass(P, PR)
+ with self.assertRaises(TypeError):
+ PR[int]
+ with self.assertRaises(TypeError):
+ PR[int, 1]
+
+ class P1(Protocol, Generic[T]):
+ def bar(self, x: T) -> str: ...
+
+ class P2(Generic[T], Protocol):
+ def bar(self, x: T) -> str: ...
+
+ @runtime_checkable
+ class PSub(P1[str], Protocol):
+ x = 1
+
+ class Test:
+ x = 1
+
+ def bar(self, x: str) -> str:
+ return x
+
+ self.assertIsInstance(Test(), PSub)
+ with self.assertRaises(TypeError):
+ PR[int, ClassVar]
+
+ def test_init_called(self):
+ T = TypeVar('T')
+
+ class P(Protocol[T]): pass
+
+ class C(P[T]):
+ def __init__(self):
+ self.test = 'OK'
+
+ self.assertEqual(C[int]().test, 'OK')
+
+ def test_protocols_bad_subscripts(self):
+ T = TypeVar('T')
+ S = TypeVar('S')
+ with self.assertRaises(TypeError):
+ class P(Protocol[T, T]): pass
+ with self.assertRaises(TypeError):
+ class P(Protocol[int]): pass
+ with self.assertRaises(TypeError):
+ class P(Protocol[T], Protocol[S]): pass
+ with self.assertRaises(TypeError):
+ class P(typing.Mapping[T, S], Protocol[T]): pass
+
+ def test_generic_protocols_repr(self):
+ T = TypeVar('T')
+ S = TypeVar('S')
+
+ class P(Protocol[T, S]): pass
+
+ self.assertTrue(repr(P[T, S]).endswith('P[~T, ~S]'))
+ self.assertTrue(repr(P[int, str]).endswith('P[int, str]'))
+
+ def test_generic_protocols_eq(self):
+ T = TypeVar('T')
+ S = TypeVar('S')
+
+ class P(Protocol[T, S]): pass
+
+ self.assertEqual(P, P)
+ self.assertEqual(P[int, T], P[int, T])
+ self.assertEqual(P[T, T][Tuple[T, S]][int, str],
+ P[Tuple[int, str], Tuple[int, str]])
+
+ def test_generic_protocols_special_from_generic(self):
+ T = TypeVar('T')
+
+ class P(Protocol[T]): pass
+
+ self.assertEqual(P.__parameters__, (T,))
+ self.assertEqual(P[int].__parameters__, ())
+ self.assertEqual(P[int].__args__, (int,))
+ self.assertIs(P[int].__origin__, P)
+
+ def test_generic_protocols_special_from_protocol(self):
+ @runtime_checkable
+ class PR(Protocol):
+ x = 1
+
+ class P(Protocol):
+ def meth(self):
+ pass
+
+ T = TypeVar('T')
+
+ class PG(Protocol[T]):
+ x = 1
+
+ def meth(self):
+ pass
+
+ self.assertTrue(P._is_protocol)
+ self.assertTrue(PR._is_protocol)
+ self.assertTrue(PG._is_protocol)
+ self.assertFalse(P._is_runtime_protocol)
+ self.assertTrue(PR._is_runtime_protocol)
+ self.assertTrue(PG[int]._is_protocol)
+ self.assertEqual(typing._get_protocol_attrs(P), {'meth'})
+ self.assertEqual(typing._get_protocol_attrs(PR), {'x'})
+ self.assertEqual(frozenset(typing._get_protocol_attrs(PG)),
+ frozenset({'x', 'meth'}))
+
+ def test_no_runtime_deco_on_nominal(self):
+ with self.assertRaises(TypeError):
+ @runtime_checkable
+ class C: pass
+
+ class Proto(Protocol):
+ x = 1
+
+ with self.assertRaises(TypeError):
+ @runtime_checkable
+ class Concrete(Proto):
+ pass
+
+ def test_none_treated_correctly(self):
+ @runtime_checkable
+ class P(Protocol):
+ x = None # type: int
+
+ class B(object): pass
+
+ self.assertNotIsInstance(B(), P)
+
+ class C:
+ x = 1
+
+ class D:
+ x = None
+
+ self.assertIsInstance(C(), P)
+ self.assertIsInstance(D(), P)
+
+ class CI:
+ def __init__(self):
+ self.x = 1
+
+ class DI:
+ def __init__(self):
+ self.x = None
+
+ self.assertIsInstance(C(), P)
+ self.assertIsInstance(D(), P)
+
+ def test_protocols_in_unions(self):
+ class P(Protocol):
+ x = None # type: int
+
+ Alias = typing.Union[typing.Iterable, P]
+ Alias2 = typing.Union[P, typing.Iterable]
+ self.assertEqual(Alias, Alias2)
+
+ def test_protocols_pickleable(self):
+ global P, CP # pickle wants to reference the class by name
+ T = TypeVar('T')
+
+ @runtime_checkable
+ class P(Protocol[T]):
+ x = 1
+
+ class CP(P[int]):
+ pass
+
+ c = CP()
+ c.foo = 42
+ c.bar = 'abc'
+ for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+ z = pickle.dumps(c, proto)
+ x = pickle.loads(z)
+ self.assertEqual(x.foo, 42)
+ self.assertEqual(x.bar, 'abc')
+ self.assertEqual(x.x, 1)
+ self.assertEqual(x.__dict__, {'foo': 42, 'bar': 'abc'})
+ s = pickle.dumps(P)
+ D = pickle.loads(s)
+
+ class E:
+ x = 1
+
+ self.assertIsInstance(E(), D)
def test_supports_int(self):
self.assertIsSubclass(int, typing.SupportsInt)
@@ -634,9 +1338,8 @@ class ProtocolTests(BaseTestCase):
self.assertIsSubclass(int, typing.SupportsIndex)
self.assertNotIsSubclass(str, typing.SupportsIndex)
- def test_protocol_instance_type_error(self):
- with self.assertRaises(TypeError):
- isinstance(0, typing.SupportsAbs)
+ def test_bundled_protocol_instance_works(self):
+ self.assertIsInstance(0, typing.SupportsAbs)
class C1(typing.SupportsInt):
def __int__(self) -> int:
return 42
@@ -645,6 +1348,20 @@ class ProtocolTests(BaseTestCase):
c = C2()
self.assertIsInstance(c, C1)
+ def test_collections_protocols_allowed(self):
+ @runtime_checkable
+ class Custom(collections.abc.Iterable, Protocol):
+ def close(self): ...
+
+ class A: pass
+ class B:
+ def __iter__(self):
+ return []
+ def close(self):
+ return 0
+
+ self.assertIsSubclass(B, Custom)
+ self.assertNotIsSubclass(A, Custom)
class GenericTests(BaseTestCase):
@@ -771,7 +1488,7 @@ class GenericTests(BaseTestCase):
def test_new_repr_bare(self):
T = TypeVar('T')
self.assertEqual(repr(Generic[T]), 'typing.Generic[~T]')
- self.assertEqual(repr(typing._Protocol[T]), 'typing._Protocol[~T]')
+ self.assertEqual(repr(typing.Protocol[T]), 'typing.Protocol[~T]')
class C(typing.Dict[Any, Any]): ...
# this line should just work
repr(C.__mro__)
@@ -1067,7 +1784,7 @@ class GenericTests(BaseTestCase):
with self.assertRaises(TypeError):
Tuple[Generic[T]]
with self.assertRaises(TypeError):
- List[typing._Protocol]
+ List[typing.Protocol]
def test_type_erasure_special(self):
T = TypeVar('T')
diff --git a/Lib/typing.py b/Lib/typing.py
index d3e84cd..14bd06b 100644
--- a/Lib/typing.py
+++ b/Lib/typing.py
@@ -9,8 +9,7 @@ At large scale, the structure of the module is following:
* The core of internal generics API: _GenericAlias and _VariadicGenericAlias, the latter is
currently only used by Tuple and Callable. All subscripted types like X[int], Union[int, str],
etc., are instances of either of these classes.
-* The public counterpart of the generics API consists of two classes: Generic and Protocol
- (the latter is currently private, but will be made public after PEP 544 acceptance).
+* The public counterpart of the generics API consists of two classes: Generic and Protocol.
* Public helper functions: get_type_hints, overload, cast, no_type_check,
no_type_check_decorator.
* Generic aliases for collections.abc ABCs and few additional protocols.
@@ -18,7 +17,7 @@ At large scale, the structure of the module is following:
* Wrapper submodules for re and io related types.
"""
-from abc import abstractmethod, abstractproperty
+from abc import abstractmethod, abstractproperty, ABCMeta
import collections
import collections.abc
import contextlib
@@ -39,6 +38,7 @@ __all__ = [
'Generic',
'Literal',
'Optional',
+ 'Protocol',
'Tuple',
'Type',
'TypeVar',
@@ -102,6 +102,7 @@ __all__ = [
'no_type_check_decorator',
'NoReturn',
'overload',
+ 'runtime_checkable',
'Text',
'TYPE_CHECKING',
]
@@ -123,7 +124,7 @@ def _type_check(arg, msg, is_argument=True):
We append the repr() of the actual value (truncated to 100 chars).
"""
- invalid_generic_forms = (Generic, _Protocol)
+ invalid_generic_forms = (Generic, Protocol)
if is_argument:
invalid_generic_forms = invalid_generic_forms + (ClassVar, Final)
@@ -135,7 +136,7 @@ def _type_check(arg, msg, is_argument=True):
arg.__origin__ in invalid_generic_forms):
raise TypeError(f"{arg} is not valid as type argument")
if (isinstance(arg, _SpecialForm) and arg not in (Any, NoReturn) or
- arg in (Generic, _Protocol)):
+ arg in (Generic, Protocol)):
raise TypeError(f"Plain {arg} is not valid as type argument")
if isinstance(arg, (type, TypeVar, ForwardRef)):
return arg
@@ -665,8 +666,8 @@ class _GenericAlias(_Final, _root=True):
@_tp_cache
def __getitem__(self, params):
- if self.__origin__ in (Generic, _Protocol):
- # Can't subscript Generic[...] or _Protocol[...].
+ if self.__origin__ in (Generic, Protocol):
+ # Can't subscript Generic[...] or Protocol[...].
raise TypeError(f"Cannot subscript already-subscripted {self}")
if not isinstance(params, tuple):
params = (params,)
@@ -733,6 +734,8 @@ class _GenericAlias(_Final, _root=True):
res.append(Generic)
return tuple(res)
if self.__origin__ is Generic:
+ if Protocol in bases:
+ return ()
i = bases.index(self)
for b in bases[i+1:]:
if isinstance(b, _GenericAlias) and b is not self:
@@ -850,10 +853,11 @@ class Generic:
return default
"""
__slots__ = ()
+ _is_protocol = False
def __new__(cls, *args, **kwds):
- if cls is Generic:
- raise TypeError("Type Generic cannot be instantiated; "
+ if cls in (Generic, Protocol):
+ raise TypeError(f"Type {cls.__name__} cannot be instantiated; "
"it can be used only as a base class")
if super().__new__ is object.__new__ and cls.__init__ is not object.__init__:
obj = super().__new__(cls)
@@ -870,17 +874,14 @@ class Generic:
f"Parameter list to {cls.__qualname__}[...] cannot be empty")
msg = "Parameters to generic types must be types."
params = tuple(_type_check(p, msg) for p in params)
- if cls is Generic:
- # Generic can only be subscripted with unique type variables.
+ if cls in (Generic, Protocol):
+ # Generic and Protocol can only be subscripted with unique type variables.
if not all(isinstance(p, TypeVar) for p in params):
raise TypeError(
- "Parameters to Generic[...] must all be type variables")
+ f"Parameters to {cls.__name__}[...] must all be type variables")
if len(set(params)) != len(params):
raise TypeError(
- "Parameters to Generic[...] must all be unique")
- elif cls is _Protocol:
- # _Protocol is internal at the moment, just skip the check
- pass
+ f"Parameters to {cls.__name__}[...] must all be unique")
else:
# Subscripting a regular Generic subclass.
_check_generic(cls, params)
@@ -892,7 +893,7 @@ class Generic:
if '__orig_bases__' in cls.__dict__:
error = Generic in cls.__orig_bases__
else:
- error = Generic in cls.__bases__ and cls.__name__ != '_Protocol'
+ error = Generic in cls.__bases__ and cls.__name__ != 'Protocol'
if error:
raise TypeError("Cannot inherit from plain Generic")
if '__orig_bases__' in cls.__dict__:
@@ -910,9 +911,7 @@ class Generic:
raise TypeError(
"Cannot inherit from Generic[...] multiple types.")
gvars = base.__parameters__
- if gvars is None:
- gvars = tvars
- else:
+ if gvars is not None:
tvarset = set(tvars)
gvarset = set(gvars)
if not tvarset <= gvarset:
@@ -935,6 +934,204 @@ class _TypingEllipsis:
"""Internal placeholder for ... (ellipsis)."""
+_TYPING_INTERNALS = ['__parameters__', '__orig_bases__', '__orig_class__',
+ '_is_protocol', '_is_runtime_protocol']
+
+_SPECIAL_NAMES = ['__abstractmethods__', '__annotations__', '__dict__', '__doc__',
+ '__init__', '__module__', '__new__', '__slots__',
+ '__subclasshook__', '__weakref__']
+
+# These special attributes will be not collected as protocol members.
+EXCLUDED_ATTRIBUTES = _TYPING_INTERNALS + _SPECIAL_NAMES + ['_MutableMapping__marker']
+
+
+def _get_protocol_attrs(cls):
+ """Collect protocol members from a protocol class objects.
+
+ This includes names actually defined in the class dictionary, as well
+ as names that appear in annotations. Special names (above) are skipped.
+ """
+ attrs = set()
+ for base in cls.__mro__[:-1]: # without object
+ if base.__name__ in ('Protocol', 'Generic'):
+ continue
+ annotations = getattr(base, '__annotations__', {})
+ for attr in list(base.__dict__.keys()) + list(annotations.keys()):
+ if not attr.startswith('_abc_') and attr not in EXCLUDED_ATTRIBUTES:
+ attrs.add(attr)
+ return attrs
+
+
+def _is_callable_members_only(cls):
+ # PEP 544 prohibits using issubclass() with protocols that have non-method members.
+ return all(callable(getattr(cls, attr, None)) for attr in _get_protocol_attrs(cls))
+
+
+def _no_init(self, *args, **kwargs):
+ if type(self)._is_protocol:
+ raise TypeError('Protocols cannot be instantiated')
+
+
+def _allow_reckless_class_cheks():
+ """Allow instnance and class checks for special stdlib modules.
+
+ The abc and functools modules indiscriminately call isinstance() and
+ issubclass() on the whole MRO of a user class, which may contain protocols.
+ """
+ try:
+ return sys._getframe(3).f_globals['__name__'] in ['abc', 'functools']
+ except (AttributeError, ValueError): # For platforms without _getframe().
+ return True
+
+
+_PROTO_WHITELIST = ['Callable', 'Awaitable',
+ 'Iterable', 'Iterator', 'AsyncIterable', 'AsyncIterator',
+ 'Hashable', 'Sized', 'Container', 'Collection', 'Reversible',
+ 'ContextManager', 'AsyncContextManager']
+
+
+class _ProtocolMeta(ABCMeta):
+ # This metaclass is really unfortunate and exists only because of
+ # the lack of __instancehook__.
+ def __instancecheck__(cls, instance):
+ # We need this method for situations where attributes are
+ # assigned in __init__.
+ if ((not getattr(cls, '_is_protocol', False) or
+ _is_callable_members_only(cls)) and
+ issubclass(instance.__class__, cls)):
+ return True
+ if cls._is_protocol:
+ if all(hasattr(instance, attr) and
+ # All *methods* can be blocked by setting them to None.
+ (not callable(getattr(cls, attr, None)) or
+ getattr(instance, attr) is not None)
+ for attr in _get_protocol_attrs(cls)):
+ return True
+ return super().__instancecheck__(instance)
+
+
+class Protocol(Generic, metaclass=_ProtocolMeta):
+ """Base class for protocol classes.
+
+ Protocol classes are defined as::
+
+ class Proto(Protocol):
+ def meth(self) -> int:
+ ...
+
+ Such classes are primarily used with static type checkers that recognize
+ structural subtyping (static duck-typing), for example::
+
+ class C:
+ def meth(self) -> int:
+ return 0
+
+ def func(x: Proto) -> int:
+ return x.meth()
+
+ func(C()) # Passes static type check
+
+ See PEP 544 for details. Protocol classes decorated with
+ @typing.runtime_checkable act as simple-minded runtime protocols that check
+ only the presence of given attributes, ignoring their type signatures.
+ Protocol classes can be generic, they are defined as::
+
+ class GenProto(Protocol[T]):
+ def meth(self) -> T:
+ ...
+ """
+ __slots__ = ()
+ _is_protocol = True
+ _is_runtime_protocol = False
+
+ def __init_subclass__(cls, *args, **kwargs):
+ super().__init_subclass__(*args, **kwargs)
+
+ # Determine if this is a protocol or a concrete subclass.
+ if not cls.__dict__.get('_is_protocol', False):
+ cls._is_protocol = any(b is Protocol for b in cls.__bases__)
+
+ # Set (or override) the protocol subclass hook.
+ def _proto_hook(other):
+ if not cls.__dict__.get('_is_protocol', False):
+ return NotImplemented
+
+ # First, perform various sanity checks.
+ if not getattr(cls, '_is_runtime_protocol', False):
+ if _allow_reckless_class_cheks():
+ return NotImplemented
+ raise TypeError("Instance and class checks can only be used with"
+ " @runtime_checkable protocols")
+ if not _is_callable_members_only(cls):
+ if _allow_reckless_class_cheks():
+ return NotImplemented
+ raise TypeError("Protocols with non-method members"
+ " don't support issubclass()")
+ if not isinstance(other, type):
+ # Same error message as for issubclass(1, int).
+ raise TypeError('issubclass() arg 1 must be a class')
+
+ # Second, perform the actual structural compatibility check.
+ for attr in _get_protocol_attrs(cls):
+ for base in other.__mro__:
+ # Check if the members appears in the class dictionary...
+ if attr in base.__dict__:
+ if base.__dict__[attr] is None:
+ return NotImplemented
+ break
+
+ # ...or in annotations, if it is a sub-protocol.
+ annotations = getattr(base, '__annotations__', {})
+ if (isinstance(annotations, collections.abc.Mapping) and
+ attr in annotations and
+ issubclass(other, Generic) and other._is_protocol):
+ break
+ else:
+ return NotImplemented
+ return True
+
+ if '__subclasshook__' not in cls.__dict__:
+ cls.__subclasshook__ = _proto_hook
+
+ # We have nothing more to do for non-protocols...
+ if not cls._is_protocol:
+ return
+
+ # ... otherwise check consistency of bases, and prohibit instantiation.
+ for base in cls.__bases__:
+ if not (base in (object, Generic) or
+ base.__module__ == 'collections.abc' and base.__name__ in _PROTO_WHITELIST or
+ issubclass(base, Generic) and base._is_protocol):
+ raise TypeError('Protocols can only inherit from other'
+ ' protocols, got %r' % base)
+ cls.__init__ = _no_init
+
+
+def runtime_checkable(cls):
+ """Mark a protocol class as a runtime protocol.
+
+ Such protocol can be used with isinstance() and issubclass().
+ Raise TypeError if applied to a non-protocol class.
+ This allows a simple-minded structural check very similar to
+ one trick ponies in collections.abc such as Iterable.
+ For example::
+
+ @runtime_checkable
+ class Closable(Protocol):
+ def close(self): ...
+
+ assert isinstance(open('/some/file'), Closable)
+
+ Warning: this will check only the presence of the required methods,
+ not their type signatures!
+ """
+ if not issubclass(cls, Generic) or not cls._is_protocol:
+ raise TypeError('@runtime_checkable can be only applied to protocol classes,'
+ ' got %r' % cls)
+ cls._is_runtime_protocol = True
+ return cls
+
+
def cast(typ, val):
"""Cast a value to a type.
@@ -1159,90 +1356,6 @@ def final(f):
return f
-class _ProtocolMeta(type):
- """Internal metaclass for _Protocol.
-
- This exists so _Protocol classes can be generic without deriving
- from Generic.
- """
-
- def __instancecheck__(self, obj):
- if _Protocol not in self.__bases__:
- return super().__instancecheck__(obj)
- raise TypeError("Protocols cannot be used with isinstance().")
-
- def __subclasscheck__(self, cls):
- if not self._is_protocol:
- # No structural checks since this isn't a protocol.
- return NotImplemented
-
- if self is _Protocol:
- # Every class is a subclass of the empty protocol.
- return True
-
- # Find all attributes defined in the protocol.
- attrs = self._get_protocol_attrs()
-
- for attr in attrs:
- if not any(attr in d.__dict__ for d in cls.__mro__):
- return False
- return True
-
- def _get_protocol_attrs(self):
- # Get all Protocol base classes.
- protocol_bases = []
- for c in self.__mro__:
- if getattr(c, '_is_protocol', False) and c.__name__ != '_Protocol':
- protocol_bases.append(c)
-
- # Get attributes included in protocol.
- attrs = set()
- for base in protocol_bases:
- for attr in base.__dict__.keys():
- # Include attributes not defined in any non-protocol bases.
- for c in self.__mro__:
- if (c is not base and attr in c.__dict__ and
- not getattr(c, '_is_protocol', False)):
- break
- else:
- if (not attr.startswith('_abc_') and
- attr != '__abstractmethods__' and
- attr != '__annotations__' and
- attr != '__weakref__' and
- attr != '_is_protocol' and
- attr != '_gorg' and
- attr != '__dict__' and
- attr != '__args__' and
- attr != '__slots__' and
- attr != '_get_protocol_attrs' and
- attr != '__next_in_mro__' and
- attr != '__parameters__' and
- attr != '__origin__' and
- attr != '__orig_bases__' and
- attr != '__extra__' and
- attr != '__tree_hash__' and
- attr != '__module__'):
- attrs.add(attr)
-
- return attrs
-
-
-class _Protocol(Generic, metaclass=_ProtocolMeta):
- """Internal base class for protocol classes.
-
- This implements a simple-minded structural issubclass check
- (similar but more general than the one-offs in collections.abc
- such as Hashable).
- """
-
- __slots__ = ()
-
- _is_protocol = True
-
- def __class_getitem__(cls, params):
- return super().__class_getitem__(params)
-
-
# Some unconstrained type variables. These are used by the container types.
# (These are not for export.)
T = TypeVar('T') # Any type.
@@ -1347,7 +1460,8 @@ Type.__doc__ = \
"""
-class SupportsInt(_Protocol):
+@runtime_checkable
+class SupportsInt(Protocol):
__slots__ = ()
@abstractmethod
@@ -1355,7 +1469,8 @@ class SupportsInt(_Protocol):
pass
-class SupportsFloat(_Protocol):
+@runtime_checkable
+class SupportsFloat(Protocol):
__slots__ = ()
@abstractmethod
@@ -1363,7 +1478,8 @@ class SupportsFloat(_Protocol):
pass
-class SupportsComplex(_Protocol):
+@runtime_checkable
+class SupportsComplex(Protocol):
__slots__ = ()
@abstractmethod
@@ -1371,7 +1487,8 @@ class SupportsComplex(_Protocol):
pass
-class SupportsBytes(_Protocol):
+@runtime_checkable
+class SupportsBytes(Protocol):
__slots__ = ()
@abstractmethod
@@ -1379,7 +1496,8 @@ class SupportsBytes(_Protocol):
pass
-class SupportsIndex(_Protocol):
+@runtime_checkable
+class SupportsIndex(Protocol):
__slots__ = ()
@abstractmethod
@@ -1387,7 +1505,8 @@ class SupportsIndex(_Protocol):
pass
-class SupportsAbs(_Protocol[T_co]):
+@runtime_checkable
+class SupportsAbs(Protocol[T_co]):
__slots__ = ()
@abstractmethod
@@ -1395,7 +1514,8 @@ class SupportsAbs(_Protocol[T_co]):
pass
-class SupportsRound(_Protocol[T_co]):
+@runtime_checkable
+class SupportsRound(Protocol[T_co]):
__slots__ = ()
@abstractmethod