diff options
author | Guido van Rossum <guido@python.org> | 2015-09-29 18:54:45 (GMT) |
---|---|---|
committer | Guido van Rossum <guido@python.org> | 2015-09-29 18:54:45 (GMT) |
commit | d455a50773eb1f4531882a0b99ff7a253ad1d41e (patch) | |
tree | 28211b4643e499edddd219416ba19a5a93cfa5d0 /Lib | |
parent | d94c1b92ed6420044d38b59371fd934b9ca9a79f (diff) | |
download | cpython-d455a50773eb1f4531882a0b99ff7a253ad1d41e.zip cpython-d455a50773eb1f4531882a0b99ff7a253ad1d41e.tar.gz cpython-d455a50773eb1f4531882a0b99ff7a253ad1d41e.tar.bz2 |
Also rewrote the guts of asyncio.Semaphore (patch by manipopopo).
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/asyncio/locks.py | 37 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_locks.py | 52 |
2 files changed, 64 insertions, 25 deletions
diff --git a/Lib/asyncio/locks.py b/Lib/asyncio/locks.py index 7a13279..34f6bc1 100644 --- a/Lib/asyncio/locks.py +++ b/Lib/asyncio/locks.py @@ -411,6 +411,13 @@ class Semaphore(_ContextManagerMixin): extra = '{},waiters:{}'.format(extra, len(self._waiters)) return '<{} [{}]>'.format(res[1:-1], extra) + def _wake_up_next(self): + while self._waiters: + waiter = self._waiters.popleft() + if not waiter.done(): + waiter.set_result(None) + return + def locked(self): """Returns True if semaphore can not be acquired immediately.""" return self._value == 0 @@ -425,18 +432,19 @@ class Semaphore(_ContextManagerMixin): called release() to make it larger than 0, and then return True. """ - if not self._waiters and self._value > 0: - self._value -= 1 - return True - - fut = futures.Future(loop=self._loop) - self._waiters.append(fut) - try: - yield from fut - self._value -= 1 - return True - finally: - self._waiters.remove(fut) + while self._value <= 0: + fut = futures.Future(loop=self._loop) + self._waiters.append(fut) + try: + yield from fut + except: + # See the similar code in Queue.get. + fut.cancel() + if self._value > 0 and not fut.cancelled(): + self._wake_up_next() + raise + self._value -= 1 + return True def release(self): """Release a semaphore, incrementing the internal counter by one. @@ -444,10 +452,7 @@ class Semaphore(_ContextManagerMixin): become larger than zero again, wake up that coroutine. """ self._value += 1 - for waiter in self._waiters: - if not waiter.done(): - waiter.set_result(True) - break + self._wake_up_next() class BoundedSemaphore(Semaphore): diff --git a/Lib/test/test_asyncio/test_locks.py b/Lib/test/test_asyncio/test_locks.py index dda4577..cdf5d9d 100644 --- a/Lib/test/test_asyncio/test_locks.py +++ b/Lib/test/test_asyncio/test_locks.py @@ -7,7 +7,6 @@ import re import asyncio from asyncio import test_utils - STR_RGX_REPR = ( r'^<(?P<class>.*?) object at (?P<address>.*?)' r'\[(?P<extras>' @@ -783,22 +782,20 @@ class SemaphoreTests(test_utils.TestCase): test_utils.run_briefly(self.loop) self.assertEqual(0, sem._value) - self.assertEqual([1, 2, 3], result) + self.assertEqual(3, len(result)) self.assertTrue(sem.locked()) self.assertEqual(1, len(sem._waiters)) self.assertEqual(0, sem._value) self.assertTrue(t1.done()) self.assertTrue(t1.result()) - self.assertTrue(t2.done()) - self.assertTrue(t2.result()) - self.assertTrue(t3.done()) - self.assertTrue(t3.result()) - self.assertFalse(t4.done()) + race_tasks = [t2, t3, t4] + done_tasks = [t for t in race_tasks if t.done() and t.result()] + self.assertTrue(2, len(done_tasks)) # cleanup locked semaphore sem.release() - self.loop.run_until_complete(t4) + self.loop.run_until_complete(asyncio.gather(*race_tasks)) def test_acquire_cancel(self): sem = asyncio.Semaphore(loop=self.loop) @@ -809,7 +806,44 @@ class SemaphoreTests(test_utils.TestCase): self.assertRaises( asyncio.CancelledError, self.loop.run_until_complete, acquire) - self.assertFalse(sem._waiters) + self.assertTrue((not sem._waiters) or + all(waiter.done() for waiter in sem._waiters)) + + def test_acquire_cancel_before_awoken(self): + sem = asyncio.Semaphore(value=0, loop=self.loop) + + t1 = asyncio.Task(sem.acquire(), loop=self.loop) + t2 = asyncio.Task(sem.acquire(), loop=self.loop) + t3 = asyncio.Task(sem.acquire(), loop=self.loop) + t4 = asyncio.Task(sem.acquire(), loop=self.loop) + + test_utils.run_briefly(self.loop) + + sem.release() + t1.cancel() + t2.cancel() + + test_utils.run_briefly(self.loop) + num_done = sum(t.done() for t in [t3, t4]) + self.assertEqual(num_done, 1) + + t3.cancel() + t4.cancel() + test_utils.run_briefly(self.loop) + + def test_acquire_hang(self): + sem = asyncio.Semaphore(value=0, loop=self.loop) + + t1 = asyncio.Task(sem.acquire(), loop=self.loop) + t2 = asyncio.Task(sem.acquire(), loop=self.loop) + + test_utils.run_briefly(self.loop) + + sem.release() + t1.cancel() + + test_utils.run_briefly(self.loop) + self.assertTrue(sem.locked()) def test_release_not_acquired(self): sem = asyncio.BoundedSemaphore(loop=self.loop) |