diff options
Diffstat (limited to 'Lib/xml/sax/saxutils.py')
-rw-r--r-- | Lib/xml/sax/saxutils.py | 67 |
1 files changed, 46 insertions, 21 deletions
diff --git a/Lib/xml/sax/saxutils.py b/Lib/xml/sax/saxutils.py index 625bc12..a62183a 100644 --- a/Lib/xml/sax/saxutils.py +++ b/Lib/xml/sax/saxutils.py @@ -4,18 +4,10 @@ convenience of application and driver writers. """ import os, urllib.parse, urllib.request +import io from . import handler from . import xmlreader -# See whether the xmlcharrefreplace error handler is -# supported -try: - from codecs import xmlcharrefreplace_errors - _error_handling = "xmlcharrefreplace" - del xmlcharrefreplace_errors -except ImportError: - _error_handling = "strict" - def __dict_replace(s, d): """Replace substrings of a string using a dictionary.""" for key, value in d.items(): @@ -76,14 +68,50 @@ def quoteattr(data, entities={}): return data +def _gettextwriter(out, encoding): + if out is None: + import sys + return sys.stdout + + if isinstance(out, io.TextIOBase): + # use a text writer as is + return out + + # wrap a binary writer with TextIOWrapper + if isinstance(out, io.RawIOBase): + # Keep the original file open when the TextIOWrapper is + # destroyed + class _wrapper: + __class__ = out.__class__ + def __getattr__(self, name): + return getattr(out, name) + buffer = _wrapper() + buffer.close = lambda: None + else: + # This is to handle passed objects that aren't in the + # IOBase hierarchy, but just have a write method + buffer = io.BufferedIOBase() + buffer.writable = lambda: True + buffer.write = out.write + try: + # TextIOWrapper uses this methods to determine + # if BOM (for UTF-16, etc) should be added + buffer.seekable = out.seekable + buffer.tell = out.tell + except AttributeError: + pass + return io.TextIOWrapper(buffer, encoding=encoding, + errors='xmlcharrefreplace', + newline='\n', + write_through=True) + class XMLGenerator(handler.ContentHandler): def __init__(self, out=None, encoding="iso-8859-1", short_empty_elements=False): - if out is None: - import sys - out = sys.stdout handler.ContentHandler.__init__(self) - self._out = out + out = _gettextwriter(out, encoding) + self._write = out.write + self._flush = out.flush self._ns_contexts = [{}] # contains uri -> prefix dicts self._current_context = self._ns_contexts[-1] self._undeclared_ns_maps = [] @@ -91,12 +119,6 @@ class XMLGenerator(handler.ContentHandler): self._short_empty_elements = short_empty_elements self._pending_start_element = False - def _write(self, text): - if isinstance(text, str): - self._out.write(text) - else: - self._out.write(text.encode(self._encoding, _error_handling)) - def _qname(self, name): """Builds a qualified name from a (ns_url, localname) pair""" if name[0]: @@ -125,6 +147,9 @@ class XMLGenerator(handler.ContentHandler): self._write('<?xml version="1.0" encoding="%s"?>\n' % self._encoding) + def endDocument(self): + self._flush() + def startPrefixMapping(self, prefix, uri): self._ns_contexts.append(self._current_context.copy()) self._current_context[uri] = prefix @@ -157,9 +182,9 @@ class XMLGenerator(handler.ContentHandler): for prefix, uri in self._undeclared_ns_maps: if prefix: - self._out.write(' xmlns:%s="%s"' % (prefix, uri)) + self._write(' xmlns:%s="%s"' % (prefix, uri)) else: - self._out.write(' xmlns="%s"' % uri) + self._write(' xmlns="%s"' % uri) self._undeclared_ns_maps = [] for (name, value) in attrs.items(): |