summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorGuido van Rossum <guido@python.org>2014-01-26 00:51:57 (GMT)
committerGuido van Rossum <guido@python.org>2014-01-26 00:51:57 (GMT)
commitab3c88983bc6c2a6ae98625296eef9d7588c8d69 (patch)
tree011a16257d0b36aa1f0fa4ab3535f9eca25098d0 /Lib
parentab27a9fc4b4edff7c17e699b7e9e2173e9f8bc53 (diff)
downloadcpython-ab3c88983bc6c2a6ae98625296eef9d7588c8d69.zip
cpython-ab3c88983bc6c2a6ae98625296eef9d7588c8d69.tar.gz
cpython-ab3c88983bc6c2a6ae98625296eef9d7588c8d69.tar.bz2
asyncio: Locks refactor: use a separate context manager; remove Semaphore._locked.
Diffstat (limited to 'Lib')
-rw-r--r--Lib/asyncio/locks.py82
-rw-r--r--Lib/test/test_asyncio/test_locks.py35
2 files changed, 95 insertions, 22 deletions
diff --git a/Lib/asyncio/locks.py b/Lib/asyncio/locks.py
index 9fdb937..29c4434 100644
--- a/Lib/asyncio/locks.py
+++ b/Lib/asyncio/locks.py
@@ -9,6 +9,36 @@ from . import futures
from . import tasks
+class _ContextManager:
+ """Context manager.
+
+ This enables the following idiom for acquiring and releasing a
+ lock around a block:
+
+ with (yield from lock):
+ <block>
+
+ while failing loudly when accidentally using:
+
+ with lock:
+ <block>
+ """
+
+ def __init__(self, lock):
+ self._lock = lock
+
+ def __enter__(self):
+ # We have no use for the "as ..." clause in the with
+ # statement for locks.
+ return None
+
+ def __exit__(self, *args):
+ try:
+ self._lock.release()
+ finally:
+ self._lock = None # Crudely prevent reuse.
+
+
class Lock:
"""Primitive lock objects.
@@ -124,17 +154,29 @@ class Lock:
raise RuntimeError('Lock is not acquired.')
def __enter__(self):
- if not self._locked:
- raise RuntimeError(
- '"yield from" should be used as context manager expression')
- return True
+ raise RuntimeError(
+ '"yield from" should be used as context manager expression')
def __exit__(self, *args):
- self.release()
+ # This must exist because __enter__ exists, even though that
+ # always raises; that's how the with-statement works.
+ pass
def __iter__(self):
+ # This is not a coroutine. It is meant to enable the idiom:
+ #
+ # with (yield from lock):
+ # <block>
+ #
+ # as an alternative to:
+ #
+ # yield from lock.acquire()
+ # try:
+ # <block>
+ # finally:
+ # lock.release()
yield from self.acquire()
- return self
+ return _ContextManager(self)
class Event:
@@ -311,14 +353,16 @@ class Condition:
self.notify(len(self._waiters))
def __enter__(self):
- return self._lock.__enter__()
+ raise RuntimeError(
+ '"yield from" should be used as context manager expression')
def __exit__(self, *args):
- return self._lock.__exit__(*args)
+ pass
def __iter__(self):
+ # See comment in Lock.__iter__().
yield from self.acquire()
- return self
+ return _ContextManager(self)
class Semaphore:
@@ -341,7 +385,6 @@ class Semaphore:
raise ValueError("Semaphore initial value must be >= 0")
self._value = value
self._waiters = collections.deque()
- self._locked = (value == 0)
if loop is not None:
self._loop = loop
else:
@@ -349,7 +392,7 @@ class Semaphore:
def __repr__(self):
res = super().__repr__()
- extra = 'locked' if self._locked else 'unlocked,value:{}'.format(
+ extra = 'locked' if self.locked() else 'unlocked,value:{}'.format(
self._value)
if self._waiters:
extra = '{},waiters:{}'.format(extra, len(self._waiters))
@@ -357,7 +400,7 @@ class Semaphore:
def locked(self):
"""Returns True if semaphore can not be acquired immediately."""
- return self._locked
+ return self._value == 0
@tasks.coroutine
def acquire(self):
@@ -371,8 +414,6 @@ class Semaphore:
"""
if not self._waiters and self._value > 0:
self._value -= 1
- if self._value == 0:
- self._locked = True
return True
fut = futures.Future(loop=self._loop)
@@ -380,8 +421,6 @@ class Semaphore:
try:
yield from fut
self._value -= 1
- if self._value == 0:
- self._locked = True
return True
finally:
self._waiters.remove(fut)
@@ -392,23 +431,22 @@ class Semaphore:
become larger than zero again, wake up that coroutine.
"""
self._value += 1
- self._locked = False
for waiter in self._waiters:
if not waiter.done():
waiter.set_result(True)
break
def __enter__(self):
- # TODO: This is questionable. How do we know the user actually
- # wrote "with (yield from sema)" instead of "with sema"?
- return True
+ raise RuntimeError(
+ '"yield from" should be used as context manager expression')
def __exit__(self, *args):
- self.release()
+ pass
def __iter__(self):
+ # See comment in Lock.__iter__().
yield from self.acquire()
- return self
+ return _ContextManager(self)
class BoundedSemaphore(Semaphore):
diff --git a/Lib/test/test_asyncio/test_locks.py b/Lib/test/test_asyncio/test_locks.py
index 5d0e09e..0975f49 100644
--- a/Lib/test/test_asyncio/test_locks.py
+++ b/Lib/test/test_asyncio/test_locks.py
@@ -208,6 +208,24 @@ class LockTests(unittest.TestCase):
self.assertFalse(lock.locked())
+ def test_context_manager_cant_reuse(self):
+ lock = asyncio.Lock(loop=self.loop)
+
+ @asyncio.coroutine
+ def acquire_lock():
+ return (yield from lock)
+
+ # This spells "yield from lock" outside a generator.
+ cm = self.loop.run_until_complete(acquire_lock())
+ with cm:
+ self.assertTrue(lock.locked())
+
+ self.assertFalse(lock.locked())
+
+ with self.assertRaises(AttributeError):
+ with cm:
+ pass
+
def test_context_manager_no_yield(self):
lock = asyncio.Lock(loop=self.loop)
@@ -219,6 +237,8 @@ class LockTests(unittest.TestCase):
str(err),
'"yield from" should be used as context manager expression')
+ self.assertFalse(lock.locked())
+
class EventTests(unittest.TestCase):
@@ -655,6 +675,8 @@ class ConditionTests(unittest.TestCase):
str(err),
'"yield from" should be used as context manager expression')
+ self.assertFalse(cond.locked())
+
class SemaphoreTests(unittest.TestCase):
@@ -830,6 +852,19 @@ class SemaphoreTests(unittest.TestCase):
self.assertEqual(2, sem._value)
+ def test_context_manager_no_yield(self):
+ sem = asyncio.Semaphore(2, loop=self.loop)
+
+ try:
+ with sem:
+ self.fail('RuntimeError is not raised in with expression')
+ except RuntimeError as err:
+ self.assertEqual(
+ str(err),
+ '"yield from" should be used as context manager expression')
+
+ self.assertEqual(2, sem._value)
+
if __name__ == '__main__':
unittest.main()