diff options
-rw-r--r-- | Lib/asyncio/locks.py | 37 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_locks.py | 52 | ||||
-rw-r--r-- | Misc/NEWS | 3 |
3 files changed, 66 insertions, 26 deletions
diff --git a/Lib/asyncio/locks.py b/Lib/asyncio/locks.py index 7a13279..34f6bc1 100644 --- a/Lib/asyncio/locks.py +++ b/Lib/asyncio/locks.py @@ -411,6 +411,13 @@ class Semaphore(_ContextManagerMixin): extra = '{},waiters:{}'.format(extra, len(self._waiters)) return '<{} [{}]>'.format(res[1:-1], extra) + def _wake_up_next(self): + while self._waiters: + waiter = self._waiters.popleft() + if not waiter.done(): + waiter.set_result(None) + return + def locked(self): """Returns True if semaphore can not be acquired immediately.""" return self._value == 0 @@ -425,18 +432,19 @@ class Semaphore(_ContextManagerMixin): called release() to make it larger than 0, and then return True. """ - if not self._waiters and self._value > 0: - self._value -= 1 - return True - - fut = futures.Future(loop=self._loop) - self._waiters.append(fut) - try: - yield from fut - self._value -= 1 - return True - finally: - self._waiters.remove(fut) + while self._value <= 0: + fut = futures.Future(loop=self._loop) + self._waiters.append(fut) + try: + yield from fut + except: + # See the similar code in Queue.get. + fut.cancel() + if self._value > 0 and not fut.cancelled(): + self._wake_up_next() + raise + self._value -= 1 + return True def release(self): """Release a semaphore, incrementing the internal counter by one. @@ -444,10 +452,7 @@ class Semaphore(_ContextManagerMixin): become larger than zero again, wake up that coroutine. """ self._value += 1 - for waiter in self._waiters: - if not waiter.done(): - waiter.set_result(True) - break + self._wake_up_next() class BoundedSemaphore(Semaphore): diff --git a/Lib/test/test_asyncio/test_locks.py b/Lib/test/test_asyncio/test_locks.py index dda4577..cdf5d9d 100644 --- a/Lib/test/test_asyncio/test_locks.py +++ b/Lib/test/test_asyncio/test_locks.py @@ -7,7 +7,6 @@ import re import asyncio from asyncio import test_utils - STR_RGX_REPR = ( r'^<(?P<class>.*?) object at (?P<address>.*?)' r'\[(?P<extras>' @@ -783,22 +782,20 @@ class SemaphoreTests(test_utils.TestCase): test_utils.run_briefly(self.loop) self.assertEqual(0, sem._value) - self.assertEqual([1, 2, 3], result) + self.assertEqual(3, len(result)) self.assertTrue(sem.locked()) self.assertEqual(1, len(sem._waiters)) self.assertEqual(0, sem._value) self.assertTrue(t1.done()) self.assertTrue(t1.result()) - self.assertTrue(t2.done()) - self.assertTrue(t2.result()) - self.assertTrue(t3.done()) - self.assertTrue(t3.result()) - self.assertFalse(t4.done()) + race_tasks = [t2, t3, t4] + done_tasks = [t for t in race_tasks if t.done() and t.result()] + self.assertTrue(2, len(done_tasks)) # cleanup locked semaphore sem.release() - self.loop.run_until_complete(t4) + self.loop.run_until_complete(asyncio.gather(*race_tasks)) def test_acquire_cancel(self): sem = asyncio.Semaphore(loop=self.loop) @@ -809,7 +806,44 @@ class SemaphoreTests(test_utils.TestCase): self.assertRaises( asyncio.CancelledError, self.loop.run_until_complete, acquire) - self.assertFalse(sem._waiters) + self.assertTrue((not sem._waiters) or + all(waiter.done() for waiter in sem._waiters)) + + def test_acquire_cancel_before_awoken(self): + sem = asyncio.Semaphore(value=0, loop=self.loop) + + t1 = asyncio.Task(sem.acquire(), loop=self.loop) + t2 = asyncio.Task(sem.acquire(), loop=self.loop) + t3 = asyncio.Task(sem.acquire(), loop=self.loop) + t4 = asyncio.Task(sem.acquire(), loop=self.loop) + + test_utils.run_briefly(self.loop) + + sem.release() + t1.cancel() + t2.cancel() + + test_utils.run_briefly(self.loop) + num_done = sum(t.done() for t in [t3, t4]) + self.assertEqual(num_done, 1) + + t3.cancel() + t4.cancel() + test_utils.run_briefly(self.loop) + + def test_acquire_hang(self): + sem = asyncio.Semaphore(value=0, loop=self.loop) + + t1 = asyncio.Task(sem.acquire(), loop=self.loop) + t2 = asyncio.Task(sem.acquire(), loop=self.loop) + + test_utils.run_briefly(self.loop) + + sem.release() + t1.cancel() + + test_utils.run_briefly(self.loop) + self.assertTrue(sem.locked()) def test_release_not_acquired(self): sem = asyncio.BoundedSemaphore(loop=self.loop) @@ -81,7 +81,8 @@ Library - Issue #25034: Fix string.Formatter problem with auto-numbering and nested format_specs. Patch by Anthon van der Neut. -- Issue #25233: Rewrite the guts of asyncio.Queue to be more understandable and correct. +- Issue #25233: Rewrite the guts of asyncio.Queue and + asyncio.Semaphore to be more understandable and correct. - Issue #23600: Default implementation of tzinfo.fromutc() was returning wrong results in some cases. |