summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/ctypes/test/test_arrays.py52
-rw-r--r--Misc/NEWS2
-rw-r--r--Modules/_ctypes/_ctypes.c79
3 files changed, 93 insertions, 40 deletions
diff --git a/Lib/ctypes/test/test_arrays.py b/Lib/ctypes/test/test_arrays.py
index f2a9f07..cfc219e 100644
--- a/Lib/ctypes/test/test_arrays.py
+++ b/Lib/ctypes/test/test_arrays.py
@@ -127,5 +127,57 @@ class ArrayTestCase(unittest.TestCase):
t2 = my_int * 1
self.assertTrue(t1 is t2)
+ def test_subclass(self):
+ class T(Array):
+ _type_ = c_int
+ _length_ = 13
+ class U(T):
+ pass
+ class V(U):
+ pass
+ class W(V):
+ pass
+ class X(T):
+ _type_ = c_short
+ class Y(T):
+ _length_ = 187
+
+ for c in [T, U, V, W]:
+ self.assertEqual(c._type_, c_int)
+ self.assertEqual(c._length_, 13)
+ self.assertEqual(c()._type_, c_int)
+ self.assertEqual(c()._length_, 13)
+
+ self.assertEqual(X._type_, c_short)
+ self.assertEqual(X._length_, 13)
+ self.assertEqual(X()._type_, c_short)
+ self.assertEqual(X()._length_, 13)
+
+ self.assertEqual(Y._type_, c_int)
+ self.assertEqual(Y._length_, 187)
+ self.assertEqual(Y()._type_, c_int)
+ self.assertEqual(Y()._length_, 187)
+
+ def test_bad_subclass(self):
+ import sys
+
+ with self.assertRaises(AttributeError):
+ class T(Array):
+ pass
+ with self.assertRaises(AttributeError):
+ class T(Array):
+ _type_ = c_int
+ with self.assertRaises(AttributeError):
+ class T(Array):
+ _length_ = 13
+ with self.assertRaises(OverflowError):
+ class T(Array):
+ _type_ = c_int
+ _length_ = sys.maxsize * 2
+ with self.assertRaises(AttributeError):
+ class T(Array):
+ _type_ = c_int
+ _length_ = 1.87
+
if __name__ == '__main__':
unittest.main()
diff --git a/Misc/NEWS b/Misc/NEWS
index a6db483..a562d29 100644
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -185,6 +185,8 @@ Library
Extension Modules
-----------------
+- Issue #11241: subclasses of ctypes.Array can now be subclassed.
+
- Issue #9651: Fix a crash when ctypes.create_string_buffer(0) was passed to
some functions like file.write().
diff --git a/Modules/_ctypes/_ctypes.c b/Modules/_ctypes/_ctypes.c
index 277206c..17a00f5 100644
--- a/Modules/_ctypes/_ctypes.c
+++ b/Modules/_ctypes/_ctypes.c
@@ -1256,49 +1256,57 @@ PyCArrayType_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
PyTypeObject *result;
StgDictObject *stgdict;
StgDictObject *itemdict;
- PyObject *proto;
- PyObject *typedict;
+ PyObject *length_attr, *type_attr;
long length;
int overflow;
Py_ssize_t itemsize, itemalign;
char buf[32];
- typedict = PyTuple_GetItem(args, 2);
- if (!typedict)
+ /* create the new instance (which is a class,
+ since we are a metatype!) */
+ result = (PyTypeObject *)PyType_Type.tp_new(type, args, kwds);
+ if (result == NULL)
return NULL;
- proto = PyDict_GetItemString(typedict, "_length_"); /* Borrowed ref */
- if (!proto || !PyLong_Check(proto)) {
+ /* Initialize these variables to NULL so that we can simplify error
+ handling by using Py_XDECREF. */
+ stgdict = NULL;
+ type_attr = NULL;
+
+ length_attr = PyObject_GetAttrString((PyObject *)result, "_length_");
+ if (!length_attr || !PyLong_Check(length_attr)) {
PyErr_SetString(PyExc_AttributeError,
"class must define a '_length_' attribute, "
"which must be a positive integer");
- return NULL;
+ Py_XDECREF(length_attr);
+ goto error;
}
- length = PyLong_AsLongAndOverflow(proto, &overflow);
+ length = PyLong_AsLongAndOverflow(length_attr, &overflow);
if (overflow) {
PyErr_SetString(PyExc_OverflowError,
"The '_length_' attribute is too large");
- return NULL;
+ Py_DECREF(length_attr);
+ goto error;
}
+ Py_DECREF(length_attr);
- proto = PyDict_GetItemString(typedict, "_type_"); /* Borrowed ref */
- if (!proto) {
+ type_attr = PyObject_GetAttrString((PyObject *)result, "_type_");
+ if (!type_attr) {
PyErr_SetString(PyExc_AttributeError,
"class must define a '_type_' attribute");
- return NULL;
+ goto error;
}
stgdict = (StgDictObject *)PyObject_CallObject(
(PyObject *)&PyCStgDict_Type, NULL);
if (!stgdict)
- return NULL;
+ goto error;
- itemdict = PyType_stgdict(proto);
+ itemdict = PyType_stgdict(type_attr);
if (!itemdict) {
PyErr_SetString(PyExc_TypeError,
"_type_ must have storage info");
- Py_DECREF((PyObject *)stgdict);
- return NULL;
+ goto error;
}
assert(itemdict->format);
@@ -1309,16 +1317,12 @@ PyCArrayType_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
sprintf(buf, "(%ld)", length);
stgdict->format = _ctypes_alloc_format_string(buf, itemdict->format);
}
- if (stgdict->format == NULL) {
- Py_DECREF((PyObject *)stgdict);
- return NULL;
- }
+ if (stgdict->format == NULL)
+ goto error;
stgdict->ndim = itemdict->ndim + 1;
stgdict->shape = PyMem_Malloc(sizeof(Py_ssize_t *) * stgdict->ndim);
- if (stgdict->shape == NULL) {
- Py_DECREF((PyObject *)stgdict);
- return NULL;
- }
+ if (stgdict->shape == NULL)
+ goto error;
stgdict->shape[0] = length;
memmove(&stgdict->shape[1], itemdict->shape,
sizeof(Py_ssize_t) * (stgdict->ndim - 1));
@@ -1327,7 +1331,7 @@ PyCArrayType_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
if (length * itemsize < 0) {
PyErr_SetString(PyExc_OverflowError,
"array too large");
- return NULL;
+ goto error;
}
itemalign = itemdict->align;
@@ -1338,26 +1342,16 @@ PyCArrayType_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
stgdict->size = itemsize * length;
stgdict->align = itemalign;
stgdict->length = length;
- Py_INCREF(proto);
- stgdict->proto = proto;
+ stgdict->proto = type_attr;
stgdict->paramfunc = &PyCArrayType_paramfunc;
/* Arrays are passed as pointers to function calls. */
stgdict->ffi_type_pointer = ffi_type_pointer;
- /* create the new instance (which is a class,
- since we are a metatype!) */
- result = (PyTypeObject *)PyType_Type.tp_new(type, args, kwds);
- if (result == NULL)
- return NULL;
-
/* replace the class dict by our updated spam dict */
- if (-1 == PyDict_Update((PyObject *)stgdict, result->tp_dict)) {
- Py_DECREF(result);
- Py_DECREF((PyObject *)stgdict);
- return NULL;
- }
+ if (-1 == PyDict_Update((PyObject *)stgdict, result->tp_dict))
+ goto error;
Py_DECREF(result->tp_dict);
result->tp_dict = (PyObject *)stgdict;
@@ -1366,15 +1360,20 @@ PyCArrayType_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
*/
if (itemdict->getfunc == _ctypes_get_fielddesc("c")->getfunc) {
if (-1 == add_getset(result, CharArray_getsets))
- return NULL;
+ goto error;
#ifdef CTYPES_UNICODE
} else if (itemdict->getfunc == _ctypes_get_fielddesc("u")->getfunc) {
if (-1 == add_getset(result, WCharArray_getsets))
- return NULL;
+ goto error;
#endif
}
return (PyObject *)result;
+error:
+ Py_XDECREF((PyObject*)stgdict);
+ Py_XDECREF(type_attr);
+ Py_DECREF(result);
+ return NULL;
}
PyTypeObject PyCArrayType_Type = {