summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/asyncio/locks.py37
-rw-r--r--Lib/test/test_asyncio/test_locks.py52
-rw-r--r--Misc/NEWS3
3 files changed, 66 insertions, 26 deletions
diff --git a/Lib/asyncio/locks.py b/Lib/asyncio/locks.py
index 7a13279..34f6bc1 100644
--- a/Lib/asyncio/locks.py
+++ b/Lib/asyncio/locks.py
@@ -411,6 +411,13 @@ class Semaphore(_ContextManagerMixin):
extra = '{},waiters:{}'.format(extra, len(self._waiters))
return '<{} [{}]>'.format(res[1:-1], extra)
+ def _wake_up_next(self):
+ while self._waiters:
+ waiter = self._waiters.popleft()
+ if not waiter.done():
+ waiter.set_result(None)
+ return
+
def locked(self):
"""Returns True if semaphore can not be acquired immediately."""
return self._value == 0
@@ -425,18 +432,19 @@ class Semaphore(_ContextManagerMixin):
called release() to make it larger than 0, and then return
True.
"""
- if not self._waiters and self._value > 0:
- self._value -= 1
- return True
-
- fut = futures.Future(loop=self._loop)
- self._waiters.append(fut)
- try:
- yield from fut
- self._value -= 1
- return True
- finally:
- self._waiters.remove(fut)
+ while self._value <= 0:
+ fut = futures.Future(loop=self._loop)
+ self._waiters.append(fut)
+ try:
+ yield from fut
+ except:
+ # See the similar code in Queue.get.
+ fut.cancel()
+ if self._value > 0 and not fut.cancelled():
+ self._wake_up_next()
+ raise
+ self._value -= 1
+ return True
def release(self):
"""Release a semaphore, incrementing the internal counter by one.
@@ -444,10 +452,7 @@ class Semaphore(_ContextManagerMixin):
become larger than zero again, wake up that coroutine.
"""
self._value += 1
- for waiter in self._waiters:
- if not waiter.done():
- waiter.set_result(True)
- break
+ self._wake_up_next()
class BoundedSemaphore(Semaphore):
diff --git a/Lib/test/test_asyncio/test_locks.py b/Lib/test/test_asyncio/test_locks.py
index dda4577..cdf5d9d 100644
--- a/Lib/test/test_asyncio/test_locks.py
+++ b/Lib/test/test_asyncio/test_locks.py
@@ -7,7 +7,6 @@ import re
import asyncio
from asyncio import test_utils
-
STR_RGX_REPR = (
r'^<(?P<class>.*?) object at (?P<address>.*?)'
r'\[(?P<extras>'
@@ -783,22 +782,20 @@ class SemaphoreTests(test_utils.TestCase):
test_utils.run_briefly(self.loop)
self.assertEqual(0, sem._value)
- self.assertEqual([1, 2, 3], result)
+ self.assertEqual(3, len(result))
self.assertTrue(sem.locked())
self.assertEqual(1, len(sem._waiters))
self.assertEqual(0, sem._value)
self.assertTrue(t1.done())
self.assertTrue(t1.result())
- self.assertTrue(t2.done())
- self.assertTrue(t2.result())
- self.assertTrue(t3.done())
- self.assertTrue(t3.result())
- self.assertFalse(t4.done())
+ race_tasks = [t2, t3, t4]
+ done_tasks = [t for t in race_tasks if t.done() and t.result()]
+ self.assertTrue(2, len(done_tasks))
# cleanup locked semaphore
sem.release()
- self.loop.run_until_complete(t4)
+ self.loop.run_until_complete(asyncio.gather(*race_tasks))
def test_acquire_cancel(self):
sem = asyncio.Semaphore(loop=self.loop)
@@ -809,7 +806,44 @@ class SemaphoreTests(test_utils.TestCase):
self.assertRaises(
asyncio.CancelledError,
self.loop.run_until_complete, acquire)
- self.assertFalse(sem._waiters)
+ self.assertTrue((not sem._waiters) or
+ all(waiter.done() for waiter in sem._waiters))
+
+ def test_acquire_cancel_before_awoken(self):
+ sem = asyncio.Semaphore(value=0, loop=self.loop)
+
+ t1 = asyncio.Task(sem.acquire(), loop=self.loop)
+ t2 = asyncio.Task(sem.acquire(), loop=self.loop)
+ t3 = asyncio.Task(sem.acquire(), loop=self.loop)
+ t4 = asyncio.Task(sem.acquire(), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+
+ sem.release()
+ t1.cancel()
+ t2.cancel()
+
+ test_utils.run_briefly(self.loop)
+ num_done = sum(t.done() for t in [t3, t4])
+ self.assertEqual(num_done, 1)
+
+ t3.cancel()
+ t4.cancel()
+ test_utils.run_briefly(self.loop)
+
+ def test_acquire_hang(self):
+ sem = asyncio.Semaphore(value=0, loop=self.loop)
+
+ t1 = asyncio.Task(sem.acquire(), loop=self.loop)
+ t2 = asyncio.Task(sem.acquire(), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+
+ sem.release()
+ t1.cancel()
+
+ test_utils.run_briefly(self.loop)
+ self.assertTrue(sem.locked())
def test_release_not_acquired(self):
sem = asyncio.BoundedSemaphore(loop=self.loop)
diff --git a/Misc/NEWS b/Misc/NEWS
index 5fb3be1..861b83e 100644
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -193,7 +193,8 @@ Library
- Issue #25034: Fix string.Formatter problem with auto-numbering and
nested format_specs. Patch by Anthon van der Neut.
-- Issue #25233: Rewrite the guts of asyncio.Queue to be more understandable and correct.
+- Issue #25233: Rewrite the guts of asyncio.Queue and
+ asyncio.Semaphore to be more understandable and correct.
- Issue #23600: Default implementation of tzinfo.fromutc() was returning
wrong results in some cases.