From 00f402bfcbe3245f9c62f86376fc77bb9e7de639 Mon Sep 17 00:00:00 2001 From: Eli Bendersky Date: Sun, 15 Jul 2012 06:02:22 +0300 Subject: Close #1767933: Badly formed XML using etree and utf-16. Patch by Serhiy Storchaka, with some minor fixes by me --- Doc/library/xml.etree.elementtree.rst | 1 - Lib/test/test_xml_etree.py | 240 +++++++++++++++++++++++++--------- Lib/xml/etree/ElementTree.py | 138 +++++++++++-------- 3 files changed, 257 insertions(+), 122 deletions(-) diff --git a/Doc/library/xml.etree.elementtree.rst b/Doc/library/xml.etree.elementtree.rst index 335a6e2..3c2ddd3 100644 --- a/Doc/library/xml.etree.elementtree.rst +++ b/Doc/library/xml.etree.elementtree.rst @@ -659,7 +659,6 @@ ElementTree Objects should be added to the file. Use False for never, True for always, None for only if not US-ASCII or UTF-8 or Unicode (default is None). *method* is either ``"xml"``, ``"html"`` or ``"text"`` (default is ``"xml"``). - Returns an (optionally) encoded string. This is the XML file that is going to be manipulated:: diff --git a/Lib/test/test_xml_etree.py b/Lib/test/test_xml_etree.py index c1fc955..d90f978 100644 --- a/Lib/test/test_xml_etree.py +++ b/Lib/test/test_xml_etree.py @@ -21,7 +21,7 @@ import unittest import weakref from test import support -from test.support import findfile, import_fresh_module, gc_collect +from test.support import TESTFN, findfile, unlink, import_fresh_module, gc_collect pyET = None ET = None @@ -888,65 +888,6 @@ def check_encoding(encoding): """ ET.XML("" % encoding) -def encoding(): - r""" - Test encoding issues. - - >>> elem = ET.Element("tag") - >>> elem.text = "abc" - >>> serialize(elem) - 'abc' - >>> serialize(elem, encoding="utf-8") - b'abc' - >>> serialize(elem, encoding="us-ascii") - b'abc' - >>> serialize(elem, encoding="iso-8859-1") - b"\nabc" - - >>> elem.text = "<&\"\'>" - >>> serialize(elem) - '<&"\'>' - >>> serialize(elem, encoding="utf-8") - b'<&"\'>' - >>> serialize(elem, encoding="us-ascii") # cdata characters - b'<&"\'>' - >>> serialize(elem, encoding="iso-8859-1") - b'\n<&"\'>' - - >>> elem.attrib["key"] = "<&\"\'>" - >>> elem.text = None - >>> serialize(elem) - '' - >>> serialize(elem, encoding="utf-8") - b'' - >>> serialize(elem, encoding="us-ascii") - b'' - >>> serialize(elem, encoding="iso-8859-1") - b'\n' - - >>> elem.text = '\xe5\xf6\xf6<>' - >>> elem.attrib.clear() - >>> serialize(elem) - '\xe5\xf6\xf6<>' - >>> serialize(elem, encoding="utf-8") - b'\xc3\xa5\xc3\xb6\xc3\xb6<>' - >>> serialize(elem, encoding="us-ascii") - b'åöö<>' - >>> serialize(elem, encoding="iso-8859-1") - b"\n\xe5\xf6\xf6<>" - - >>> elem.attrib["key"] = '\xe5\xf6\xf6<>' - >>> elem.text = None - >>> serialize(elem) - '' - >>> serialize(elem, encoding="utf-8") - b'' - >>> serialize(elem, encoding="us-ascii") - b'' - >>> serialize(elem, encoding="iso-8859-1") - b'\n' - """ - def methods(): r""" Test serialization methods. @@ -2166,16 +2107,185 @@ class ElementSlicingTest(unittest.TestCase): self.assertEqual(self._subelem_tags(e), ['a1']) -class StringIOTest(unittest.TestCase): +class IOTest(unittest.TestCase): + def tearDown(self): + unlink(TESTFN) + + def test_encoding(self): + # Test encoding issues. + elem = ET.Element("tag") + elem.text = "abc" + self.assertEqual(serialize(elem), 'abc') + self.assertEqual(serialize(elem, encoding="utf-8"), + b'abc') + self.assertEqual(serialize(elem, encoding="us-ascii"), + b'abc') + for enc in ("iso-8859-1", "utf-16", "utf-32"): + self.assertEqual(serialize(elem, encoding=enc), + ("\n" + "abc" % enc).encode(enc)) + + elem = ET.Element("tag") + elem.text = "<&\"\'>" + self.assertEqual(serialize(elem), '<&"\'>') + self.assertEqual(serialize(elem, encoding="utf-8"), + b'<&"\'>') + self.assertEqual(serialize(elem, encoding="us-ascii"), + b'<&"\'>') + for enc in ("iso-8859-1", "utf-16", "utf-32"): + self.assertEqual(serialize(elem, encoding=enc), + ("\n" + "<&\"'>" % enc).encode(enc)) + + elem = ET.Element("tag") + elem.attrib["key"] = "<&\"\'>" + self.assertEqual(serialize(elem), '') + self.assertEqual(serialize(elem, encoding="utf-8"), + b'') + self.assertEqual(serialize(elem, encoding="us-ascii"), + b'') + for enc in ("iso-8859-1", "utf-16", "utf-32"): + self.assertEqual(serialize(elem, encoding=enc), + ("\n" + "" % enc).encode(enc)) + + elem = ET.Element("tag") + elem.text = '\xe5\xf6\xf6<>' + self.assertEqual(serialize(elem), '\xe5\xf6\xf6<>') + self.assertEqual(serialize(elem, encoding="utf-8"), + b'\xc3\xa5\xc3\xb6\xc3\xb6<>') + self.assertEqual(serialize(elem, encoding="us-ascii"), + b'åöö<>') + for enc in ("iso-8859-1", "utf-16", "utf-32"): + self.assertEqual(serialize(elem, encoding=enc), + ("\n" + "åöö<>" % enc).encode(enc)) + + elem = ET.Element("tag") + elem.attrib["key"] = '\xe5\xf6\xf6<>' + self.assertEqual(serialize(elem), '') + self.assertEqual(serialize(elem, encoding="utf-8"), + b'') + self.assertEqual(serialize(elem, encoding="us-ascii"), + b'') + for enc in ("iso-8859-1", "utf-16", "utf-16le", "utf-16be", "utf-32"): + self.assertEqual(serialize(elem, encoding=enc), + ("\n" + "" % enc).encode(enc)) + + def test_write_to_filename(self): + tree = ET.ElementTree(ET.XML('''''')) + tree.write(TESTFN) + with open(TESTFN, 'rb') as f: + self.assertEqual(f.read(), b'''''') + + def test_write_to_text_file(self): + tree = ET.ElementTree(ET.XML('''''')) + with open(TESTFN, 'w', encoding='utf-8') as f: + tree.write(f, encoding='unicode') + self.assertFalse(f.closed) + with open(TESTFN, 'rb') as f: + self.assertEqual(f.read(), b'''''') + + def test_write_to_binary_file(self): + tree = ET.ElementTree(ET.XML('''''')) + with open(TESTFN, 'wb') as f: + tree.write(f) + self.assertFalse(f.closed) + with open(TESTFN, 'rb') as f: + self.assertEqual(f.read(), b'''''') + + def test_write_to_binary_file_with_bom(self): + tree = ET.ElementTree(ET.XML('''''')) + # test BOM writing to buffered file + with open(TESTFN, 'wb') as f: + tree.write(f, encoding='utf-16') + self.assertFalse(f.closed) + with open(TESTFN, 'rb') as f: + self.assertEqual(f.read(), + '''\n''' + ''''''.encode("utf-16")) + # test BOM writing to non-buffered file + with open(TESTFN, 'wb', buffering=0) as f: + tree.write(f, encoding='utf-16') + self.assertFalse(f.closed) + with open(TESTFN, 'rb') as f: + self.assertEqual(f.read(), + '''\n''' + ''''''.encode("utf-16")) + def test_read_from_stringio(self): tree = ET.ElementTree() - stream = io.StringIO() - stream.write('''''') - stream.seek(0) + stream = io.StringIO('''''') tree.parse(stream) + self.assertEqual(tree.getroot().tag, 'site') + def test_write_to_stringio(self): + tree = ET.ElementTree(ET.XML('''''')) + stream = io.StringIO() + tree.write(stream, encoding='unicode') + self.assertEqual(stream.getvalue(), '''''') + + def test_read_from_bytesio(self): + tree = ET.ElementTree() + raw = io.BytesIO(b'''''') + tree.parse(raw) + self.assertEqual(tree.getroot().tag, 'site') + + def test_write_to_bytesio(self): + tree = ET.ElementTree(ET.XML('''''')) + raw = io.BytesIO() + tree.write(raw) + self.assertEqual(raw.getvalue(), b'''''') + + class dummy: + pass + + def test_read_from_user_text_reader(self): + stream = io.StringIO('''''') + reader = self.dummy() + reader.read = stream.read + tree = ET.ElementTree() + tree.parse(reader) self.assertEqual(tree.getroot().tag, 'site') + def test_write_to_user_text_writer(self): + tree = ET.ElementTree(ET.XML('''''')) + stream = io.StringIO() + writer = self.dummy() + writer.write = stream.write + tree.write(writer, encoding='unicode') + self.assertEqual(stream.getvalue(), '''''') + + def test_read_from_user_binary_reader(self): + raw = io.BytesIO(b'''''') + reader = self.dummy() + reader.read = raw.read + tree = ET.ElementTree() + tree.parse(reader) + self.assertEqual(tree.getroot().tag, 'site') + tree = ET.ElementTree() + + def test_write_to_user_binary_writer(self): + tree = ET.ElementTree(ET.XML('''''')) + raw = io.BytesIO() + writer = self.dummy() + writer.write = raw.write + tree.write(writer) + self.assertEqual(raw.getvalue(), b'''''') + + def test_write_to_user_binary_writer_with_bom(self): + tree = ET.ElementTree(ET.XML('''''')) + raw = io.BytesIO() + writer = self.dummy() + writer.write = raw.write + writer.seekable = lambda: True + writer.tell = raw.tell + tree.write(writer, encoding="utf-16") + self.assertEqual(raw.getvalue(), + '''\n''' + ''''''.encode("utf-16")) + class ParseErrorTest(unittest.TestCase): def test_subclass(self): @@ -2299,7 +2409,7 @@ def test_main(module=None): test_classes = [ ElementSlicingTest, BasicElementTest, - StringIOTest, + IOTest, ParseErrorTest, XincludeTest, ElementTreeTest, diff --git a/Lib/xml/etree/ElementTree.py b/Lib/xml/etree/ElementTree.py index 61fe155..10bf849 100644 --- a/Lib/xml/etree/ElementTree.py +++ b/Lib/xml/etree/ElementTree.py @@ -100,6 +100,8 @@ VERSION = "1.3.0" import sys import re import warnings +import io +import contextlib from . import ElementPath @@ -792,59 +794,38 @@ class ElementTree: # None for only if not US-ASCII or UTF-8 or Unicode. None is default. def write(self, file_or_filename, - # keyword arguments encoding=None, xml_declaration=None, default_namespace=None, method=None): - # assert self._root is not None if not method: method = "xml" elif method not in _serialize: - # FIXME: raise an ImportError for c14n if ElementC14N is missing? raise ValueError("unknown method %r" % method) if not encoding: if method == "c14n": encoding = "utf-8" else: encoding = "us-ascii" - elif encoding == str: # lxml.etree compatibility. - encoding = "unicode" else: encoding = encoding.lower() - if hasattr(file_or_filename, "write"): - file = file_or_filename - else: - if encoding != "unicode": - file = open(file_or_filename, "wb") + with _get_writer(file_or_filename, encoding) as write: + if method == "xml" and (xml_declaration or + (xml_declaration is None and + encoding not in ("utf-8", "us-ascii", "unicode"))): + declared_encoding = encoding + if encoding == "unicode": + # Retrieve the default encoding for the xml declaration + import locale + declared_encoding = locale.getpreferredencoding() + write("\n" % ( + declared_encoding,)) + if method == "text": + _serialize_text(write, self._root) else: - file = open(file_or_filename, "w") - if encoding != "unicode": - def write(text): - try: - return file.write(text.encode(encoding, - "xmlcharrefreplace")) - except (TypeError, AttributeError): - _raise_serialization_error(text) - else: - write = file.write - if method == "xml" and (xml_declaration or - (xml_declaration is None and - encoding not in ("utf-8", "us-ascii", "unicode"))): - declared_encoding = encoding - if encoding == "unicode": - # Retrieve the default encoding for the xml declaration - import locale - declared_encoding = locale.getpreferredencoding() - write("\n" % declared_encoding) - if method == "text": - _serialize_text(write, self._root) - else: - qnames, namespaces = _namespaces(self._root, default_namespace) - serialize = _serialize[method] - serialize(write, self._root, qnames, namespaces) - if file_or_filename is not file: - file.close() + qnames, namespaces = _namespaces(self._root, default_namespace) + serialize = _serialize[method] + serialize(write, self._root, qnames, namespaces) def write_c14n(self, file): # lxml.etree compatibility. use output method instead @@ -853,6 +834,58 @@ class ElementTree: # -------------------------------------------------------------------- # serialization support +@contextlib.contextmanager +def _get_writer(file_or_filename, encoding): + # returns text write method and release all resourses after using + try: + write = file_or_filename.write + except AttributeError: + # file_or_filename is a file name + if encoding == "unicode": + file = open(file_or_filename, "w") + else: + file = open(file_or_filename, "w", encoding=encoding, + errors="xmlcharrefreplace") + with file: + yield file.write + else: + # file_or_filename is a file-like object + # encoding determines if it is a text or binary writer + if encoding == "unicode": + # use a text writer as is + yield write + else: + # wrap a binary writer with TextIOWrapper + with contextlib.ExitStack() as stack: + if isinstance(file_or_filename, io.BufferedIOBase): + file = file_or_filename + elif isinstance(file_or_filename, io.RawIOBase): + file = io.BufferedWriter(file_or_filename) + # Keep the original file open when the BufferedWriter is + # destroyed + stack.callback(file.detach) + else: + # This is to handle passed objects that aren't in the + # IOBase hierarchy, but just have a write method + file = io.BufferedIOBase() + file.writable = lambda: True + file.write = write + try: + # TextIOWrapper uses this methods to determine + # if BOM (for UTF-16, etc) should be added + file.seekable = file_or_filename.seekable + file.tell = file_or_filename.tell + except AttributeError: + pass + file = io.TextIOWrapper(file, + encoding=encoding, + errors="xmlcharrefreplace", + newline="\n") + # Keep the original file open when the TextIOWrapper is + # destroyed + stack.callback(file.detach) + yield file.write + def _namespaces(elem, default_namespace=None): # identify namespaces used in this tree @@ -1134,22 +1167,13 @@ def _escape_attrib_html(text): # @defreturn string def tostring(element, encoding=None, method=None): - class dummy: - pass - data = [] - file = dummy() - file.write = data.append - ElementTree(element).write(file, encoding, method=method) - if encoding in (str, "unicode"): - return "".join(data) - else: - return b"".join(data) + stream = io.StringIO() if encoding == 'unicode' else io.BytesIO() + ElementTree(element).write(stream, encoding, method=method) + return stream.getvalue() ## # Generates a string representation of an XML element, including all -# subelements. If encoding is False, the string is returned as a -# sequence of string fragments; otherwise it is a sequence of -# bytestrings. +# subelements. # # @param element An Element instance. # @keyparam encoding Optional output encoding (default is US-ASCII). @@ -1161,13 +1185,15 @@ def tostring(element, encoding=None, method=None): # @since 1.3 def tostringlist(element, encoding=None, method=None): - class dummy: - pass data = [] - file = dummy() - file.write = data.append - ElementTree(element).write(file, encoding, method=method) - # FIXME: merge small fragments into larger parts + class DataStream(io.BufferedIOBase): + def writable(self): + return True + + def write(self, b): + data.append(b) + + ElementTree(element).write(DataStream(), encoding, method=method) return data ## -- cgit v0.12