summaryrefslogtreecommitdiffstats
path: root/Lib/test
diff options
context:
space:
mode:
authorDuprat <yduprat@gmail.com>2022-03-25 22:01:21 (GMT)
committerGitHub <noreply@github.com>2022-03-25 22:01:21 (GMT)
commitd03acd7270d66ddb8e987f9743405147ecc15087 (patch)
treecffe25f0c26d55aef28c910dcf825747da99a6d4 /Lib/test
parent20e6e5636a06fe5e1472062918d0a302d82a71c3 (diff)
downloadcpython-d03acd7270d66ddb8e987f9743405147ecc15087.zip
cpython-d03acd7270d66ddb8e987f9743405147ecc15087.tar.gz
cpython-d03acd7270d66ddb8e987f9743405147ecc15087.tar.bz2
bpo-43352: Add a Barrier object in asyncio lib (GH-24903)
Co-authored-by: Yury Selivanov <yury@edgedb.com> Co-authored-by: Andrew Svetlov <andrew.svetlov@gmail.com>
Diffstat (limited to 'Lib/test')
-rw-r--r--Lib/test/test_asyncio/test_locks.py578
1 files changed, 576 insertions, 2 deletions
diff --git a/Lib/test/test_asyncio/test_locks.py b/Lib/test/test_asyncio/test_locks.py
index 920b3b5..415cbe5 100644
--- a/Lib/test/test_asyncio/test_locks.py
+++ b/Lib/test/test_asyncio/test_locks.py
@@ -1,4 +1,4 @@
-"""Tests for lock.py"""
+"""Tests for locks.py"""
import unittest
from unittest import mock
@@ -9,7 +9,10 @@ import asyncio
STR_RGX_REPR = (
r'^<(?P<class>.*?) object at (?P<address>.*?)'
r'\[(?P<extras>'
- r'(set|unset|locked|unlocked)(, value:\d)?(, waiters:\d+)?'
+ r'(set|unset|locked|unlocked|filling|draining|resetting|broken)'
+ r'(, value:\d)?'
+ r'(, waiters:\d+)?'
+ r'(, waiters:\d+\/\d+)?' # barrier
r')\]>\Z'
)
RGX_REPR = re.compile(STR_RGX_REPR)
@@ -943,5 +946,576 @@ class SemaphoreTests(unittest.IsolatedAsyncioTestCase):
)
+class BarrierTests(unittest.IsolatedAsyncioTestCase):
+
+ async def asyncSetUp(self):
+ await super().asyncSetUp()
+ self.N = 5
+
+ def make_tasks(self, n, coro):
+ tasks = [asyncio.create_task(coro()) for _ in range(n)]
+ return tasks
+
+ async def gather_tasks(self, n, coro):
+ tasks = self.make_tasks(n, coro)
+ res = await asyncio.gather(*tasks)
+ return res, tasks
+
+ async def test_barrier(self):
+ barrier = asyncio.Barrier(self.N)
+ self.assertIn("filling", repr(barrier))
+ with self.assertRaisesRegex(
+ TypeError,
+ "object Barrier can't be used in 'await' expression",
+ ):
+ await barrier
+
+ self.assertIn("filling", repr(barrier))
+
+ async def test_repr(self):
+ barrier = asyncio.Barrier(self.N)
+
+ self.assertTrue(RGX_REPR.match(repr(barrier)))
+ self.assertIn("filling", repr(barrier))
+
+ waiters = []
+ async def wait(barrier):
+ await barrier.wait()
+
+ incr = 2
+ for i in range(incr):
+ waiters.append(asyncio.create_task(wait(barrier)))
+ await asyncio.sleep(0)
+
+ self.assertTrue(RGX_REPR.match(repr(barrier)))
+ self.assertTrue(f"waiters:{incr}/{self.N}" in repr(barrier))
+ self.assertIn("filling", repr(barrier))
+
+ # create missing waiters
+ for i in range(barrier.parties - barrier.n_waiting):
+ waiters.append(asyncio.create_task(wait(barrier)))
+ await asyncio.sleep(0)
+
+ self.assertTrue(RGX_REPR.match(repr(barrier)))
+ self.assertIn("draining", repr(barrier))
+
+ # add a part of waiters
+ for i in range(incr):
+ waiters.append(asyncio.create_task(wait(barrier)))
+ await asyncio.sleep(0)
+ # and reset
+ await barrier.reset()
+
+ self.assertTrue(RGX_REPR.match(repr(barrier)))
+ self.assertIn("resetting", repr(barrier))
+
+ # add a part of waiters again
+ for i in range(incr):
+ waiters.append(asyncio.create_task(wait(barrier)))
+ await asyncio.sleep(0)
+ # and abort
+ await barrier.abort()
+
+ self.assertTrue(RGX_REPR.match(repr(barrier)))
+ self.assertIn("broken", repr(barrier))
+ self.assertTrue(barrier.broken)
+
+ # suppress unhandled exceptions
+ await asyncio.gather(*waiters, return_exceptions=True)
+
+ async def test_barrier_parties(self):
+ self.assertRaises(ValueError, lambda: asyncio.Barrier(0))
+ self.assertRaises(ValueError, lambda: asyncio.Barrier(-4))
+
+ self.assertIsInstance(asyncio.Barrier(self.N), asyncio.Barrier)
+
+ async def test_context_manager(self):
+ self.N = 3
+ barrier = asyncio.Barrier(self.N)
+ results = []
+
+ async def coro():
+ async with barrier as i:
+ results.append(i)
+
+ await self.gather_tasks(self.N, coro)
+
+ self.assertListEqual(sorted(results), list(range(self.N)))
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertFalse(barrier.broken)
+
+ async def test_filling_one_task(self):
+ barrier = asyncio.Barrier(1)
+
+ async def f():
+ async with barrier as i:
+ return True
+
+ ret = await f()
+
+ self.assertTrue(ret)
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertFalse(barrier.broken)
+
+ async def test_filling_one_task_twice(self):
+ barrier = asyncio.Barrier(1)
+
+ t1 = asyncio.create_task(barrier.wait())
+ await asyncio.sleep(0)
+ self.assertEqual(barrier.n_waiting, 0)
+
+ t2 = asyncio.create_task(barrier.wait())
+ await asyncio.sleep(0)
+
+ self.assertEqual(t1.result(), t2.result())
+ self.assertEqual(t1.done(), t2.done())
+
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertFalse(barrier.broken)
+
+ async def test_filling_task_by_task(self):
+ self.N = 3
+ barrier = asyncio.Barrier(self.N)
+
+ t1 = asyncio.create_task(barrier.wait())
+ await asyncio.sleep(0)
+ self.assertEqual(barrier.n_waiting, 1)
+ self.assertIn("filling", repr(barrier))
+
+ t2 = asyncio.create_task(barrier.wait())
+ await asyncio.sleep(0)
+ self.assertEqual(barrier.n_waiting, 2)
+ self.assertIn("filling", repr(barrier))
+
+ t3 = asyncio.create_task(barrier.wait())
+ await asyncio.sleep(0)
+
+ await asyncio.wait([t1, t2, t3])
+
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertFalse(barrier.broken)
+
+ async def test_filling_tasks_wait_twice(self):
+ barrier = asyncio.Barrier(self.N)
+ results = []
+
+ async def coro():
+ async with barrier:
+ results.append(True)
+
+ async with barrier:
+ results.append(False)
+
+ await self.gather_tasks(self.N, coro)
+
+ self.assertEqual(len(results), self.N*2)
+ self.assertEqual(results.count(True), self.N)
+ self.assertEqual(results.count(False), self.N)
+
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertFalse(barrier.broken)
+
+ async def test_filling_tasks_check_return_value(self):
+ barrier = asyncio.Barrier(self.N)
+ results1 = []
+ results2 = []
+
+ async def coro():
+ async with barrier:
+ results1.append(True)
+
+ async with barrier as i:
+ results2.append(True)
+ return i
+
+ res, _ = await self.gather_tasks(self.N, coro)
+
+ self.assertEqual(len(results1), self.N)
+ self.assertTrue(all(results1))
+ self.assertEqual(len(results2), self.N)
+ self.assertTrue(all(results2))
+ self.assertListEqual(sorted(res), list(range(self.N)))
+
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertFalse(barrier.broken)
+
+ async def test_draining_state(self):
+ barrier = asyncio.Barrier(self.N)
+ results = []
+
+ async def coro():
+ async with barrier:
+ # barrier state change to filling for the last task release
+ results.append("draining" in repr(barrier))
+
+ await self.gather_tasks(self.N, coro)
+
+ self.assertEqual(len(results), self.N)
+ self.assertEqual(results[-1], False)
+ self.assertTrue(all(results[:self.N-1]))
+
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertFalse(barrier.broken)
+
+ async def test_blocking_tasks_while_draining(self):
+ rewait = 2
+ barrier = asyncio.Barrier(self.N)
+ barrier_nowaiting = asyncio.Barrier(self.N - rewait)
+ results = []
+ rewait_n = rewait
+ counter = 0
+
+ async def coro():
+ nonlocal rewait_n
+
+ # first time waiting
+ await barrier.wait()
+
+ # after wainting once for all tasks
+ if rewait_n > 0:
+ rewait_n -= 1
+ # wait again only for rewait tasks
+ await barrier.wait()
+ else:
+ # wait for end of draining state`
+ await barrier_nowaiting.wait()
+ # wait for other waiting tasks
+ await barrier.wait()
+
+ # a success means that barrier_nowaiting
+ # was waited for exactly N-rewait=3 times
+ await self.gather_tasks(self.N, coro)
+
+ async def test_filling_tasks_cancel_one(self):
+ self.N = 3
+ barrier = asyncio.Barrier(self.N)
+ results = []
+
+ async def coro():
+ await barrier.wait()
+ results.append(True)
+
+ t1 = asyncio.create_task(coro())
+ await asyncio.sleep(0)
+ self.assertEqual(barrier.n_waiting, 1)
+
+ t2 = asyncio.create_task(coro())
+ await asyncio.sleep(0)
+ self.assertEqual(barrier.n_waiting, 2)
+
+ t1.cancel()
+ await asyncio.sleep(0)
+ self.assertEqual(barrier.n_waiting, 1)
+ with self.assertRaises(asyncio.CancelledError):
+ await t1
+ self.assertTrue(t1.cancelled())
+
+ t3 = asyncio.create_task(coro())
+ await asyncio.sleep(0)
+ self.assertEqual(barrier.n_waiting, 2)
+
+ t4 = asyncio.create_task(coro())
+ await asyncio.gather(t2, t3, t4)
+
+ self.assertEqual(len(results), self.N)
+ self.assertTrue(all(results))
+
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertFalse(barrier.broken)
+
+ async def test_reset_barrier(self):
+ barrier = asyncio.Barrier(1)
+
+ asyncio.create_task(barrier.reset())
+ await asyncio.sleep(0)
+
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertFalse(barrier.broken)
+
+ async def test_reset_barrier_while_tasks_waiting(self):
+ barrier = asyncio.Barrier(self.N)
+ results = []
+
+ async def coro():
+ try:
+ await barrier.wait()
+ except asyncio.BrokenBarrierError:
+ results.append(True)
+
+ async def coro_reset():
+ await barrier.reset()
+
+ # N-1 tasks waiting on barrier with N parties
+ tasks = self.make_tasks(self.N-1, coro)
+ await asyncio.sleep(0)
+
+ # reset the barrier
+ asyncio.create_task(coro_reset())
+ await asyncio.gather(*tasks)
+
+ self.assertEqual(len(results), self.N-1)
+ self.assertTrue(all(results))
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertNotIn("resetting", repr(barrier))
+ self.assertFalse(barrier.broken)
+
+ async def test_reset_barrier_when_tasks_half_draining(self):
+ barrier = asyncio.Barrier(self.N)
+ results1 = []
+ rest_of_tasks = self.N//2
+
+ async def coro():
+ try:
+ await barrier.wait()
+ except asyncio.BrokenBarrierError:
+ # catch here waiting tasks
+ results1.append(True)
+ else:
+ # here drained task ouside the barrier
+ if rest_of_tasks == barrier._count:
+ # tasks outside the barrier
+ await barrier.reset()
+
+ await self.gather_tasks(self.N, coro)
+
+ self.assertEqual(results1, [True]*rest_of_tasks)
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertNotIn("resetting", repr(barrier))
+ self.assertFalse(barrier.broken)
+
+ async def test_reset_barrier_when_tasks_half_draining_half_blocking(self):
+ barrier = asyncio.Barrier(self.N)
+ results1 = []
+ results2 = []
+ blocking_tasks = self.N//2
+ count = 0
+
+ async def coro():
+ nonlocal count
+ try:
+ await barrier.wait()
+ except asyncio.BrokenBarrierError:
+ # here catch still waiting tasks
+ results1.append(True)
+
+ # so now waiting again to reach nb_parties
+ await barrier.wait()
+ else:
+ count += 1
+ if count > blocking_tasks:
+ # reset now: raise asyncio.BrokenBarrierError for waiting tasks
+ await barrier.reset()
+
+ # so now waiting again to reach nb_parties
+ await barrier.wait()
+ else:
+ try:
+ await barrier.wait()
+ except asyncio.BrokenBarrierError:
+ # here no catch - blocked tasks go to wait
+ results2.append(True)
+
+ await self.gather_tasks(self.N, coro)
+
+ self.assertEqual(results1, [True]*blocking_tasks)
+ self.assertEqual(results2, [])
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertNotIn("resetting", repr(barrier))
+ self.assertFalse(barrier.broken)
+
+ async def test_reset_barrier_while_tasks_waiting_and_waiting_again(self):
+ barrier = asyncio.Barrier(self.N)
+ results1 = []
+ results2 = []
+
+ async def coro1():
+ try:
+ await barrier.wait()
+ except asyncio.BrokenBarrierError:
+ results1.append(True)
+ finally:
+ await barrier.wait()
+ results2.append(True)
+
+ async def coro2():
+ async with barrier:
+ results2.append(True)
+
+ tasks = self.make_tasks(self.N-1, coro1)
+
+ # reset barrier, N-1 waiting tasks raise an BrokenBarrierError
+ asyncio.create_task(barrier.reset())
+ await asyncio.sleep(0)
+
+ # complete waiting tasks in the `finally`
+ asyncio.create_task(coro2())
+
+ await asyncio.gather(*tasks)
+
+ self.assertFalse(barrier.broken)
+ self.assertEqual(len(results1), self.N-1)
+ self.assertTrue(all(results1))
+ self.assertEqual(len(results2), self.N)
+ self.assertTrue(all(results2))
+
+ self.assertEqual(barrier.n_waiting, 0)
+
+
+ async def test_reset_barrier_while_tasks_draining(self):
+ barrier = asyncio.Barrier(self.N)
+ results1 = []
+ results2 = []
+ results3 = []
+ count = 0
+
+ async def coro():
+ nonlocal count
+
+ i = await barrier.wait()
+ count += 1
+ if count == self.N:
+ # last task exited from barrier
+ await barrier.reset()
+
+ # wit here to reach the `parties`
+ await barrier.wait()
+ else:
+ try:
+ # second waiting
+ await barrier.wait()
+
+ # N-1 tasks here
+ results1.append(True)
+ except Exception as e:
+ # never goes here
+ results2.append(True)
+
+ # Now, pass the barrier again
+ # last wait, must be completed
+ k = await barrier.wait()
+ results3.append(True)
+
+ await self.gather_tasks(self.N, coro)
+
+ self.assertFalse(barrier.broken)
+ self.assertTrue(all(results1))
+ self.assertEqual(len(results1), self.N-1)
+ self.assertEqual(len(results2), 0)
+ self.assertEqual(len(results3), self.N)
+ self.assertTrue(all(results3))
+
+ self.assertEqual(barrier.n_waiting, 0)
+
+ async def test_abort_barrier(self):
+ barrier = asyncio.Barrier(1)
+
+ asyncio.create_task(barrier.abort())
+ await asyncio.sleep(0)
+
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertTrue(barrier.broken)
+
+ async def test_abort_barrier_when_tasks_half_draining_half_blocking(self):
+ barrier = asyncio.Barrier(self.N)
+ results1 = []
+ results2 = []
+ blocking_tasks = self.N//2
+ count = 0
+
+ async def coro():
+ nonlocal count
+ try:
+ await barrier.wait()
+ except asyncio.BrokenBarrierError:
+ # here catch tasks waiting to drain
+ results1.append(True)
+ else:
+ count += 1
+ if count > blocking_tasks:
+ # abort now: raise asyncio.BrokenBarrierError for all tasks
+ await barrier.abort()
+ else:
+ try:
+ await barrier.wait()
+ except asyncio.BrokenBarrierError:
+ # here catch blocked tasks (already drained)
+ results2.append(True)
+
+ await self.gather_tasks(self.N, coro)
+
+ self.assertTrue(barrier.broken)
+ self.assertEqual(results1, [True]*blocking_tasks)
+ self.assertEqual(results2, [True]*(self.N-blocking_tasks-1))
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertNotIn("resetting", repr(barrier))
+
+ async def test_abort_barrier_when_exception(self):
+ # test from threading.Barrier: see `lock_tests.test_reset`
+ barrier = asyncio.Barrier(self.N)
+ results1 = []
+ results2 = []
+
+ async def coro():
+ try:
+ async with barrier as i :
+ if i == self.N//2:
+ raise RuntimeError
+ async with barrier:
+ results1.append(True)
+ except asyncio.BrokenBarrierError:
+ results2.append(True)
+ except RuntimeError:
+ await barrier.abort()
+
+ await self.gather_tasks(self.N, coro)
+
+ self.assertTrue(barrier.broken)
+ self.assertEqual(len(results1), 0)
+ self.assertEqual(len(results2), self.N-1)
+ self.assertTrue(all(results2))
+ self.assertEqual(barrier.n_waiting, 0)
+
+ async def test_abort_barrier_when_exception_then_resetting(self):
+ # test from threading.Barrier: see `lock_tests.test_abort_and_reset``
+ barrier1 = asyncio.Barrier(self.N)
+ barrier2 = asyncio.Barrier(self.N)
+ results1 = []
+ results2 = []
+ results3 = []
+
+ async def coro():
+ try:
+ i = await barrier1.wait()
+ if i == self.N//2:
+ raise RuntimeError
+ await barrier1.wait()
+ results1.append(True)
+ except asyncio.BrokenBarrierError:
+ results2.append(True)
+ except RuntimeError:
+ await barrier1.abort()
+
+ # Synchronize and reset the barrier. Must synchronize first so
+ # that everyone has left it when we reset, and after so that no
+ # one enters it before the reset.
+ i = await barrier2.wait()
+ if i == self.N//2:
+ await barrier1.reset()
+ await barrier2.wait()
+ await barrier1.wait()
+ results3.append(True)
+
+ await self.gather_tasks(self.N, coro)
+
+ self.assertFalse(barrier1.broken)
+ self.assertEqual(len(results1), 0)
+ self.assertEqual(len(results2), self.N-1)
+ self.assertTrue(all(results2))
+ self.assertEqual(len(results3), self.N)
+ self.assertTrue(all(results3))
+
+ self.assertEqual(barrier1.n_waiting, 0)
+
+
if __name__ == '__main__':
unittest.main()