summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorLaurie O <laurie_opperman@hotmail.com>2024-04-06 14:27:13 (GMT)
committerGitHub <noreply@github.com>2024-04-06 14:27:13 (GMT)
commitdf4d84c3cdca572f1be8f5dc5ef8ead5351b51fb (patch)
treedbbb1036001b8caccf3ff0a0b436aee47f96fff5 /Lib
parent1d3225ae056245da75e4a443ccafcc8f4f982cf2 (diff)
downloadcpython-df4d84c3cdca572f1be8f5dc5ef8ead5351b51fb.zip
cpython-df4d84c3cdca572f1be8f5dc5ef8ead5351b51fb.tar.gz
cpython-df4d84c3cdca572f1be8f5dc5ef8ead5351b51fb.tar.bz2
gh-96471: Add asyncio queue shutdown (#104228)
Co-authored-by: Duprat <yduprat@gmail.com>
Diffstat (limited to 'Lib')
-rw-r--r--Lib/asyncio/queues.py68
-rw-r--r--Lib/test/test_asyncio/test_queues.py199
2 files changed, 264 insertions, 3 deletions
diff --git a/Lib/asyncio/queues.py b/Lib/asyncio/queues.py
index a9656a6..b815670 100644
--- a/Lib/asyncio/queues.py
+++ b/Lib/asyncio/queues.py
@@ -1,4 +1,11 @@
-__all__ = ('Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty')
+__all__ = (
+ 'Queue',
+ 'PriorityQueue',
+ 'LifoQueue',
+ 'QueueFull',
+ 'QueueEmpty',
+ 'QueueShutDown',
+)
import collections
import heapq
@@ -18,6 +25,11 @@ class QueueFull(Exception):
pass
+class QueueShutDown(Exception):
+ """Raised when putting on to or getting from a shut-down Queue."""
+ pass
+
+
class Queue(mixins._LoopBoundMixin):
"""A queue, useful for coordinating producer and consumer coroutines.
@@ -41,6 +53,7 @@ class Queue(mixins._LoopBoundMixin):
self._finished = locks.Event()
self._finished.set()
self._init(maxsize)
+ self._is_shutdown = False
# These three are overridable in subclasses.
@@ -81,6 +94,8 @@ class Queue(mixins._LoopBoundMixin):
result += f' _putters[{len(self._putters)}]'
if self._unfinished_tasks:
result += f' tasks={self._unfinished_tasks}'
+ if self._is_shutdown:
+ result += ' shutdown'
return result
def qsize(self):
@@ -112,8 +127,12 @@ class Queue(mixins._LoopBoundMixin):
Put an item into the queue. If the queue is full, wait until a free
slot is available before adding item.
+
+ Raises QueueShutDown if the queue has been shut down.
"""
while self.full():
+ if self._is_shutdown:
+ raise QueueShutDown
putter = self._get_loop().create_future()
self._putters.append(putter)
try:
@@ -125,7 +144,7 @@ class Queue(mixins._LoopBoundMixin):
self._putters.remove(putter)
except ValueError:
# The putter could be removed from self._putters by a
- # previous get_nowait call.
+ # previous get_nowait call or a shutdown call.
pass
if not self.full() and not putter.cancelled():
# We were woken up by get_nowait(), but can't take
@@ -138,7 +157,11 @@ class Queue(mixins._LoopBoundMixin):
"""Put an item into the queue without blocking.
If no free slot is immediately available, raise QueueFull.
+
+ Raises QueueShutDown if the queue has been shut down.
"""
+ if self._is_shutdown:
+ raise QueueShutDown
if self.full():
raise QueueFull
self._put(item)
@@ -150,8 +173,13 @@ class Queue(mixins._LoopBoundMixin):
"""Remove and return an item from the queue.
If queue is empty, wait until an item is available.
+
+ Raises QueueShutDown if the queue has been shut down and is empty, or
+ if the queue has been shut down immediately.
"""
while self.empty():
+ if self._is_shutdown and self.empty():
+ raise QueueShutDown
getter = self._get_loop().create_future()
self._getters.append(getter)
try:
@@ -163,7 +191,7 @@ class Queue(mixins._LoopBoundMixin):
self._getters.remove(getter)
except ValueError:
# The getter could be removed from self._getters by a
- # previous put_nowait call.
+ # previous put_nowait call, or a shutdown call.
pass
if not self.empty() and not getter.cancelled():
# We were woken up by put_nowait(), but can't take
@@ -176,8 +204,13 @@ class Queue(mixins._LoopBoundMixin):
"""Remove and return an item from the queue.
Return an item if one is immediately available, else raise QueueEmpty.
+
+ Raises QueueShutDown if the queue has been shut down and is empty, or
+ if the queue has been shut down immediately.
"""
if self.empty():
+ if self._is_shutdown:
+ raise QueueShutDown
raise QueueEmpty
item = self._get()
self._wakeup_next(self._putters)
@@ -194,6 +227,9 @@ class Queue(mixins._LoopBoundMixin):
been processed (meaning that a task_done() call was received for every
item that had been put() into the queue).
+ shutdown(immediate=True) calls task_done() for each remaining item in
+ the queue.
+
Raises ValueError if called more times than there were items placed in
the queue.
"""
@@ -214,6 +250,32 @@ class Queue(mixins._LoopBoundMixin):
if self._unfinished_tasks > 0:
await self._finished.wait()
+ def shutdown(self, immediate=False):
+ """Shut-down the queue, making queue gets and puts raise QueueShutDown.
+
+ By default, gets will only raise once the queue is empty. Set
+ 'immediate' to True to make gets raise immediately instead.
+
+ All blocked callers of put() will be unblocked, and also get()
+ and join() if 'immediate'.
+ """
+ self._is_shutdown = True
+ if immediate:
+ while not self.empty():
+ self._get()
+ if self._unfinished_tasks > 0:
+ self._unfinished_tasks -= 1
+ if self._unfinished_tasks == 0:
+ self._finished.set()
+ while self._getters:
+ getter = self._getters.popleft()
+ if not getter.done():
+ getter.set_result(None)
+ while self._putters:
+ putter = self._putters.popleft()
+ if not putter.done():
+ putter.set_result(None)
+
class PriorityQueue(Queue):
"""A subclass of Queue; retrieves entries in priority order (lowest first).
diff --git a/Lib/test/test_asyncio/test_queues.py b/Lib/test/test_asyncio/test_queues.py
index 2d058cc..5019e9a 100644
--- a/Lib/test/test_asyncio/test_queues.py
+++ b/Lib/test/test_asyncio/test_queues.py
@@ -522,5 +522,204 @@ class PriorityQueueJoinTests(_QueueJoinTestMixin, unittest.IsolatedAsyncioTestCa
q_class = asyncio.PriorityQueue
+class _QueueShutdownTestMixin:
+ q_class = None
+
+ def assertRaisesShutdown(self, msg="Didn't appear to shut-down queue"):
+ return self.assertRaises(asyncio.QueueShutDown, msg=msg)
+
+ async def test_format(self):
+ q = self.q_class()
+ q.shutdown()
+ self.assertEqual(q._format(), 'maxsize=0 shutdown')
+
+ async def test_shutdown_empty(self):
+ # Test shutting down an empty queue
+
+ # Setup empty queue, and join() and get() tasks
+ q = self.q_class()
+ loop = asyncio.get_running_loop()
+ get_task = loop.create_task(q.get())
+ await asyncio.sleep(0) # want get task pending before shutdown
+
+ # Perform shut-down
+ q.shutdown(immediate=False) # unfinished tasks: 0 -> 0
+
+ self.assertEqual(q.qsize(), 0)
+
+ # Ensure join() task successfully finishes
+ await q.join()
+
+ # Ensure get() task is finished, and raised ShutDown
+ await asyncio.sleep(0)
+ self.assertTrue(get_task.done())
+ with self.assertRaisesShutdown():
+ await get_task
+
+ # Ensure put() and get() raise ShutDown
+ with self.assertRaisesShutdown():
+ await q.put("data")
+ with self.assertRaisesShutdown():
+ q.put_nowait("data")
+
+ with self.assertRaisesShutdown():
+ await q.get()
+ with self.assertRaisesShutdown():
+ q.get_nowait()
+
+ async def test_shutdown_nonempty(self):
+ # Test shutting down a non-empty queue
+
+ # Setup full queue with 1 item, and join() and put() tasks
+ q = self.q_class(maxsize=1)
+ loop = asyncio.get_running_loop()
+
+ q.put_nowait("data")
+ join_task = loop.create_task(q.join())
+ put_task = loop.create_task(q.put("data2"))
+
+ # Ensure put() task is not finished
+ await asyncio.sleep(0)
+ self.assertFalse(put_task.done())
+
+ # Perform shut-down
+ q.shutdown(immediate=False) # unfinished tasks: 1 -> 1
+
+ self.assertEqual(q.qsize(), 1)
+
+ # Ensure put() task is finished, and raised ShutDown
+ await asyncio.sleep(0)
+ self.assertTrue(put_task.done())
+ with self.assertRaisesShutdown():
+ await put_task
+
+ # Ensure get() succeeds on enqueued item
+ self.assertEqual(await q.get(), "data")
+
+ # Ensure join() task is not finished
+ await asyncio.sleep(0)
+ self.assertFalse(join_task.done())
+
+ # Ensure put() and get() raise ShutDown
+ with self.assertRaisesShutdown():
+ await q.put("data")
+ with self.assertRaisesShutdown():
+ q.put_nowait("data")
+
+ with self.assertRaisesShutdown():
+ await q.get()
+ with self.assertRaisesShutdown():
+ q.get_nowait()
+
+ # Ensure there is 1 unfinished task, and join() task succeeds
+ q.task_done()
+
+ await asyncio.sleep(0)
+ self.assertTrue(join_task.done())
+ await join_task
+
+ with self.assertRaises(
+ ValueError, msg="Didn't appear to mark all tasks done"
+ ):
+ q.task_done()
+
+ async def test_shutdown_immediate(self):
+ # Test immediately shutting down a queue
+
+ # Setup queue with 1 item, and a join() task
+ q = self.q_class()
+ loop = asyncio.get_running_loop()
+ q.put_nowait("data")
+ join_task = loop.create_task(q.join())
+
+ # Perform shut-down
+ q.shutdown(immediate=True) # unfinished tasks: 1 -> 0
+
+ self.assertEqual(q.qsize(), 0)
+
+ # Ensure join() task has successfully finished
+ await asyncio.sleep(0)
+ self.assertTrue(join_task.done())
+ await join_task
+
+ # Ensure put() and get() raise ShutDown
+ with self.assertRaisesShutdown():
+ await q.put("data")
+ with self.assertRaisesShutdown():
+ q.put_nowait("data")
+
+ with self.assertRaisesShutdown():
+ await q.get()
+ with self.assertRaisesShutdown():
+ q.get_nowait()
+
+ # Ensure there are no unfinished tasks
+ with self.assertRaises(
+ ValueError, msg="Didn't appear to mark all tasks done"
+ ):
+ q.task_done()
+
+ async def test_shutdown_immediate_with_unfinished(self):
+ # Test immediately shutting down a queue with unfinished tasks
+
+ # Setup queue with 2 items (1 retrieved), and a join() task
+ q = self.q_class()
+ loop = asyncio.get_running_loop()
+ q.put_nowait("data")
+ q.put_nowait("data")
+ join_task = loop.create_task(q.join())
+ self.assertEqual(await q.get(), "data")
+
+ # Perform shut-down
+ q.shutdown(immediate=True) # unfinished tasks: 2 -> 1
+
+ self.assertEqual(q.qsize(), 0)
+
+ # Ensure join() task is not finished
+ await asyncio.sleep(0)
+ self.assertFalse(join_task.done())
+
+ # Ensure put() and get() raise ShutDown
+ with self.assertRaisesShutdown():
+ await q.put("data")
+ with self.assertRaisesShutdown():
+ q.put_nowait("data")
+
+ with self.assertRaisesShutdown():
+ await q.get()
+ with self.assertRaisesShutdown():
+ q.get_nowait()
+
+ # Ensure there is 1 unfinished task
+ q.task_done()
+ with self.assertRaises(
+ ValueError, msg="Didn't appear to mark all tasks done"
+ ):
+ q.task_done()
+
+ # Ensure join() task has successfully finished
+ await asyncio.sleep(0)
+ self.assertTrue(join_task.done())
+ await join_task
+
+
+class QueueShutdownTests(
+ _QueueShutdownTestMixin, unittest.IsolatedAsyncioTestCase
+):
+ q_class = asyncio.Queue
+
+
+class LifoQueueShutdownTests(
+ _QueueShutdownTestMixin, unittest.IsolatedAsyncioTestCase
+):
+ q_class = asyncio.LifoQueue
+
+
+class PriorityQueueShutdownTests(
+ _QueueShutdownTestMixin, unittest.IsolatedAsyncioTestCase
+):
+ q_class = asyncio.PriorityQueue
+
+
if __name__ == '__main__':
unittest.main()