diff options
author | Eli Bendersky <eliben@gmail.com> | 2012-07-15 03:02:22 (GMT) |
---|---|---|
committer | Eli Bendersky <eliben@gmail.com> | 2012-07-15 03:02:22 (GMT) |
commit | 00f402bfcbe3245f9c62f86376fc77bb9e7de639 (patch) | |
tree | c5035e1c4af4be283479aca143ba687d74d19c0f /Lib | |
parent | 1191709b1379661a15287a2c6ac8263f23655f73 (diff) | |
download | cpython-00f402bfcbe3245f9c62f86376fc77bb9e7de639.zip cpython-00f402bfcbe3245f9c62f86376fc77bb9e7de639.tar.gz cpython-00f402bfcbe3245f9c62f86376fc77bb9e7de639.tar.bz2 |
Close #1767933: Badly formed XML using etree and utf-16. Patch by Serhiy Storchaka, with some minor fixes by me
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/test/test_xml_etree.py | 240 | ||||
-rw-r--r-- | Lib/xml/etree/ElementTree.py | 138 |
2 files changed, 257 insertions, 121 deletions
diff --git a/Lib/test/test_xml_etree.py b/Lib/test/test_xml_etree.py index c1fc955..d90f978 100644 --- a/Lib/test/test_xml_etree.py +++ b/Lib/test/test_xml_etree.py @@ -21,7 +21,7 @@ import unittest import weakref from test import support -from test.support import findfile, import_fresh_module, gc_collect +from test.support import TESTFN, findfile, unlink, import_fresh_module, gc_collect pyET = None ET = None @@ -888,65 +888,6 @@ def check_encoding(encoding): """ ET.XML("<?xml version='1.0' encoding='%s'?><xml />" % encoding) -def encoding(): - r""" - Test encoding issues. - - >>> elem = ET.Element("tag") - >>> elem.text = "abc" - >>> serialize(elem) - '<tag>abc</tag>' - >>> serialize(elem, encoding="utf-8") - b'<tag>abc</tag>' - >>> serialize(elem, encoding="us-ascii") - b'<tag>abc</tag>' - >>> serialize(elem, encoding="iso-8859-1") - b"<?xml version='1.0' encoding='iso-8859-1'?>\n<tag>abc</tag>" - - >>> elem.text = "<&\"\'>" - >>> serialize(elem) - '<tag><&"\'></tag>' - >>> serialize(elem, encoding="utf-8") - b'<tag><&"\'></tag>' - >>> serialize(elem, encoding="us-ascii") # cdata characters - b'<tag><&"\'></tag>' - >>> serialize(elem, encoding="iso-8859-1") - b'<?xml version=\'1.0\' encoding=\'iso-8859-1\'?>\n<tag><&"\'></tag>' - - >>> elem.attrib["key"] = "<&\"\'>" - >>> elem.text = None - >>> serialize(elem) - '<tag key="<&"\'>" />' - >>> serialize(elem, encoding="utf-8") - b'<tag key="<&"\'>" />' - >>> serialize(elem, encoding="us-ascii") - b'<tag key="<&"\'>" />' - >>> serialize(elem, encoding="iso-8859-1") - b'<?xml version=\'1.0\' encoding=\'iso-8859-1\'?>\n<tag key="<&"\'>" />' - - >>> elem.text = '\xe5\xf6\xf6<>' - >>> elem.attrib.clear() - >>> serialize(elem) - '<tag>\xe5\xf6\xf6<></tag>' - >>> serialize(elem, encoding="utf-8") - b'<tag>\xc3\xa5\xc3\xb6\xc3\xb6<></tag>' - >>> serialize(elem, encoding="us-ascii") - b'<tag>åöö<></tag>' - >>> serialize(elem, encoding="iso-8859-1") - b"<?xml version='1.0' encoding='iso-8859-1'?>\n<tag>\xe5\xf6\xf6<></tag>" - - >>> elem.attrib["key"] = '\xe5\xf6\xf6<>' - >>> elem.text = None - >>> serialize(elem) - '<tag key="\xe5\xf6\xf6<>" />' - >>> serialize(elem, encoding="utf-8") - b'<tag key="\xc3\xa5\xc3\xb6\xc3\xb6<>" />' - >>> serialize(elem, encoding="us-ascii") - b'<tag key="åöö<>" />' - >>> serialize(elem, encoding="iso-8859-1") - b'<?xml version=\'1.0\' encoding=\'iso-8859-1\'?>\n<tag key="\xe5\xf6\xf6<>" />' - """ - def methods(): r""" Test serialization methods. @@ -2166,16 +2107,185 @@ class ElementSlicingTest(unittest.TestCase): self.assertEqual(self._subelem_tags(e), ['a1']) -class StringIOTest(unittest.TestCase): +class IOTest(unittest.TestCase): + def tearDown(self): + unlink(TESTFN) + + def test_encoding(self): + # Test encoding issues. + elem = ET.Element("tag") + elem.text = "abc" + self.assertEqual(serialize(elem), '<tag>abc</tag>') + self.assertEqual(serialize(elem, encoding="utf-8"), + b'<tag>abc</tag>') + self.assertEqual(serialize(elem, encoding="us-ascii"), + b'<tag>abc</tag>') + for enc in ("iso-8859-1", "utf-16", "utf-32"): + self.assertEqual(serialize(elem, encoding=enc), + ("<?xml version='1.0' encoding='%s'?>\n" + "<tag>abc</tag>" % enc).encode(enc)) + + elem = ET.Element("tag") + elem.text = "<&\"\'>" + self.assertEqual(serialize(elem), '<tag><&"\'></tag>') + self.assertEqual(serialize(elem, encoding="utf-8"), + b'<tag><&"\'></tag>') + self.assertEqual(serialize(elem, encoding="us-ascii"), + b'<tag><&"\'></tag>') + for enc in ("iso-8859-1", "utf-16", "utf-32"): + self.assertEqual(serialize(elem, encoding=enc), + ("<?xml version='1.0' encoding='%s'?>\n" + "<tag><&\"'></tag>" % enc).encode(enc)) + + elem = ET.Element("tag") + elem.attrib["key"] = "<&\"\'>" + self.assertEqual(serialize(elem), '<tag key="<&"\'>" />') + self.assertEqual(serialize(elem, encoding="utf-8"), + b'<tag key="<&"\'>" />') + self.assertEqual(serialize(elem, encoding="us-ascii"), + b'<tag key="<&"\'>" />') + for enc in ("iso-8859-1", "utf-16", "utf-32"): + self.assertEqual(serialize(elem, encoding=enc), + ("<?xml version='1.0' encoding='%s'?>\n" + "<tag key=\"<&"'>\" />" % enc).encode(enc)) + + elem = ET.Element("tag") + elem.text = '\xe5\xf6\xf6<>' + self.assertEqual(serialize(elem), '<tag>\xe5\xf6\xf6<></tag>') + self.assertEqual(serialize(elem, encoding="utf-8"), + b'<tag>\xc3\xa5\xc3\xb6\xc3\xb6<></tag>') + self.assertEqual(serialize(elem, encoding="us-ascii"), + b'<tag>åöö<></tag>') + for enc in ("iso-8859-1", "utf-16", "utf-32"): + self.assertEqual(serialize(elem, encoding=enc), + ("<?xml version='1.0' encoding='%s'?>\n" + "<tag>åöö<></tag>" % enc).encode(enc)) + + elem = ET.Element("tag") + elem.attrib["key"] = '\xe5\xf6\xf6<>' + self.assertEqual(serialize(elem), '<tag key="\xe5\xf6\xf6<>" />') + self.assertEqual(serialize(elem, encoding="utf-8"), + b'<tag key="\xc3\xa5\xc3\xb6\xc3\xb6<>" />') + self.assertEqual(serialize(elem, encoding="us-ascii"), + b'<tag key="åöö<>" />') + for enc in ("iso-8859-1", "utf-16", "utf-16le", "utf-16be", "utf-32"): + self.assertEqual(serialize(elem, encoding=enc), + ("<?xml version='1.0' encoding='%s'?>\n" + "<tag key=\"åöö<>\" />" % enc).encode(enc)) + + def test_write_to_filename(self): + tree = ET.ElementTree(ET.XML('''<site />''')) + tree.write(TESTFN) + with open(TESTFN, 'rb') as f: + self.assertEqual(f.read(), b'''<site />''') + + def test_write_to_text_file(self): + tree = ET.ElementTree(ET.XML('''<site />''')) + with open(TESTFN, 'w', encoding='utf-8') as f: + tree.write(f, encoding='unicode') + self.assertFalse(f.closed) + with open(TESTFN, 'rb') as f: + self.assertEqual(f.read(), b'''<site />''') + + def test_write_to_binary_file(self): + tree = ET.ElementTree(ET.XML('''<site />''')) + with open(TESTFN, 'wb') as f: + tree.write(f) + self.assertFalse(f.closed) + with open(TESTFN, 'rb') as f: + self.assertEqual(f.read(), b'''<site />''') + + def test_write_to_binary_file_with_bom(self): + tree = ET.ElementTree(ET.XML('''<site />''')) + # test BOM writing to buffered file + with open(TESTFN, 'wb') as f: + tree.write(f, encoding='utf-16') + self.assertFalse(f.closed) + with open(TESTFN, 'rb') as f: + self.assertEqual(f.read(), + '''<?xml version='1.0' encoding='utf-16'?>\n''' + '''<site />'''.encode("utf-16")) + # test BOM writing to non-buffered file + with open(TESTFN, 'wb', buffering=0) as f: + tree.write(f, encoding='utf-16') + self.assertFalse(f.closed) + with open(TESTFN, 'rb') as f: + self.assertEqual(f.read(), + '''<?xml version='1.0' encoding='utf-16'?>\n''' + '''<site />'''.encode("utf-16")) + def test_read_from_stringio(self): tree = ET.ElementTree() - stream = io.StringIO() - stream.write('''<?xml version="1.0"?><site></site>''') - stream.seek(0) + stream = io.StringIO('''<?xml version="1.0"?><site></site>''') tree.parse(stream) + self.assertEqual(tree.getroot().tag, 'site') + def test_write_to_stringio(self): + tree = ET.ElementTree(ET.XML('''<site />''')) + stream = io.StringIO() + tree.write(stream, encoding='unicode') + self.assertEqual(stream.getvalue(), '''<site />''') + + def test_read_from_bytesio(self): + tree = ET.ElementTree() + raw = io.BytesIO(b'''<?xml version="1.0"?><site></site>''') + tree.parse(raw) + self.assertEqual(tree.getroot().tag, 'site') + + def test_write_to_bytesio(self): + tree = ET.ElementTree(ET.XML('''<site />''')) + raw = io.BytesIO() + tree.write(raw) + self.assertEqual(raw.getvalue(), b'''<site />''') + + class dummy: + pass + + def test_read_from_user_text_reader(self): + stream = io.StringIO('''<?xml version="1.0"?><site></site>''') + reader = self.dummy() + reader.read = stream.read + tree = ET.ElementTree() + tree.parse(reader) self.assertEqual(tree.getroot().tag, 'site') + def test_write_to_user_text_writer(self): + tree = ET.ElementTree(ET.XML('''<site />''')) + stream = io.StringIO() + writer = self.dummy() + writer.write = stream.write + tree.write(writer, encoding='unicode') + self.assertEqual(stream.getvalue(), '''<site />''') + + def test_read_from_user_binary_reader(self): + raw = io.BytesIO(b'''<?xml version="1.0"?><site></site>''') + reader = self.dummy() + reader.read = raw.read + tree = ET.ElementTree() + tree.parse(reader) + self.assertEqual(tree.getroot().tag, 'site') + tree = ET.ElementTree() + + def test_write_to_user_binary_writer(self): + tree = ET.ElementTree(ET.XML('''<site />''')) + raw = io.BytesIO() + writer = self.dummy() + writer.write = raw.write + tree.write(writer) + self.assertEqual(raw.getvalue(), b'''<site />''') + + def test_write_to_user_binary_writer_with_bom(self): + tree = ET.ElementTree(ET.XML('''<site />''')) + raw = io.BytesIO() + writer = self.dummy() + writer.write = raw.write + writer.seekable = lambda: True + writer.tell = raw.tell + tree.write(writer, encoding="utf-16") + self.assertEqual(raw.getvalue(), + '''<?xml version='1.0' encoding='utf-16'?>\n''' + '''<site />'''.encode("utf-16")) + class ParseErrorTest(unittest.TestCase): def test_subclass(self): @@ -2299,7 +2409,7 @@ def test_main(module=None): test_classes = [ ElementSlicingTest, BasicElementTest, - StringIOTest, + IOTest, ParseErrorTest, XincludeTest, ElementTreeTest, diff --git a/Lib/xml/etree/ElementTree.py b/Lib/xml/etree/ElementTree.py index 61fe155..10bf849 100644 --- a/Lib/xml/etree/ElementTree.py +++ b/Lib/xml/etree/ElementTree.py @@ -100,6 +100,8 @@ VERSION = "1.3.0" import sys import re import warnings +import io +import contextlib from . import ElementPath @@ -792,59 +794,38 @@ class ElementTree: # None for only if not US-ASCII or UTF-8 or Unicode. None is default. def write(self, file_or_filename, - # keyword arguments encoding=None, xml_declaration=None, default_namespace=None, method=None): - # assert self._root is not None if not method: method = "xml" elif method not in _serialize: - # FIXME: raise an ImportError for c14n if ElementC14N is missing? raise ValueError("unknown method %r" % method) if not encoding: if method == "c14n": encoding = "utf-8" else: encoding = "us-ascii" - elif encoding == str: # lxml.etree compatibility. - encoding = "unicode" else: encoding = encoding.lower() - if hasattr(file_or_filename, "write"): - file = file_or_filename - else: - if encoding != "unicode": - file = open(file_or_filename, "wb") + with _get_writer(file_or_filename, encoding) as write: + if method == "xml" and (xml_declaration or + (xml_declaration is None and + encoding not in ("utf-8", "us-ascii", "unicode"))): + declared_encoding = encoding + if encoding == "unicode": + # Retrieve the default encoding for the xml declaration + import locale + declared_encoding = locale.getpreferredencoding() + write("<?xml version='1.0' encoding='%s'?>\n" % ( + declared_encoding,)) + if method == "text": + _serialize_text(write, self._root) else: - file = open(file_or_filename, "w") - if encoding != "unicode": - def write(text): - try: - return file.write(text.encode(encoding, - "xmlcharrefreplace")) - except (TypeError, AttributeError): - _raise_serialization_error(text) - else: - write = file.write - if method == "xml" and (xml_declaration or - (xml_declaration is None and - encoding not in ("utf-8", "us-ascii", "unicode"))): - declared_encoding = encoding - if encoding == "unicode": - # Retrieve the default encoding for the xml declaration - import locale - declared_encoding = locale.getpreferredencoding() - write("<?xml version='1.0' encoding='%s'?>\n" % declared_encoding) - if method == "text": - _serialize_text(write, self._root) - else: - qnames, namespaces = _namespaces(self._root, default_namespace) - serialize = _serialize[method] - serialize(write, self._root, qnames, namespaces) - if file_or_filename is not file: - file.close() + qnames, namespaces = _namespaces(self._root, default_namespace) + serialize = _serialize[method] + serialize(write, self._root, qnames, namespaces) def write_c14n(self, file): # lxml.etree compatibility. use output method instead @@ -853,6 +834,58 @@ class ElementTree: # -------------------------------------------------------------------- # serialization support +@contextlib.contextmanager +def _get_writer(file_or_filename, encoding): + # returns text write method and release all resourses after using + try: + write = file_or_filename.write + except AttributeError: + # file_or_filename is a file name + if encoding == "unicode": + file = open(file_or_filename, "w") + else: + file = open(file_or_filename, "w", encoding=encoding, + errors="xmlcharrefreplace") + with file: + yield file.write + else: + # file_or_filename is a file-like object + # encoding determines if it is a text or binary writer + if encoding == "unicode": + # use a text writer as is + yield write + else: + # wrap a binary writer with TextIOWrapper + with contextlib.ExitStack() as stack: + if isinstance(file_or_filename, io.BufferedIOBase): + file = file_or_filename + elif isinstance(file_or_filename, io.RawIOBase): + file = io.BufferedWriter(file_or_filename) + # Keep the original file open when the BufferedWriter is + # destroyed + stack.callback(file.detach) + else: + # This is to handle passed objects that aren't in the + # IOBase hierarchy, but just have a write method + file = io.BufferedIOBase() + file.writable = lambda: True + file.write = write + try: + # TextIOWrapper uses this methods to determine + # if BOM (for UTF-16, etc) should be added + file.seekable = file_or_filename.seekable + file.tell = file_or_filename.tell + except AttributeError: + pass + file = io.TextIOWrapper(file, + encoding=encoding, + errors="xmlcharrefreplace", + newline="\n") + # Keep the original file open when the TextIOWrapper is + # destroyed + stack.callback(file.detach) + yield file.write + def _namespaces(elem, default_namespace=None): # identify namespaces used in this tree @@ -1134,22 +1167,13 @@ def _escape_attrib_html(text): # @defreturn string def tostring(element, encoding=None, method=None): - class dummy: - pass - data = [] - file = dummy() - file.write = data.append - ElementTree(element).write(file, encoding, method=method) - if encoding in (str, "unicode"): - return "".join(data) - else: - return b"".join(data) + stream = io.StringIO() if encoding == 'unicode' else io.BytesIO() + ElementTree(element).write(stream, encoding, method=method) + return stream.getvalue() ## # Generates a string representation of an XML element, including all -# subelements. If encoding is False, the string is returned as a -# sequence of string fragments; otherwise it is a sequence of -# bytestrings. +# subelements. # # @param element An Element instance. # @keyparam encoding Optional output encoding (default is US-ASCII). @@ -1161,13 +1185,15 @@ def tostring(element, encoding=None, method=None): # @since 1.3 def tostringlist(element, encoding=None, method=None): - class dummy: - pass data = [] - file = dummy() - file.write = data.append - ElementTree(element).write(file, encoding, method=method) - # FIXME: merge small fragments into larger parts + class DataStream(io.BufferedIOBase): + def writable(self): + return True + + def write(self, b): + data.append(b) + + ElementTree(element).write(DataStream(), encoding, method=method) return data ## |