summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
Diffstat (limited to 'Lib')
-rw-r--r--Lib/asyncio/locks.py64
-rw-r--r--Lib/test/test_asyncio/test_locks.py100
2 files changed, 142 insertions, 22 deletions
diff --git a/Lib/asyncio/locks.py b/Lib/asyncio/locks.py
index 7b81c25..fc03830 100644
--- a/Lib/asyncio/locks.py
+++ b/Lib/asyncio/locks.py
@@ -349,9 +349,8 @@ class Semaphore(_ContextManagerMixin, mixins._LoopBoundMixin):
super().__init__(loop=loop)
if value < 0:
raise ValueError("Semaphore initial value must be >= 0")
+ self._waiters = None
self._value = value
- self._waiters = collections.deque()
- self._wakeup_scheduled = False
def __repr__(self):
res = super().__repr__()
@@ -360,16 +359,8 @@ class Semaphore(_ContextManagerMixin, mixins._LoopBoundMixin):
extra = f'{extra}, waiters:{len(self._waiters)}'
return f'<{res[1:-1]} [{extra}]>'
- def _wake_up_next(self):
- while self._waiters:
- waiter = self._waiters.popleft()
- if not waiter.done():
- waiter.set_result(None)
- self._wakeup_scheduled = True
- return
-
def locked(self):
- """Returns True if semaphore can not be acquired immediately."""
+ """Returns True if semaphore counter is zero."""
return self._value == 0
async def acquire(self):
@@ -381,28 +372,57 @@ class Semaphore(_ContextManagerMixin, mixins._LoopBoundMixin):
called release() to make it larger than 0, and then return
True.
"""
- # _wakeup_scheduled is set if *another* task is scheduled to wakeup
- # but its acquire() is not resumed yet
- while self._wakeup_scheduled or self._value <= 0:
- fut = self._get_loop().create_future()
- self._waiters.append(fut)
+ if (not self.locked() and (self._waiters is None or
+ all(w.cancelled() for w in self._waiters))):
+ self._value -= 1
+ return True
+
+ if self._waiters is None:
+ self._waiters = collections.deque()
+ fut = self._get_loop().create_future()
+ self._waiters.append(fut)
+
+ # Finally block should be called before the CancelledError
+ # handling as we don't want CancelledError to call
+ # _wake_up_first() and attempt to wake up itself.
+ try:
try:
await fut
- # reset _wakeup_scheduled *after* waiting for a future
- self._wakeup_scheduled = False
- except exceptions.CancelledError:
- self._wake_up_next()
- raise
+ finally:
+ self._waiters.remove(fut)
+ except exceptions.CancelledError:
+ if not self.locked():
+ self._wake_up_first()
+ raise
+
self._value -= 1
+ if not self.locked():
+ self._wake_up_first()
return True
def release(self):
"""Release a semaphore, incrementing the internal counter by one.
+
When it was zero on entry and another coroutine is waiting for it to
become larger than zero again, wake up that coroutine.
"""
self._value += 1
- self._wake_up_next()
+ self._wake_up_first()
+
+ def _wake_up_first(self):
+ """Wake up the first waiter if it isn't done."""
+ if not self._waiters:
+ return
+ try:
+ fut = next(iter(self._waiters))
+ except StopIteration:
+ return
+
+ # .done() necessarily means that a waiter will wake up later on and
+ # either take the lock, or, if it was cancelled and lock wasn't
+ # taken already, will hit this again and wake up a new waiter.
+ if not fut.done():
+ fut.set_result(True)
class BoundedSemaphore(Semaphore):
diff --git a/Lib/test/test_asyncio/test_locks.py b/Lib/test/test_asyncio/test_locks.py
index 167ae70..c539267 100644
--- a/Lib/test/test_asyncio/test_locks.py
+++ b/Lib/test/test_asyncio/test_locks.py
@@ -5,6 +5,7 @@ from unittest import mock
import re
import asyncio
+import collections
STR_RGX_REPR = (
r'^<(?P<class>.*?) object at (?P<address>.*?)'
@@ -782,6 +783,9 @@ class SemaphoreTests(unittest.IsolatedAsyncioTestCase):
self.assertTrue('waiters' not in repr(sem))
self.assertTrue(RGX_REPR.match(repr(sem)))
+ if sem._waiters is None:
+ sem._waiters = collections.deque()
+
sem._waiters.append(mock.Mock())
self.assertTrue('waiters:1' in repr(sem))
self.assertTrue(RGX_REPR.match(repr(sem)))
@@ -856,6 +860,7 @@ class SemaphoreTests(unittest.IsolatedAsyncioTestCase):
self.assertEqual(2, sem._value)
await asyncio.sleep(0)
+ await asyncio.sleep(0)
self.assertEqual(0, sem._value)
self.assertEqual(3, len(result))
self.assertTrue(sem.locked())
@@ -898,6 +903,7 @@ class SemaphoreTests(unittest.IsolatedAsyncioTestCase):
sem.release()
await asyncio.sleep(0)
+ await asyncio.sleep(0)
num_done = sum(t.done() for t in [t3, t4])
self.assertEqual(num_done, 1)
self.assertTrue(t3.done())
@@ -917,9 +923,32 @@ class SemaphoreTests(unittest.IsolatedAsyncioTestCase):
t1.cancel()
sem.release()
await asyncio.sleep(0)
+ await asyncio.sleep(0)
self.assertTrue(sem.locked())
self.assertTrue(t2.done())
+ async def test_acquire_no_hang(self):
+
+ sem = asyncio.Semaphore(1)
+
+ async def c1():
+ async with sem:
+ await asyncio.sleep(0)
+ t2.cancel()
+
+ async def c2():
+ async with sem:
+ self.assertFalse(True)
+
+ t1 = asyncio.create_task(c1())
+ t2 = asyncio.create_task(c2())
+
+ r1, r2 = await asyncio.gather(t1, t2, return_exceptions=True)
+ self.assertTrue(r1 is None)
+ self.assertTrue(isinstance(r2, asyncio.CancelledError))
+
+ await asyncio.wait_for(sem.acquire(), timeout=1.0)
+
def test_release_not_acquired(self):
sem = asyncio.BoundedSemaphore()
@@ -959,6 +988,77 @@ class SemaphoreTests(unittest.IsolatedAsyncioTestCase):
result
)
+ async def test_acquire_fifo_order_2(self):
+ sem = asyncio.Semaphore(1)
+ result = []
+
+ async def c1(result):
+ await sem.acquire()
+ result.append(1)
+ return True
+
+ async def c2(result):
+ await sem.acquire()
+ result.append(2)
+ sem.release()
+ await sem.acquire()
+ result.append(4)
+ return True
+
+ async def c3(result):
+ await sem.acquire()
+ result.append(3)
+ return True
+
+ t1 = asyncio.create_task(c1(result))
+ t2 = asyncio.create_task(c2(result))
+ t3 = asyncio.create_task(c3(result))
+
+ await asyncio.sleep(0)
+
+ sem.release()
+ sem.release()
+
+ tasks = [t1, t2, t3]
+ await asyncio.gather(*tasks)
+ self.assertEqual([1, 2, 3, 4], result)
+
+ async def test_acquire_fifo_order_3(self):
+ sem = asyncio.Semaphore(0)
+ result = []
+
+ async def c1(result):
+ await sem.acquire()
+ result.append(1)
+ return True
+
+ async def c2(result):
+ await sem.acquire()
+ result.append(2)
+ return True
+
+ async def c3(result):
+ await sem.acquire()
+ result.append(3)
+ return True
+
+ t1 = asyncio.create_task(c1(result))
+ t2 = asyncio.create_task(c2(result))
+ t3 = asyncio.create_task(c3(result))
+
+ await asyncio.sleep(0)
+
+ t1.cancel()
+
+ await asyncio.sleep(0)
+
+ sem.release()
+ sem.release()
+
+ tasks = [t1, t2, t3]
+ await asyncio.gather(*tasks, return_exceptions=True)
+ self.assertEqual([2, 3], result)
+
if __name__ == '__main__':
unittest.main()