From 3ab0136ac5d6059ce96d4debca89c5f5ab0356f5 Mon Sep 17 00:00:00 2001 From: Eric Snow Date: Thu, 17 May 2018 10:27:09 -0400 Subject: bpo-32604: Implement force-closing channels. (gh-6937) This will make it easier to clean up channels (e.g. when used in tests). --- Lib/test/test__xxsubinterpreters.py | 100 ++++++++++++++++++++++- Modules/_xxsubinterpretersmodule.c | 157 +++++++++++++++++++++++++++++++----- 2 files changed, 232 insertions(+), 25 deletions(-) diff --git a/Lib/test/test__xxsubinterpreters.py b/Lib/test/test__xxsubinterpreters.py index 118f2e4..f66cc95 100644 --- a/Lib/test/test__xxsubinterpreters.py +++ b/Lib/test/test__xxsubinterpreters.py @@ -1379,13 +1379,105 @@ class ChannelTests(TestBase): with self.assertRaises(interpreters.ChannelClosedError): interpreters.channel_close(cid) - def test_close_with_unused_items(self): + def test_close_empty(self): + tests = [ + (False, False), + (True, False), + (False, True), + (True, True), + ] + for send, recv in tests: + with self.subTest((send, recv)): + cid = interpreters.channel_create() + interpreters.channel_send(cid, b'spam') + interpreters.channel_recv(cid) + interpreters.channel_close(cid, send=send, recv=recv) + + with self.assertRaises(interpreters.ChannelClosedError): + interpreters.channel_send(cid, b'eggs') + with self.assertRaises(interpreters.ChannelClosedError): + interpreters.channel_recv(cid) + + def test_close_defaults_with_unused_items(self): + cid = interpreters.channel_create() + interpreters.channel_send(cid, b'spam') + interpreters.channel_send(cid, b'ham') + + with self.assertRaises(interpreters.ChannelNotEmptyError): + interpreters.channel_close(cid) + interpreters.channel_recv(cid) + interpreters.channel_send(cid, b'eggs') + + def test_close_recv_with_unused_items_unforced(self): cid = interpreters.channel_create() interpreters.channel_send(cid, b'spam') interpreters.channel_send(cid, b'ham') - interpreters.channel_close(cid) + + with self.assertRaises(interpreters.ChannelNotEmptyError): + interpreters.channel_close(cid, recv=True) + interpreters.channel_recv(cid) + interpreters.channel_send(cid, b'eggs') + interpreters.channel_recv(cid) + interpreters.channel_recv(cid) + interpreters.channel_close(cid, recv=True) + + def test_close_send_with_unused_items_unforced(self): + cid = interpreters.channel_create() + interpreters.channel_send(cid, b'spam') + interpreters.channel_send(cid, b'ham') + interpreters.channel_close(cid, send=True) with self.assertRaises(interpreters.ChannelClosedError): + interpreters.channel_send(cid, b'eggs') + interpreters.channel_recv(cid) + interpreters.channel_recv(cid) + with self.assertRaises(interpreters.ChannelClosedError): + interpreters.channel_recv(cid) + + def test_close_both_with_unused_items_unforced(self): + cid = interpreters.channel_create() + interpreters.channel_send(cid, b'spam') + interpreters.channel_send(cid, b'ham') + + with self.assertRaises(interpreters.ChannelNotEmptyError): + interpreters.channel_close(cid, recv=True, send=True) + interpreters.channel_recv(cid) + interpreters.channel_send(cid, b'eggs') + interpreters.channel_recv(cid) + interpreters.channel_recv(cid) + interpreters.channel_close(cid, recv=True) + + def test_close_recv_with_unused_items_forced(self): + cid = interpreters.channel_create() + interpreters.channel_send(cid, b'spam') + interpreters.channel_send(cid, b'ham') + interpreters.channel_close(cid, recv=True, force=True) + + with self.assertRaises(interpreters.ChannelClosedError): + interpreters.channel_send(cid, b'eggs') + with self.assertRaises(interpreters.ChannelClosedError): + interpreters.channel_recv(cid) + + def test_close_send_with_unused_items_forced(self): + cid = interpreters.channel_create() + interpreters.channel_send(cid, b'spam') + interpreters.channel_send(cid, b'ham') + interpreters.channel_close(cid, send=True, force=True) + + with self.assertRaises(interpreters.ChannelClosedError): + interpreters.channel_send(cid, b'eggs') + with self.assertRaises(interpreters.ChannelClosedError): + interpreters.channel_recv(cid) + + def test_close_both_with_unused_items_forced(self): + cid = interpreters.channel_create() + interpreters.channel_send(cid, b'spam') + interpreters.channel_send(cid, b'ham') + interpreters.channel_close(cid, send=True, recv=True, force=True) + + with self.assertRaises(interpreters.ChannelClosedError): + interpreters.channel_send(cid, b'eggs') + with self.assertRaises(interpreters.ChannelClosedError): interpreters.channel_recv(cid) def test_close_never_used(self): @@ -1403,7 +1495,7 @@ class ChannelTests(TestBase): interp = interpreters.create() interpreters.run_string(interp, dedent(f""" import _xxsubinterpreters as _interpreters - _interpreters.channel_close({cid}) + _interpreters.channel_close({cid}, force=True) """)) with self.assertRaises(interpreters.ChannelClosedError): interpreters.channel_recv(cid) @@ -1416,7 +1508,7 @@ class ChannelTests(TestBase): interpreters.channel_send(cid, b'spam') interpreters.channel_send(cid, b'spam') interpreters.channel_recv(cid) - interpreters.channel_close(cid) + interpreters.channel_close(cid, force=True) with self.assertRaises(interpreters.ChannelClosedError): interpreters.channel_send(cid, b'eggs') diff --git a/Modules/_xxsubinterpretersmodule.c b/Modules/_xxsubinterpretersmodule.c index 5184f65..72387d8 100644 --- a/Modules/_xxsubinterpretersmodule.c +++ b/Modules/_xxsubinterpretersmodule.c @@ -306,10 +306,15 @@ _sharedexception_apply(_sharedexception *exc, PyObject *wrapperclass) /* channel-specific code ****************************************************/ +#define CHANNEL_SEND 1 +#define CHANNEL_BOTH 0 +#define CHANNEL_RECV -1 + static PyObject *ChannelError; static PyObject *ChannelNotFoundError; static PyObject *ChannelClosedError; static PyObject *ChannelEmptyError; +static PyObject *ChannelNotEmptyError; static int channel_exceptions_init(PyObject *ns) @@ -356,6 +361,16 @@ channel_exceptions_init(PyObject *ns) return -1; } + // An operation tried to close a non-empty channel. + ChannelNotEmptyError = PyErr_NewException( + "_xxsubinterpreters.ChannelNotEmptyError", ChannelError, NULL); + if (ChannelNotEmptyError == NULL) { + return -1; + } + if (PyDict_SetItemString(ns, "ChannelNotEmptyError", ChannelNotEmptyError) != 0) { + return -1; + } + return 0; } @@ -696,8 +711,11 @@ _channelends_close_interpreter(_channelends *ends, int64_t interp, int which) } static void -_channelends_close_all(_channelends *ends) +_channelends_close_all(_channelends *ends, int which, int force) { + // XXX Handle the ends. + // XXX Handle force is True. + // Ensure all the "send"-associated interpreters are closed. _channelend *end; for (end = ends->send; end != NULL; end = end->next) { @@ -713,12 +731,16 @@ _channelends_close_all(_channelends *ends) /* channels */ struct _channel; +struct _channel_closing; +static void _channel_clear_closing(struct _channel *); +static void _channel_finish_closing(struct _channel *); typedef struct _channel { PyThread_type_lock mutex; _channelqueue *queue; _channelends *ends; int open; + struct _channel_closing *closing; } _PyChannelState; static _PyChannelState * @@ -747,12 +769,14 @@ _channel_new(void) return NULL; } chan->open = 1; + chan->closing = NULL; return chan; } static void _channel_free(_PyChannelState *chan) { + _channel_clear_closing(chan); PyThread_acquire_lock(chan->mutex, WAIT_LOCK); _channelqueue_free(chan->queue); _channelends_free(chan->ends); @@ -802,13 +826,20 @@ _channel_next(_PyChannelState *chan, int64_t interp) } data = _channelqueue_get(chan->queue); + if (data == NULL && !PyErr_Occurred() && chan->closing != NULL) { + chan->open = 0; + } + done: PyThread_release_lock(chan->mutex); + if (chan->queue->count == 0) { + _channel_finish_closing(chan); + } return data; } static int -_channel_close_interpreter(_PyChannelState *chan, int64_t interp, int which) +_channel_close_interpreter(_PyChannelState *chan, int64_t interp, int end) { PyThread_acquire_lock(chan->mutex, WAIT_LOCK); @@ -818,7 +849,7 @@ _channel_close_interpreter(_PyChannelState *chan, int64_t interp, int which) goto done; } - if (_channelends_close_interpreter(chan->ends, interp, which) != 0) { + if (_channelends_close_interpreter(chan->ends, interp, end) != 0) { goto done; } chan->open = _channelends_is_open(chan->ends); @@ -830,7 +861,7 @@ done: } static int -_channel_close_all(_PyChannelState *chan) +_channel_close_all(_PyChannelState *chan, int end, int force) { int res = -1; PyThread_acquire_lock(chan->mutex, WAIT_LOCK); @@ -840,11 +871,17 @@ _channel_close_all(_PyChannelState *chan) goto done; } + if (!force && chan->queue->count > 0) { + PyErr_SetString(ChannelNotEmptyError, + "may not be closed if not empty (try force=True)"); + goto done; + } + chan->open = 0; // We *could* also just leave these in place, since we've marked // the channel as closed already. - _channelends_close_all(chan->ends); + _channelends_close_all(chan->ends, end, force); res = 0; done: @@ -889,6 +926,9 @@ _channelref_new(int64_t id, _PyChannelState *chan) static void _channelref_free(_channelref *ref) { + if (ref->chan != NULL) { + _channel_clear_closing(ref->chan); + } //_channelref_clear(ref); PyMem_Free(ref); } @@ -1009,8 +1049,12 @@ done: return cid; } +/* forward */ +static int _channel_set_closing(struct _channelref *, PyThread_type_lock); + static int -_channels_close(_channels *channels, int64_t cid, _PyChannelState **pchan) +_channels_close(_channels *channels, int64_t cid, _PyChannelState **pchan, + int end, int force) { int res = -1; PyThread_acquire_lock(channels->mutex, WAIT_LOCK); @@ -1028,14 +1072,35 @@ _channels_close(_channels *channels, int64_t cid, _PyChannelState **pchan) PyErr_Format(ChannelClosedError, "channel %d closed", cid); goto done; } + else if (!force && end == CHANNEL_SEND && ref->chan->closing != NULL) { + PyErr_Format(ChannelClosedError, "channel %d closed", cid); + goto done; + } else { - if (_channel_close_all(ref->chan) != 0) { + if (_channel_close_all(ref->chan, end, force) != 0) { + if (end == CHANNEL_SEND && + PyErr_ExceptionMatches(ChannelNotEmptyError)) { + if (ref->chan->closing != NULL) { + PyErr_Format(ChannelClosedError, "channel %d closed", cid); + goto done; + } + // Mark the channel as closing and return. The channel + // will be cleaned up in _channel_next(). + PyErr_Clear(); + if (_channel_set_closing(ref, channels->mutex) != 0) { + goto done; + } + if (pchan != NULL) { + *pchan = ref->chan; + } + res = 0; + } goto done; } if (pchan != NULL) { *pchan = ref->chan; } - else { + else { _channel_free(ref->chan); } ref->chan = NULL; @@ -1161,6 +1226,60 @@ done: return cids; } +/* support for closing non-empty channels */ + +struct _channel_closing { + struct _channelref *ref; +}; + +static int +_channel_set_closing(struct _channelref *ref, PyThread_type_lock mutex) { + struct _channel *chan = ref->chan; + if (chan == NULL) { + // already closed + return 0; + } + int res = -1; + PyThread_acquire_lock(chan->mutex, WAIT_LOCK); + if (chan->closing != NULL) { + PyErr_SetString(ChannelClosedError, "channel closed"); + goto done; + } + chan->closing = PyMem_NEW(struct _channel_closing, 1); + if (chan->closing == NULL) { + goto done; + } + chan->closing->ref = ref; + + res = 0; +done: + PyThread_release_lock(chan->mutex); + return res; +} + +static void +_channel_clear_closing(struct _channel *chan) { + PyThread_acquire_lock(chan->mutex, WAIT_LOCK); + if (chan->closing != NULL) { + PyMem_Free(chan->closing); + chan->closing = NULL; + } + PyThread_release_lock(chan->mutex); +} + +static void +_channel_finish_closing(struct _channel *chan) { + struct _channel_closing *closing = chan->closing; + if (closing == NULL) { + return; + } + _channelref *ref = closing->ref; + _channel_clear_closing(chan); + // Do the things that would have been done in _channels_close(). + ref->chan = NULL; + _channel_free(chan); +}; + /* "high"-level channel-related functions */ static int64_t @@ -1207,6 +1326,12 @@ _channel_send(_channels *channels, int64_t id, PyObject *obj) } // Past this point we are responsible for releasing the mutex. + if (chan->closing != NULL) { + PyErr_Format(ChannelClosedError, "channel %d closed", id); + PyThread_release_lock(mutex); + return -1; + } + // Convert the object to cross-interpreter data. _PyCrossInterpreterData *data = PyMem_NEW(_PyCrossInterpreterData, 1); if (data == NULL) { @@ -1290,16 +1415,13 @@ _channel_drop(_channels *channels, int64_t id, int send, int recv) } static int -_channel_close(_channels *channels, int64_t id) +_channel_close(_channels *channels, int64_t id, int end, int force) { - return _channels_close(channels, id, NULL); + return _channels_close(channels, id, NULL, end, force); } /* ChannelID class */ -#define CHANNEL_SEND 1 -#define CHANNEL_RECV -1 - static PyTypeObject ChannelIDtype; typedef struct channelid { @@ -2555,15 +2677,8 @@ channel_close(PyObject *self, PyObject *args, PyObject *kwds) if (cid < 0) { return NULL; } - if (send == 0 && recv == 0) { - send = 1; - recv = 1; - } - - // XXX Handle the ends. - // XXX Handle force is True. - if (_channel_close(&_globals.channels, cid) != 0) { + if (_channel_close(&_globals.channels, cid, send-recv, force) != 0) { return NULL; } Py_RETURN_NONE; -- cgit v0.12