summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/test/test_codecs.py41
-rw-r--r--Objects/exceptions.c17
2 files changed, 50 insertions, 8 deletions
diff --git a/Lib/test/test_codecs.py b/Lib/test/test_codecs.py
index 31bd089..728f7d0 100644
--- a/Lib/test/test_codecs.py
+++ b/Lib/test/test_codecs.py
@@ -2402,6 +2402,25 @@ class TransformCodecTest(unittest.TestCase):
self.assertTrue(isinstance(failure.exception.__cause__,
AttributeError))
+ def test_custom_zlib_error_is_wrapped(self):
+ # Check zlib codec gives a good error for malformed input
+ msg = "^decoding with 'zlib_codec' codec failed"
+ with self.assertRaisesRegex(Exception, msg) as failure:
+ b"hello".decode("zlib_codec")
+ self.assertTrue(isinstance(failure.exception.__cause__,
+ type(failure.exception)))
+
+ def test_custom_hex_error_is_wrapped(self):
+ # Check hex codec gives a good error for malformed input
+ msg = "^decoding with 'hex_codec' codec failed"
+ with self.assertRaisesRegex(Exception, msg) as failure:
+ b"hello".decode("hex_codec")
+ self.assertTrue(isinstance(failure.exception.__cause__,
+ type(failure.exception)))
+
+ # Unfortunately, the bz2 module throws OSError, which the codec
+ # machinery currently can't wrap :(
+
def test_bad_decoding_output_type(self):
# Check bytes.decode and bytearray.decode give a good error
# message for binary -> binary codecs
@@ -2466,15 +2485,15 @@ class ExceptionChainingTest(unittest.TestCase):
with self.assertRaisesRegex(exc_type, full_msg) as caught:
yield caught
- def check_wrapped(self, obj_to_raise, msg):
+ def check_wrapped(self, obj_to_raise, msg, exc_type=RuntimeError):
self.set_codec(obj_to_raise)
- with self.assertWrapped("encoding", RuntimeError, msg):
+ with self.assertWrapped("encoding", exc_type, msg):
"str_input".encode(self.codec_name)
- with self.assertWrapped("encoding", RuntimeError, msg):
+ with self.assertWrapped("encoding", exc_type, msg):
codecs.encode("str_input", self.codec_name)
- with self.assertWrapped("decoding", RuntimeError, msg):
+ with self.assertWrapped("decoding", exc_type, msg):
b"bytes input".decode(self.codec_name)
- with self.assertWrapped("decoding", RuntimeError, msg):
+ with self.assertWrapped("decoding", exc_type, msg):
codecs.decode(b"bytes input", self.codec_name)
def test_raise_by_type(self):
@@ -2484,6 +2503,18 @@ class ExceptionChainingTest(unittest.TestCase):
msg = "This should be wrapped"
self.check_wrapped(RuntimeError(msg), msg)
+ def test_raise_grandchild_subclass_exact_size(self):
+ msg = "This should be wrapped"
+ class MyRuntimeError(RuntimeError):
+ __slots__ = ()
+ self.check_wrapped(MyRuntimeError(msg), msg, MyRuntimeError)
+
+ def test_raise_subclass_with_weakref_support(self):
+ msg = "This should be wrapped"
+ class MyRuntimeError(RuntimeError):
+ pass
+ self.check_wrapped(MyRuntimeError(msg), msg, MyRuntimeError)
+
@contextlib.contextmanager
def assertNotWrapped(self, operation, exc_type, msg_re, msg=None):
if msg is None:
diff --git a/Objects/exceptions.c b/Objects/exceptions.c
index 3476db0..af40bc8 100644
--- a/Objects/exceptions.c
+++ b/Objects/exceptions.c
@@ -2630,16 +2630,27 @@ _PyErr_TrySetFromCause(const char *format, ...)
PyTypeObject *caught_type;
PyObject **dictptr;
PyObject *instance_args;
- Py_ssize_t num_args;
+ Py_ssize_t num_args, caught_type_size, base_exc_size;
PyObject *new_exc, *new_val, *new_tb;
va_list vargs;
+ int same_basic_size;
PyErr_Fetch(&exc, &val, &tb);
caught_type = (PyTypeObject *)exc;
- /* Ensure type info indicates no extra state is stored at the C level */
+ /* Ensure type info indicates no extra state is stored at the C level
+ * and that the type can be reinstantiated using PyErr_Format
+ */
+ caught_type_size = caught_type->tp_basicsize;
+ base_exc_size = _PyExc_BaseException.tp_basicsize;
+ same_basic_size = (
+ caught_type_size == base_exc_size ||
+ (PyType_SUPPORTS_WEAKREFS(caught_type) &&
+ (caught_type_size == base_exc_size + sizeof(PyObject *))
+ )
+ );
if (caught_type->tp_init != (initproc)BaseException_init ||
caught_type->tp_new != BaseException_new ||
- caught_type->tp_basicsize != _PyExc_BaseException.tp_basicsize ||
+ !same_basic_size ||
caught_type->tp_itemsize != _PyExc_BaseException.tp_itemsize) {
/* We can't be sure we can wrap this safely, since it may contain
* more state than just the exception type. Accordingly, we just