diff options
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/asyncio/locks.py | 64 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_locks.py | 100 |
2 files changed, 142 insertions, 22 deletions
diff --git a/Lib/asyncio/locks.py b/Lib/asyncio/locks.py index 7b81c25..fc03830 100644 --- a/Lib/asyncio/locks.py +++ b/Lib/asyncio/locks.py @@ -349,9 +349,8 @@ class Semaphore(_ContextManagerMixin, mixins._LoopBoundMixin): super().__init__(loop=loop) if value < 0: raise ValueError("Semaphore initial value must be >= 0") + self._waiters = None self._value = value - self._waiters = collections.deque() - self._wakeup_scheduled = False def __repr__(self): res = super().__repr__() @@ -360,16 +359,8 @@ class Semaphore(_ContextManagerMixin, mixins._LoopBoundMixin): extra = f'{extra}, waiters:{len(self._waiters)}' return f'<{res[1:-1]} [{extra}]>' - def _wake_up_next(self): - while self._waiters: - waiter = self._waiters.popleft() - if not waiter.done(): - waiter.set_result(None) - self._wakeup_scheduled = True - return - def locked(self): - """Returns True if semaphore can not be acquired immediately.""" + """Returns True if semaphore counter is zero.""" return self._value == 0 async def acquire(self): @@ -381,28 +372,57 @@ class Semaphore(_ContextManagerMixin, mixins._LoopBoundMixin): called release() to make it larger than 0, and then return True. """ - # _wakeup_scheduled is set if *another* task is scheduled to wakeup - # but its acquire() is not resumed yet - while self._wakeup_scheduled or self._value <= 0: - fut = self._get_loop().create_future() - self._waiters.append(fut) + if (not self.locked() and (self._waiters is None or + all(w.cancelled() for w in self._waiters))): + self._value -= 1 + return True + + if self._waiters is None: + self._waiters = collections.deque() + fut = self._get_loop().create_future() + self._waiters.append(fut) + + # Finally block should be called before the CancelledError + # handling as we don't want CancelledError to call + # _wake_up_first() and attempt to wake up itself. + try: try: await fut - # reset _wakeup_scheduled *after* waiting for a future - self._wakeup_scheduled = False - except exceptions.CancelledError: - self._wake_up_next() - raise + finally: + self._waiters.remove(fut) + except exceptions.CancelledError: + if not self.locked(): + self._wake_up_first() + raise + self._value -= 1 + if not self.locked(): + self._wake_up_first() return True def release(self): """Release a semaphore, incrementing the internal counter by one. + When it was zero on entry and another coroutine is waiting for it to become larger than zero again, wake up that coroutine. """ self._value += 1 - self._wake_up_next() + self._wake_up_first() + + def _wake_up_first(self): + """Wake up the first waiter if it isn't done.""" + if not self._waiters: + return + try: + fut = next(iter(self._waiters)) + except StopIteration: + return + + # .done() necessarily means that a waiter will wake up later on and + # either take the lock, or, if it was cancelled and lock wasn't + # taken already, will hit this again and wake up a new waiter. + if not fut.done(): + fut.set_result(True) class BoundedSemaphore(Semaphore): diff --git a/Lib/test/test_asyncio/test_locks.py b/Lib/test/test_asyncio/test_locks.py index 167ae70..c539267 100644 --- a/Lib/test/test_asyncio/test_locks.py +++ b/Lib/test/test_asyncio/test_locks.py @@ -5,6 +5,7 @@ from unittest import mock import re import asyncio +import collections STR_RGX_REPR = ( r'^<(?P<class>.*?) object at (?P<address>.*?)' @@ -782,6 +783,9 @@ class SemaphoreTests(unittest.IsolatedAsyncioTestCase): self.assertTrue('waiters' not in repr(sem)) self.assertTrue(RGX_REPR.match(repr(sem))) + if sem._waiters is None: + sem._waiters = collections.deque() + sem._waiters.append(mock.Mock()) self.assertTrue('waiters:1' in repr(sem)) self.assertTrue(RGX_REPR.match(repr(sem))) @@ -856,6 +860,7 @@ class SemaphoreTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(2, sem._value) await asyncio.sleep(0) + await asyncio.sleep(0) self.assertEqual(0, sem._value) self.assertEqual(3, len(result)) self.assertTrue(sem.locked()) @@ -898,6 +903,7 @@ class SemaphoreTests(unittest.IsolatedAsyncioTestCase): sem.release() await asyncio.sleep(0) + await asyncio.sleep(0) num_done = sum(t.done() for t in [t3, t4]) self.assertEqual(num_done, 1) self.assertTrue(t3.done()) @@ -917,9 +923,32 @@ class SemaphoreTests(unittest.IsolatedAsyncioTestCase): t1.cancel() sem.release() await asyncio.sleep(0) + await asyncio.sleep(0) self.assertTrue(sem.locked()) self.assertTrue(t2.done()) + async def test_acquire_no_hang(self): + + sem = asyncio.Semaphore(1) + + async def c1(): + async with sem: + await asyncio.sleep(0) + t2.cancel() + + async def c2(): + async with sem: + self.assertFalse(True) + + t1 = asyncio.create_task(c1()) + t2 = asyncio.create_task(c2()) + + r1, r2 = await asyncio.gather(t1, t2, return_exceptions=True) + self.assertTrue(r1 is None) + self.assertTrue(isinstance(r2, asyncio.CancelledError)) + + await asyncio.wait_for(sem.acquire(), timeout=1.0) + def test_release_not_acquired(self): sem = asyncio.BoundedSemaphore() @@ -959,6 +988,77 @@ class SemaphoreTests(unittest.IsolatedAsyncioTestCase): result ) + async def test_acquire_fifo_order_2(self): + sem = asyncio.Semaphore(1) + result = [] + + async def c1(result): + await sem.acquire() + result.append(1) + return True + + async def c2(result): + await sem.acquire() + result.append(2) + sem.release() + await sem.acquire() + result.append(4) + return True + + async def c3(result): + await sem.acquire() + result.append(3) + return True + + t1 = asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) + t3 = asyncio.create_task(c3(result)) + + await asyncio.sleep(0) + + sem.release() + sem.release() + + tasks = [t1, t2, t3] + await asyncio.gather(*tasks) + self.assertEqual([1, 2, 3, 4], result) + + async def test_acquire_fifo_order_3(self): + sem = asyncio.Semaphore(0) + result = [] + + async def c1(result): + await sem.acquire() + result.append(1) + return True + + async def c2(result): + await sem.acquire() + result.append(2) + return True + + async def c3(result): + await sem.acquire() + result.append(3) + return True + + t1 = asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) + t3 = asyncio.create_task(c3(result)) + + await asyncio.sleep(0) + + t1.cancel() + + await asyncio.sleep(0) + + sem.release() + sem.release() + + tasks = [t1, t2, t3] + await asyncio.gather(*tasks, return_exceptions=True) + self.assertEqual([2, 3], result) + if __name__ == '__main__': unittest.main() |