diff options
-rw-r--r-- | Lib/test/test_codecs.py | 41 | ||||
-rw-r--r-- | Objects/exceptions.c | 17 |
2 files changed, 50 insertions, 8 deletions
diff --git a/Lib/test/test_codecs.py b/Lib/test/test_codecs.py index 31bd089..728f7d0 100644 --- a/Lib/test/test_codecs.py +++ b/Lib/test/test_codecs.py @@ -2402,6 +2402,25 @@ class TransformCodecTest(unittest.TestCase): self.assertTrue(isinstance(failure.exception.__cause__, AttributeError)) + def test_custom_zlib_error_is_wrapped(self): + # Check zlib codec gives a good error for malformed input + msg = "^decoding with 'zlib_codec' codec failed" + with self.assertRaisesRegex(Exception, msg) as failure: + b"hello".decode("zlib_codec") + self.assertTrue(isinstance(failure.exception.__cause__, + type(failure.exception))) + + def test_custom_hex_error_is_wrapped(self): + # Check hex codec gives a good error for malformed input + msg = "^decoding with 'hex_codec' codec failed" + with self.assertRaisesRegex(Exception, msg) as failure: + b"hello".decode("hex_codec") + self.assertTrue(isinstance(failure.exception.__cause__, + type(failure.exception))) + + # Unfortunately, the bz2 module throws OSError, which the codec + # machinery currently can't wrap :( + def test_bad_decoding_output_type(self): # Check bytes.decode and bytearray.decode give a good error # message for binary -> binary codecs @@ -2466,15 +2485,15 @@ class ExceptionChainingTest(unittest.TestCase): with self.assertRaisesRegex(exc_type, full_msg) as caught: yield caught - def check_wrapped(self, obj_to_raise, msg): + def check_wrapped(self, obj_to_raise, msg, exc_type=RuntimeError): self.set_codec(obj_to_raise) - with self.assertWrapped("encoding", RuntimeError, msg): + with self.assertWrapped("encoding", exc_type, msg): "str_input".encode(self.codec_name) - with self.assertWrapped("encoding", RuntimeError, msg): + with self.assertWrapped("encoding", exc_type, msg): codecs.encode("str_input", self.codec_name) - with self.assertWrapped("decoding", RuntimeError, msg): + with self.assertWrapped("decoding", exc_type, msg): b"bytes input".decode(self.codec_name) - with self.assertWrapped("decoding", RuntimeError, msg): + with self.assertWrapped("decoding", exc_type, msg): codecs.decode(b"bytes input", self.codec_name) def test_raise_by_type(self): @@ -2484,6 +2503,18 @@ class ExceptionChainingTest(unittest.TestCase): msg = "This should be wrapped" self.check_wrapped(RuntimeError(msg), msg) + def test_raise_grandchild_subclass_exact_size(self): + msg = "This should be wrapped" + class MyRuntimeError(RuntimeError): + __slots__ = () + self.check_wrapped(MyRuntimeError(msg), msg, MyRuntimeError) + + def test_raise_subclass_with_weakref_support(self): + msg = "This should be wrapped" + class MyRuntimeError(RuntimeError): + pass + self.check_wrapped(MyRuntimeError(msg), msg, MyRuntimeError) + @contextlib.contextmanager def assertNotWrapped(self, operation, exc_type, msg_re, msg=None): if msg is None: diff --git a/Objects/exceptions.c b/Objects/exceptions.c index 3476db0..af40bc8 100644 --- a/Objects/exceptions.c +++ b/Objects/exceptions.c @@ -2630,16 +2630,27 @@ _PyErr_TrySetFromCause(const char *format, ...) PyTypeObject *caught_type; PyObject **dictptr; PyObject *instance_args; - Py_ssize_t num_args; + Py_ssize_t num_args, caught_type_size, base_exc_size; PyObject *new_exc, *new_val, *new_tb; va_list vargs; + int same_basic_size; PyErr_Fetch(&exc, &val, &tb); caught_type = (PyTypeObject *)exc; - /* Ensure type info indicates no extra state is stored at the C level */ + /* Ensure type info indicates no extra state is stored at the C level + * and that the type can be reinstantiated using PyErr_Format + */ + caught_type_size = caught_type->tp_basicsize; + base_exc_size = _PyExc_BaseException.tp_basicsize; + same_basic_size = ( + caught_type_size == base_exc_size || + (PyType_SUPPORTS_WEAKREFS(caught_type) && + (caught_type_size == base_exc_size + sizeof(PyObject *)) + ) + ); if (caught_type->tp_init != (initproc)BaseException_init || caught_type->tp_new != BaseException_new || - caught_type->tp_basicsize != _PyExc_BaseException.tp_basicsize || + !same_basic_size || caught_type->tp_itemsize != _PyExc_BaseException.tp_itemsize) { /* We can't be sure we can wrap this safely, since it may contain * more state than just the exception type. Accordingly, we just |