diff options
-rw-r--r-- | Lib/asyncio/queues.py | 102 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_queues.py | 10 |
2 files changed, 48 insertions, 64 deletions
diff --git a/Lib/asyncio/queues.py b/Lib/asyncio/queues.py index 4aeb6c4..84cdabc 100644 --- a/Lib/asyncio/queues.py +++ b/Lib/asyncio/queues.py @@ -1,7 +1,7 @@ """Queues""" -__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'JoinableQueue', - 'QueueFull', 'QueueEmpty'] +__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty', + 'JoinableQueue'] import collections import heapq @@ -49,6 +49,9 @@ class Queue: self._getters = collections.deque() # Pairs of (item, Future). self._putters = collections.deque() + self._unfinished_tasks = 0 + self._finished = locks.Event(loop=self._loop) + self._finished.set() self._init(maxsize) def _init(self, maxsize): @@ -59,6 +62,8 @@ class Queue: def _put(self, item): self._queue.append(item) + self._unfinished_tasks += 1 + self._finished.clear() def __repr__(self): return '<{} at {:#x} {}>'.format( @@ -75,6 +80,8 @@ class Queue: result += ' _getters[{}]'.format(len(self._getters)) if self._putters: result += ' _putters[{}]'.format(len(self._putters)) + if self._unfinished_tasks: + result += ' tasks={}'.format(self._unfinished_tasks) return result def _consume_done_getters(self): @@ -126,9 +133,6 @@ class Queue: 'queue non-empty, why are getters waiting?') getter = self._getters.popleft() - - # Use _put and _get instead of passing item straight to getter, in - # case a subclass has logic that must run (e.g. JoinableQueue). self._put(item) # getter cannot be cancelled, we just removed done getters @@ -154,9 +158,6 @@ class Queue: 'queue non-empty, why are getters waiting?') getter = self._getters.popleft() - - # Use _put and _get instead of passing item straight to getter, in - # case a subclass has logic that must run (e.g. JoinableQueue). self._put(item) # getter cannot be cancelled, we just removed done getters @@ -219,6 +220,38 @@ class Queue: else: raise QueueEmpty + def task_done(self): + """Indicate that a formerly enqueued task is complete. + + Used by queue consumers. For each get() used to fetch a task, + a subsequent call to task_done() tells the queue that the processing + on the task is complete. + + If a join() is currently blocking, it will resume when all items have + been processed (meaning that a task_done() call was received for every + item that had been put() into the queue). + + Raises ValueError if called more times than there were items placed in + the queue. + """ + if self._unfinished_tasks <= 0: + raise ValueError('task_done() called too many times') + self._unfinished_tasks -= 1 + if self._unfinished_tasks == 0: + self._finished.set() + + @coroutine + def join(self): + """Block until all items in the queue have been gotten and processed. + + The count of unfinished tasks goes up whenever an item is added to the + queue. The count goes down whenever a consumer calls task_done() to + indicate that the item was retrieved and all work on it is complete. + When the count of unfinished tasks drops to zero, join() unblocks. + """ + if self._unfinished_tasks > 0: + yield from self._finished.wait() + class PriorityQueue(Queue): """A subclass of Queue; retrieves entries in priority order (lowest first). @@ -249,54 +282,5 @@ class LifoQueue(Queue): return self._queue.pop() -class JoinableQueue(Queue): - """A subclass of Queue with task_done() and join() methods.""" - - def __init__(self, maxsize=0, *, loop=None): - super().__init__(maxsize=maxsize, loop=loop) - self._unfinished_tasks = 0 - self._finished = locks.Event(loop=self._loop) - self._finished.set() - - def _format(self): - result = Queue._format(self) - if self._unfinished_tasks: - result += ' tasks={}'.format(self._unfinished_tasks) - return result - - def _put(self, item): - super()._put(item) - self._unfinished_tasks += 1 - self._finished.clear() - - def task_done(self): - """Indicate that a formerly enqueued task is complete. - - Used by queue consumers. For each get() used to fetch a task, - a subsequent call to task_done() tells the queue that the processing - on the task is complete. - - If a join() is currently blocking, it will resume when all items have - been processed (meaning that a task_done() call was received for every - item that had been put() into the queue). - - Raises ValueError if called more times than there were items placed in - the queue. - """ - if self._unfinished_tasks <= 0: - raise ValueError('task_done() called too many times') - self._unfinished_tasks -= 1 - if self._unfinished_tasks == 0: - self._finished.set() - - @coroutine - def join(self): - """Block until all items in the queue have been gotten and processed. - - The count of unfinished tasks goes up whenever an item is added to the - queue. The count goes down whenever a consumer thread calls task_done() - to indicate that the item was retrieved and all work on it is complete. - When the count of unfinished tasks drops to zero, join() unblocks. - """ - if self._unfinished_tasks > 0: - yield from self._finished.wait() +JoinableQueue = Queue +"""Deprecated alias for Queue.""" diff --git a/Lib/test/test_asyncio/test_queues.py b/Lib/test/test_asyncio/test_queues.py index 3d4ac51..a73539d 100644 --- a/Lib/test/test_asyncio/test_queues.py +++ b/Lib/test/test_asyncio/test_queues.py @@ -408,14 +408,14 @@ class PriorityQueueTests(_QueueTestBase): self.assertEqual([1, 2, 3], items) -class JoinableQueueTests(_QueueTestBase): +class QueueJoinTests(_QueueTestBase): def test_task_done_underflow(self): - q = asyncio.JoinableQueue(loop=self.loop) + q = asyncio.Queue(loop=self.loop) self.assertRaises(ValueError, q.task_done) def test_task_done(self): - q = asyncio.JoinableQueue(loop=self.loop) + q = asyncio.Queue(loop=self.loop) for i in range(100): q.put_nowait(i) @@ -452,7 +452,7 @@ class JoinableQueueTests(_QueueTestBase): self.loop.run_until_complete(asyncio.wait(tasks, loop=self.loop)) def test_join_empty_queue(self): - q = asyncio.JoinableQueue(loop=self.loop) + q = asyncio.Queue(loop=self.loop) # Test that a queue join()s successfully, and before anything else # (done twice for insurance). @@ -465,7 +465,7 @@ class JoinableQueueTests(_QueueTestBase): self.loop.run_until_complete(join()) def test_format(self): - q = asyncio.JoinableQueue(loop=self.loop) + q = asyncio.Queue(loop=self.loop) self.assertEqual(q._format(), 'maxsize=0') q._unfinished_tasks = 2 |