summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDuprat <yduprat@gmail.com>2022-03-25 22:01:21 (GMT)
committerGitHub <noreply@github.com>2022-03-25 22:01:21 (GMT)
commitd03acd7270d66ddb8e987f9743405147ecc15087 (patch)
treecffe25f0c26d55aef28c910dcf825747da99a6d4
parent20e6e5636a06fe5e1472062918d0a302d82a71c3 (diff)
downloadcpython-d03acd7270d66ddb8e987f9743405147ecc15087.zip
cpython-d03acd7270d66ddb8e987f9743405147ecc15087.tar.gz
cpython-d03acd7270d66ddb8e987f9743405147ecc15087.tar.bz2
bpo-43352: Add a Barrier object in asyncio lib (GH-24903)
Co-authored-by: Yury Selivanov <yury@edgedb.com> Co-authored-by: Andrew Svetlov <andrew.svetlov@gmail.com>
-rw-r--r--Doc/library/asyncio-api-index.rst8
-rw-r--r--Doc/library/asyncio-sync.rst110
-rw-r--r--Lib/asyncio/exceptions.py7
-rw-r--r--Lib/asyncio/locks.py157
-rw-r--r--Lib/test/test_asyncio/test_locks.py578
-rw-r--r--Misc/NEWS.d/next/Library/2021-03-31-15-22-45.bpo-43352.nSjMuE.rst1
6 files changed, 856 insertions, 5 deletions
diff --git a/Doc/library/asyncio-api-index.rst b/Doc/library/asyncio-api-index.rst
index 8bc7943..a4e38e4 100644
--- a/Doc/library/asyncio-api-index.rst
+++ b/Doc/library/asyncio-api-index.rst
@@ -186,11 +186,16 @@ Threading-like synchronization primitives that can be used in Tasks.
* - :class:`BoundedSemaphore`
- A bounded semaphore.
+ * - :class:`Barrier`
+ - A barrier object.
+
.. rubric:: Examples
* :ref:`Using asyncio.Event <asyncio_example_sync_event>`.
+* :ref:`Using asyncio.Barrier <asyncio_example_barrier>`.
+
* See also the documentation of asyncio
:ref:`synchronization primitives <asyncio-sync>`.
@@ -206,6 +211,9 @@ Exceptions
* - :exc:`asyncio.CancelledError`
- Raised when a Task is cancelled. See also :meth:`Task.cancel`.
+ * - :exc:`asyncio.BrokenBarrierError`
+ - Raised when a Barrier is broken. See also :meth:`Barrier.wait`.
+
.. rubric:: Examples
diff --git a/Doc/library/asyncio-sync.rst b/Doc/library/asyncio-sync.rst
index f4063db..141733e 100644
--- a/Doc/library/asyncio-sync.rst
+++ b/Doc/library/asyncio-sync.rst
@@ -28,6 +28,7 @@ asyncio has the following basic synchronization primitives:
* :class:`Condition`
* :class:`Semaphore`
* :class:`BoundedSemaphore`
+* :class:`Barrier`
---------
@@ -340,6 +341,115 @@ BoundedSemaphore
.. versionchanged:: 3.10
Removed the *loop* parameter.
+
+Barrier
+=======
+
+.. class:: Barrier(parties, action=None)
+
+ A barrier object. Not thread-safe.
+
+ A barrier is a simple synchronization primitive that allows to block until
+ *parties* number of tasks are waiting on it.
+ Tasks can wait on the :meth:`~Barrier.wait` method and would be blocked until
+ the specified number of tasks end up waiting on :meth:`~Barrier.wait`.
+ At that point all of the waiting tasks would unblock simultaneously.
+
+ :keyword:`async with` can be used as an alternative to awaiting on
+ :meth:`~Barrier.wait`.
+
+ The barrier can be reused any number of times.
+
+ .. _asyncio_example_barrier:
+
+ Example::
+
+ async def example_barrier():
+ # barrier with 3 parties
+ b = asyncio.Barrier(3)
+
+ # create 2 new waiting tasks
+ asyncio.create_task(b.wait())
+ asyncio.create_task(b.wait())
+
+ await asyncio.sleep(0)
+ print(b)
+
+ # The third .wait() call passes the barrier
+ await b.wait()
+ print(b)
+ print("barrier passed")
+
+ await asyncio.sleep(0)
+ print(b)
+
+ asyncio.run(example_barrier())
+
+ Result of this example is::
+
+ <asyncio.locks.Barrier object at 0x... [filling, waiters:2/3]>
+ <asyncio.locks.Barrier object at 0x... [draining, waiters:0/3]>
+ barrier passed
+ <asyncio.locks.Barrier object at 0x... [filling, waiters:0/3]>
+
+ .. versionadded:: 3.11
+
+ .. coroutinemethod:: wait()
+
+ Pass the barrier. When all the tasks party to the barrier have called
+ this function, they are all unblocked simultaneously.
+
+ When a waiting or blocked task in the barrier is cancelled,
+ this task exits the barrier which stays in the same state.
+ If the state of the barrier is "filling", the number of waiting task
+ decreases by 1.
+
+ The return value is an integer in the range of 0 to ``parties-1``, different
+ for each task. This can be used to select a task to do some special
+ housekeeping, e.g.::
+
+ ...
+ async with barrier as position:
+ if position == 0:
+ # Only one task print this
+ print('End of *draining phasis*')
+
+ This method may raise a :class:`BrokenBarrierError` exception if the
+ barrier is broken or reset while a task is waiting.
+ It could raise a :exc:`CancelledError` if a task is cancelled.
+
+ .. coroutinemethod:: reset()
+
+ Return the barrier to the default, empty state. Any tasks waiting on it
+ will receive the :class:`BrokenBarrierError` exception.
+
+ If a barrier is broken it may be better to just leave it and create a new one.
+
+ .. coroutinemethod:: abort()
+
+ Put the barrier into a broken state. This causes any active or future
+ calls to :meth:`wait` to fail with the :class:`BrokenBarrierError`.
+ Use this for example if one of the taks needs to abort, to avoid infinite
+ waiting tasks.
+
+ .. attribute:: parties
+
+ The number of tasks required to pass the barrier.
+
+ .. attribute:: n_waiting
+
+ The number of tasks currently waiting in the barrier while filling.
+
+ .. attribute:: broken
+
+ A boolean that is ``True`` if the barrier is in the broken state.
+
+
+.. exception:: BrokenBarrierError
+
+ This exception, a subclass of :exc:`RuntimeError`, is raised when the
+ :class:`Barrier` object is reset or broken.
+
---------
diff --git a/Lib/asyncio/exceptions.py b/Lib/asyncio/exceptions.py
index c764c9f..5ece595 100644
--- a/Lib/asyncio/exceptions.py
+++ b/Lib/asyncio/exceptions.py
@@ -1,7 +1,8 @@
"""asyncio exceptions."""
-__all__ = ('CancelledError', 'InvalidStateError', 'TimeoutError',
+__all__ = ('BrokenBarrierError',
+ 'CancelledError', 'InvalidStateError', 'TimeoutError',
'IncompleteReadError', 'LimitOverrunError',
'SendfileNotAvailableError')
@@ -55,3 +56,7 @@ class LimitOverrunError(Exception):
def __reduce__(self):
return type(self), (self.args[0], self.consumed)
+
+
+class BrokenBarrierError(RuntimeError):
+ """Barrier is broken by barrier.abort() call."""
diff --git a/Lib/asyncio/locks.py b/Lib/asyncio/locks.py
index 9b46121..e711302 100644
--- a/Lib/asyncio/locks.py
+++ b/Lib/asyncio/locks.py
@@ -1,14 +1,15 @@
"""Synchronization primitives."""
-__all__ = ('Lock', 'Event', 'Condition', 'Semaphore', 'BoundedSemaphore')
+__all__ = ('Lock', 'Event', 'Condition', 'Semaphore',
+ 'BoundedSemaphore', 'Barrier')
import collections
+import enum
from . import exceptions
from . import mixins
from . import tasks
-
class _ContextManagerMixin:
async def __aenter__(self):
await self.acquire()
@@ -416,3 +417,155 @@ class BoundedSemaphore(Semaphore):
if self._value >= self._bound_value:
raise ValueError('BoundedSemaphore released too many times')
super().release()
+
+
+
+class _BarrierState(enum.Enum):
+ FILLING = 'filling'
+ DRAINING = 'draining'
+ RESETTING = 'resetting'
+ BROKEN = 'broken'
+
+
+class Barrier(mixins._LoopBoundMixin):
+ """Asyncio equivalent to threading.Barrier
+
+ Implements a Barrier primitive.
+ Useful for synchronizing a fixed number of tasks at known synchronization
+ points. Tasks block on 'wait()' and are simultaneously awoken once they
+ have all made their call.
+ """
+
+ def __init__(self, parties):
+ """Create a barrier, initialised to 'parties' tasks."""
+ if parties < 1:
+ raise ValueError('parties must be > 0')
+
+ self._cond = Condition() # notify all tasks when state changes
+
+ self._parties = parties
+ self._state = _BarrierState.FILLING
+ self._count = 0 # count tasks in Barrier
+
+ def __repr__(self):
+ res = super().__repr__()
+ extra = f'{self._state.value}'
+ if not self.broken:
+ extra += f', waiters:{self.n_waiting}/{self.parties}'
+ return f'<{res[1:-1]} [{extra}]>'
+
+ async def __aenter__(self):
+ # wait for the barrier reaches the parties number
+ # when start draining release and return index of waited task
+ return await self.wait()
+
+ async def __aexit__(self, *args):
+ pass
+
+ async def wait(self):
+ """Wait for the barrier.
+
+ When the specified number of tasks have started waiting, they are all
+ simultaneously awoken.
+ Returns an unique and individual index number from 0 to 'parties-1'.
+ """
+ async with self._cond:
+ await self._block() # Block while the barrier drains or resets.
+ try:
+ index = self._count
+ self._count += 1
+ if index + 1 == self._parties:
+ # We release the barrier
+ await self._release()
+ else:
+ await self._wait()
+ return index
+ finally:
+ self._count -= 1
+ # Wake up any tasks waiting for barrier to drain.
+ self._exit()
+
+ async def _block(self):
+ # Block until the barrier is ready for us,
+ # or raise an exception if it is broken.
+ #
+ # It is draining or resetting, wait until done
+ # unless a CancelledError occurs
+ await self._cond.wait_for(
+ lambda: self._state not in (
+ _BarrierState.DRAINING, _BarrierState.RESETTING
+ )
+ )
+
+ # see if the barrier is in a broken state
+ if self._state is _BarrierState.BROKEN:
+ raise exceptions.BrokenBarrierError("Barrier aborted")
+
+ async def _release(self):
+ # Release the tasks waiting in the barrier.
+
+ # Enter draining state.
+ # Next waiting tasks will be blocked until the end of draining.
+ self._state = _BarrierState.DRAINING
+ self._cond.notify_all()
+
+ async def _wait(self):
+ # Wait in the barrier until we are released. Raise an exception
+ # if the barrier is reset or broken.
+
+ # wait for end of filling
+ # unless a CancelledError occurs
+ await self._cond.wait_for(lambda: self._state is not _BarrierState.FILLING)
+
+ if self._state in (_BarrierState.BROKEN, _BarrierState.RESETTING):
+ raise exceptions.BrokenBarrierError("Abort or reset of barrier")
+
+ def _exit(self):
+ # If we are the last tasks to exit the barrier, signal any tasks
+ # waiting for the barrier to drain.
+ if self._count == 0:
+ if self._state in (_BarrierState.RESETTING, _BarrierState.DRAINING):
+ self._state = _BarrierState.FILLING
+ self._cond.notify_all()
+
+ async def reset(self):
+ """Reset the barrier to the initial state.
+
+ Any tasks currently waiting will get the BrokenBarrier exception
+ raised.
+ """
+ async with self._cond:
+ if self._count > 0:
+ if self._state is not _BarrierState.RESETTING:
+ #reset the barrier, waking up tasks
+ self._state = _BarrierState.RESETTING
+ else:
+ self._state = _BarrierState.FILLING
+ self._cond.notify_all()
+
+ async def abort(self):
+ """Place the barrier into a 'broken' state.
+
+ Useful in case of error. Any currently waiting tasks and tasks
+ attempting to 'wait()' will have BrokenBarrierError raised.
+ """
+ async with self._cond:
+ self._state = _BarrierState.BROKEN
+ self._cond.notify_all()
+
+ @property
+ def parties(self):
+ """Return the number of tasks required to trip the barrier."""
+ return self._parties
+
+ @property
+ def n_waiting(self):
+ """Return the number of tasks currently waiting at the barrier."""
+ if self._state is _BarrierState.FILLING:
+ return self._count
+ return 0
+
+ @property
+ def broken(self):
+ """Return True if the barrier is in a broken state."""
+ return self._state is _BarrierState.BROKEN
diff --git a/Lib/test/test_asyncio/test_locks.py b/Lib/test/test_asyncio/test_locks.py
index 920b3b5..415cbe5 100644
--- a/Lib/test/test_asyncio/test_locks.py
+++ b/Lib/test/test_asyncio/test_locks.py
@@ -1,4 +1,4 @@
-"""Tests for lock.py"""
+"""Tests for locks.py"""
import unittest
from unittest import mock
@@ -9,7 +9,10 @@ import asyncio
STR_RGX_REPR = (
r'^<(?P<class>.*?) object at (?P<address>.*?)'
r'\[(?P<extras>'
- r'(set|unset|locked|unlocked)(, value:\d)?(, waiters:\d+)?'
+ r'(set|unset|locked|unlocked|filling|draining|resetting|broken)'
+ r'(, value:\d)?'
+ r'(, waiters:\d+)?'
+ r'(, waiters:\d+\/\d+)?' # barrier
r')\]>\Z'
)
RGX_REPR = re.compile(STR_RGX_REPR)
@@ -943,5 +946,576 @@ class SemaphoreTests(unittest.IsolatedAsyncioTestCase):
)
+class BarrierTests(unittest.IsolatedAsyncioTestCase):
+
+ async def asyncSetUp(self):
+ await super().asyncSetUp()
+ self.N = 5
+
+ def make_tasks(self, n, coro):
+ tasks = [asyncio.create_task(coro()) for _ in range(n)]
+ return tasks
+
+ async def gather_tasks(self, n, coro):
+ tasks = self.make_tasks(n, coro)
+ res = await asyncio.gather(*tasks)
+ return res, tasks
+
+ async def test_barrier(self):
+ barrier = asyncio.Barrier(self.N)
+ self.assertIn("filling", repr(barrier))
+ with self.assertRaisesRegex(
+ TypeError,
+ "object Barrier can't be used in 'await' expression",
+ ):
+ await barrier
+
+ self.assertIn("filling", repr(barrier))
+
+ async def test_repr(self):
+ barrier = asyncio.Barrier(self.N)
+
+ self.assertTrue(RGX_REPR.match(repr(barrier)))
+ self.assertIn("filling", repr(barrier))
+
+ waiters = []
+ async def wait(barrier):
+ await barrier.wait()
+
+ incr = 2
+ for i in range(incr):
+ waiters.append(asyncio.create_task(wait(barrier)))
+ await asyncio.sleep(0)
+
+ self.assertTrue(RGX_REPR.match(repr(barrier)))
+ self.assertTrue(f"waiters:{incr}/{self.N}" in repr(barrier))
+ self.assertIn("filling", repr(barrier))
+
+ # create missing waiters
+ for i in range(barrier.parties - barrier.n_waiting):
+ waiters.append(asyncio.create_task(wait(barrier)))
+ await asyncio.sleep(0)
+
+ self.assertTrue(RGX_REPR.match(repr(barrier)))
+ self.assertIn("draining", repr(barrier))
+
+ # add a part of waiters
+ for i in range(incr):
+ waiters.append(asyncio.create_task(wait(barrier)))
+ await asyncio.sleep(0)
+ # and reset
+ await barrier.reset()
+
+ self.assertTrue(RGX_REPR.match(repr(barrier)))
+ self.assertIn("resetting", repr(barrier))
+
+ # add a part of waiters again
+ for i in range(incr):
+ waiters.append(asyncio.create_task(wait(barrier)))
+ await asyncio.sleep(0)
+ # and abort
+ await barrier.abort()
+
+ self.assertTrue(RGX_REPR.match(repr(barrier)))
+ self.assertIn("broken", repr(barrier))
+ self.assertTrue(barrier.broken)
+
+ # suppress unhandled exceptions
+ await asyncio.gather(*waiters, return_exceptions=True)
+
+ async def test_barrier_parties(self):
+ self.assertRaises(ValueError, lambda: asyncio.Barrier(0))
+ self.assertRaises(ValueError, lambda: asyncio.Barrier(-4))
+
+ self.assertIsInstance(asyncio.Barrier(self.N), asyncio.Barrier)
+
+ async def test_context_manager(self):
+ self.N = 3
+ barrier = asyncio.Barrier(self.N)
+ results = []
+
+ async def coro():
+ async with barrier as i:
+ results.append(i)
+
+ await self.gather_tasks(self.N, coro)
+
+ self.assertListEqual(sorted(results), list(range(self.N)))
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertFalse(barrier.broken)
+
+ async def test_filling_one_task(self):
+ barrier = asyncio.Barrier(1)
+
+ async def f():
+ async with barrier as i:
+ return True
+
+ ret = await f()
+
+ self.assertTrue(ret)
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertFalse(barrier.broken)
+
+ async def test_filling_one_task_twice(self):
+ barrier = asyncio.Barrier(1)
+
+ t1 = asyncio.create_task(barrier.wait())
+ await asyncio.sleep(0)
+ self.assertEqual(barrier.n_waiting, 0)
+
+ t2 = asyncio.create_task(barrier.wait())
+ await asyncio.sleep(0)
+
+ self.assertEqual(t1.result(), t2.result())
+ self.assertEqual(t1.done(), t2.done())
+
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertFalse(barrier.broken)
+
+ async def test_filling_task_by_task(self):
+ self.N = 3
+ barrier = asyncio.Barrier(self.N)
+
+ t1 = asyncio.create_task(barrier.wait())
+ await asyncio.sleep(0)
+ self.assertEqual(barrier.n_waiting, 1)
+ self.assertIn("filling", repr(barrier))
+
+ t2 = asyncio.create_task(barrier.wait())
+ await asyncio.sleep(0)
+ self.assertEqual(barrier.n_waiting, 2)
+ self.assertIn("filling", repr(barrier))
+
+ t3 = asyncio.create_task(barrier.wait())
+ await asyncio.sleep(0)
+
+ await asyncio.wait([t1, t2, t3])
+
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertFalse(barrier.broken)
+
+ async def test_filling_tasks_wait_twice(self):
+ barrier = asyncio.Barrier(self.N)
+ results = []
+
+ async def coro():
+ async with barrier:
+ results.append(True)
+
+ async with barrier:
+ results.append(False)
+
+ await self.gather_tasks(self.N, coro)
+
+ self.assertEqual(len(results), self.N*2)
+ self.assertEqual(results.count(True), self.N)
+ self.assertEqual(results.count(False), self.N)
+
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertFalse(barrier.broken)
+
+ async def test_filling_tasks_check_return_value(self):
+ barrier = asyncio.Barrier(self.N)
+ results1 = []
+ results2 = []
+
+ async def coro():
+ async with barrier:
+ results1.append(True)
+
+ async with barrier as i:
+ results2.append(True)
+ return i
+
+ res, _ = await self.gather_tasks(self.N, coro)
+
+ self.assertEqual(len(results1), self.N)
+ self.assertTrue(all(results1))
+ self.assertEqual(len(results2), self.N)
+ self.assertTrue(all(results2))
+ self.assertListEqual(sorted(res), list(range(self.N)))
+
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertFalse(barrier.broken)
+
+ async def test_draining_state(self):
+ barrier = asyncio.Barrier(self.N)
+ results = []
+
+ async def coro():
+ async with barrier:
+ # barrier state change to filling for the last task release
+ results.append("draining" in repr(barrier))
+
+ await self.gather_tasks(self.N, coro)
+
+ self.assertEqual(len(results), self.N)
+ self.assertEqual(results[-1], False)
+ self.assertTrue(all(results[:self.N-1]))
+
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertFalse(barrier.broken)
+
+ async def test_blocking_tasks_while_draining(self):
+ rewait = 2
+ barrier = asyncio.Barrier(self.N)
+ barrier_nowaiting = asyncio.Barrier(self.N - rewait)
+ results = []
+ rewait_n = rewait
+ counter = 0
+
+ async def coro():
+ nonlocal rewait_n
+
+ # first time waiting
+ await barrier.wait()
+
+ # after wainting once for all tasks
+ if rewait_n > 0:
+ rewait_n -= 1
+ # wait again only for rewait tasks
+ await barrier.wait()
+ else:
+ # wait for end of draining state`
+ await barrier_nowaiting.wait()
+ # wait for other waiting tasks
+ await barrier.wait()
+
+ # a success means that barrier_nowaiting
+ # was waited for exactly N-rewait=3 times
+ await self.gather_tasks(self.N, coro)
+
+ async def test_filling_tasks_cancel_one(self):
+ self.N = 3
+ barrier = asyncio.Barrier(self.N)
+ results = []
+
+ async def coro():
+ await barrier.wait()
+ results.append(True)
+
+ t1 = asyncio.create_task(coro())
+ await asyncio.sleep(0)
+ self.assertEqual(barrier.n_waiting, 1)
+
+ t2 = asyncio.create_task(coro())
+ await asyncio.sleep(0)
+ self.assertEqual(barrier.n_waiting, 2)
+
+ t1.cancel()
+ await asyncio.sleep(0)
+ self.assertEqual(barrier.n_waiting, 1)
+ with self.assertRaises(asyncio.CancelledError):
+ await t1
+ self.assertTrue(t1.cancelled())
+
+ t3 = asyncio.create_task(coro())
+ await asyncio.sleep(0)
+ self.assertEqual(barrier.n_waiting, 2)
+
+ t4 = asyncio.create_task(coro())
+ await asyncio.gather(t2, t3, t4)
+
+ self.assertEqual(len(results), self.N)
+ self.assertTrue(all(results))
+
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertFalse(barrier.broken)
+
+ async def test_reset_barrier(self):
+ barrier = asyncio.Barrier(1)
+
+ asyncio.create_task(barrier.reset())
+ await asyncio.sleep(0)
+
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertFalse(barrier.broken)
+
+ async def test_reset_barrier_while_tasks_waiting(self):
+ barrier = asyncio.Barrier(self.N)
+ results = []
+
+ async def coro():
+ try:
+ await barrier.wait()
+ except asyncio.BrokenBarrierError:
+ results.append(True)
+
+ async def coro_reset():
+ await barrier.reset()
+
+ # N-1 tasks waiting on barrier with N parties
+ tasks = self.make_tasks(self.N-1, coro)
+ await asyncio.sleep(0)
+
+ # reset the barrier
+ asyncio.create_task(coro_reset())
+ await asyncio.gather(*tasks)
+
+ self.assertEqual(len(results), self.N-1)
+ self.assertTrue(all(results))
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertNotIn("resetting", repr(barrier))
+ self.assertFalse(barrier.broken)
+
+ async def test_reset_barrier_when_tasks_half_draining(self):
+ barrier = asyncio.Barrier(self.N)
+ results1 = []
+ rest_of_tasks = self.N//2
+
+ async def coro():
+ try:
+ await barrier.wait()
+ except asyncio.BrokenBarrierError:
+ # catch here waiting tasks
+ results1.append(True)
+ else:
+ # here drained task ouside the barrier
+ if rest_of_tasks == barrier._count:
+ # tasks outside the barrier
+ await barrier.reset()
+
+ await self.gather_tasks(self.N, coro)
+
+ self.assertEqual(results1, [True]*rest_of_tasks)
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertNotIn("resetting", repr(barrier))
+ self.assertFalse(barrier.broken)
+
+ async def test_reset_barrier_when_tasks_half_draining_half_blocking(self):
+ barrier = asyncio.Barrier(self.N)
+ results1 = []
+ results2 = []
+ blocking_tasks = self.N//2
+ count = 0
+
+ async def coro():
+ nonlocal count
+ try:
+ await barrier.wait()
+ except asyncio.BrokenBarrierError:
+ # here catch still waiting tasks
+ results1.append(True)
+
+ # so now waiting again to reach nb_parties
+ await barrier.wait()
+ else:
+ count += 1
+ if count > blocking_tasks:
+ # reset now: raise asyncio.BrokenBarrierError for waiting tasks
+ await barrier.reset()
+
+ # so now waiting again to reach nb_parties
+ await barrier.wait()
+ else:
+ try:
+ await barrier.wait()
+ except asyncio.BrokenBarrierError:
+ # here no catch - blocked tasks go to wait
+ results2.append(True)
+
+ await self.gather_tasks(self.N, coro)
+
+ self.assertEqual(results1, [True]*blocking_tasks)
+ self.assertEqual(results2, [])
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertNotIn("resetting", repr(barrier))
+ self.assertFalse(barrier.broken)
+
+ async def test_reset_barrier_while_tasks_waiting_and_waiting_again(self):
+ barrier = asyncio.Barrier(self.N)
+ results1 = []
+ results2 = []
+
+ async def coro1():
+ try:
+ await barrier.wait()
+ except asyncio.BrokenBarrierError:
+ results1.append(True)
+ finally:
+ await barrier.wait()
+ results2.append(True)
+
+ async def coro2():
+ async with barrier:
+ results2.append(True)
+
+ tasks = self.make_tasks(self.N-1, coro1)
+
+ # reset barrier, N-1 waiting tasks raise an BrokenBarrierError
+ asyncio.create_task(barrier.reset())
+ await asyncio.sleep(0)
+
+ # complete waiting tasks in the `finally`
+ asyncio.create_task(coro2())
+
+ await asyncio.gather(*tasks)
+
+ self.assertFalse(barrier.broken)
+ self.assertEqual(len(results1), self.N-1)
+ self.assertTrue(all(results1))
+ self.assertEqual(len(results2), self.N)
+ self.assertTrue(all(results2))
+
+ self.assertEqual(barrier.n_waiting, 0)
+
+
+ async def test_reset_barrier_while_tasks_draining(self):
+ barrier = asyncio.Barrier(self.N)
+ results1 = []
+ results2 = []
+ results3 = []
+ count = 0
+
+ async def coro():
+ nonlocal count
+
+ i = await barrier.wait()
+ count += 1
+ if count == self.N:
+ # last task exited from barrier
+ await barrier.reset()
+
+ # wit here to reach the `parties`
+ await barrier.wait()
+ else:
+ try:
+ # second waiting
+ await barrier.wait()
+
+ # N-1 tasks here
+ results1.append(True)
+ except Exception as e:
+ # never goes here
+ results2.append(True)
+
+ # Now, pass the barrier again
+ # last wait, must be completed
+ k = await barrier.wait()
+ results3.append(True)
+
+ await self.gather_tasks(self.N, coro)
+
+ self.assertFalse(barrier.broken)
+ self.assertTrue(all(results1))
+ self.assertEqual(len(results1), self.N-1)
+ self.assertEqual(len(results2), 0)
+ self.assertEqual(len(results3), self.N)
+ self.assertTrue(all(results3))
+
+ self.assertEqual(barrier.n_waiting, 0)
+
+ async def test_abort_barrier(self):
+ barrier = asyncio.Barrier(1)
+
+ asyncio.create_task(barrier.abort())
+ await asyncio.sleep(0)
+
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertTrue(barrier.broken)
+
+ async def test_abort_barrier_when_tasks_half_draining_half_blocking(self):
+ barrier = asyncio.Barrier(self.N)
+ results1 = []
+ results2 = []
+ blocking_tasks = self.N//2
+ count = 0
+
+ async def coro():
+ nonlocal count
+ try:
+ await barrier.wait()
+ except asyncio.BrokenBarrierError:
+ # here catch tasks waiting to drain
+ results1.append(True)
+ else:
+ count += 1
+ if count > blocking_tasks:
+ # abort now: raise asyncio.BrokenBarrierError for all tasks
+ await barrier.abort()
+ else:
+ try:
+ await barrier.wait()
+ except asyncio.BrokenBarrierError:
+ # here catch blocked tasks (already drained)
+ results2.append(True)
+
+ await self.gather_tasks(self.N, coro)
+
+ self.assertTrue(barrier.broken)
+ self.assertEqual(results1, [True]*blocking_tasks)
+ self.assertEqual(results2, [True]*(self.N-blocking_tasks-1))
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertNotIn("resetting", repr(barrier))
+
+ async def test_abort_barrier_when_exception(self):
+ # test from threading.Barrier: see `lock_tests.test_reset`
+ barrier = asyncio.Barrier(self.N)
+ results1 = []
+ results2 = []
+
+ async def coro():
+ try:
+ async with barrier as i :
+ if i == self.N//2:
+ raise RuntimeError
+ async with barrier:
+ results1.append(True)
+ except asyncio.BrokenBarrierError:
+ results2.append(True)
+ except RuntimeError:
+ await barrier.abort()
+
+ await self.gather_tasks(self.N, coro)
+
+ self.assertTrue(barrier.broken)
+ self.assertEqual(len(results1), 0)
+ self.assertEqual(len(results2), self.N-1)
+ self.assertTrue(all(results2))
+ self.assertEqual(barrier.n_waiting, 0)
+
+ async def test_abort_barrier_when_exception_then_resetting(self):
+ # test from threading.Barrier: see `lock_tests.test_abort_and_reset``
+ barrier1 = asyncio.Barrier(self.N)
+ barrier2 = asyncio.Barrier(self.N)
+ results1 = []
+ results2 = []
+ results3 = []
+
+ async def coro():
+ try:
+ i = await barrier1.wait()
+ if i == self.N//2:
+ raise RuntimeError
+ await barrier1.wait()
+ results1.append(True)
+ except asyncio.BrokenBarrierError:
+ results2.append(True)
+ except RuntimeError:
+ await barrier1.abort()
+
+ # Synchronize and reset the barrier. Must synchronize first so
+ # that everyone has left it when we reset, and after so that no
+ # one enters it before the reset.
+ i = await barrier2.wait()
+ if i == self.N//2:
+ await barrier1.reset()
+ await barrier2.wait()
+ await barrier1.wait()
+ results3.append(True)
+
+ await self.gather_tasks(self.N, coro)
+
+ self.assertFalse(barrier1.broken)
+ self.assertEqual(len(results1), 0)
+ self.assertEqual(len(results2), self.N-1)
+ self.assertTrue(all(results2))
+ self.assertEqual(len(results3), self.N)
+ self.assertTrue(all(results3))
+
+ self.assertEqual(barrier1.n_waiting, 0)
+
+
if __name__ == '__main__':
unittest.main()
diff --git a/Misc/NEWS.d/next/Library/2021-03-31-15-22-45.bpo-43352.nSjMuE.rst b/Misc/NEWS.d/next/Library/2021-03-31-15-22-45.bpo-43352.nSjMuE.rst
new file mode 100644
index 0000000..e53ba28
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2021-03-31-15-22-45.bpo-43352.nSjMuE.rst
@@ -0,0 +1 @@
+Add an Barrier object in synchronization primitives of *asyncio* Lib in order to be consistant with Barrier from *threading* and *multiprocessing* libs*