summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNikita Sobolev <mail@sobolevn.me>2023-12-10 15:21:20 (GMT)
committerGitHub <noreply@github.com>2023-12-10 15:21:20 (GMT)
commit9d02d3451a61521c65db6f93596ece2f572f1f3e (patch)
tree8ac8b45431f38d94c2e39e5d169f612731a48a4a
parent7595d47722ae359e6642506646640a3f86816cef (diff)
downloadcpython-9d02d3451a61521c65db6f93596ece2f572f1f3e.zip
cpython-9d02d3451a61521c65db6f93596ece2f572f1f3e.tar.gz
cpython-9d02d3451a61521c65db6f93596ece2f572f1f3e.tar.bz2
gh-110686: Test pattern matching with `runtime_checkable` protocols (#110687)
-rw-r--r--Lib/test/test_patma.py155
1 files changed, 155 insertions, 0 deletions
diff --git a/Lib/test/test_patma.py b/Lib/test/test_patma.py
index dedbc82..298e78c 100644
--- a/Lib/test/test_patma.py
+++ b/Lib/test/test_patma.py
@@ -2760,6 +2760,132 @@ class TestPatma(unittest.TestCase):
self.assertEqual(y, 1)
self.assertIs(z, x)
+ def test_patma_runtime_checkable_protocol(self):
+ # Runtime-checkable protocol
+ from typing import Protocol, runtime_checkable
+
+ @runtime_checkable
+ class P(Protocol):
+ x: int
+ y: int
+
+ class A:
+ def __init__(self, x: int, y: int):
+ self.x = x
+ self.y = y
+
+ class B(A): ...
+
+ for cls in (A, B):
+ with self.subTest(cls=cls.__name__):
+ inst = cls(1, 2)
+ w = 0
+ match inst:
+ case P() as p:
+ self.assertIsInstance(p, cls)
+ self.assertEqual(p.x, 1)
+ self.assertEqual(p.y, 2)
+ w = 1
+ self.assertEqual(w, 1)
+
+ q = 0
+ match inst:
+ case P(x=x, y=y):
+ self.assertEqual(x, 1)
+ self.assertEqual(y, 2)
+ q = 1
+ self.assertEqual(q, 1)
+
+
+ def test_patma_generic_protocol(self):
+ # Runtime-checkable generic protocol
+ from typing import Generic, TypeVar, Protocol, runtime_checkable
+
+ T = TypeVar('T') # not using PEP695 to be able to backport changes
+
+ @runtime_checkable
+ class P(Protocol[T]):
+ a: T
+ b: T
+
+ class A:
+ def __init__(self, x: int, y: int):
+ self.x = x
+ self.y = y
+
+ class G(Generic[T]):
+ def __init__(self, x: T, y: T):
+ self.x = x
+ self.y = y
+
+ for cls in (A, G):
+ with self.subTest(cls=cls.__name__):
+ inst = cls(1, 2)
+ w = 0
+ match inst:
+ case P():
+ w = 1
+ self.assertEqual(w, 0)
+
+ def test_patma_protocol_with_match_args(self):
+ # Runtime-checkable protocol with `__match_args__`
+ from typing import Protocol, runtime_checkable
+
+ # Used to fail before
+ # https://github.com/python/cpython/issues/110682
+ @runtime_checkable
+ class P(Protocol):
+ __match_args__ = ('x', 'y')
+ x: int
+ y: int
+
+ class A:
+ def __init__(self, x: int, y: int):
+ self.x = x
+ self.y = y
+
+ class B(A): ...
+
+ for cls in (A, B):
+ with self.subTest(cls=cls.__name__):
+ inst = cls(1, 2)
+ w = 0
+ match inst:
+ case P() as p:
+ self.assertIsInstance(p, cls)
+ self.assertEqual(p.x, 1)
+ self.assertEqual(p.y, 2)
+ w = 1
+ self.assertEqual(w, 1)
+
+ q = 0
+ match inst:
+ case P(x=x, y=y):
+ self.assertEqual(x, 1)
+ self.assertEqual(y, 2)
+ q = 1
+ self.assertEqual(q, 1)
+
+ j = 0
+ match inst:
+ case P(x=1, y=2):
+ j = 1
+ self.assertEqual(j, 1)
+
+ g = 0
+ match inst:
+ case P(x, y):
+ self.assertEqual(x, 1)
+ self.assertEqual(y, 2)
+ g = 1
+ self.assertEqual(g, 1)
+
+ h = 0
+ match inst:
+ case P(1, 2):
+ h = 1
+ self.assertEqual(h, 1)
+
class TestSyntaxErrors(unittest.TestCase):
@@ -3198,6 +3324,35 @@ class TestTypeErrors(unittest.TestCase):
w = 0
self.assertIsNone(w)
+ def test_regular_protocol(self):
+ from typing import Protocol
+ class P(Protocol): ...
+ msg = (
+ 'Instance and class checks can only be used '
+ 'with @runtime_checkable protocols'
+ )
+ w = None
+ with self.assertRaisesRegex(TypeError, msg):
+ match 1:
+ case P():
+ w = 0
+ self.assertIsNone(w)
+
+ def test_positional_patterns_with_regular_protocol(self):
+ from typing import Protocol
+ class P(Protocol):
+ x: int # no `__match_args__`
+ y: int
+ class A:
+ x = 1
+ y = 2
+ w = None
+ with self.assertRaises(TypeError):
+ match A():
+ case P(x, y):
+ w = 0
+ self.assertIsNone(w)
+
class TestValueErrors(unittest.TestCase):