xref: /aosp_15_r20/external/fonttools/Lib/fontTools/misc/etree.py (revision e1fe3e4ad2793916b15cccdc4a7da52a7e1dd0e9)
1"""Shim module exporting the same ElementTree API for lxml and
2xml.etree backends.
3
4When lxml is installed, it is automatically preferred over the built-in
5xml.etree module.
6On Python 2.7, the cElementTree module is preferred over the pure-python
7ElementTree module.
8
Besides exporting a unified interface, this also defines extra functions
or subclasses built-in ElementTree classes to add features that are
only available in lxml, like OrderedDict for attributes, pretty_print and
iterwalk.
13"""
14
15from fontTools.misc.textTools import tostr
16
17
18XML_DECLARATION = """<?xml version='1.0' encoding='%s'?>"""
19
20__all__ = [
21    # public symbols
22    "Comment",
23    "dump",
24    "Element",
25    "ElementTree",
26    "fromstring",
27    "fromstringlist",
28    "iselement",
29    "iterparse",
30    "parse",
31    "ParseError",
32    "PI",
33    "ProcessingInstruction",
34    "QName",
35    "SubElement",
36    "tostring",
37    "tostringlist",
38    "TreeBuilder",
39    "XML",
40    "XMLParser",
41    "register_namespace",
42]
43
44try:
45    from lxml.etree import *
46
47    _have_lxml = True
48except ImportError:
49    try:
50        from xml.etree.cElementTree import *
51
52        # the cElementTree version of XML function doesn't support
53        # the optional 'parser' keyword argument
54        from xml.etree.ElementTree import XML
55    except ImportError:  # pragma: no cover
56        from xml.etree.ElementTree import *
57    _have_lxml = False
58
59    import sys
60
61    # dict is always ordered in python >= 3.6 and on pypy
62    PY36 = sys.version_info >= (3, 6)
63    try:
64        import __pypy__
65    except ImportError:
66        __pypy__ = None
67    _dict_is_ordered = bool(PY36 or __pypy__)
68    del PY36, __pypy__
69
70    if _dict_is_ordered:
71        _Attrib = dict
72    else:
73        from collections import OrderedDict as _Attrib
74
75    if isinstance(Element, type):
76        _Element = Element
77    else:
78        # in py27, cElementTree.Element cannot be subclassed, so
79        # we need to import the pure-python class
80        from xml.etree.ElementTree import Element as _Element
81
82    class Element(_Element):
83        """Element subclass that keeps the order of attributes."""
84
85        def __init__(self, tag, attrib=_Attrib(), **extra):
86            super(Element, self).__init__(tag)
87            self.attrib = _Attrib()
88            if attrib:
89                self.attrib.update(attrib)
90            if extra:
91                self.attrib.update(extra)
92
93    def SubElement(parent, tag, attrib=_Attrib(), **extra):
94        """Must override SubElement as well otherwise _elementtree.SubElement
95        fails if 'parent' is a subclass of Element object.
96        """
97        element = parent.__class__(tag, attrib, **extra)
98        parent.append(element)
99        return element
100
101    def _iterwalk(element, events, tag):
102        include = tag is None or element.tag == tag
103        if include and "start" in events:
104            yield ("start", element)
105        for e in element:
106            for item in _iterwalk(e, events, tag):
107                yield item
108        if include:
109            yield ("end", element)
110
111    def iterwalk(element_or_tree, events=("end",), tag=None):
112        """A tree walker that generates events from an existing tree as
113        if it was parsing XML data with iterparse().
114        Drop-in replacement for lxml.etree.iterwalk.
115        """
116        if iselement(element_or_tree):
117            element = element_or_tree
118        else:
119            element = element_or_tree.getroot()
120        if tag == "*":
121            tag = None
122        for item in _iterwalk(element, events, tag):
123            yield item
124
125    _ElementTree = ElementTree
126
127    class ElementTree(_ElementTree):
128        """ElementTree subclass that adds 'pretty_print' and 'doctype'
129        arguments to the 'write' method.
130        Currently these are only supported for the default XML serialization
131        'method', and not also for "html" or "text", for these are delegated
132        to the base class.
133        """
134
135        def write(
136            self,
137            file_or_filename,
138            encoding=None,
139            xml_declaration=False,
140            method=None,
141            doctype=None,
142            pretty_print=False,
143        ):
144            if method and method != "xml":
145                # delegate to super-class
146                super(ElementTree, self).write(
147                    file_or_filename,
148                    encoding=encoding,
149                    xml_declaration=xml_declaration,
150                    method=method,
151                )
152                return
153
154            if encoding is not None and encoding.lower() == "unicode":
155                if xml_declaration:
156                    raise ValueError(
157                        "Serialisation to unicode must not request an XML declaration"
158                    )
159                write_declaration = False
160                encoding = "unicode"
161            elif xml_declaration is None:
162                # by default, write an XML declaration only for non-standard encodings
163                write_declaration = encoding is not None and encoding.upper() not in (
164                    "ASCII",
165                    "UTF-8",
166                    "UTF8",
167                    "US-ASCII",
168                )
169            else:
170                write_declaration = xml_declaration
171
172            if encoding is None:
173                encoding = "ASCII"
174
175            if pretty_print:
176                # NOTE this will modify the tree in-place
177                _indent(self._root)
178
179            with _get_writer(file_or_filename, encoding) as write:
180                if write_declaration:
181                    write(XML_DECLARATION % encoding.upper())
182                    if pretty_print:
183                        write("\n")
184                if doctype:
185                    write(_tounicode(doctype))
186                    if pretty_print:
187                        write("\n")
188
189                qnames, namespaces = _namespaces(self._root)
190                _serialize_xml(write, self._root, qnames, namespaces)
191
192    import io
193
194    def tostring(
195        element,
196        encoding=None,
197        xml_declaration=None,
198        method=None,
199        doctype=None,
200        pretty_print=False,
201    ):
202        """Custom 'tostring' function that uses our ElementTree subclass, with
203        pretty_print support.
204        """
205        stream = io.StringIO() if encoding == "unicode" else io.BytesIO()
206        ElementTree(element).write(
207            stream,
208            encoding=encoding,
209            xml_declaration=xml_declaration,
210            method=method,
211            doctype=doctype,
212            pretty_print=pretty_print,
213        )
214        return stream.getvalue()
215
216    # serialization support
217
218    import re
219
220    # Valid XML strings can include any Unicode character, excluding control
221    # characters, the surrogate blocks, FFFE, and FFFF:
222    #   Char ::= #x9 | #xA | #xD | [#x20-#xD7FF] | [#xE000-#xFFFD] | [#x10000-#x10FFFF]
223    # Here we reversed the pattern to match only the invalid characters.
224    # For the 'narrow' python builds supporting only UCS-2, which represent
225    # characters beyond BMP as UTF-16 surrogate pairs, we need to pass through
226    # the surrogate block. I haven't found a more elegant solution...
227    UCS2 = sys.maxunicode < 0x10FFFF
228    if UCS2:
229        _invalid_xml_string = re.compile(
230            "[\u0000-\u0008\u000B-\u000C\u000E-\u001F\uFFFE-\uFFFF]"
231        )
232    else:
233        _invalid_xml_string = re.compile(
234            "[\u0000-\u0008\u000B-\u000C\u000E-\u001F\uD800-\uDFFF\uFFFE-\uFFFF]"
235        )
236
237    def _tounicode(s):
238        """Test if a string is valid user input and decode it to unicode string
239        using ASCII encoding if it's a bytes string.
240        Reject all bytes/unicode input that contains non-XML characters.
241        Reject all bytes input that contains non-ASCII characters.
242        """
243        try:
244            s = tostr(s, encoding="ascii", errors="strict")
245        except UnicodeDecodeError:
246            raise ValueError(
247                "Bytes strings can only contain ASCII characters. "
248                "Use unicode strings for non-ASCII characters."
249            )
250        except AttributeError:
251            _raise_serialization_error(s)
252        if s and _invalid_xml_string.search(s):
253            raise ValueError(
254                "All strings must be XML compatible: Unicode or ASCII, "
255                "no NULL bytes or control characters"
256            )
257        return s
258
259    import contextlib
260
261    @contextlib.contextmanager
262    def _get_writer(file_or_filename, encoding):
263        # returns text write method and release all resources after using
264        try:
265            write = file_or_filename.write
266        except AttributeError:
267            # file_or_filename is a file name
268            f = open(
269                file_or_filename,
270                "w",
271                encoding="utf-8" if encoding == "unicode" else encoding,
272                errors="xmlcharrefreplace",
273            )
274            with f:
275                yield f.write
276        else:
277            # file_or_filename is a file-like object
278            # encoding determines if it is a text or binary writer
279            if encoding == "unicode":
280                # use a text writer as is
281                yield write
282            else:
283                # wrap a binary writer with TextIOWrapper
284                detach_buffer = False
285                if isinstance(file_or_filename, io.BufferedIOBase):
286                    buf = file_or_filename
287                elif isinstance(file_or_filename, io.RawIOBase):
288                    buf = io.BufferedWriter(file_or_filename)
289                    detach_buffer = True
290                else:
291                    # This is to handle passed objects that aren't in the
292                    # IOBase hierarchy, but just have a write method
293                    buf = io.BufferedIOBase()
294                    buf.writable = lambda: True
295                    buf.write = write
296                    try:
297                        # TextIOWrapper uses this methods to determine
298                        # if BOM (for UTF-16, etc) should be added
299                        buf.seekable = file_or_filename.seekable
300                        buf.tell = file_or_filename.tell
301                    except AttributeError:
302                        pass
303                wrapper = io.TextIOWrapper(
304                    buf,
305                    encoding=encoding,
306                    errors="xmlcharrefreplace",
307                    newline="\n",
308                )
309                try:
310                    yield wrapper.write
311                finally:
312                    # Keep the original file open when the TextIOWrapper and
313                    # the BufferedWriter are destroyed
314                    wrapper.detach()
315                    if detach_buffer:
316                        buf.detach()
317
318    from xml.etree.ElementTree import _namespace_map
319
320    def _namespaces(elem):
321        # identify namespaces used in this tree
322
323        # maps qnames to *encoded* prefix:local names
324        qnames = {None: None}
325
326        # maps uri:s to prefixes
327        namespaces = {}
328
329        def add_qname(qname):
330            # calculate serialized qname representation
331            try:
332                qname = _tounicode(qname)
333                if qname[:1] == "{":
334                    uri, tag = qname[1:].rsplit("}", 1)
335                    prefix = namespaces.get(uri)
336                    if prefix is None:
337                        prefix = _namespace_map.get(uri)
338                        if prefix is None:
339                            prefix = "ns%d" % len(namespaces)
340                        else:
341                            prefix = _tounicode(prefix)
342                        if prefix != "xml":
343                            namespaces[uri] = prefix
344                    if prefix:
345                        qnames[qname] = "%s:%s" % (prefix, tag)
346                    else:
347                        qnames[qname] = tag  # default element
348                else:
349                    qnames[qname] = qname
350            except TypeError:
351                _raise_serialization_error(qname)
352
353        # populate qname and namespaces table
354        for elem in elem.iter():
355            tag = elem.tag
356            if isinstance(tag, QName):
357                if tag.text not in qnames:
358                    add_qname(tag.text)
359            elif isinstance(tag, str):
360                if tag not in qnames:
361                    add_qname(tag)
362            elif tag is not None and tag is not Comment and tag is not PI:
363                _raise_serialization_error(tag)
364            for key, value in elem.items():
365                if isinstance(key, QName):
366                    key = key.text
367                if key not in qnames:
368                    add_qname(key)
369                if isinstance(value, QName) and value.text not in qnames:
370                    add_qname(value.text)
371            text = elem.text
372            if isinstance(text, QName) and text.text not in qnames:
373                add_qname(text.text)
374        return qnames, namespaces
375
376    def _serialize_xml(write, elem, qnames, namespaces, **kwargs):
377        tag = elem.tag
378        text = elem.text
379        if tag is Comment:
380            write("<!--%s-->" % _tounicode(text))
381        elif tag is ProcessingInstruction:
382            write("<?%s?>" % _tounicode(text))
383        else:
384            tag = qnames[_tounicode(tag) if tag is not None else None]
385            if tag is None:
386                if text:
387                    write(_escape_cdata(text))
388                for e in elem:
389                    _serialize_xml(write, e, qnames, None)
390            else:
391                write("<" + tag)
392                if namespaces:
393                    for uri, prefix in sorted(
394                        namespaces.items(), key=lambda x: x[1]
395                    ):  # sort on prefix
396                        if prefix:
397                            prefix = ":" + prefix
398                        write(' xmlns%s="%s"' % (prefix, _escape_attrib(uri)))
399                attrs = elem.attrib
400                if attrs:
401                    # try to keep existing attrib order
402                    if len(attrs) <= 1 or type(attrs) is _Attrib:
403                        items = attrs.items()
404                    else:
405                        # if plain dict, use lexical order
406                        items = sorted(attrs.items())
407                    for k, v in items:
408                        if isinstance(k, QName):
409                            k = _tounicode(k.text)
410                        else:
411                            k = _tounicode(k)
412                        if isinstance(v, QName):
413                            v = qnames[_tounicode(v.text)]
414                        else:
415                            v = _escape_attrib(v)
416                        write(' %s="%s"' % (qnames[k], v))
417                if text is not None or len(elem):
418                    write(">")
419                    if text:
420                        write(_escape_cdata(text))
421                    for e in elem:
422                        _serialize_xml(write, e, qnames, None)
423                    write("</" + tag + ">")
424                else:
425                    write("/>")
426        if elem.tail:
427            write(_escape_cdata(elem.tail))
428
429    def _raise_serialization_error(text):
430        raise TypeError("cannot serialize %r (type %s)" % (text, type(text).__name__))
431
432    def _escape_cdata(text):
433        # escape character data
434        try:
435            text = _tounicode(text)
436            # it's worth avoiding do-nothing calls for short strings
437            if "&" in text:
438                text = text.replace("&", "&amp;")
439            if "<" in text:
440                text = text.replace("<", "&lt;")
441            if ">" in text:
442                text = text.replace(">", "&gt;")
443            return text
444        except (TypeError, AttributeError):
445            _raise_serialization_error(text)
446
447    def _escape_attrib(text):
448        # escape attribute value
449        try:
450            text = _tounicode(text)
451            if "&" in text:
452                text = text.replace("&", "&amp;")
453            if "<" in text:
454                text = text.replace("<", "&lt;")
455            if ">" in text:
456                text = text.replace(">", "&gt;")
457            if '"' in text:
458                text = text.replace('"', "&quot;")
459            if "\n" in text:
460                text = text.replace("\n", "&#10;")
461            return text
462        except (TypeError, AttributeError):
463            _raise_serialization_error(text)
464
465    def _indent(elem, level=0):
466        # From http://effbot.org/zone/element-lib.htm#prettyprint
467        i = "\n" + level * "  "
468        if len(elem):
469            if not elem.text or not elem.text.strip():
470                elem.text = i + "  "
471            if not elem.tail or not elem.tail.strip():
472                elem.tail = i
473            for elem in elem:
474                _indent(elem, level + 1)
475            if not elem.tail or not elem.tail.strip():
476                elem.tail = i
477        else:
478            if level and (not elem.tail or not elem.tail.strip()):
479                elem.tail = i
480