diff options
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/test/test_xml_etree.py | 12 | ||||
-rw-r--r-- | Lib/xml/etree/ElementTree.py | 12 |
2 files changed, 17 insertions, 7 deletions
diff --git a/Lib/test/test_xml_etree.py b/Lib/test/test_xml_etree.py index 50e5196..8a1ea0f 100644 --- a/Lib/test/test_xml_etree.py +++ b/Lib/test/test_xml_etree.py @@ -1839,8 +1839,15 @@ def check_issue10777(): # -------------------------------------------------------------------- -class ElementTreeTest(unittest.TestCase): +class BasicElementTest(unittest.TestCase): + def test_augmentation_type_errors(self): + e = ET.Element('joe') + self.assertRaises(TypeError, e.append, 'b') + self.assertRaises(TypeError, e.extend, [ET.Element('bar'), 'foo']) + self.assertRaises(TypeError, e.insert, 0, 'foo') + +class ElementTreeTest(unittest.TestCase): def test_istype(self): self.assertIsInstance(ET.ParseError, type) self.assertIsInstance(ET.QName, type) @@ -1879,7 +1886,6 @@ class ElementTreeTest(unittest.TestCase): class TreeBuilderTest(unittest.TestCase): - sample1 = ('<!DOCTYPE html PUBLIC' ' "-//W3C//DTD XHTML 1.0 Transitional//EN"' ' "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">' @@ -1931,7 +1937,6 @@ class TreeBuilderTest(unittest.TestCase): class NoAcceleratorTest(unittest.TestCase): - # Test that the C accelerator was not imported for pyET def test_correct_import_pyET(self): self.assertEqual(pyET.Element.__module__, 'xml.etree.ElementTree') @@ -2096,6 +2101,7 @@ def test_main(module=pyET): test_classes = [ ElementSlicingTest, + BasicElementTest, StringIOTest, ParseErrorTest, ElementTreeTest, diff --git a/Lib/xml/etree/ElementTree.py b/Lib/xml/etree/ElementTree.py index 10ee896..5f974f6 100644 --- a/Lib/xml/etree/ElementTree.py +++ b/Lib/xml/etree/ElementTree.py @@ -298,7 +298,7 @@ class Element: # @param element The element to add. def append(self, element): - # assert iselement(element) + self._assert_is_element(element) self._children.append(element) ## @@ -308,8 +308,8 @@ class Element: # @since 1.3 def extend(self, elements): - # for element in elements: - # assert iselement(element) + for element in elements: + self._assert_is_element(element) self._children.extend(elements) ## @@ -318,9 +318,13 @@ class Element: # @param index Where to insert the new subelement. def insert(self, index, element): - # assert iselement(element) + self._assert_is_element(element) self._children.insert(index, element) + def _assert_is_element(self, e): + if not isinstance(e, Element): + raise TypeError('expected an Element, not %s' % type(e).__name__) + ## # Removes a matching subelement. Unlike the <b>find</b> methods, # this method compares elements based on identity, not on tag |