summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/test/test_xml_etree.py42
-rw-r--r--Lib/xml/etree/ElementTree.py10
-rw-r--r--Modules/_elementtree.c69
3 files changed, 71 insertions, 50 deletions
diff --git a/Lib/test/test_xml_etree.py b/Lib/test/test_xml_etree.py
index 61416ba..8b16905 100644
--- a/Lib/test/test_xml_etree.py
+++ b/Lib/test/test_xml_etree.py
@@ -1795,6 +1795,28 @@ class BasicElementTest(ElementTestCase, unittest.TestCase):
self.assertRaises(TypeError, e.append, 'b')
self.assertRaises(TypeError, e.extend, [ET.Element('bar'), 'foo'])
self.assertRaises(TypeError, e.insert, 0, 'foo')
+ e[:] = [ET.Element('bar')]
+ with self.assertRaises(TypeError):
+ e[0] = 'foo'
+ with self.assertRaises(TypeError):
+ e[:] = [ET.Element('bar'), 'foo']
+
+ if hasattr(e, '__setstate__'):
+ state = {
+ 'tag': 'tag',
+ '_children': [None], # non-Element
+ 'attrib': 'attr',
+ 'tail': 'tail',
+ 'text': 'text',
+ }
+ self.assertRaises(TypeError, e.__setstate__, state)
+
+ if hasattr(e, '__deepcopy__'):
+ class E(ET.Element):
+ def __deepcopy__(self, memo):
+ return None # non-Element
+ e[:] = [E('bar')]
+ self.assertRaises(TypeError, copy.deepcopy, e)
def test_cyclic_gc(self):
class Dummy:
@@ -1981,26 +2003,6 @@ class BadElementTest(ElementTestCase, unittest.TestCase):
elem = b.close()
self.assertEqual(elem[0].tail, 'ABCDEFGHIJKL')
- def test_element_iter(self):
- # Issue #27863
- state = {
- 'tag': 'tag',
- '_children': [None], # non-Element
- 'attrib': 'attr',
- 'tail': 'tail',
- 'text': 'text',
- }
-
- e = ET.Element('tag')
- try:
- e.__setstate__(state)
- except AttributeError:
- e.__dict__ = state
-
- it = e.iter()
- self.assertIs(next(it), e)
- self.assertRaises(AttributeError, next, it)
-
def test_subscr(self):
# Issue #27863
class X:
diff --git a/Lib/xml/etree/ElementTree.py b/Lib/xml/etree/ElementTree.py
index 371b371..85586d0 100644
--- a/Lib/xml/etree/ElementTree.py
+++ b/Lib/xml/etree/ElementTree.py
@@ -217,11 +217,11 @@ class Element:
return self._children[index]
def __setitem__(self, index, element):
- # if isinstance(index, slice):
- # for elt in element:
- # assert iselement(elt)
- # else:
- # assert iselement(element)
+ if isinstance(index, slice):
+ for elt in element:
+ self._assert_is_element(elt)
+ else:
+ self._assert_is_element(element)
self._children[index] = element
def __delitem__(self, index):
diff --git a/Modules/_elementtree.c b/Modules/_elementtree.c
index 9195914..f88315d 100644
--- a/Modules/_elementtree.c
+++ b/Modules/_elementtree.c
@@ -480,11 +480,24 @@ element_resize(ElementObject* self, Py_ssize_t extra)
return -1;
}
+LOCAL(void)
+raise_type_error(PyObject *element)
+{
+ PyErr_Format(PyExc_TypeError,
+ "expected an Element, not \"%.200s\"",
+ Py_TYPE(element)->tp_name);
+}
+
LOCAL(int)
element_add_subelement(ElementObject* self, PyObject* element)
{
/* add a child element to a parent */
+ if (!Element_Check(element)) {
+ raise_type_error(element);
+ return -1;
+ }
+
if (element_resize(self, 1) < 0)
return -1;
@@ -803,7 +816,11 @@ _elementtree_Element___deepcopy___impl(ElementObject *self, PyObject *memo)
for (i = 0; i < self->extra->length; i++) {
PyObject* child = deepcopy(self->extra->children[i], memo);
- if (!child) {
+ if (!child || !Element_Check(child)) {
+ if (child) {
+ raise_type_error(child);
+ Py_DECREF(child);
+ }
element->extra->length = i;
goto error;
}
@@ -1024,8 +1041,15 @@ element_setstate_from_attributes(ElementObject *self,
/* Copy children */
for (i = 0; i < nchildren; i++) {
- self->extra->children[i] = PyList_GET_ITEM(children, i);
- Py_INCREF(self->extra->children[i]);
+ PyObject *child = PyList_GET_ITEM(children, i);
+ if (!Element_Check(child)) {
+ raise_type_error(child);
+ self->extra->length = i;
+ dealloc_extra(oldextra);
+ return NULL;
+ }
+ Py_INCREF(child);
+ self->extra->children[i] = child;
}
assert(!self->extra->length);
@@ -1167,16 +1191,6 @@ _elementtree_Element_extend(ElementObject *self, PyObject *elements)
for (i = 0; i < PySequence_Fast_GET_SIZE(seq); i++) {
PyObject* element = PySequence_Fast_GET_ITEM(seq, i);
Py_INCREF(element);
- if (!Element_Check(element)) {
- PyErr_Format(
- PyExc_TypeError,
- "expected an Element, not \"%.200s\"",
- Py_TYPE(element)->tp_name);
- Py_DECREF(seq);
- Py_DECREF(element);
- return NULL;
- }
-
if (element_add_subelement(self, element) < 0) {
Py_DECREF(seq);
Py_DECREF(element);
@@ -1219,8 +1233,7 @@ _elementtree_Element_find_impl(ElementObject *self, PyObject *path,
for (i = 0; i < self->extra->length; i++) {
PyObject* item = self->extra->children[i];
int rc;
- if (!Element_Check(item))
- continue;
+ assert(Element_Check(item));
Py_INCREF(item);
rc = PyObject_RichCompareBool(((ElementObject*)item)->tag, path, Py_EQ);
if (rc > 0)
@@ -1266,8 +1279,7 @@ _elementtree_Element_findtext_impl(ElementObject *self, PyObject *path,
for (i = 0; i < self->extra->length; i++) {
PyObject *item = self->extra->children[i];
int rc;
- if (!Element_Check(item))
- continue;
+ assert(Element_Check(item));
Py_INCREF(item);
rc = PyObject_RichCompareBool(((ElementObject*)item)->tag, path, Py_EQ);
if (rc > 0) {
@@ -1323,8 +1335,7 @@ _elementtree_Element_findall_impl(ElementObject *self, PyObject *path,
for (i = 0; i < self->extra->length; i++) {
PyObject* item = self->extra->children[i];
int rc;
- if (!Element_Check(item))
- continue;
+ assert(Element_Check(item));
Py_INCREF(item);
rc = PyObject_RichCompareBool(((ElementObject*)item)->tag, path, Py_EQ);
if (rc != 0 && (rc < 0 || PyList_Append(out, item) < 0)) {
@@ -1736,6 +1747,10 @@ element_setitem(PyObject* self_, Py_ssize_t index, PyObject* item)
old = self->extra->children[index];
if (item) {
+ if (!Element_Check(item)) {
+ raise_type_error(item);
+ return -1;
+ }
Py_INCREF(item);
self->extra->children[index] = item;
} else {
@@ -1930,6 +1945,15 @@ element_ass_subscr(PyObject* self_, PyObject* item, PyObject* value)
}
}
+ for (i = 0; i < newlen; i++) {
+ PyObject *element = PySequence_Fast_GET_ITEM(seq, i);
+ if (!Element_Check(element)) {
+ raise_type_error(element);
+ Py_DECREF(seq);
+ return -1;
+ }
+ }
+
if (slicelen > 0) {
/* to avoid recursive calls to this method (via decref), move
old items to the recycle bin here, and get rid of them when
@@ -2207,12 +2231,7 @@ elementiter_next(ElementIterObject *it)
continue;
}
- if (!Element_Check(extra->children[child_index])) {
- PyErr_Format(PyExc_AttributeError,
- "'%.100s' object has no attribute 'iter'",
- Py_TYPE(extra->children[child_index])->tp_name);
- return NULL;
- }
+ assert(Element_Check(extra->children[child_index]));
elem = (ElementObject *)extra->children[child_index];
item->child_index++;
Py_INCREF(elem);