summaryrefslogtreecommitdiffstats
path: root/Lib/xml/sax/saxutils.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/xml/sax/saxutils.py')
-rw-r--r--Lib/xml/sax/saxutils.py67
1 files changed, 46 insertions, 21 deletions
diff --git a/Lib/xml/sax/saxutils.py b/Lib/xml/sax/saxutils.py
index 625bc12..a62183a 100644
--- a/Lib/xml/sax/saxutils.py
+++ b/Lib/xml/sax/saxutils.py
@@ -4,18 +4,10 @@ convenience of application and driver writers.
"""
import os, urllib.parse, urllib.request
+import io
from . import handler
from . import xmlreader
-# See whether the xmlcharrefreplace error handler is
-# supported
-try:
- from codecs import xmlcharrefreplace_errors
- _error_handling = "xmlcharrefreplace"
- del xmlcharrefreplace_errors
-except ImportError:
- _error_handling = "strict"
-
def __dict_replace(s, d):
"""Replace substrings of a string using a dictionary."""
for key, value in d.items():
@@ -76,14 +68,50 @@ def quoteattr(data, entities={}):
return data
+def _gettextwriter(out, encoding):
+ if out is None:
+ import sys
+ return sys.stdout
+
+ if isinstance(out, io.TextIOBase):
+ # use a text writer as is
+ return out
+
+ # wrap a binary writer with TextIOWrapper
+ if isinstance(out, io.RawIOBase):
+ # Keep the original file open when the TextIOWrapper is
+ # destroyed
+ class _wrapper:
+ __class__ = out.__class__
+ def __getattr__(self, name):
+ return getattr(out, name)
+ buffer = _wrapper()
+ buffer.close = lambda: None
+ else:
+ # This is to handle passed objects that aren't in the
+ # IOBase hierarchy, but just have a write method
+ buffer = io.BufferedIOBase()
+ buffer.writable = lambda: True
+ buffer.write = out.write
+ try:
+ # TextIOWrapper uses this methods to determine
+ # if BOM (for UTF-16, etc) should be added
+ buffer.seekable = out.seekable
+ buffer.tell = out.tell
+ except AttributeError:
+ pass
+ return io.TextIOWrapper(buffer, encoding=encoding,
+ errors='xmlcharrefreplace',
+ newline='\n',
+ write_through=True)
+
class XMLGenerator(handler.ContentHandler):
def __init__(self, out=None, encoding="iso-8859-1", short_empty_elements=False):
- if out is None:
- import sys
- out = sys.stdout
handler.ContentHandler.__init__(self)
- self._out = out
+ out = _gettextwriter(out, encoding)
+ self._write = out.write
+ self._flush = out.flush
self._ns_contexts = [{}] # contains uri -> prefix dicts
self._current_context = self._ns_contexts[-1]
self._undeclared_ns_maps = []
@@ -91,12 +119,6 @@ class XMLGenerator(handler.ContentHandler):
self._short_empty_elements = short_empty_elements
self._pending_start_element = False
- def _write(self, text):
- if isinstance(text, str):
- self._out.write(text)
- else:
- self._out.write(text.encode(self._encoding, _error_handling))
-
def _qname(self, name):
"""Builds a qualified name from a (ns_url, localname) pair"""
if name[0]:
@@ -125,6 +147,9 @@ class XMLGenerator(handler.ContentHandler):
self._write('<?xml version="1.0" encoding="%s"?>\n' %
self._encoding)
+ def endDocument(self):
+ self._flush()
+
def startPrefixMapping(self, prefix, uri):
self._ns_contexts.append(self._current_context.copy())
self._current_context[uri] = prefix
@@ -157,9 +182,9 @@ class XMLGenerator(handler.ContentHandler):
for prefix, uri in self._undeclared_ns_maps:
if prefix:
- self._out.write(' xmlns:%s="%s"' % (prefix, uri))
+ self._write(' xmlns:%s="%s"' % (prefix, uri))
else:
- self._out.write(' xmlns="%s"' % uri)
+ self._write(' xmlns="%s"' % uri)
self._undeclared_ns_maps = []
for (name, value) in attrs.items():