summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorEli Bendersky <eliben@gmail.com>2012-07-15 03:02:22 (GMT)
committerEli Bendersky <eliben@gmail.com>2012-07-15 03:02:22 (GMT)
commit00f402bfcbe3245f9c62f86376fc77bb9e7de639 (patch)
treec5035e1c4af4be283479aca143ba687d74d19c0f /Lib
parent1191709b1379661a15287a2c6ac8263f23655f73 (diff)
downloadcpython-00f402bfcbe3245f9c62f86376fc77bb9e7de639.zip
cpython-00f402bfcbe3245f9c62f86376fc77bb9e7de639.tar.gz
cpython-00f402bfcbe3245f9c62f86376fc77bb9e7de639.tar.bz2
Close #1767933: Badly formed XML using etree and utf-16. Patch by Serhiy Storchaka, with some minor fixes by me
Diffstat (limited to 'Lib')
-rw-r--r--Lib/test/test_xml_etree.py240
-rw-r--r--Lib/xml/etree/ElementTree.py138
2 files changed, 257 insertions, 121 deletions
diff --git a/Lib/test/test_xml_etree.py b/Lib/test/test_xml_etree.py
index c1fc955..d90f978 100644
--- a/Lib/test/test_xml_etree.py
+++ b/Lib/test/test_xml_etree.py
@@ -21,7 +21,7 @@ import unittest
import weakref
from test import support
-from test.support import findfile, import_fresh_module, gc_collect
+from test.support import TESTFN, findfile, unlink, import_fresh_module, gc_collect
pyET = None
ET = None
@@ -888,65 +888,6 @@ def check_encoding(encoding):
"""
ET.XML("<?xml version='1.0' encoding='%s'?><xml />" % encoding)
-def encoding():
- r"""
- Test encoding issues.
-
- >>> elem = ET.Element("tag")
- >>> elem.text = "abc"
- >>> serialize(elem)
- '<tag>abc</tag>'
- >>> serialize(elem, encoding="utf-8")
- b'<tag>abc</tag>'
- >>> serialize(elem, encoding="us-ascii")
- b'<tag>abc</tag>'
- >>> serialize(elem, encoding="iso-8859-1")
- b"<?xml version='1.0' encoding='iso-8859-1'?>\n<tag>abc</tag>"
-
- >>> elem.text = "<&\"\'>"
- >>> serialize(elem)
- '<tag>&lt;&amp;"\'&gt;</tag>'
- >>> serialize(elem, encoding="utf-8")
- b'<tag>&lt;&amp;"\'&gt;</tag>'
- >>> serialize(elem, encoding="us-ascii") # cdata characters
- b'<tag>&lt;&amp;"\'&gt;</tag>'
- >>> serialize(elem, encoding="iso-8859-1")
- b'<?xml version=\'1.0\' encoding=\'iso-8859-1\'?>\n<tag>&lt;&amp;"\'&gt;</tag>'
-
- >>> elem.attrib["key"] = "<&\"\'>"
- >>> elem.text = None
- >>> serialize(elem)
- '<tag key="&lt;&amp;&quot;\'&gt;" />'
- >>> serialize(elem, encoding="utf-8")
- b'<tag key="&lt;&amp;&quot;\'&gt;" />'
- >>> serialize(elem, encoding="us-ascii")
- b'<tag key="&lt;&amp;&quot;\'&gt;" />'
- >>> serialize(elem, encoding="iso-8859-1")
- b'<?xml version=\'1.0\' encoding=\'iso-8859-1\'?>\n<tag key="&lt;&amp;&quot;\'&gt;" />'
-
- >>> elem.text = '\xe5\xf6\xf6<>'
- >>> elem.attrib.clear()
- >>> serialize(elem)
- '<tag>\xe5\xf6\xf6&lt;&gt;</tag>'
- >>> serialize(elem, encoding="utf-8")
- b'<tag>\xc3\xa5\xc3\xb6\xc3\xb6&lt;&gt;</tag>'
- >>> serialize(elem, encoding="us-ascii")
- b'<tag>&#229;&#246;&#246;&lt;&gt;</tag>'
- >>> serialize(elem, encoding="iso-8859-1")
- b"<?xml version='1.0' encoding='iso-8859-1'?>\n<tag>\xe5\xf6\xf6&lt;&gt;</tag>"
-
- >>> elem.attrib["key"] = '\xe5\xf6\xf6<>'
- >>> elem.text = None
- >>> serialize(elem)
- '<tag key="\xe5\xf6\xf6&lt;&gt;" />'
- >>> serialize(elem, encoding="utf-8")
- b'<tag key="\xc3\xa5\xc3\xb6\xc3\xb6&lt;&gt;" />'
- >>> serialize(elem, encoding="us-ascii")
- b'<tag key="&#229;&#246;&#246;&lt;&gt;" />'
- >>> serialize(elem, encoding="iso-8859-1")
- b'<?xml version=\'1.0\' encoding=\'iso-8859-1\'?>\n<tag key="\xe5\xf6\xf6&lt;&gt;" />'
- """
-
def methods():
r"""
Test serialization methods.
@@ -2166,16 +2107,185 @@ class ElementSlicingTest(unittest.TestCase):
self.assertEqual(self._subelem_tags(e), ['a1'])
-class StringIOTest(unittest.TestCase):
+class IOTest(unittest.TestCase):
+ def tearDown(self):
+ unlink(TESTFN)
+
+ def test_encoding(self):
+ # Test encoding issues.
+ elem = ET.Element("tag")
+ elem.text = "abc"
+ self.assertEqual(serialize(elem), '<tag>abc</tag>')
+ self.assertEqual(serialize(elem, encoding="utf-8"),
+ b'<tag>abc</tag>')
+ self.assertEqual(serialize(elem, encoding="us-ascii"),
+ b'<tag>abc</tag>')
+ for enc in ("iso-8859-1", "utf-16", "utf-32"):
+ self.assertEqual(serialize(elem, encoding=enc),
+ ("<?xml version='1.0' encoding='%s'?>\n"
+ "<tag>abc</tag>" % enc).encode(enc))
+
+ elem = ET.Element("tag")
+ elem.text = "<&\"\'>"
+ self.assertEqual(serialize(elem), '<tag>&lt;&amp;"\'&gt;</tag>')
+ self.assertEqual(serialize(elem, encoding="utf-8"),
+ b'<tag>&lt;&amp;"\'&gt;</tag>')
+ self.assertEqual(serialize(elem, encoding="us-ascii"),
+ b'<tag>&lt;&amp;"\'&gt;</tag>')
+ for enc in ("iso-8859-1", "utf-16", "utf-32"):
+ self.assertEqual(serialize(elem, encoding=enc),
+ ("<?xml version='1.0' encoding='%s'?>\n"
+ "<tag>&lt;&amp;\"'&gt;</tag>" % enc).encode(enc))
+
+ elem = ET.Element("tag")
+ elem.attrib["key"] = "<&\"\'>"
+ self.assertEqual(serialize(elem), '<tag key="&lt;&amp;&quot;\'&gt;" />')
+ self.assertEqual(serialize(elem, encoding="utf-8"),
+ b'<tag key="&lt;&amp;&quot;\'&gt;" />')
+ self.assertEqual(serialize(elem, encoding="us-ascii"),
+ b'<tag key="&lt;&amp;&quot;\'&gt;" />')
+ for enc in ("iso-8859-1", "utf-16", "utf-32"):
+ self.assertEqual(serialize(elem, encoding=enc),
+ ("<?xml version='1.0' encoding='%s'?>\n"
+ "<tag key=\"&lt;&amp;&quot;'&gt;\" />" % enc).encode(enc))
+
+ elem = ET.Element("tag")
+ elem.text = '\xe5\xf6\xf6<>'
+ self.assertEqual(serialize(elem), '<tag>\xe5\xf6\xf6&lt;&gt;</tag>')
+ self.assertEqual(serialize(elem, encoding="utf-8"),
+ b'<tag>\xc3\xa5\xc3\xb6\xc3\xb6&lt;&gt;</tag>')
+ self.assertEqual(serialize(elem, encoding="us-ascii"),
+ b'<tag>&#229;&#246;&#246;&lt;&gt;</tag>')
+ for enc in ("iso-8859-1", "utf-16", "utf-32"):
+ self.assertEqual(serialize(elem, encoding=enc),
+ ("<?xml version='1.0' encoding='%s'?>\n"
+ "<tag>åöö&lt;&gt;</tag>" % enc).encode(enc))
+
+ elem = ET.Element("tag")
+ elem.attrib["key"] = '\xe5\xf6\xf6<>'
+ self.assertEqual(serialize(elem), '<tag key="\xe5\xf6\xf6&lt;&gt;" />')
+ self.assertEqual(serialize(elem, encoding="utf-8"),
+ b'<tag key="\xc3\xa5\xc3\xb6\xc3\xb6&lt;&gt;" />')
+ self.assertEqual(serialize(elem, encoding="us-ascii"),
+ b'<tag key="&#229;&#246;&#246;&lt;&gt;" />')
+ for enc in ("iso-8859-1", "utf-16", "utf-16le", "utf-16be", "utf-32"):
+ self.assertEqual(serialize(elem, encoding=enc),
+ ("<?xml version='1.0' encoding='%s'?>\n"
+ "<tag key=\"åöö&lt;&gt;\" />" % enc).encode(enc))
+
+ def test_write_to_filename(self):
+ tree = ET.ElementTree(ET.XML('''<site />'''))
+ tree.write(TESTFN)
+ with open(TESTFN, 'rb') as f:
+ self.assertEqual(f.read(), b'''<site />''')
+
+ def test_write_to_text_file(self):
+ tree = ET.ElementTree(ET.XML('''<site />'''))
+ with open(TESTFN, 'w', encoding='utf-8') as f:
+ tree.write(f, encoding='unicode')
+ self.assertFalse(f.closed)
+ with open(TESTFN, 'rb') as f:
+ self.assertEqual(f.read(), b'''<site />''')
+
+ def test_write_to_binary_file(self):
+ tree = ET.ElementTree(ET.XML('''<site />'''))
+ with open(TESTFN, 'wb') as f:
+ tree.write(f)
+ self.assertFalse(f.closed)
+ with open(TESTFN, 'rb') as f:
+ self.assertEqual(f.read(), b'''<site />''')
+
+ def test_write_to_binary_file_with_bom(self):
+ tree = ET.ElementTree(ET.XML('''<site />'''))
+ # test BOM writing to buffered file
+ with open(TESTFN, 'wb') as f:
+ tree.write(f, encoding='utf-16')
+ self.assertFalse(f.closed)
+ with open(TESTFN, 'rb') as f:
+ self.assertEqual(f.read(),
+ '''<?xml version='1.0' encoding='utf-16'?>\n'''
+ '''<site />'''.encode("utf-16"))
+ # test BOM writing to non-buffered file
+ with open(TESTFN, 'wb', buffering=0) as f:
+ tree.write(f, encoding='utf-16')
+ self.assertFalse(f.closed)
+ with open(TESTFN, 'rb') as f:
+ self.assertEqual(f.read(),
+ '''<?xml version='1.0' encoding='utf-16'?>\n'''
+ '''<site />'''.encode("utf-16"))
+
def test_read_from_stringio(self):
tree = ET.ElementTree()
- stream = io.StringIO()
- stream.write('''<?xml version="1.0"?><site></site>''')
- stream.seek(0)
+ stream = io.StringIO('''<?xml version="1.0"?><site></site>''')
tree.parse(stream)
+ self.assertEqual(tree.getroot().tag, 'site')
+ def test_write_to_stringio(self):
+ tree = ET.ElementTree(ET.XML('''<site />'''))
+ stream = io.StringIO()
+ tree.write(stream, encoding='unicode')
+ self.assertEqual(stream.getvalue(), '''<site />''')
+
+ def test_read_from_bytesio(self):
+ tree = ET.ElementTree()
+ raw = io.BytesIO(b'''<?xml version="1.0"?><site></site>''')
+ tree.parse(raw)
+ self.assertEqual(tree.getroot().tag, 'site')
+
+ def test_write_to_bytesio(self):
+ tree = ET.ElementTree(ET.XML('''<site />'''))
+ raw = io.BytesIO()
+ tree.write(raw)
+ self.assertEqual(raw.getvalue(), b'''<site />''')
+
+ class dummy:
+ pass
+
+ def test_read_from_user_text_reader(self):
+ stream = io.StringIO('''<?xml version="1.0"?><site></site>''')
+ reader = self.dummy()
+ reader.read = stream.read
+ tree = ET.ElementTree()
+ tree.parse(reader)
self.assertEqual(tree.getroot().tag, 'site')
+ def test_write_to_user_text_writer(self):
+ tree = ET.ElementTree(ET.XML('''<site />'''))
+ stream = io.StringIO()
+ writer = self.dummy()
+ writer.write = stream.write
+ tree.write(writer, encoding='unicode')
+ self.assertEqual(stream.getvalue(), '''<site />''')
+
+ def test_read_from_user_binary_reader(self):
+ raw = io.BytesIO(b'''<?xml version="1.0"?><site></site>''')
+ reader = self.dummy()
+ reader.read = raw.read
+ tree = ET.ElementTree()
+ tree.parse(reader)
+ self.assertEqual(tree.getroot().tag, 'site')
+ tree = ET.ElementTree()
+
+ def test_write_to_user_binary_writer(self):
+ tree = ET.ElementTree(ET.XML('''<site />'''))
+ raw = io.BytesIO()
+ writer = self.dummy()
+ writer.write = raw.write
+ tree.write(writer)
+ self.assertEqual(raw.getvalue(), b'''<site />''')
+
+ def test_write_to_user_binary_writer_with_bom(self):
+ tree = ET.ElementTree(ET.XML('''<site />'''))
+ raw = io.BytesIO()
+ writer = self.dummy()
+ writer.write = raw.write
+ writer.seekable = lambda: True
+ writer.tell = raw.tell
+ tree.write(writer, encoding="utf-16")
+ self.assertEqual(raw.getvalue(),
+ '''<?xml version='1.0' encoding='utf-16'?>\n'''
+ '''<site />'''.encode("utf-16"))
+
class ParseErrorTest(unittest.TestCase):
def test_subclass(self):
@@ -2299,7 +2409,7 @@ def test_main(module=None):
test_classes = [
ElementSlicingTest,
BasicElementTest,
- StringIOTest,
+ IOTest,
ParseErrorTest,
XincludeTest,
ElementTreeTest,
diff --git a/Lib/xml/etree/ElementTree.py b/Lib/xml/etree/ElementTree.py
index 61fe155..10bf849 100644
--- a/Lib/xml/etree/ElementTree.py
+++ b/Lib/xml/etree/ElementTree.py
@@ -100,6 +100,8 @@ VERSION = "1.3.0"
import sys
import re
import warnings
+import io
+import contextlib
from . import ElementPath
@@ -792,59 +794,38 @@ class ElementTree:
# None for only if not US-ASCII or UTF-8 or Unicode. None is default.
def write(self, file_or_filename,
- # keyword arguments
encoding=None,
xml_declaration=None,
default_namespace=None,
method=None):
- # assert self._root is not None
if not method:
method = "xml"
elif method not in _serialize:
- # FIXME: raise an ImportError for c14n if ElementC14N is missing?
raise ValueError("unknown method %r" % method)
if not encoding:
if method == "c14n":
encoding = "utf-8"
else:
encoding = "us-ascii"
- elif encoding == str: # lxml.etree compatibility.
- encoding = "unicode"
else:
encoding = encoding.lower()
- if hasattr(file_or_filename, "write"):
- file = file_or_filename
- else:
- if encoding != "unicode":
- file = open(file_or_filename, "wb")
+ with _get_writer(file_or_filename, encoding) as write:
+ if method == "xml" and (xml_declaration or
+ (xml_declaration is None and
+ encoding not in ("utf-8", "us-ascii", "unicode"))):
+ declared_encoding = encoding
+ if encoding == "unicode":
+ # Retrieve the default encoding for the xml declaration
+ import locale
+ declared_encoding = locale.getpreferredencoding()
+ write("<?xml version='1.0' encoding='%s'?>\n" % (
+ declared_encoding,))
+ if method == "text":
+ _serialize_text(write, self._root)
else:
- file = open(file_or_filename, "w")
- if encoding != "unicode":
- def write(text):
- try:
- return file.write(text.encode(encoding,
- "xmlcharrefreplace"))
- except (TypeError, AttributeError):
- _raise_serialization_error(text)
- else:
- write = file.write
- if method == "xml" and (xml_declaration or
- (xml_declaration is None and
- encoding not in ("utf-8", "us-ascii", "unicode"))):
- declared_encoding = encoding
- if encoding == "unicode":
- # Retrieve the default encoding for the xml declaration
- import locale
- declared_encoding = locale.getpreferredencoding()
- write("<?xml version='1.0' encoding='%s'?>\n" % declared_encoding)
- if method == "text":
- _serialize_text(write, self._root)
- else:
- qnames, namespaces = _namespaces(self._root, default_namespace)
- serialize = _serialize[method]
- serialize(write, self._root, qnames, namespaces)
- if file_or_filename is not file:
- file.close()
+ qnames, namespaces = _namespaces(self._root, default_namespace)
+ serialize = _serialize[method]
+ serialize(write, self._root, qnames, namespaces)
def write_c14n(self, file):
# lxml.etree compatibility. use output method instead
@@ -853,6 +834,58 @@ class ElementTree:
# --------------------------------------------------------------------
# serialization support
+@contextlib.contextmanager
+def _get_writer(file_or_filename, encoding):
+ # returns text write method and release all resourses after using
+ try:
+ write = file_or_filename.write
+ except AttributeError:
+ # file_or_filename is a file name
+ if encoding == "unicode":
+ file = open(file_or_filename, "w")
+ else:
+ file = open(file_or_filename, "w", encoding=encoding,
+ errors="xmlcharrefreplace")
+ with file:
+ yield file.write
+ else:
+ # file_or_filename is a file-like object
+ # encoding determines if it is a text or binary writer
+ if encoding == "unicode":
+ # use a text writer as is
+ yield write
+ else:
+ # wrap a binary writer with TextIOWrapper
+ with contextlib.ExitStack() as stack:
+ if isinstance(file_or_filename, io.BufferedIOBase):
+ file = file_or_filename
+ elif isinstance(file_or_filename, io.RawIOBase):
+ file = io.BufferedWriter(file_or_filename)
+ # Keep the original file open when the BufferedWriter is
+ # destroyed
+ stack.callback(file.detach)
+ else:
+ # This is to handle passed objects that aren't in the
+ # IOBase hierarchy, but just have a write method
+ file = io.BufferedIOBase()
+ file.writable = lambda: True
+ file.write = write
+ try:
+ # TextIOWrapper uses this methods to determine
+ # if BOM (for UTF-16, etc) should be added
+ file.seekable = file_or_filename.seekable
+ file.tell = file_or_filename.tell
+ except AttributeError:
+ pass
+ file = io.TextIOWrapper(file,
+ encoding=encoding,
+ errors="xmlcharrefreplace",
+ newline="\n")
+ # Keep the original file open when the TextIOWrapper is
+ # destroyed
+ stack.callback(file.detach)
+ yield file.write
+
def _namespaces(elem, default_namespace=None):
# identify namespaces used in this tree
@@ -1134,22 +1167,13 @@ def _escape_attrib_html(text):
# @defreturn string
def tostring(element, encoding=None, method=None):
- class dummy:
- pass
- data = []
- file = dummy()
- file.write = data.append
- ElementTree(element).write(file, encoding, method=method)
- if encoding in (str, "unicode"):
- return "".join(data)
- else:
- return b"".join(data)
+ stream = io.StringIO() if encoding == 'unicode' else io.BytesIO()
+ ElementTree(element).write(stream, encoding, method=method)
+ return stream.getvalue()
##
# Generates a string representation of an XML element, including all
-# subelements. If encoding is False, the string is returned as a
-# sequence of string fragments; otherwise it is a sequence of
-# bytestrings.
+# subelements.
#
# @param element An Element instance.
# @keyparam encoding Optional output encoding (default is US-ASCII).
@@ -1161,13 +1185,15 @@ def tostring(element, encoding=None, method=None):
# @since 1.3
def tostringlist(element, encoding=None, method=None):
- class dummy:
- pass
data = []
- file = dummy()
- file.write = data.append
- ElementTree(element).write(file, encoding, method=method)
- # FIXME: merge small fragments into larger parts
+ class DataStream(io.BufferedIOBase):
+ def writable(self):
+ return True
+
+ def write(self, b):
+ data.append(b)
+
+ ElementTree(element).write(DataStream(), encoding, method=method)
return data
##