diff options
author | Laurie O <laurie_opperman@hotmail.com> | 2024-02-10 04:58:30 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-10 04:58:30 (GMT) |
commit | b2d9d134dcb5633deebebf2b0118cd4f7ca598a2 (patch) | |
tree | b683686f4a42bdec2fe0540c585005dc4e5db8c2 /Lib | |
parent | d4d5bae1471788b345155e8e93a2fe4ab92d09dc (diff) | |
download | cpython-b2d9d134dcb5633deebebf2b0118cd4f7ca598a2.zip cpython-b2d9d134dcb5633deebebf2b0118cd4f7ca598a2.tar.gz cpython-b2d9d134dcb5633deebebf2b0118cd4f7ca598a2.tar.bz2 |
gh-96471: Add shutdown() method to queue.Queue (#104750)
Co-authored-by: Duprat <yduprat@gmail.com>
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/queue.py | 50 | ||||
-rw-r--r-- | Lib/test/test_queue.py | 378 |
2 files changed, 428 insertions, 0 deletions
diff --git a/Lib/queue.py b/Lib/queue.py index 55f5008..467ff4f 100644 --- a/Lib/queue.py +++ b/Lib/queue.py @@ -25,6 +25,10 @@ class Full(Exception): pass +class ShutDown(Exception): + '''Raised when put/get with shut-down queue.''' + + class Queue: '''Create a queue object with a given maximum size. @@ -54,6 +58,9 @@ class Queue: self.all_tasks_done = threading.Condition(self.mutex) self.unfinished_tasks = 0 + # Queue shutdown state + self.is_shutdown = False + def task_done(self): '''Indicate that a formerly enqueued task is complete. @@ -67,6 +74,8 @@ class Queue: Raises a ValueError if called more times than there were items placed in the queue. + + Raises ShutDown if the queue has been shut down immediately. ''' with self.all_tasks_done: unfinished = self.unfinished_tasks - 1 @@ -84,6 +93,8 @@ class Queue: to indicate the item was retrieved and all work on it is complete. When the count of unfinished tasks drops to zero, join() unblocks. + + Raises ShutDown if the queue has been shut down immediately. ''' with self.all_tasks_done: while self.unfinished_tasks: @@ -129,8 +140,12 @@ class Queue: Otherwise ('block' is false), put an item on the queue if a free slot is immediately available, else raise the Full exception ('timeout' is ignored in that case). + + Raises ShutDown if the queue has been shut down. ''' with self.not_full: + if self.is_shutdown: + raise ShutDown if self.maxsize > 0: if not block: if self._qsize() >= self.maxsize: @@ -138,6 +153,8 @@ class Queue: elif timeout is None: while self._qsize() >= self.maxsize: self.not_full.wait() + if self.is_shutdown: + raise ShutDown elif timeout < 0: raise ValueError("'timeout' must be a non-negative number") else: @@ -147,6 +164,8 @@ class Queue: if remaining <= 0.0: raise Full self.not_full.wait(remaining) + if self.is_shutdown: + raise ShutDown self._put(item) self.unfinished_tasks += 1 self.not_empty.notify() @@ -161,14 +180,21 @@ class Queue: Otherwise ('block' is false), return an item if one is immediately available, else raise the Empty exception ('timeout' is ignored in that case). + + Raises ShutDown if the queue has been shut down and is empty, + or if the queue has been shut down immediately. ''' with self.not_empty: + if self.is_shutdown and not self._qsize(): + raise ShutDown if not block: if not self._qsize(): raise Empty elif timeout is None: while not self._qsize(): self.not_empty.wait() + if self.is_shutdown and not self._qsize(): + raise ShutDown elif timeout < 0: raise ValueError("'timeout' must be a non-negative number") else: @@ -178,6 +204,8 @@ class Queue: if remaining <= 0.0: raise Empty self.not_empty.wait(remaining) + if self.is_shutdown and not self._qsize(): + raise ShutDown item = self._get() self.not_full.notify() return item @@ -198,6 +226,28 @@ class Queue: ''' return self.get(block=False) + def shutdown(self, immediate=False): + '''Shut-down the queue, making queue gets and puts raise. + + By default, gets will only raise once the queue is empty. Set + 'immediate' to True to make gets raise immediately instead. + + All blocked callers of put() will be unblocked, and also get() + and join() if 'immediate'. The ShutDown exception is raised. + ''' + with self.mutex: + self.is_shutdown = True + if immediate: + n_items = self._qsize() + while self._qsize(): + self._get() + if self.unfinished_tasks > 0: + self.unfinished_tasks -= 1 + self.not_empty.notify_all() + # release all blocked threads in `join()` + self.all_tasks_done.notify_all() + self.not_full.notify_all() + # Override these methods to implement other queue organizations # (e.g. stack or priority queue). # These will only be called with appropriate locks held diff --git a/Lib/test/test_queue.py b/Lib/test/test_queue.py index 33113a7..e3d4d56 100644 --- a/Lib/test/test_queue.py +++ b/Lib/test/test_queue.py @@ -241,6 +241,384 @@ class BaseQueueTestMixin(BlockingTestMixin): with self.assertRaises(self.queue.Full): q.put_nowait(4) + def test_shutdown_empty(self): + q = self.type2test() + q.shutdown() + with self.assertRaises(self.queue.ShutDown): + q.put("data") + with self.assertRaises(self.queue.ShutDown): + q.get() + + def test_shutdown_nonempty(self): + q = self.type2test() + q.put("data") + q.shutdown() + q.get() + with self.assertRaises(self.queue.ShutDown): + q.get() + + def test_shutdown_immediate(self): + q = self.type2test() + q.put("data") + q.shutdown(immediate=True) + with self.assertRaises(self.queue.ShutDown): + q.get() + + def test_shutdown_allowed_transitions(self): + # allowed transitions would be from alive via shutdown to immediate + q = self.type2test() + self.assertFalse(q.is_shutdown) + + q.shutdown() + self.assertTrue(q.is_shutdown) + + q.shutdown(immediate=True) + self.assertTrue(q.is_shutdown) + + q.shutdown(immediate=False) + + def _shutdown_all_methods_in_one_thread(self, immediate): + q = self.type2test(2) + q.put("L") + q.put_nowait("O") + q.shutdown(immediate) + + with self.assertRaises(self.queue.ShutDown): + q.put("E") + with self.assertRaises(self.queue.ShutDown): + q.put_nowait("W") + if immediate: + with self.assertRaises(self.queue.ShutDown): + q.get() + with self.assertRaises(self.queue.ShutDown): + q.get_nowait() + with self.assertRaises(ValueError): + q.task_done() + q.join() + else: + self.assertIn(q.get(), "LO") + q.task_done() + self.assertIn(q.get(), "LO") + q.task_done() + q.join() + # on shutdown(immediate=False) + # when queue is empty, should raise ShutDown Exception + with self.assertRaises(self.queue.ShutDown): + q.get() # p.get(True) + with self.assertRaises(self.queue.ShutDown): + q.get_nowait() # p.get(False) + with self.assertRaises(self.queue.ShutDown): + q.get(True, 1.0) + + def test_shutdown_all_methods_in_one_thread(self): + return self._shutdown_all_methods_in_one_thread(False) + + def test_shutdown_immediate_all_methods_in_one_thread(self): + return self._shutdown_all_methods_in_one_thread(True) + + def _write_msg_thread(self, q, n, results, delay, + i_when_exec_shutdown, + event_start, event_end): + event_start.wait() + for i in range(1, n+1): + try: + q.put((i, "YDLO")) + results.append(True) + except self.queue.ShutDown: + results.append(False) + # triggers shutdown of queue + if i == i_when_exec_shutdown: + event_end.set() + time.sleep(delay) + # end of all puts + q.join() + + def _read_msg_thread(self, q, nb, results, delay, event_start): + event_start.wait() + block = True + while nb: + time.sleep(delay) + try: + # Get at least one message + q.get(block) + block = False + q.task_done() + results.append(True) + nb -= 1 + except self.queue.ShutDown: + results.append(False) + nb -= 1 + except self.queue.Empty: + pass + q.join() + + def _shutdown_thread(self, q, event_end, immediate): + event_end.wait() + q.shutdown(immediate) + q.join() + + def _join_thread(self, q, delay, event_start): + event_start.wait() + time.sleep(delay) + q.join() + + def _shutdown_all_methods_in_many_threads(self, immediate): + q = self.type2test() + ps = [] + ev_start = threading.Event() + ev_exec_shutdown = threading.Event() + res_puts = [] + res_gets = [] + delay = 1e-4 + read_process = 4 + nb_msgs = read_process * 16 + nb_msgs_r = nb_msgs // read_process + when_exec_shutdown = nb_msgs // 2 + lprocs = ( + (self._write_msg_thread, 1, (q, nb_msgs, res_puts, delay, + when_exec_shutdown, + ev_start, ev_exec_shutdown)), + (self._read_msg_thread, read_process, (q, nb_msgs_r, + res_gets, delay*2, + ev_start)), + (self._join_thread, 2, (q, delay*2, ev_start)), + (self._shutdown_thread, 1, (q, ev_exec_shutdown, immediate)), + ) + # start all threds + for func, n, args in lprocs: + for i in range(n): + ps.append(threading.Thread(target=func, args=args)) + ps[-1].start() + # set event in order to run q.shutdown() + ev_start.set() + + if not immediate: + assert(len(res_gets) == len(res_puts)) + assert(res_gets.count(True) == res_puts.count(True)) + else: + assert(len(res_gets) <= len(res_puts)) + assert(res_gets.count(True) <= res_puts.count(True)) + + for thread in ps[1:]: + thread.join() + + def test_shutdown_all_methods_in_many_threads(self): + return self._shutdown_all_methods_in_many_threads(False) + + def test_shutdown_immediate_all_methods_in_many_threads(self): + return self._shutdown_all_methods_in_many_threads(True) + + def _get(self, q, go, results, shutdown=False): + go.wait() + try: + msg = q.get() + results.append(not shutdown) + return not shutdown + except self.queue.ShutDown: + results.append(shutdown) + return shutdown + + def _get_shutdown(self, q, go, results): + return self._get(q, go, results, True) + + def _get_task_done(self, q, go, results): + go.wait() + try: + msg = q.get() + q.task_done() + results.append(True) + return msg + except self.queue.ShutDown: + results.append(False) + return False + + def _put(self, q, msg, go, results, shutdown=False): + go.wait() + try: + q.put(msg) + results.append(not shutdown) + return not shutdown + except self.queue.ShutDown: + results.append(shutdown) + return shutdown + + def _put_shutdown(self, q, msg, go, results): + return self._put(q, msg, go, results, True) + + def _join(self, q, results, shutdown=False): + try: + q.join() + results.append(not shutdown) + return not shutdown + except self.queue.ShutDown: + results.append(shutdown) + return shutdown + + def _join_shutdown(self, q, results): + return self._join(q, results, True) + + def _shutdown_get(self, immediate): + q = self.type2test(2) + results = [] + go = threading.Event() + q.put("Y") + q.put("D") + # queue full + + if immediate: + thrds = ( + (self._get_shutdown, (q, go, results)), + (self._get_shutdown, (q, go, results)), + ) + else: + thrds = ( + # on shutdown(immediate=False) + # one of these threads shoud raise Shutdown + (self._get, (q, go, results)), + (self._get, (q, go, results)), + (self._get, (q, go, results)), + ) + threads = [] + for func, params in thrds: + threads.append(threading.Thread(target=func, args=params)) + threads[-1].start() + q.shutdown(immediate) + go.set() + for t in threads: + t.join() + if immediate: + self.assertListEqual(results, [True, True]) + else: + self.assertListEqual(sorted(results), [False] + [True]*(len(thrds)-1)) + + def test_shutdown_get(self): + return self._shutdown_get(False) + + def test_shutdown_immediate_get(self): + return self._shutdown_get(True) + + def _shutdown_put(self, immediate): + q = self.type2test(2) + results = [] + go = threading.Event() + q.put("Y") + q.put("D") + # queue fulled + + thrds = ( + (self._put_shutdown, (q, "E", go, results)), + (self._put_shutdown, (q, "W", go, results)), + ) + threads = [] + for func, params in thrds: + threads.append(threading.Thread(target=func, args=params)) + threads[-1].start() + q.shutdown() + go.set() + for t in threads: + t.join() + + self.assertEqual(results, [True]*len(thrds)) + + def test_shutdown_put(self): + return self._shutdown_put(False) + + def test_shutdown_immediate_put(self): + return self._shutdown_put(True) + + def _shutdown_join(self, immediate): + q = self.type2test() + results = [] + q.put("Y") + go = threading.Event() + nb = q.qsize() + + thrds = ( + (self._join, (q, results)), + (self._join, (q, results)), + ) + threads = [] + for func, params in thrds: + threads.append(threading.Thread(target=func, args=params)) + threads[-1].start() + if not immediate: + res = [] + for i in range(nb): + threads.append(threading.Thread(target=self._get_task_done, args=(q, go, res))) + threads[-1].start() + q.shutdown(immediate) + go.set() + for t in threads: + t.join() + + self.assertEqual(results, [True]*len(thrds)) + + def test_shutdown_immediate_join(self): + return self._shutdown_join(True) + + def test_shutdown_join(self): + return self._shutdown_join(False) + + def _shutdown_put_join(self, immediate): + q = self.type2test(2) + results = [] + go = threading.Event() + q.put("Y") + nb = q.qsize() + # queue not fulled + + thrds = ( + (self._put_shutdown, (q, "E", go, results)), + (self._join, (q, results)), + ) + threads = [] + for func, params in thrds: + threads.append(threading.Thread(target=func, args=params)) + threads[-1].start() + self.assertEqual(q.unfinished_tasks, nb) + for i in range(nb): + t = threading.Thread(target=q.task_done) + t.start() + threads.append(t) + q.shutdown(immediate) + go.set() + for t in threads: + t.join() + + self.assertEqual(results, [True]*len(thrds)) + + def test_shutdown_immediate_put_join(self): + return self._shutdown_put_join(True) + + def test_shutdown_put_join(self): + return self._shutdown_put_join(False) + + def test_shutdown_get_task_done_join(self): + q = self.type2test(2) + results = [] + go = threading.Event() + q.put("Y") + q.put("D") + self.assertEqual(q.unfinished_tasks, q.qsize()) + + thrds = ( + (self._get_task_done, (q, go, results)), + (self._get_task_done, (q, go, results)), + (self._join, (q, results)), + (self._join, (q, results)), + ) + threads = [] + for func, params in thrds: + threads.append(threading.Thread(target=func, args=params)) + threads[-1].start() + go.set() + q.shutdown(False) + for t in threads: + t.join() + + self.assertEqual(results, [True]*len(thrds)) + + class QueueTest(BaseQueueTestMixin): def setUp(self): |