diff options
Diffstat (limited to 'Lib/xml/etree/ElementTree.py')
-rw-r--r-- | Lib/xml/etree/ElementTree.py | 380 |
1 files changed, 240 insertions, 140 deletions
diff --git a/Lib/xml/etree/ElementTree.py b/Lib/xml/etree/ElementTree.py index ff8ff7d..e8e309c 100644 --- a/Lib/xml/etree/ElementTree.py +++ b/Lib/xml/etree/ElementTree.py @@ -68,8 +68,9 @@ __all__ = [ "tostring", "tostringlist", "TreeBuilder", "VERSION", - "XML", + "XML", "XMLID", "XMLParser", "XMLTreeBuilder", + "register_namespace", ] VERSION = "1.3.0" @@ -99,34 +100,11 @@ VERSION = "1.3.0" import sys import re import warnings +import io +import contextlib +from . import ElementPath -class _SimpleElementPath: - # emulate pre-1.2 find/findtext/findall behaviour - def find(self, element, tag, namespaces=None): - for elem in element: - if elem.tag == tag: - return elem - return None - def findtext(self, element, tag, default=None, namespaces=None): - elem = self.find(element, tag) - if elem is None: - return default - return elem.text or "" - def iterfind(self, element, tag, namespaces=None): - if tag[:3] == ".//": - for elem in element.iter(tag[3:]): - yield elem - for elem in element: - if elem.tag == tag: - yield elem - def findall(self, element, tag, namespaces=None): - return list(self.iterfind(element, tag, namespaces)) - -try: - from . import ElementPath -except ImportError: - ElementPath = _SimpleElementPath() ## # Parser error. This is a subclass of <b>SyntaxError</b>. @@ -148,9 +126,9 @@ class ParseError(SyntaxError): # @defreturn flag def iselement(element): - # FIXME: not sure about this; might be a better idea to look - # for tag/attrib/text attributes - return isinstance(element, Element) or hasattr(element, "tag") + # FIXME: not sure about this; + # isinstance(element, Element) or look for tag/attrib/text attributes + return hasattr(element, 'tag') ## # Element class. This class defines the Element interface, and @@ -205,6 +183,9 @@ class Element: # constructor def __init__(self, tag, attrib={}, **extra): + if not isinstance(attrib, dict): + raise TypeError("attrib must be dict, not %s" % ( + attrib.__class__.__name__,)) attrib = attrib.copy() attrib.update(extra) self.tag = tag @@ -298,7 +279,7 @@ class Element: # @param element The element to add. def append(self, element): - # assert iselement(element) + self._assert_is_element(element) self._children.append(element) ## @@ -308,8 +289,8 @@ class Element: # @since 1.3 def extend(self, elements): - # for element in elements: - # assert iselement(element) + for element in elements: + self._assert_is_element(element) self._children.extend(elements) ## @@ -318,9 +299,15 @@ class Element: # @param index Where to insert the new subelement. def insert(self, index, element): - # assert iselement(element) + self._assert_is_element(element) self._children.insert(index, element) + def _assert_is_element(self, e): + # Need to refer to the actual Python implementation, not the + # shadowing C implementation. + if not isinstance(e, _Element): + raise TypeError('expected an Element, not %s' % type(e).__name__) + ## # Removes a matching subelement. Unlike the <b>find</b> methods, # this method compares elements based on identity, not on tag @@ -810,59 +797,38 @@ class ElementTree: # "c14n"; default is "xml"). def write(self, file_or_filename, - # keyword arguments encoding=None, xml_declaration=None, default_namespace=None, method=None): - # assert self._root is not None if not method: method = "xml" elif method not in _serialize: - # FIXME: raise an ImportError for c14n if ElementC14N is missing? raise ValueError("unknown method %r" % method) if not encoding: if method == "c14n": encoding = "utf-8" else: encoding = "us-ascii" - elif encoding == str: # lxml.etree compatibility. - encoding = "unicode" else: encoding = encoding.lower() - if hasattr(file_or_filename, "write"): - file = file_or_filename - else: - if encoding != "unicode": - file = open(file_or_filename, "wb") + with _get_writer(file_or_filename, encoding) as write: + if method == "xml" and (xml_declaration or + (xml_declaration is None and + encoding not in ("utf-8", "us-ascii", "unicode"))): + declared_encoding = encoding + if encoding == "unicode": + # Retrieve the default encoding for the xml declaration + import locale + declared_encoding = locale.getpreferredencoding() + write("<?xml version='1.0' encoding='%s'?>\n" % ( + declared_encoding,)) + if method == "text": + _serialize_text(write, self._root) else: - file = open(file_or_filename, "w") - if encoding != "unicode": - def write(text): - try: - return file.write(text.encode(encoding, - "xmlcharrefreplace")) - except (TypeError, AttributeError): - _raise_serialization_error(text) - else: - write = file.write - if method == "xml" and (xml_declaration or - (xml_declaration is None and - encoding not in ("utf-8", "us-ascii", "unicode"))): - declared_encoding = encoding - if encoding == "unicode": - # Retrieve the default encoding for the xml declaration - import locale - declared_encoding = locale.getpreferredencoding() - write("<?xml version='1.0' encoding='%s'?>\n" % declared_encoding) - if method == "text": - _serialize_text(write, self._root) - else: - qnames, namespaces = _namespaces(self._root, default_namespace) - serialize = _serialize[method] - serialize(write, self._root, qnames, namespaces) - if file_or_filename is not file: - file.close() + qnames, namespaces = _namespaces(self._root, default_namespace) + serialize = _serialize[method] + serialize(write, self._root, qnames, namespaces) def write_c14n(self, file): # lxml.etree compatibility. use output method instead @@ -871,6 +837,58 @@ class ElementTree: # -------------------------------------------------------------------- # serialization support +@contextlib.contextmanager +def _get_writer(file_or_filename, encoding): + # returns text write method and release all resourses after using + try: + write = file_or_filename.write + except AttributeError: + # file_or_filename is a file name + if encoding == "unicode": + file = open(file_or_filename, "w") + else: + file = open(file_or_filename, "w", encoding=encoding, + errors="xmlcharrefreplace") + with file: + yield file.write + else: + # file_or_filename is a file-like object + # encoding determines if it is a text or binary writer + if encoding == "unicode": + # use a text writer as is + yield write + else: + # wrap a binary writer with TextIOWrapper + with contextlib.ExitStack() as stack: + if isinstance(file_or_filename, io.BufferedIOBase): + file = file_or_filename + elif isinstance(file_or_filename, io.RawIOBase): + file = io.BufferedWriter(file_or_filename) + # Keep the original file open when the BufferedWriter is + # destroyed + stack.callback(file.detach) + else: + # This is to handle passed objects that aren't in the + # IOBase hierarchy, but just have a write method + file = io.BufferedIOBase() + file.writable = lambda: True + file.write = write + try: + # TextIOWrapper uses this methods to determine + # if BOM (for UTF-16, etc) should be added + file.seekable = file_or_filename.seekable + file.tell = file_or_filename.tell + except AttributeError: + pass + file = io.TextIOWrapper(file, + encoding=encoding, + errors="xmlcharrefreplace", + newline="\n") + # Keep the original file open when the TextIOWrapper is + # destroyed + stack.callback(file.detach) + yield file.write + def _namespaces(elem, default_namespace=None): # identify namespaces used in this tree @@ -910,11 +928,7 @@ def _namespaces(elem, default_namespace=None): _raise_serialization_error(qname) # populate qname and namespaces table - try: - iterate = elem.iter - except AttributeError: - iterate = elem.getiterator # cET compatibility - for elem in iterate(): + for elem in elem.iter(): tag = elem.tag if isinstance(tag, QName): if tag.text not in qnames: @@ -1086,6 +1100,8 @@ _namespace_map = { # dublin core "http://purl.org/dc/elements/1.1/": "dc", } +# For tests and troubleshooting +register_namespace._namespace_map = _namespace_map def _raise_serialization_error(text): raise TypeError( @@ -1154,22 +1170,13 @@ def _escape_attrib_html(text): # @defreturn string def tostring(element, encoding=None, method=None): - class dummy: - pass - data = [] - file = dummy() - file.write = data.append - ElementTree(element).write(file, encoding, method=method) - if encoding in (str, "unicode"): - return "".join(data) - else: - return b"".join(data) + stream = io.StringIO() if encoding == 'unicode' else io.BytesIO() + ElementTree(element).write(stream, encoding, method=method) + return stream.getvalue() ## # Generates a string representation of an XML element, including all -# subelements. If encoding is False, the string is returned as a -# sequence of string fragments; otherwise it is a sequence of -# bytestrings. +# subelements. # # @param element An Element instance. # @keyparam encoding Optional output encoding (default is US-ASCII). @@ -1180,15 +1187,29 @@ def tostring(element, encoding=None, method=None): # @defreturn sequence # @since 1.3 +class _ListDataStream(io.BufferedIOBase): + """ An auxiliary stream accumulating into a list reference + """ + def __init__(self, lst): + self.lst = lst + + def writable(self): + return True + + def seekable(self): + return True + + def write(self, b): + self.lst.append(b) + + def tell(self): + return len(self.lst) + def tostringlist(element, encoding=None, method=None): - class dummy: - pass - data = [] - file = dummy() - file.write = data.append - ElementTree(element).write(file, encoding, method=method) - # FIXME: merge small fragments into larger parts - return data + lst = [] + stream = _ListDataStream(lst) + ElementTree(element).write(stream, encoding, method=method) + return lst ## # Writes an element tree or element structure to sys.stdout. This @@ -1510,24 +1531,30 @@ class XMLParser: self.target = self._target = target self._error = expat.error self._names = {} # name memo cache - # callbacks + # main callbacks parser.DefaultHandlerExpand = self._default - parser.StartElementHandler = self._start - parser.EndElementHandler = self._end - parser.CharacterDataHandler = self._data - # optional callbacks - parser.CommentHandler = self._comment - parser.ProcessingInstructionHandler = self._pi + if hasattr(target, 'start'): + parser.StartElementHandler = self._start + if hasattr(target, 'end'): + parser.EndElementHandler = self._end + if hasattr(target, 'data'): + parser.CharacterDataHandler = target.data + # miscellaneous callbacks + if hasattr(target, 'comment'): + parser.CommentHandler = target.comment + if hasattr(target, 'pi'): + parser.ProcessingInstructionHandler = target.pi # let expat do the buffering, if supported try: - self._parser.buffer_text = 1 + parser.buffer_text = 1 except AttributeError: pass # use new-style attribute handling, if supported try: - self._parser.ordered_attributes = 1 - self._parser.specified_attributes = 1 - parser.StartElementHandler = self._start_list + parser.ordered_attributes = 1 + parser.specified_attributes = 1 + if hasattr(target, 'start'): + parser.StartElementHandler = self._start_list except AttributeError: pass self._doctype = None @@ -1571,44 +1598,29 @@ class XMLParser: attrib[fixname(attrib_in[i])] = attrib_in[i+1] return self.target.start(tag, attrib) - def _data(self, text): - return self.target.data(text) - def _end(self, tag): return self.target.end(self._fixname(tag)) - def _comment(self, data): - try: - comment = self.target.comment - except AttributeError: - pass - else: - return comment(data) - - def _pi(self, target, data): - try: - pi = self.target.pi - except AttributeError: - pass - else: - return pi(target, data) - def _default(self, text): prefix = text[:1] if prefix == "&": # deal with undefined entities try: - self.target.data(self.entity[text[1:-1]]) + data_handler = self.target.data + except AttributeError: + return + try: + data_handler(self.entity[text[1:-1]]) except KeyError: from xml.parsers import expat err = expat.error( "undefined entity %s: line %d, column %d" % - (text, self._parser.ErrorLineNumber, - self._parser.ErrorColumnNumber) + (text, self.parser.ErrorLineNumber, + self.parser.ErrorColumnNumber) ) err.code = 11 # XML_ERROR_UNDEFINED_ENTITY - err.lineno = self._parser.ErrorLineNumber - err.offset = self._parser.ErrorColumnNumber + err.lineno = self.parser.ErrorLineNumber + err.offset = self.parser.ErrorColumnNumber raise err elif prefix == "<" and text[:9] == "<!DOCTYPE": self._doctype = [] # inside a doctype declaration @@ -1626,16 +1638,16 @@ class XMLParser: type = self._doctype[1] if type == "PUBLIC" and n == 4: name, type, pubid, system = self._doctype + if pubid: + pubid = pubid[1:-1] elif type == "SYSTEM" and n == 3: name, type, system = self._doctype pubid = None else: return - if pubid: - pubid = pubid[1:-1] if hasattr(self.target, "doctype"): self.target.doctype(name, pubid, system[1:-1]) - elif self.doctype is not self._XMLParser__doctype: + elif self.doctype != self._XMLParser__doctype: # warn about deprecated call self._XMLParser__doctype(name, pubid, system[1:-1]) self.doctype(name, pubid, system[1:-1]) @@ -1666,7 +1678,7 @@ class XMLParser: def feed(self, data): try: - self._parser.Parse(data, 0) + self.parser.Parse(data, 0) except self._error as v: self._raiseerror(v) @@ -1678,12 +1690,100 @@ class XMLParser: def close(self): try: - self._parser.Parse("", 1) # end of data + self.parser.Parse("", 1) # end of data except self._error as v: self._raiseerror(v) - tree = self.target.close() - del self.target, self._parser # get rid of circular references - return tree + try: + close_handler = self.target.close + except AttributeError: + pass + else: + return close_handler() + finally: + # get rid of circular references + del self.parser, self._parser + del self.target, self._target + + +# Import the C accelerators +try: + # Element, SubElement, ParseError, TreeBuilder, XMLParser + from _elementtree import * +except ImportError: + pass +else: + # Overwrite 'ElementTree.parse' and 'iterparse' to use the C XMLParser + + class ElementTree(ElementTree): + def parse(self, source, parser=None): + close_source = False + if not hasattr(source, 'read'): + source = open(source, 'rb') + close_source = True + try: + if parser is not None: + while True: + data = source.read(65536) + if not data: + break + parser.feed(data) + self._root = parser.close() + else: + parser = XMLParser() + self._root = parser._parse(source) + return self._root + finally: + if close_source: + source.close() + + class iterparse: + root = None + def __init__(self, file, events=None): + self._close_file = False + if not hasattr(file, 'read'): + file = open(file, 'rb') + self._close_file = True + self._file = file + self._events = [] + self._index = 0 + self._error = None + self.root = self._root = None + b = TreeBuilder() + self._parser = XMLParser(b) + self._parser._setevents(self._events, events) + + def __next__(self): + while True: + try: + item = self._events[self._index] + self._index += 1 + return item + except IndexError: + pass + if self._error: + e = self._error + self._error = None + raise e + if self._parser is None: + self.root = self._root + if self._close_file: + self._file.close() + raise StopIteration + # load event buffer + del self._events[:] + self._index = 0 + data = self._file.read(16384) + if data: + try: + self._parser.feed(data) + except SyntaxError as exc: + self._error = exc + else: + self._root = self._parser.close() + self._parser = None + + def __iter__(self): + return self # compatibility XMLTreeBuilder = XMLParser |