diff options
author | Guido van Rossum <guido@python.org> | 2014-01-26 00:51:57 (GMT) |
---|---|---|
committer | Guido van Rossum <guido@python.org> | 2014-01-26 00:51:57 (GMT) |
commit | ab3c88983bc6c2a6ae98625296eef9d7588c8d69 (patch) | |
tree | 011a16257d0b36aa1f0fa4ab3535f9eca25098d0 /Lib | |
parent | ab27a9fc4b4edff7c17e699b7e9e2173e9f8bc53 (diff) | |
download | cpython-ab3c88983bc6c2a6ae98625296eef9d7588c8d69.zip cpython-ab3c88983bc6c2a6ae98625296eef9d7588c8d69.tar.gz cpython-ab3c88983bc6c2a6ae98625296eef9d7588c8d69.tar.bz2 |
asyncio: Locks refactor: use a separate context manager; remove Semaphore._locked.
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/asyncio/locks.py | 82 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_locks.py | 35 |
2 files changed, 95 insertions, 22 deletions
diff --git a/Lib/asyncio/locks.py b/Lib/asyncio/locks.py index 9fdb937..29c4434 100644 --- a/Lib/asyncio/locks.py +++ b/Lib/asyncio/locks.py @@ -9,6 +9,36 @@ from . import futures from . import tasks +class _ContextManager: + """Context manager. + + This enables the following idiom for acquiring and releasing a + lock around a block: + + with (yield from lock): + <block> + + while failing loudly when accidentally using: + + with lock: + <block> + """ + + def __init__(self, lock): + self._lock = lock + + def __enter__(self): + # We have no use for the "as ..." clause in the with + # statement for locks. + return None + + def __exit__(self, *args): + try: + self._lock.release() + finally: + self._lock = None # Crudely prevent reuse. + + class Lock: """Primitive lock objects. @@ -124,17 +154,29 @@ class Lock: raise RuntimeError('Lock is not acquired.') def __enter__(self): - if not self._locked: - raise RuntimeError( - '"yield from" should be used as context manager expression') - return True + raise RuntimeError( + '"yield from" should be used as context manager expression') def __exit__(self, *args): - self.release() + # This must exist because __enter__ exists, even though that + # always raises; that's how the with-statement works. + pass def __iter__(self): + # This is not a coroutine. It is meant to enable the idiom: + # + # with (yield from lock): + # <block> + # + # as an alternative to: + # + # yield from lock.acquire() + # try: + # <block> + # finally: + # lock.release() yield from self.acquire() - return self + return _ContextManager(self) class Event: @@ -311,14 +353,16 @@ class Condition: self.notify(len(self._waiters)) def __enter__(self): - return self._lock.__enter__() + raise RuntimeError( + '"yield from" should be used as context manager expression') def __exit__(self, *args): - return self._lock.__exit__(*args) + pass def __iter__(self): + # See comment in Lock.__iter__(). yield from self.acquire() - return self + return _ContextManager(self) class Semaphore: @@ -341,7 +385,6 @@ class Semaphore: raise ValueError("Semaphore initial value must be >= 0") self._value = value self._waiters = collections.deque() - self._locked = (value == 0) if loop is not None: self._loop = loop else: @@ -349,7 +392,7 @@ class Semaphore: def __repr__(self): res = super().__repr__() - extra = 'locked' if self._locked else 'unlocked,value:{}'.format( + extra = 'locked' if self.locked() else 'unlocked,value:{}'.format( self._value) if self._waiters: extra = '{},waiters:{}'.format(extra, len(self._waiters)) @@ -357,7 +400,7 @@ class Semaphore: def locked(self): """Returns True if semaphore can not be acquired immediately.""" - return self._locked + return self._value == 0 @tasks.coroutine def acquire(self): @@ -371,8 +414,6 @@ class Semaphore: """ if not self._waiters and self._value > 0: self._value -= 1 - if self._value == 0: - self._locked = True return True fut = futures.Future(loop=self._loop) @@ -380,8 +421,6 @@ class Semaphore: try: yield from fut self._value -= 1 - if self._value == 0: - self._locked = True return True finally: self._waiters.remove(fut) @@ -392,23 +431,22 @@ class Semaphore: become larger than zero again, wake up that coroutine. """ self._value += 1 - self._locked = False for waiter in self._waiters: if not waiter.done(): waiter.set_result(True) break def __enter__(self): - # TODO: This is questionable. How do we know the user actually - # wrote "with (yield from sema)" instead of "with sema"? - return True + raise RuntimeError( + '"yield from" should be used as context manager expression') def __exit__(self, *args): - self.release() + pass def __iter__(self): + # See comment in Lock.__iter__(). yield from self.acquire() - return self + return _ContextManager(self) class BoundedSemaphore(Semaphore): diff --git a/Lib/test/test_asyncio/test_locks.py b/Lib/test/test_asyncio/test_locks.py index 5d0e09e..0975f49 100644 --- a/Lib/test/test_asyncio/test_locks.py +++ b/Lib/test/test_asyncio/test_locks.py @@ -208,6 +208,24 @@ class LockTests(unittest.TestCase): self.assertFalse(lock.locked()) + def test_context_manager_cant_reuse(self): + lock = asyncio.Lock(loop=self.loop) + + @asyncio.coroutine + def acquire_lock(): + return (yield from lock) + + # This spells "yield from lock" outside a generator. + cm = self.loop.run_until_complete(acquire_lock()) + with cm: + self.assertTrue(lock.locked()) + + self.assertFalse(lock.locked()) + + with self.assertRaises(AttributeError): + with cm: + pass + def test_context_manager_no_yield(self): lock = asyncio.Lock(loop=self.loop) @@ -219,6 +237,8 @@ class LockTests(unittest.TestCase): str(err), '"yield from" should be used as context manager expression') + self.assertFalse(lock.locked()) + class EventTests(unittest.TestCase): @@ -655,6 +675,8 @@ class ConditionTests(unittest.TestCase): str(err), '"yield from" should be used as context manager expression') + self.assertFalse(cond.locked()) + class SemaphoreTests(unittest.TestCase): @@ -830,6 +852,19 @@ class SemaphoreTests(unittest.TestCase): self.assertEqual(2, sem._value) + def test_context_manager_no_yield(self): + sem = asyncio.Semaphore(2, loop=self.loop) + + try: + with sem: + self.fail('RuntimeError is not raised in with expression') + except RuntimeError as err: + self.assertEqual( + str(err), + '"yield from" should be used as context manager expression') + + self.assertEqual(2, sem._value) + if __name__ == '__main__': unittest.main() |