diff options
-rw-r--r-- | Lib/test/test__xxsubinterpreters.py | 301 | ||||
-rw-r--r-- | Modules/_xxsubinterpretersmodule.c | 1139 |
2 files changed, 1317 insertions, 123 deletions
diff --git a/Lib/test/test__xxsubinterpreters.py b/Lib/test/test__xxsubinterpreters.py index e17bfde..039c040 100644 --- a/Lib/test/test__xxsubinterpreters.py +++ b/Lib/test/test__xxsubinterpreters.py @@ -1,3 +1,4 @@ +import builtins from collections import namedtuple import contextlib import itertools @@ -866,10 +867,11 @@ class RunStringTests(TestBase): yield if msg is None: self.assertEqual(str(caught.exception).split(':')[0], - str(exctype)) + exctype.__name__) else: self.assertEqual(str(caught.exception), - "{}: {}".format(exctype, msg)) + "{}: {}".format(exctype.__name__, msg)) + self.assertIsInstance(caught.exception.__cause__, exctype) def test_invalid_syntax(self): with self.assert_run_failed(SyntaxError): @@ -1060,6 +1062,301 @@ class RunStringTests(TestBase): self.assertEqual(retcode, 0) +def build_exception(exctype, /, *args, **kwargs): + # XXX Use __qualname__? + name = exctype.__name__ + argreprs = [repr(a) for a in args] + if kwargs: + kwargreprs = [f'{k}={v!r}' for k, v in kwargs.items()] + script = f'{name}({", ".join(argreprs)}, {", ".join(kwargreprs)})' + else: + script = f'{name}({", ".join(argreprs)})' + expected = exctype(*args, **kwargs) + return script, expected + + +def build_exceptions(self, *exctypes, default=None, custom=None, bases=True): + if not exctypes: + raise NotImplementedError + if not default: + default = ((), {}) + elif isinstance(default, str): + default = ((default,), {}) + elif type(default) is not tuple: + raise NotImplementedError + elif len(default) != 2: + default = (default, {}) + elif type(default[0]) is not tuple: + default = (default, {}) + elif type(default[1]) is not dict: + default = (default, {}) + # else leave it alone + + for exctype in exctypes: + customtype = None + values = default + if custom: + if exctype in custom: + customtype = exctype + elif bases: + for customtype in custom: + if issubclass(exctype, customtype): + break + else: + customtype = None + if customtype is not None: + values = custom[customtype] + if values is None: + continue + args, kwargs = values + script, expected = build_exception(exctype, *args, **kwargs) + yield exctype, customtype, script, expected + + +try: + raise Exception +except Exception as exc: + assert exc.__traceback__ is not None + Traceback = type(exc.__traceback__) + + +class RunFailedTests(TestBase): + + BUILTINS = [v + for v in vars(builtins).values() + if (type(v) is type + and issubclass(v, Exception) + #and issubclass(v, BaseException) + ) + ] + BUILTINS_SPECIAL = [ + # These all have extra attributes (i.e. args/kwargs) + SyntaxError, + ImportError, + UnicodeError, + OSError, + SystemExit, + StopIteration, + ] + + @classmethod + def build_exceptions(cls, exctypes=None, default=(), custom=None): + if exctypes is None: + exctypes = cls.BUILTINS + if custom is None: + # Skip the "special" ones. + custom = {et: None for et in cls.BUILTINS_SPECIAL} + yield from build_exceptions(*exctypes, default=default, custom=custom) + + def assertExceptionsEqual(self, exc, expected, *, chained=True): + if type(expected) is type: + self.assertIs(type(exc), expected) + return + elif not isinstance(exc, Exception): + self.assertEqual(exc, expected) + elif not isinstance(expected, Exception): + self.assertEqual(exc, expected) + else: + # Plain equality doesn't work, so we have to compare manually. + self.assertIs(type(exc), type(expected)) + self.assertEqual(exc.args, expected.args) + self.assertEqual(exc.__reduce__(), expected.__reduce__()) + if chained: + self.assertExceptionsEqual(exc.__context__, + expected.__context__) + self.assertExceptionsEqual(exc.__cause__, + expected.__cause__) + self.assertEqual(exc.__suppress_context__, + expected.__suppress_context__) + + def assertTracebacksEqual(self, tb, expected): + if not isinstance(tb, Traceback): + self.assertEqual(tb, expected) + elif not isinstance(expected, Traceback): + self.assertEqual(tb, expected) + else: + self.assertEqual(tb.tb_frame.f_code.co_name, + expected.tb_frame.f_code.co_name) + self.assertEqual(tb.tb_frame.f_code.co_filename, + expected.tb_frame.f_code.co_filename) + self.assertEqual(tb.tb_lineno, expected.tb_lineno) + self.assertTracebacksEqual(tb.tb_next, expected.tb_next) + + # XXX Move this to TestBase? + @contextlib.contextmanager + def expected_run_failure(self, expected): + exctype = expected if type(expected) is type else type(expected) + + with self.assertRaises(interpreters.RunFailedError) as caught: + yield caught + exc = caught.exception + + modname = exctype.__module__ + if modname == 'builtins' or modname == '__main__': + exctypename = exctype.__name__ + else: + exctypename = f'{modname}.{exctype.__name__}' + if exctype is expected: + self.assertEqual(str(exc).split(':')[0], exctypename) + else: + self.assertEqual(str(exc), f'{exctypename}: {expected}') + self.assertExceptionsEqual(exc.__cause__, expected) + if exc.__cause__ is not None: + self.assertIsNotNone(exc.__cause__.__traceback__) + + def test_builtin_exceptions(self): + interpid = interpreters.create() + msg = '<a message>' + for i, info in enumerate(self.build_exceptions( + default=msg, + custom={ + SyntaxError: ((msg, '<stdin>', 1, 3, 'a +?'), {}), + ImportError: ((msg,), {'name': 'spam', 'path': '/x/spam.py'}), + UnicodeError: None, + #UnicodeError: ((), {}), + #OSError: ((), {}), + SystemExit: ((1,), {}), + StopIteration: (('<a value>',), {}), + }, + )): + exctype, _, script, expected = info + testname = f'{i+1} - {script}' + script = f'raise {script}' + + with self.subTest(testname): + with self.expected_run_failure(expected): + interpreters.run_string(interpid, script) + + def test_custom_exception_from___main__(self): + script = dedent(""" + class SpamError(Exception): + def __init__(self, q): + super().__init__(f'got {q}') + self.q = q + raise SpamError('eggs') + """) + expected = Exception(f'SpamError: got {"eggs"}') + + interpid = interpreters.create() + with self.assertRaises(interpreters.RunFailedError) as caught: + interpreters.run_string(interpid, script) + cause = caught.exception.__cause__ + + self.assertExceptionsEqual(cause, expected) + + class SpamError(Exception): + # The normal Exception.__reduce__() produces a funny result + # here. So we have to use a custom __new__(). + def __new__(cls, q): + if type(q) is SpamError: + return q + return super().__new__(cls, q) + def __init__(self, q): + super().__init__(f'got {q}') + self.q = q + + def test_custom_exception(self): + script = dedent(""" + import test.test__xxsubinterpreters + SpamError = test.test__xxsubinterpreters.RunFailedTests.SpamError + raise SpamError('eggs') + """) + try: + ns = {} + exec(script, ns, ns) + except Exception as exc: + expected = exc + + interpid = interpreters.create() + with self.expected_run_failure(expected): + interpreters.run_string(interpid, script) + + class SpamReducedError(Exception): + def __init__(self, q): + super().__init__(f'got {q}') + self.q = q + def __reduce__(self): + return (type(self), (self.q,), {}) + + def test_custom___reduce__(self): + script = dedent(""" + import test.test__xxsubinterpreters + SpamError = test.test__xxsubinterpreters.RunFailedTests.SpamReducedError + raise SpamError('eggs') + """) + try: + exec(script, (ns := {'__name__': '__main__'}), ns) + except Exception as exc: + expected = exc + + interpid = interpreters.create() + with self.expected_run_failure(expected): + interpreters.run_string(interpid, script) + + def test_traceback_propagated(self): + script = dedent(""" + def do_spam(): + raise Exception('uh-oh') + def do_eggs(): + return do_spam() + class Spam: + def do(self): + return do_eggs() + def get_handler(): + def handler(): + return Spam().do() + return handler + go = (lambda: get_handler()()) + def iter_all(): + yield from (go() for _ in [True]) + yield None + def main(): + for v in iter_all(): + pass + main() + """) + try: + ns = {} + exec(script, ns, ns) + except Exception as exc: + expected = exc + expectedtb = exc.__traceback__.tb_next + + interpid = interpreters.create() + with self.expected_run_failure(expected) as caught: + interpreters.run_string(interpid, script) + exc = caught.exception + + self.assertTracebacksEqual(exc.__cause__.__traceback__, + expectedtb) + + def test_chained_exceptions(self): + script = dedent(""" + try: + raise ValueError('msg 1') + except Exception as exc1: + try: + raise TypeError('msg 2') + except Exception as exc2: + try: + raise IndexError('msg 3') from exc2 + except Exception: + raise AttributeError('msg 4') + """) + try: + exec(script, {}, {}) + except Exception as exc: + expected = exc + + interpid = interpreters.create() + with self.expected_run_failure(expected) as caught: + interpreters.run_string(interpid, script) + exc = caught.exception + + # ...just to be sure. + self.assertIs(type(exc.__cause__), AttributeError) + + ################################## # channel tests diff --git a/Modules/_xxsubinterpretersmodule.c b/Modules/_xxsubinterpretersmodule.c index 8a6fce9..9c5df16 100644 --- a/Modules/_xxsubinterpretersmodule.c +++ b/Modules/_xxsubinterpretersmodule.c @@ -1,5 +1,4 @@ - -/* interpreters module */ +/* _interpreters module */ /* low-level access to interpreter primitives */ #include "Python.h" @@ -7,35 +6,921 @@ #include "interpreteridobject.h" +// XXX Emit a warning? +#define IGNORE_FAILURE(msg) \ + fprintf(stderr, " -----\nRunFailedError: %s\n", msg); \ + PyErr_PrintEx(0); \ + fprintf(stderr, " -----\n"); \ + PyErr_Clear(); + +typedef void (*_deallocfunc)(void *); + +static PyInterpreterState * +_get_current(void) +{ + // _PyInterpreterState_Get() aborts if lookup fails, so don't need + // to check the result for NULL. + return _PyInterpreterState_Get(); +} + + +/* string utils *************************************************************/ + +// PyMem_Free() must be used to dealocate the resulting string. static char * -_copy_raw_string(PyObject *strobj) +_strdup_and_size(const char *data, Py_ssize_t *psize, _deallocfunc *dealloc) { - const char *str = PyUnicode_AsUTF8(strobj); - if (str == NULL) { - return NULL; + if (data == NULL) { + if (psize != NULL) { + *psize = 0; + } + if (dealloc != NULL) { + *dealloc = NULL; + } + return ""; + } + + Py_ssize_t size; + if (psize == NULL) { + size = strlen(data); + } else { + size = *psize; + if (size == 0) { + size = strlen(data); + *psize = size; // The size "return" value. + } } - char *copied = PyMem_Malloc(strlen(str)+1); + char *copied = PyMem_Malloc(size+1); if (copied == NULL) { PyErr_NoMemory(); return NULL; } - strcpy(copied, str); + if (dealloc != NULL) { + *dealloc = PyMem_Free; + } + memcpy(copied, data, size+1); return copied; } -static PyInterpreterState * -_get_current(void) +static const char * +_pyobj_get_str_and_size(PyObject *obj, Py_ssize_t *psize) { - // PyInterpreterState_Get() aborts if lookup fails, so don't need - // to check the result for NULL. - return PyInterpreterState_Get(); + if (PyUnicode_Check(obj)) { + return PyUnicode_AsUTF8AndSize(obj, psize); + } else { + const char *data = NULL; + PyBytes_AsStringAndSize(obj, (char **)&data, psize); + return data; + } +} + +/* "raw" strings */ + +typedef struct _rawstring { + Py_ssize_t size; + const char *data; + _deallocfunc dealloc; +} _rawstring; + +static void +_rawstring_init(_rawstring *raw) +{ + raw->size = 0; + raw->data = NULL; + raw->dealloc = NULL; +} + +static _rawstring * +_rawstring_new(void) +{ + _rawstring *raw = PyMem_NEW(_rawstring, 1); + if (raw == NULL) { + PyErr_NoMemory(); + return NULL; + } + _rawstring_init(raw); + return raw; +} + +static void +_rawstring_clear(_rawstring *raw) +{ + if (raw->data != NULL && raw->dealloc != NULL) { + (*raw->dealloc)((void *)raw->data); + } + _rawstring_init(raw); +} + +static void +_rawstring_free(_rawstring *raw) +{ + _rawstring_clear(raw); + PyMem_Free(raw); +} + +static int +_rawstring_is_clear(_rawstring *raw) +{ + return raw->size == 0 && raw->data == NULL && raw->dealloc == NULL; +} + +//static void +//_rawstring_move(_rawstring *raw, _rawstring *src) +//{ +// raw->size = src->size; +// raw->data = src->data; +// raw->dealloc = src->dealloc; +// _rawstring_init(src); +//} + +static void +_rawstring_proxy(_rawstring *raw, const char *str) +{ + if (str == NULL) { + str = ""; + } + raw->size = strlen(str); + raw->data = str; + raw->dealloc = NULL; +} + +static int +_rawstring_buffer(_rawstring *raw, Py_ssize_t size) +{ + raw->data = PyMem_Malloc(size+1); + if (raw->data == NULL) { + PyErr_NoMemory(); + return -1; + } + raw->size = size; + raw->dealloc = PyMem_Free; + return 0; +} + +static int +_rawstring_strcpy(_rawstring *raw, const char *str, Py_ssize_t size) +{ + _deallocfunc dealloc = NULL; + const char *copied = _strdup_and_size(str, &size, &dealloc); + if (copied == NULL) { + return -1; + } + + raw->size = size; + raw->dealloc = dealloc; + raw->data = copied; + return 0; +} + +static int +_rawstring_from_pyobj(_rawstring *raw, PyObject *obj) +{ + Py_ssize_t size = 0; + const char *data = _pyobj_get_str_and_size(obj, &size); + if (PyErr_Occurred()) { + return -1; + } + if (_rawstring_strcpy(raw, data, size) != 0) { + return -1; + } + return 0; +} + +static int +_rawstring_from_pyobj_attr(_rawstring *raw, PyObject *obj, const char *attr) +{ + int res = -1; + PyObject *valueobj = PyObject_GetAttrString(obj, attr); + if (valueobj == NULL) { + goto done; + } + if (!PyUnicode_Check(valueobj)) { + // XXX PyObject_Str()? Repr()? + goto done; + } + const char *valuestr = PyUnicode_AsUTF8(valueobj); + if (valuestr == NULL) { + if (PyErr_Occurred()) { + goto done; + } + } else if (_rawstring_strcpy(raw, valuestr, 0) != 0) { + _rawstring_clear(raw); + goto done; + } + res = 0; + +done: + Py_XDECREF(valueobj); + return res; +} + +static PyObject * +_rawstring_as_pybytes(_rawstring *raw) +{ + return PyBytes_FromStringAndSize(raw->data, raw->size); +} + + +/* object utils *************************************************************/ + +static void +_pyobj_identify_type(PyObject *obj, _rawstring *modname, _rawstring *clsname) +{ + PyObject *objtype = (PyObject *)Py_TYPE(obj); + + // Try __module__ and __name__. + if (_rawstring_from_pyobj_attr(modname, objtype, "__module__") != 0) { + // Fall back to the previous values in "modname". + IGNORE_FAILURE("bad __module__"); + } + if (_rawstring_from_pyobj_attr(clsname, objtype, "__name__") != 0) { + // Fall back to the previous values in "clsname". + IGNORE_FAILURE("bad __name__"); + } + + // XXX Fall back to __qualname__? + // XXX Fall back to tp_name? +} + +static PyObject * +_pyobj_get_class(const char *modname, const char *clsname) +{ + assert(clsname != NULL); + if (modname == NULL) { + modname = "builtins"; + } + + PyObject *module = PyImport_ImportModule(modname); + if (module == NULL) { + return NULL; + } + PyObject *cls = PyObject_GetAttrString(module, clsname); + Py_DECREF(module); + return cls; +} + +static PyObject * +_pyobj_create(const char *modname, const char *clsname, PyObject *arg) +{ + PyObject *cls = _pyobj_get_class(modname, clsname); + if (cls == NULL) { + return NULL; + } + PyObject *obj = NULL; + if (arg == NULL) { + obj = _PyObject_CallNoArg(cls); + } else { + obj = PyObject_CallFunction(cls, "O", arg); + } + Py_DECREF(cls); + return obj; +} + + +/* object snapshots */ + +typedef struct _objsnapshot { + // If modname is NULL then try "builtins" and "__main__". + _rawstring modname; + // clsname is required. + _rawstring clsname; + + // The rest are optional. + + // The serialized exception. + _rawstring *serialized; +} _objsnapshot; + +static void +_objsnapshot_init(_objsnapshot *osn) +{ + _rawstring_init(&osn->modname); + _rawstring_init(&osn->clsname); + osn->serialized = NULL; +} + +//static _objsnapshot * +//_objsnapshot_new(void) +//{ +// _objsnapshot *osn = PyMem_NEW(_objsnapshot, 1); +// if (osn == NULL) { +// PyErr_NoMemory(); +// return NULL; +// } +// _objsnapshot_init(osn); +// return osn; +//} + +static void +_objsnapshot_clear(_objsnapshot *osn) +{ + _rawstring_clear(&osn->modname); + _rawstring_clear(&osn->clsname); + if (osn->serialized != NULL) { + _rawstring_free(osn->serialized); + osn->serialized = NULL; + } +} + +//static void +//_objsnapshot_free(_objsnapshot *osn) +//{ +// _objsnapshot_clear(osn); +// PyMem_Free(osn); +//} + +static int +_objsnapshot_is_clear(_objsnapshot *osn) +{ + return osn->serialized == NULL + && _rawstring_is_clear(&osn->modname) + && _rawstring_is_clear(&osn->clsname); +} + +static void +_objsnapshot_summarize(_objsnapshot *osn, _rawstring *rawbuf, const char *msg) +{ + if (msg == NULL || *msg == '\0') { + // XXX Keep it NULL? + // XXX Keep it an empty string? + // XXX Use something more informative? + msg = "<no message>"; + } + const char *clsname = osn->clsname.data; + const char *modname = osn->modname.data; + if (modname && *modname == '\0') { + modname = NULL; + } + + // Prep the buffer. + Py_ssize_t size = strlen(clsname); + if (modname != NULL) { + if (strcmp(modname, "builtins") == 0) { + modname = NULL; + } else if (strcmp(modname, "__main__") == 0) { + modname = NULL; + } else { + size += strlen(modname) + 1; + } + } + if (msg != NULL) { + size += strlen(": ") + strlen(msg); + } + if (modname != NULL || msg != NULL) { + if (_rawstring_buffer(rawbuf, size) != 0) { + IGNORE_FAILURE("could not summarize object snapshot"); + return; + } + } + // ...else we'll proxy clsname as-is, so no need to allocate a buffer. + + // XXX Use __qualname__ somehow? + char *buf = (char *)rawbuf->data; + if (modname != NULL) { + if (msg != NULL) { + snprintf(buf, size+1, "%s.%s: %s", modname, clsname, msg); + } else { + snprintf(buf, size+1, "%s.%s", modname, clsname); + } + } else if (msg != NULL) { + snprintf(buf, size+1, "%s: %s", clsname, msg); + } else { + _rawstring_proxy(rawbuf, clsname); + } +} + +static _rawstring * +_objsnapshot_get_minimal_summary(_objsnapshot *osn, PyObject *obj) +{ + const char *str = NULL; + PyObject *objstr = PyObject_Str(obj); + if (objstr == NULL) { + PyErr_Clear(); + } else { + str = PyUnicode_AsUTF8(objstr); + if (str == NULL) { + PyErr_Clear(); + } + } + + _rawstring *summary = _rawstring_new(); + if (summary == NULL) { + return NULL; + } + _objsnapshot_summarize(osn, summary, str); + return summary; +} + +static void +_objsnapshot_extract(_objsnapshot *osn, PyObject *obj) +{ + assert(_objsnapshot_is_clear(osn)); + + // Get the "qualname". + _rawstring_proxy(&osn->modname, "<unknown>"); + _rawstring_proxy(&osn->clsname, "<unknown>"); + _pyobj_identify_type(obj, &osn->modname, &osn->clsname); + + // Serialize the object. + // XXX Use marshal? + PyObject *pickle = PyImport_ImportModule("pickle"); + if (pickle == NULL) { + IGNORE_FAILURE("could not serialize object: pickle import failed"); + return; + } + PyObject *objdata = PyObject_CallMethod(pickle, "dumps", "(O)", obj); + Py_DECREF(pickle); + if (objdata == NULL) { + IGNORE_FAILURE("could not serialize object: pickle.dumps failed"); + } else { + _rawstring *serialized = _rawstring_new(); + int res = _rawstring_from_pyobj(serialized, objdata); + Py_DECREF(objdata); + if (res != 0) { + IGNORE_FAILURE("could not serialize object: raw str failed"); + _rawstring_free(serialized); + } else if (serialized->size == 0) { + _rawstring_free(serialized); + } else { + osn->serialized = serialized; + } + } +} + +static PyObject * +_objsnapshot_resolve_serialized(_objsnapshot *osn) +{ + assert(osn->serialized != NULL); + + // XXX Use marshal? + PyObject *pickle = PyImport_ImportModule("pickle"); + if (pickle == NULL) { + return NULL; + } + PyObject *objdata = _rawstring_as_pybytes(osn->serialized); + if (objdata == NULL) { + return NULL; + } else { + PyObject *obj = PyObject_CallMethod(pickle, "loads", "O", objdata); + Py_DECREF(objdata); + return obj; + } +} + +static PyObject * +_objsnapshot_resolve_naive(_objsnapshot *osn, PyObject *arg) +{ + if (_rawstring_is_clear(&osn->clsname)) { + // We can't proceed without at least the class name. + PyErr_SetString(PyExc_ValueError, "missing class name"); + return NULL; + } + + if (osn->modname.data != NULL) { + return _pyobj_create(osn->modname.data, osn->clsname.data, arg); + } else { + PyObject *obj = _pyobj_create("builtins", osn->clsname.data, arg); + if (obj == NULL) { + PyErr_Clear(); + obj = _pyobj_create("__main__", osn->clsname.data, arg); + } + return obj; + } +} + +static PyObject * +_objsnapshot_resolve(_objsnapshot *osn) +{ + if (osn->serialized != NULL) { + PyObject *obj = _objsnapshot_resolve_serialized(osn); + if (obj != NULL) { + return obj; + } + IGNORE_FAILURE("could not de-serialize object"); + } + + // Fall back to naive resolution. + return _objsnapshot_resolve_naive(osn, NULL); +} + + +/* exception utils **********************************************************/ + +// _pyexc_create is inspired by _PyErr_SetObject(). + +static PyObject * +_pyexc_create(PyObject *exctype, const char *msg, PyObject *tb) +{ + assert(exctype != NULL && PyExceptionClass_Check(exctype)); + + PyObject *curtype = NULL, *curexc = NULL, *curtb = NULL; + PyErr_Fetch(&curtype, &curexc, &curtb); + + // Create the object. + PyObject *exc = NULL; + if (msg != NULL) { + PyObject *msgobj = PyUnicode_FromString(msg); + if (msgobj == NULL) { + IGNORE_FAILURE("could not deserialize propagated error message"); + } + exc = _PyObject_CallOneArg(exctype, msgobj); + Py_XDECREF(msgobj); + } else { + exc = _PyObject_CallNoArg(exctype); + } + if (exc == NULL) { + return NULL; + } + + // Set the traceback, if any. + if (tb == NULL) { + tb = curtb; + } + if (tb != NULL) { + // This does *not* steal a reference! + PyException_SetTraceback(exc, tb); + } + + PyErr_Restore(curtype, curexc, curtb); + + return exc; +} + +/* traceback snapshots */ + +typedef struct _tbsnapshot { + _rawstring tbs_funcname; + _rawstring tbs_filename; + int tbs_lineno; + struct _tbsnapshot *tbs_next; +} _tbsnapshot; + +static void +_tbsnapshot_init(_tbsnapshot *tbs) +{ + _rawstring_init(&tbs->tbs_funcname); + _rawstring_init(&tbs->tbs_filename); + tbs->tbs_lineno = -1; + tbs->tbs_next = NULL; +} + +static _tbsnapshot * +_tbsnapshot_new(void) +{ + _tbsnapshot *tbs = PyMem_NEW(_tbsnapshot, 1); + if (tbs == NULL) { + PyErr_NoMemory(); + return NULL; + } + _tbsnapshot_init(tbs); + return tbs; +} + +static void _tbsnapshot_free(_tbsnapshot *); // forward + +static void +_tbsnapshot_clear(_tbsnapshot *tbs) +{ + _rawstring_clear(&tbs->tbs_funcname); + _rawstring_clear(&tbs->tbs_filename); + tbs->tbs_lineno = -1; + if (tbs->tbs_next != NULL) { + _tbsnapshot_free(tbs->tbs_next); + tbs->tbs_next = NULL; + } +} + +static void +_tbsnapshot_free(_tbsnapshot *tbs) +{ + _tbsnapshot_clear(tbs); + PyMem_Free(tbs); +} + +static int +_tbsnapshot_is_clear(_tbsnapshot *tbs) +{ + return tbs->tbs_lineno == -1 && tbs->tbs_next == NULL + && _rawstring_is_clear(&tbs->tbs_funcname) + && _rawstring_is_clear(&tbs->tbs_filename); +} + +static int +_tbsnapshot_from_pytb(_tbsnapshot *tbs, PyTracebackObject *pytb) +{ + assert(_tbsnapshot_is_clear(tbs)); + assert(pytb != NULL); + + PyCodeObject *pycode = pytb->tb_frame->f_code; + const char *funcname = PyUnicode_AsUTF8(pycode->co_name); + if (_rawstring_strcpy(&tbs->tbs_funcname, funcname, 0) != 0) { + goto error; + } + const char *filename = PyUnicode_AsUTF8(pycode->co_filename); + if (_rawstring_strcpy(&tbs->tbs_filename, filename, 0) != 0) { + goto error; + } + tbs->tbs_lineno = pytb->tb_lineno; + + return 0; + +error: + _tbsnapshot_clear(tbs); + return -1; +} + +static int +_tbsnapshot_extract(_tbsnapshot *tbs, PyTracebackObject *pytb) +{ + assert(_tbsnapshot_is_clear(tbs)); + assert(pytb != NULL); + + _tbsnapshot *next = NULL; + while (pytb->tb_next != NULL) { + _tbsnapshot *_next = _tbsnapshot_new(); + if (_next == NULL) { + goto error; + } + if (_tbsnapshot_from_pytb(_next, pytb) != 0) { + goto error; + } + if (next != NULL) { + _next->tbs_next = next; + } + next = _next; + pytb = pytb->tb_next; + } + if (_tbsnapshot_from_pytb(tbs, pytb) != 0) { + goto error; + } + tbs->tbs_next = next; + + return 0; + +error: + _tbsnapshot_clear(tbs); + return -1; +} + +static PyObject * +_tbsnapshot_resolve(_tbsnapshot *tbs) +{ + assert(!PyErr_Occurred()); + // At this point there should be no traceback set yet. + + while (tbs != NULL) { + const char *funcname = tbs->tbs_funcname.data; + const char *filename = tbs->tbs_filename.data; + _PyTraceback_Add(funcname ? funcname : "", + filename ? filename : "", + tbs->tbs_lineno); + tbs = tbs->tbs_next; + } + + PyObject *exctype = NULL, *excval = NULL, *tb = NULL; + PyErr_Fetch(&exctype, &excval, &tb); + // Leave it cleared. + return tb; +} + +/* exception snapshots */ + +typedef struct _excsnapshot { + _objsnapshot es_object; + _rawstring *es_msg; + struct _excsnapshot *es_cause; + struct _excsnapshot *es_context; + char es_suppress_context; + struct _tbsnapshot *es_traceback; +} _excsnapshot; + +static void +_excsnapshot_init(_excsnapshot *es) +{ + _objsnapshot_init(&es->es_object); + es->es_msg = NULL; + es->es_cause = NULL; + es->es_context = NULL; + es->es_suppress_context = 0; + es->es_traceback = NULL; +} + +static _excsnapshot * +_excsnapshot_new(void) { + _excsnapshot *es = PyMem_NEW(_excsnapshot, 1); + if (es == NULL) { + PyErr_NoMemory(); + return NULL; + } + _excsnapshot_init(es); + return es; +} + +static void _excsnapshot_free(_excsnapshot *); // forward + +static void +_excsnapshot_clear(_excsnapshot *es) +{ + _objsnapshot_clear(&es->es_object); + if (es->es_msg != NULL) { + _rawstring_free(es->es_msg); + es->es_msg = NULL; + } + if (es->es_cause != NULL) { + _excsnapshot_free(es->es_cause); + es->es_cause = NULL; + } + if (es->es_context != NULL) { + _excsnapshot_free(es->es_context); + es->es_context = NULL; + } + es->es_suppress_context = 0; + if (es->es_traceback != NULL) { + _tbsnapshot_free(es->es_traceback); + es->es_traceback = NULL; + } +} + +static void +_excsnapshot_free(_excsnapshot *es) +{ + _excsnapshot_clear(es); + PyMem_Free(es); +} + +static int +_excsnapshot_is_clear(_excsnapshot *es) +{ + return es->es_suppress_context == 0 + && es->es_cause == NULL + && es->es_context == NULL + && es->es_traceback == NULL + && es->es_msg == NULL + && _objsnapshot_is_clear(&es->es_object); +} + +static PyObject * +_excsnapshot_get_exc_naive(_excsnapshot *es) +{ + _rawstring buf; + const char *msg = NULL; + if (es->es_msg != NULL) { + msg = es->es_msg->data; + } else { + _objsnapshot_summarize(&es->es_object, &buf, NULL); + if (buf.size > 0) { + msg = buf.data; + } + } + + PyObject *exc = NULL; + // XXX Use _objsnapshot_resolve_naive()? + const char *modname = es->es_object.modname.size > 0 + ? es->es_object.modname.data + : NULL; + PyObject *exctype = _pyobj_get_class(modname, es->es_object.clsname.data); + if (exctype != NULL) { + exc = _pyexc_create(exctype, msg, NULL); + Py_DECREF(exctype); + if (exc != NULL) { + return exc; + } + PyErr_Clear(); + } else { + PyErr_Clear(); + } + exctype = PyExc_Exception; + return _pyexc_create(exctype, msg, NULL); +} + +static PyObject * +_excsnapshot_get_exc(_excsnapshot *es) +{ + assert(!_objsnapshot_is_clear(&es->es_object)); + + PyObject *exc = _objsnapshot_resolve(&es->es_object); + if (exc == NULL) { + // Fall back to resolving the object. + PyObject *curtype = NULL, *curexc = NULL, *curtb = NULL; + PyErr_Fetch(&curtype, &curexc, &curtb); + + exc = _excsnapshot_get_exc_naive(es); + if (exc == NULL) { + PyErr_Restore(curtype, curexc, curtb); + return NULL; + } + } + // People can do some weird stuff... + if (!PyExceptionInstance_Check(exc)) { + // We got a bogus "exception". + Py_DECREF(exc); + PyErr_SetString(PyExc_TypeError, "expected exception"); + return NULL; + } + return exc; +} + +static void _excsnapshot_extract(_excsnapshot *, PyObject *); +static void +_excsnapshot_extract(_excsnapshot *es, PyObject *excobj) +{ + assert(_excsnapshot_is_clear(es)); + assert(PyExceptionInstance_Check(excobj)); + + _objsnapshot_extract(&es->es_object, excobj); + + es->es_msg = _objsnapshot_get_minimal_summary(&es->es_object, excobj); + if (es->es_msg == NULL) { + PyErr_Clear(); + } + + PyBaseExceptionObject *exc = (PyBaseExceptionObject *)excobj; + + if (exc->cause != NULL && exc->cause != Py_None) { + es->es_cause = _excsnapshot_new(); + _excsnapshot_extract(es->es_cause, exc->cause); + } + + if (exc->context != NULL && exc->context != Py_None) { + es->es_context = _excsnapshot_new(); + _excsnapshot_extract(es->es_context, exc->context); + } + + es->es_suppress_context = exc->suppress_context; + + PyObject *tb = PyException_GetTraceback(excobj); + if (PyErr_Occurred()) { + IGNORE_FAILURE("could not get traceback"); + } else if (tb == Py_None) { + Py_DECREF(tb); + tb = NULL; + } + if (tb != NULL) { + es->es_traceback = _tbsnapshot_new(); + if (_tbsnapshot_extract(es->es_traceback, + (PyTracebackObject *)tb) != 0) { + IGNORE_FAILURE("could not extract __traceback__"); + } + } +} + +static PyObject * +_excsnapshot_resolve(_excsnapshot *es) +{ + PyObject *exc = _excsnapshot_get_exc(es); + if (exc == NULL) { + return NULL; + } + + if (es->es_traceback != NULL) { + PyObject *tb = _tbsnapshot_resolve(es->es_traceback); + if (tb == NULL) { + // The snapshot is still somewhat useful without this. + IGNORE_FAILURE("could not deserialize traceback"); + } else { + // This does not steal references. + PyException_SetTraceback(exc, tb); + Py_DECREF(tb); + } + } + // NULL means "not set". + + if (es->es_context != NULL) { + PyObject *context = _excsnapshot_resolve(es->es_context); + if (context == NULL) { + // The snapshot is still useful without this. + IGNORE_FAILURE("could not deserialize __context__"); + } else { + // This steals references but we have one to give. + PyException_SetContext(exc, context); + } + } + // NULL means "not set". + + if (es->es_cause != NULL) { + PyObject *cause = _excsnapshot_resolve(es->es_cause); + if (cause == NULL) { + // The snapshot is still useful without this. + IGNORE_FAILURE("could not deserialize __cause__"); + } else { + // This steals references, but we have one to give. + PyException_SetCause(exc, cause); + } + } + // NULL means "not set". + + ((PyBaseExceptionObject *)exc)->suppress_context = es->es_suppress_context; + + return exc; } /* data-sharing-specific code ***********************************************/ +/* shared "object" */ + struct _sharednsitem { - char *name; + _rawstring name; _PyCrossInterpreterData data; }; @@ -44,8 +929,7 @@ static void _sharednsitem_clear(struct _sharednsitem *); // forward static int _sharednsitem_init(struct _sharednsitem *item, PyObject *key, PyObject *value) { - item->name = _copy_raw_string(key); - if (item->name == NULL) { + if (_rawstring_from_pyobj(&item->name, key) != 0) { return -1; } if (_PyObject_GetCrossInterpreterData(value, &item->data) != 0) { @@ -58,17 +942,14 @@ _sharednsitem_init(struct _sharednsitem *item, PyObject *key, PyObject *value) static void _sharednsitem_clear(struct _sharednsitem *item) { - if (item->name != NULL) { - PyMem_Free(item->name); - item->name = NULL; - } + _rawstring_clear(&item->name); _PyCrossInterpreterData_Release(&item->data); } static int _sharednsitem_apply(struct _sharednsitem *item, PyObject *ns) { - PyObject *name = PyUnicode_FromString(item->name); + PyObject *name = PyUnicode_FromString(item->name.data); if (name == NULL) { return -1; } @@ -159,121 +1040,119 @@ _sharedns_apply(_sharedns *shared, PyObject *ns) return 0; } +/* shared exception */ + // Ultimately we'd like to preserve enough information about the // exception and traceback that we could re-constitute (or at least // simulate, a la traceback.TracebackException), and even chain, a copy // of the exception in the calling interpreter. typedef struct _sharedexception { - char *name; - char *msg; + _excsnapshot snapshot; + _rawstring msg; } _sharedexception; +static void +_sharedexception_init(_sharedexception *she) +{ + _excsnapshot_init(&she->snapshot); + _rawstring_init(&she->msg); +} + static _sharedexception * _sharedexception_new(void) { - _sharedexception *err = PyMem_NEW(_sharedexception, 1); - if (err == NULL) { + _sharedexception *she = PyMem_NEW(_sharedexception, 1); + if (she == NULL) { PyErr_NoMemory(); return NULL; } - err->name = NULL; - err->msg = NULL; - return err; + _sharedexception_init(she); + return she; } static void -_sharedexception_clear(_sharedexception *exc) +_sharedexception_clear(_sharedexception *she) { - if (exc->name != NULL) { - PyMem_Free(exc->name); - } - if (exc->msg != NULL) { - PyMem_Free(exc->msg); - } + _excsnapshot_clear(&she->snapshot); + _rawstring_clear(&she->msg); } static void -_sharedexception_free(_sharedexception *exc) +_sharedexception_free(_sharedexception *she) { - _sharedexception_clear(exc); - PyMem_Free(exc); + _sharedexception_clear(she); + PyMem_Free(she); } -static _sharedexception * -_sharedexception_bind(PyObject *exctype, PyObject *exc, PyObject *tb) +static int +_sharedexception_is_clear(_sharedexception *she) { - assert(exctype != NULL); - char *failure = NULL; - - _sharedexception *err = _sharedexception_new(); - if (err == NULL) { - goto finally; - } + return 1 + && _excsnapshot_is_clear(&she->snapshot) + && _rawstring_is_clear(&she->msg); +} - PyObject *name = PyUnicode_FromFormat("%S", exctype); - if (name == NULL) { - failure = "unable to format exception type name"; - goto finally; - } - err->name = _copy_raw_string(name); - Py_DECREF(name); - if (err->name == NULL) { - if (PyErr_ExceptionMatches(PyExc_MemoryError)) { - failure = "out of memory copying exception type name"; - } else { - failure = "unable to encode and copy exception type name"; +static PyObject * +_sharedexception_get_cause(_sharedexception *sharedexc) +{ + // FYI, "cause" is already normalized. + PyObject *cause = _excsnapshot_resolve(&sharedexc->snapshot); + if (cause == NULL) { + if (PyErr_Occurred()) { + IGNORE_FAILURE("could not deserialize exc snapshot"); } - goto finally; + return NULL; } + // XXX Ensure "cause" has a traceback. + return cause; +} - if (exc != NULL) { - PyObject *msg = PyUnicode_FromFormat("%S", exc); - if (msg == NULL) { - failure = "unable to format exception message"; - goto finally; - } - err->msg = _copy_raw_string(msg); - Py_DECREF(msg); - if (err->msg == NULL) { - if (PyErr_ExceptionMatches(PyExc_MemoryError)) { - failure = "out of memory copying exception message"; - } else { - failure = "unable to encode and copy exception message"; - } - goto finally; - } - } +static void +_sharedexception_extract(_sharedexception *she, PyObject *exc) +{ + assert(_sharedexception_is_clear(she)); + assert(exc != NULL); -finally: - if (failure != NULL) { - PyErr_Clear(); - if (err->name != NULL) { - PyMem_Free(err->name); - err->name = NULL; + _excsnapshot_extract(&she->snapshot, exc); + + // Compose the message. + const char *msg = NULL; + PyObject *msgobj = PyUnicode_FromFormat("%S", exc); + if (msgobj == NULL) { + IGNORE_FAILURE("unable to format exception message"); + } else { + msg = PyUnicode_AsUTF8(msgobj); + if (PyErr_Occurred()) { + PyErr_Clear(); } - err->msg = failure; } - return err; + _objsnapshot_summarize(&she->snapshot.es_object, &she->msg, msg); + Py_XDECREF(msgobj); } -static void -_sharedexception_apply(_sharedexception *exc, PyObject *wrapperclass) +static PyObject * +_sharedexception_resolve(_sharedexception *sharedexc, PyObject *wrapperclass) { - if (exc->name != NULL) { - if (exc->msg != NULL) { - PyErr_Format(wrapperclass, "%s: %s", exc->name, exc->msg); - } - else { - PyErr_SetString(wrapperclass, exc->name); - } - } - else if (exc->msg != NULL) { - PyErr_SetString(wrapperclass, exc->msg); - } - else { - PyErr_SetNone(wrapperclass); + assert(!PyErr_Occurred()); + + // Get the exception object (already normalized). + PyObject *exc = _pyexc_create(wrapperclass, sharedexc->msg.data, NULL); + assert(exc != NULL); + + // Set __cause__, is possible. + PyObject *cause = _sharedexception_get_cause(sharedexc); + if (cause != NULL) { + // Set __context__. + Py_INCREF(cause); // PyException_SetContext() steals a reference. + PyException_SetContext(exc, cause); + + // Set __cause__. + Py_INCREF(cause); // PyException_SetCause() steals a reference. + PyException_SetCause(exc, cause); } + + return exc; } @@ -1869,11 +2748,9 @@ _ensure_not_running(PyInterpreterState *interp) static int _run_script(PyInterpreterState *interp, const char *codestr, - _sharedns *shared, _sharedexception **exc) + _sharedns *shared, _sharedexception **pexc) { - PyObject *exctype = NULL; - PyObject *excval = NULL; - PyObject *tb = NULL; + assert(!PyErr_Occurred()); // ...in the called interpreter. PyObject *main_mod = _PyInterpreterState_GetMainModule(interp); if (main_mod == NULL) { @@ -1904,25 +2781,38 @@ _run_script(PyInterpreterState *interp, const char *codestr, Py_DECREF(result); // We throw away the result. } - *exc = NULL; + *pexc = NULL; return 0; + PyObject *exctype = NULL, *exc = NULL, *tb = NULL; error: - PyErr_Fetch(&exctype, &excval, &tb); + PyErr_Fetch(&exctype, &exc, &tb); - _sharedexception *sharedexc = _sharedexception_bind(exctype, excval, tb); - Py_XDECREF(exctype); - Py_XDECREF(excval); - Py_XDECREF(tb); - if (sharedexc == NULL) { - fprintf(stderr, "RunFailedError: script raised an uncaught exception"); - PyErr_Clear(); - sharedexc = NULL; + // First normalize the exception. + PyErr_NormalizeException(&exctype, &exc, &tb); + assert(PyExceptionInstance_Check(exc)); + if (tb != NULL) { + PyException_SetTraceback(exc, tb); } - else { + + // Behave as though the exception was caught in this thread. + PyErr_SetExcInfo(exctype, exc, tb); // Like entering "except" block. + + // Serialize the exception. + _sharedexception *sharedexc = _sharedexception_new(); + if (sharedexc == NULL) { + IGNORE_FAILURE("script raised an uncaught exception"); + } else { + _sharedexception_extract(sharedexc, exc); assert(!PyErr_Occurred()); } - *exc = sharedexc; + + // Clear the exception. + PyErr_SetExcInfo(NULL, NULL, NULL); // Like leaving "except" block. + PyErr_Clear(); // Do not re-raise. + + // "Return" the serialized exception. + *pexc = sharedexc; return -1; } @@ -1930,6 +2820,8 @@ static int _run_script_in_interpreter(PyInterpreterState *interp, const char *codestr, PyObject *shareables) { + assert(!PyErr_Occurred()); // ...in the calling interpreter. + if (_ensure_not_running(interp) < 0) { return -1; } @@ -1963,8 +2855,8 @@ _run_script_in_interpreter(PyInterpreterState *interp, const char *codestr, } // Run the script. - _sharedexception *exc = NULL; - int result = _run_script(interp, codestr, shared, &exc); + _sharedexception *sharedexc = NULL; + int result = _run_script(interp, codestr, shared, &sharedexc); // Switch back. if (save_tstate != NULL) { @@ -1973,9 +2865,14 @@ _run_script_in_interpreter(PyInterpreterState *interp, const char *codestr, #endif // Propagate any exception out to the caller. - if (exc != NULL) { - _sharedexception_apply(exc, RunFailedError); - _sharedexception_free(exc); + if (sharedexc != NULL) { + assert(!PyErr_Occurred()); + PyObject *exc = _sharedexception_resolve(sharedexc, RunFailedError); + // XXX This is not safe once interpreters no longer share allocators. + _sharedexception_free(sharedexc); + PyObject *exctype = (PyObject *)Py_TYPE(exc); + Py_INCREF(exctype); // PyErr_Restore() steals a reference. + PyErr_Restore(exctype, exc, PyException_GetTraceback(exc)); } else if (result != 0) { // We were unable to allocate a shared exception. |