diff options
-rw-r--r-- | Lib/test/test_xml_etree.py | 122 | ||||
-rw-r--r-- | Misc/NEWS | 3 | ||||
-rw-r--r-- | Modules/_elementtree.c | 144 |
3 files changed, 222 insertions, 47 deletions
diff --git a/Lib/test/test_xml_etree.py b/Lib/test/test_xml_etree.py index d3c0da0..d4c7c6e 100644 --- a/Lib/test/test_xml_etree.py +++ b/Lib/test/test_xml_etree.py @@ -1709,6 +1709,126 @@ class BasicElementTest(ElementTestCase, unittest.TestCase): self.assertEqual(e2[0].tag, 'dogs') +class BadElementTest(ElementTestCase, unittest.TestCase): + def test_extend_mutable_list(self): + class X: + @property + def __class__(self): + L[:] = [ET.Element('baz')] + return ET.Element + L = [X()] + e = ET.Element('foo') + try: + e.extend(L) + except TypeError: + pass + + class Y(X, ET.Element): + pass + L = [Y('x')] + e = ET.Element('foo') + e.extend(L) + + def test_extend_mutable_list2(self): + class X: + @property + def __class__(self): + del L[:] + return ET.Element + L = [X(), ET.Element('baz')] + e = ET.Element('foo') + try: + e.extend(L) + except TypeError: + pass + + class Y(X, ET.Element): + pass + L = [Y('bar'), ET.Element('baz')] + e = ET.Element('foo') + e.extend(L) + + def test_remove_with_mutating(self): + class X(ET.Element): + def __eq__(self, o): + del e[:] + return False + e = ET.Element('foo') + e.extend([X('bar')]) + self.assertRaises(ValueError, e.remove, ET.Element('baz')) + + e = ET.Element('foo') + e.extend([ET.Element('bar')]) + self.assertRaises(ValueError, e.remove, X('baz')) + + +class MutatingElementPath(str): + def __new__(cls, elem, *args): + self = str.__new__(cls, *args) + self.elem = elem + return self + def __eq__(self, o): + del self.elem[:] + return True +MutatingElementPath.__hash__ = str.__hash__ + +class BadElementPath(str): + def __eq__(self, o): + raise 1/0 +BadElementPath.__hash__ = str.__hash__ + +class BadElementPathTest(ElementTestCase, unittest.TestCase): + def setUp(self): + super().setUp() + from xml.etree import ElementPath + self.path_cache = ElementPath._cache + ElementPath._cache = {} + + def tearDown(self): + from xml.etree import ElementPath + ElementPath._cache = self.path_cache + super().tearDown() + + def test_find_with_mutating(self): + e = ET.Element('foo') + e.extend([ET.Element('bar')]) + e.find(MutatingElementPath(e, 'x')) + + def test_find_with_error(self): + e = ET.Element('foo') + e.extend([ET.Element('bar')]) + try: + e.find(BadElementPath('x')) + except ZeroDivisionError: + pass + + def test_findtext_with_mutating(self): + e = ET.Element('foo') + e.extend([ET.Element('bar')]) + e.findtext(MutatingElementPath(e, 'x')) + + def test_findtext_with_error(self): + e = ET.Element('foo') + e.extend([ET.Element('bar')]) + try: + e.findtext(BadElementPath('x')) + except ZeroDivisionError: + pass + + def test_findall_with_mutating(self): + e = ET.Element('foo') + e.extend([ET.Element('bar')]) + e.findall(MutatingElementPath(e, 'x')) + + def test_findall_with_error(self): + e = ET.Element('foo') + e.extend([ET.Element('bar')]) + try: + e.findall(BadElementPath('x')) + except ZeroDivisionError: + pass + + class ElementTreeTypeTest(unittest.TestCase): def test_istype(self): self.assertIsInstance(ET.ParseError, type) @@ -2556,6 +2676,8 @@ def test_main(module=None): ModuleTest, ElementSlicingTest, BasicElementTest, + BadElementTest, + BadElementPathTest, ElementTreeTest, IOTest, ParseErrorTest, @@ -50,6 +50,9 @@ Core and Builtins Library ------- +- Issue #24091: Fixed various crashes in corner cases in C implementation of + ElementTree. + - Issue #21931: msilib.FCICreate() now raises TypeError in the case of a bad argument instead of a ValueError with a bogus FCI error number. Patch by Jeffrey Armstrong. diff --git a/Modules/_elementtree.c b/Modules/_elementtree.c index b3b6976..136f19c 100644 --- a/Modules/_elementtree.c +++ b/Modules/_elementtree.c @@ -1036,7 +1036,7 @@ static PyObject* element_extend(ElementObject* self, PyObject* args) { PyObject* seq; - Py_ssize_t i, seqlen = 0; + Py_ssize_t i; PyObject* seq_in; if (!PyArg_ParseTuple(args, "O:extend", &seq_in)) @@ -1051,22 +1051,25 @@ element_extend(ElementObject* self, PyObject* args) return NULL; } - seqlen = PySequence_Size(seq); - for (i = 0; i < seqlen; i++) { + for (i = 0; i < PySequence_Fast_GET_SIZE(seq); i++) { PyObject* element = PySequence_Fast_GET_ITEM(seq, i); - if (!PyObject_IsInstance(element, (PyObject *)&Element_Type)) { - Py_DECREF(seq); + Py_INCREF(element); + if (!PyObject_TypeCheck(element, (PyTypeObject *)&Element_Type)) { PyErr_Format( PyExc_TypeError, "expected an Element, not \"%.200s\"", Py_TYPE(element)->tp_name); + Py_DECREF(seq); + Py_DECREF(element); return NULL; } if (element_add_subelement(self, element) < 0) { Py_DECREF(seq); + Py_DECREF(element); return NULL; } + Py_DECREF(element); } Py_DECREF(seq); @@ -1099,11 +1102,16 @@ element_find(ElementObject *self, PyObject *args, PyObject *kwds) for (i = 0; i < self->extra->length; i++) { PyObject* item = self->extra->children[i]; - if (Element_CheckExact(item) && - PyObject_RichCompareBool(((ElementObject*)item)->tag, tag, Py_EQ) == 1) { - Py_INCREF(item); + int rc; + if (!Element_CheckExact(item)) + continue; + Py_INCREF(item); + rc = PyObject_RichCompareBool(((ElementObject*)item)->tag, tag, Py_EQ); + if (rc > 0) return item; - } + Py_DECREF(item); + if (rc < 0) + return NULL; } Py_RETURN_NONE; @@ -1136,14 +1144,24 @@ element_findtext(ElementObject *self, PyObject *args, PyObject *kwds) for (i = 0; i < self->extra->length; i++) { ElementObject* item = (ElementObject*) self->extra->children[i]; - if (Element_CheckExact(item) && - (PyObject_RichCompareBool(item->tag, tag, Py_EQ) == 1)) { + int rc; + if (!Element_CheckExact(item)) + continue; + Py_INCREF(item); + rc = PyObject_RichCompareBool(item->tag, tag, Py_EQ); + if (rc > 0) { PyObject* text = element_get_text(item); - if (text == Py_None) + if (text == Py_None) { + Py_DECREF(item); return PyUnicode_New(0, 0); + } Py_XINCREF(text); + Py_DECREF(item); return text; } + Py_DECREF(item); + if (rc < 0) + return NULL; } Py_INCREF(default_value); @@ -1180,13 +1198,17 @@ element_findall(ElementObject *self, PyObject *args, PyObject *kwds) for (i = 0; i < self->extra->length; i++) { PyObject* item = self->extra->children[i]; - if (Element_CheckExact(item) && - PyObject_RichCompareBool(((ElementObject*)item)->tag, tag, Py_EQ) == 1) { - if (PyList_Append(out, item) < 0) { - Py_DECREF(out); - return NULL; - } + int rc; + if (!Element_CheckExact(item)) + continue; + Py_INCREF(item); + rc = PyObject_RichCompareBool(((ElementObject*)item)->tag, tag, Py_EQ); + if (rc != 0 && (rc < 0 || PyList_Append(out, item) < 0)) { + Py_DECREF(item); + Py_DECREF(out); + return NULL; } + Py_DECREF(item); } return out; @@ -1403,8 +1425,10 @@ static PyObject* element_remove(ElementObject* self, PyObject* args) { int i; - + int rc; PyObject* element; + PyObject* found; + if (!PyArg_ParseTuple(args, "O!:remove", &Element_Type, &element)) return NULL; @@ -1420,11 +1444,14 @@ element_remove(ElementObject* self, PyObject* args) for (i = 0; i < self->extra->length; i++) { if (self->extra->children[i] == element) break; - if (PyObject_RichCompareBool(self->extra->children[i], element, Py_EQ) == 1) + rc = PyObject_RichCompareBool(self->extra->children[i], element, Py_EQ); + if (rc > 0) break; + if (rc < 0) + return NULL; } - if (i == self->extra->length) { + if (i >= self->extra->length) { /* element is not in children, so raise exception */ PyErr_SetString( PyExc_ValueError, @@ -1433,13 +1460,13 @@ element_remove(ElementObject* self, PyObject* args) return NULL; } - Py_DECREF(self->extra->children[i]); + found = self->extra->children[i]; self->extra->length--; - for (; i < self->extra->length; i++) self->extra->children[i] = self->extra->children[i+1]; + Py_DECREF(found); Py_RETURN_NONE; } @@ -2012,6 +2039,7 @@ elementiter_next(ElementIterObject *it) */ ElementObject *cur_parent; Py_ssize_t child_index; + int rc; while (1) { /* Handle the case reached in the beginning and end of iteration, where @@ -2033,14 +2061,22 @@ elementiter_next(ElementIterObject *it) } it->root_done = 1; - if (it->sought_tag == Py_None || - PyObject_RichCompareBool(it->root_element->tag, - it->sought_tag, Py_EQ) == 1) { + rc = (it->sought_tag == Py_None); + if (!rc) { + rc = PyObject_RichCompareBool(it->root_element->tag, + it->sought_tag, Py_EQ); + if (rc < 0) + return NULL; + } + if (rc) { if (it->gettext) { PyObject *text = element_get_text(it->root_element); if (!text) return NULL; - if (PyObject_IsTrue(text)) { + rc = PyObject_IsTrue(text); + if (rc < 0) + return NULL; + if (rc) { Py_INCREF(text); return text; } @@ -2072,18 +2108,26 @@ elementiter_next(ElementIterObject *it) PyObject *text = element_get_text(child); if (!text) return NULL; - if (PyObject_IsTrue(text)) { + rc = PyObject_IsTrue(text); + if (rc < 0) + return NULL; + if (rc) { Py_INCREF(text); return text; } - } else if (it->sought_tag == Py_None || - PyObject_RichCompareBool(child->tag, - it->sought_tag, Py_EQ) == 1) { - Py_INCREF(child); - return (PyObject *)child; + } else { + rc = (it->sought_tag == Py_None); + if (!rc) { + rc = PyObject_RichCompareBool(child->tag, + it->sought_tag, Py_EQ); + if (rc < 0) + return NULL; + } + if (rc) { + Py_INCREF(child); + return (PyObject *)child; + } } - else - continue; } else { PyObject *tail; @@ -2103,9 +2147,14 @@ elementiter_next(ElementIterObject *it) * this is because itertext() is supposed to only return *inner* * text, not text following the element it began iteration with. */ - if (it->parent_stack->parent && PyObject_IsTrue(tail)) { - Py_INCREF(tail); - return tail; + if (it->parent_stack->parent) { + rc = PyObject_IsTrue(tail); + if (rc < 0) + return NULL; + if (rc) { + Py_INCREF(tail); + return tail; + } } } } @@ -2163,20 +2212,21 @@ static PyObject * create_elementiter(ElementObject *self, PyObject *tag, int gettext) { ElementIterObject *it; - PyObject *star = NULL; it = PyObject_GC_New(ElementIterObject, &ElementIter_Type); if (!it) return NULL; - if (PyUnicode_Check(tag)) - star = PyUnicode_FromString("*"); - else if (PyBytes_Check(tag)) - star = PyBytes_FromString("*"); - - if (star && PyObject_RichCompareBool(tag, star, Py_EQ) == 1) - tag = Py_None; - Py_XDECREF(star); + if (PyUnicode_Check(tag)) { + if (PyUnicode_READY(tag) < 0) + return NULL; + if (PyUnicode_GET_LENGTH(tag) == 1 && PyUnicode_READ_CHAR(tag, 0) == '*') + tag = Py_None; + } + else if (PyBytes_Check(tag)) { + if (PyBytes_GET_SIZE(tag) == 1 && *PyBytes_AS_STRING(tag) == '*') + tag = Py_None; + } Py_INCREF(tag); it->sought_tag = tag; |