summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author: Eli Bendersky <eliben@gmail.com>, 2012-03-16 03:53:30 (GMT)
committer: Eli Bendersky <eliben@gmail.com>, 2012-03-16 03:53:30 (GMT)
commit: f996e775eaf22e6a6465e640a6de46ea74011bc0 (patch)
tree: 933b3ea33b18bc2101b883555ac587a221c02ee9
parent: e53d977e8077759e8123da3da563e6b73392ed8b (diff)
download: cpython-f996e775eaf22e6a6465e640a6de46ea74011bc0.zip
download: cpython-f996e775eaf22e6a6465e640a6de46ea74011bc0.tar.gz
download: cpython-f996e775eaf22e6a6465e640a6de46ea74011bc0.tar.bz2
Closes Issue #14246: _elementtree parser will now handle io.StringIO
-rw-r--r-- Lib/test/test_xml_etree.py | 14
-rw-r--r-- Modules/_elementtree.c | 23
2 files changed, 36 insertions, 1 deletions
diff --git a/Lib/test/test_xml_etree.py b/Lib/test/test_xml_etree.py
index fedf550..97fc690 100644
--- a/Lib/test/test_xml_etree.py
+++ b/Lib/test/test_xml_etree.py
@@ -16,6 +16,7 @@
import sys
import html
+import io
import unittest
from test import support
@@ -2026,6 +2027,18 @@ class ElementSlicingTest(unittest.TestCase):
del e[::2]
self.assertEqual(self._subelem_tags(e), ['a1'])
+
+class StringIOTest(unittest.TestCase):
+ def test_read_from_stringio(self):
+ tree = ET.ElementTree()
+ stream = io.StringIO()
+ stream.write('''<?xml version="1.0"?><site></site>''')
+ stream.seek(0)
+ tree.parse(stream)
+
+ self.assertEqual(tree.getroot().tag, 'site')
+
+
# --------------------------------------------------------------------
@@ -2077,6 +2090,7 @@ def test_main(module=pyET):
test_classes = [
ElementSlicingTest,
+ StringIOTest,
ElementTreeTest,
TreeBuilderTest]
if module is pyET:
diff --git a/Modules/_elementtree.c b/Modules/_elementtree.c
index ba37cd7..99935b9 100644
--- a/Modules/_elementtree.c
+++ b/Modules/_elementtree.c
@@ -2682,6 +2682,7 @@ xmlparser_parse(XMLParserObject* self, PyObject* args)
PyObject* reader;
PyObject* buffer;
+ PyObject* temp;
PyObject* res;
PyObject* fileobj;
@@ -2703,7 +2704,27 @@ xmlparser_parse(XMLParserObject* self, PyObject* args)
return NULL;
}
- if (!PyBytes_CheckExact(buffer) || PyBytes_GET_SIZE(buffer) == 0) {
+ if (PyUnicode_CheckExact(buffer)) {
+ /* A unicode object is encoded into bytes using UTF-8 */
+ if (PyUnicode_GET_SIZE(buffer) == 0) {
+ Py_DECREF(buffer);
+ break;
+ }
+ temp = PyUnicode_AsEncodedString(buffer, "utf-8", "surrogatepass");
+ if (!temp) {
+ /* Propagate exception from PyUnicode_AsEncodedString */
+ Py_DECREF(buffer);
+ Py_DECREF(reader);
+ return NULL;
+ }
+
+ /* Here we no longer need the original buffer since it contains
+ * unicode. Make it point to the encoded bytes object.
+ */
+ Py_DECREF(buffer);
+ buffer = temp;
+ }
+ else if (!PyBytes_CheckExact(buffer) || PyBytes_GET_SIZE(buffer) == 0) {
Py_DECREF(buffer);
break;
}