diff options
Diffstat (limited to 'Lib/xml')
-rw-r--r-- | Lib/xml/etree/ElementTree.py | 138 |
1 files changed, 82 insertions, 56 deletions
diff --git a/Lib/xml/etree/ElementTree.py b/Lib/xml/etree/ElementTree.py index 61fe155..10bf849 100644 --- a/Lib/xml/etree/ElementTree.py +++ b/Lib/xml/etree/ElementTree.py @@ -100,6 +100,8 @@ VERSION = "1.3.0" import sys import re import warnings +import io +import contextlib from . import ElementPath @@ -792,59 +794,38 @@ class ElementTree: # None for only if not US-ASCII or UTF-8 or Unicode. None is default. def write(self, file_or_filename, - # keyword arguments encoding=None, xml_declaration=None, default_namespace=None, method=None): - # assert self._root is not None if not method: method = "xml" elif method not in _serialize: - # FIXME: raise an ImportError for c14n if ElementC14N is missing? raise ValueError("unknown method %r" % method) if not encoding: if method == "c14n": encoding = "utf-8" else: encoding = "us-ascii" - elif encoding == str: # lxml.etree compatibility. - encoding = "unicode" else: encoding = encoding.lower() - if hasattr(file_or_filename, "write"): - file = file_or_filename - else: - if encoding != "unicode": - file = open(file_or_filename, "wb") + with _get_writer(file_or_filename, encoding) as write: + if method == "xml" and (xml_declaration or + (xml_declaration is None and + encoding not in ("utf-8", "us-ascii", "unicode"))): + declared_encoding = encoding + if encoding == "unicode": + # Retrieve the default encoding for the xml declaration + import locale + declared_encoding = locale.getpreferredencoding() + write("<?xml version='1.0' encoding='%s'?>\n" % ( + declared_encoding,)) + if method == "text": + _serialize_text(write, self._root) else: - file = open(file_or_filename, "w") - if encoding != "unicode": - def write(text): - try: - return file.write(text.encode(encoding, - "xmlcharrefreplace")) - except (TypeError, AttributeError): - _raise_serialization_error(text) - else: - write = file.write - if method == "xml" and (xml_declaration or - (xml_declaration is None and - encoding not in ("utf-8", "us-ascii", "unicode"))): - declared_encoding = encoding - if encoding == "unicode": - # Retrieve the default encoding for the xml declaration - import locale - declared_encoding = locale.getpreferredencoding() - write("<?xml version='1.0' encoding='%s'?>\n" % declared_encoding) - if method == "text": - _serialize_text(write, self._root) - else: - qnames, namespaces = _namespaces(self._root, default_namespace) - serialize = _serialize[method] - serialize(write, self._root, qnames, namespaces) - if file_or_filename is not file: - file.close() + qnames, namespaces = _namespaces(self._root, default_namespace) + serialize = _serialize[method] + serialize(write, self._root, qnames, namespaces) def write_c14n(self, file): # lxml.etree compatibility. use output method instead @@ -853,6 +834,58 @@ class ElementTree: # -------------------------------------------------------------------- # serialization support +@contextlib.contextmanager +def _get_writer(file_or_filename, encoding): + # returns text write method and release all resourses after using + try: + write = file_or_filename.write + except AttributeError: + # file_or_filename is a file name + if encoding == "unicode": + file = open(file_or_filename, "w") + else: + file = open(file_or_filename, "w", encoding=encoding, + errors="xmlcharrefreplace") + with file: + yield file.write + else: + # file_or_filename is a file-like object + # encoding determines if it is a text or binary writer + if encoding == "unicode": + # use a text writer as is + yield write + else: + # wrap a binary writer with TextIOWrapper + with contextlib.ExitStack() as stack: + if isinstance(file_or_filename, io.BufferedIOBase): + file = file_or_filename + elif isinstance(file_or_filename, io.RawIOBase): + file = io.BufferedWriter(file_or_filename) + # Keep the original file open when the BufferedWriter is + # destroyed + stack.callback(file.detach) + else: + # This is to handle passed objects that aren't in the + # IOBase hierarchy, but just have a write method + file = io.BufferedIOBase() + file.writable = lambda: True + file.write = write + try: + # TextIOWrapper uses this methods to determine + # if BOM (for UTF-16, etc) should be added + file.seekable = file_or_filename.seekable + file.tell = file_or_filename.tell + except AttributeError: + pass + file = io.TextIOWrapper(file, + encoding=encoding, + errors="xmlcharrefreplace", + newline="\n") + # Keep the original file open when the TextIOWrapper is + # destroyed + stack.callback(file.detach) + yield file.write + def _namespaces(elem, default_namespace=None): # identify namespaces used in this tree @@ -1134,22 +1167,13 @@ def _escape_attrib_html(text): # @defreturn string def tostring(element, encoding=None, method=None): - class dummy: - pass - data = [] - file = dummy() - file.write = data.append - ElementTree(element).write(file, encoding, method=method) - if encoding in (str, "unicode"): - return "".join(data) - else: - return b"".join(data) + stream = io.StringIO() if encoding == 'unicode' else io.BytesIO() + ElementTree(element).write(stream, encoding, method=method) + return stream.getvalue() ## # Generates a string representation of an XML element, including all -# subelements. If encoding is False, the string is returned as a -# sequence of string fragments; otherwise it is a sequence of -# bytestrings. +# subelements. # # @param element An Element instance. # @keyparam encoding Optional output encoding (default is US-ASCII). @@ -1161,13 +1185,15 @@ def tostring(element, encoding=None, method=None): # @since 1.3 def tostringlist(element, encoding=None, method=None): - class dummy: - pass data = [] - file = dummy() - file.write = data.append - ElementTree(element).write(file, encoding, method=method) - # FIXME: merge small fragments into larger parts + class DataStream(io.BufferedIOBase): + def writable(self): + return True + + def write(self, b): + data.append(b) + + ElementTree(element).write(DataStream(), encoding, method=method) return data ## |