summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
Diffstat (limited to 'Lib')
-rw-r--r--Lib/test/test_xml_etree.py12
-rw-r--r--Lib/xml/etree/ElementTree.py12
2 files changed, 17 insertions, 7 deletions
diff --git a/Lib/test/test_xml_etree.py b/Lib/test/test_xml_etree.py
index 50e5196..8a1ea0f 100644
--- a/Lib/test/test_xml_etree.py
+++ b/Lib/test/test_xml_etree.py
@@ -1839,8 +1839,15 @@ def check_issue10777():
# --------------------------------------------------------------------
-class ElementTreeTest(unittest.TestCase):
+class BasicElementTest(unittest.TestCase):
+ def test_augmentation_type_errors(self):
+ e = ET.Element('joe')
+ self.assertRaises(TypeError, e.append, 'b')
+ self.assertRaises(TypeError, e.extend, [ET.Element('bar'), 'foo'])
+ self.assertRaises(TypeError, e.insert, 0, 'foo')
+
+class ElementTreeTest(unittest.TestCase):
def test_istype(self):
self.assertIsInstance(ET.ParseError, type)
self.assertIsInstance(ET.QName, type)
@@ -1879,7 +1886,6 @@ class ElementTreeTest(unittest.TestCase):
class TreeBuilderTest(unittest.TestCase):
-
sample1 = ('<!DOCTYPE html PUBLIC'
' "-//W3C//DTD XHTML 1.0 Transitional//EN"'
' "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">'
@@ -1931,7 +1937,6 @@ class TreeBuilderTest(unittest.TestCase):
class NoAcceleratorTest(unittest.TestCase):
-
# Test that the C accelerator was not imported for pyET
def test_correct_import_pyET(self):
self.assertEqual(pyET.Element.__module__, 'xml.etree.ElementTree')
@@ -2096,6 +2101,7 @@ def test_main(module=pyET):
test_classes = [
ElementSlicingTest,
+ BasicElementTest,
StringIOTest,
ParseErrorTest,
ElementTreeTest,
diff --git a/Lib/xml/etree/ElementTree.py b/Lib/xml/etree/ElementTree.py
index 10ee896..5f974f6 100644
--- a/Lib/xml/etree/ElementTree.py
+++ b/Lib/xml/etree/ElementTree.py
@@ -298,7 +298,7 @@ class Element:
# @param element The element to add.
def append(self, element):
- # assert iselement(element)
+ self._assert_is_element(element)
self._children.append(element)
##
@@ -308,8 +308,8 @@ class Element:
# @since 1.3
def extend(self, elements):
- # for element in elements:
- # assert iselement(element)
+ for element in elements:
+ self._assert_is_element(element)
self._children.extend(elements)
##
@@ -318,9 +318,13 @@ class Element:
# @param index Where to insert the new subelement.
def insert(self, index, element):
- # assert iselement(element)
+ self._assert_is_element(element)
self._children.insert(index, element)
+ def _assert_is_element(self, e):
+ if not isinstance(e, Element):
+ raise TypeError('expected an Element, not %s' % type(e).__name__)
+
##
# Removes a matching subelement. Unlike the <b>find</b> methods,
# this method compares elements based on identity, not on tag