1# mako/util.py
2# Copyright 2006-2023 the Mako authors and contributors <see AUTHORS file>
3#
4# This module is part of Mako and is released under
5# the MIT License: http://www.opensource.org/licenses/mit-license.php
6from ast import parse
7import codecs
8import collections
9import operator
10import os
11import re
12import timeit
13
14from .compat import importlib_metadata_get
15
16
def update_wrapper(decorated, fn):
    """Make *decorated* present itself as a wrapper of *fn*.

    Sets ``__wrapped__`` to point back at *fn* and copies *fn*'s
    ``__name__``; no other metadata is transferred.  The same *decorated*
    object is returned so this can be used inline.
    """
    setattr(decorated, "__wrapped__", fn)
    setattr(decorated, "__name__", getattr(fn, "__name__"))
    return decorated
21
22
class PluginLoader:
    """Look up plugin implementations by name within one entry-point group.

    Implementations come from two sources: loaders added explicitly with
    :meth:`register`, and entry points yielded by
    ``importlib_metadata_get(group)``.  Anything found is cached in
    ``self.impls`` as a zero-argument callable that produces the plugin.
    """

    def __init__(self, group):
        self.group = group
        # name -> zero-argument callable returning the plugin object
        self.impls = {}

    def load(self, name):
        """Return the plugin registered or advertised under *name*.

        :raises mako.exceptions.RuntimeException: if no implementation
            with that name can be found.
        """
        if name not in self.impls:
            found = next(
                (
                    entry_point
                    for entry_point in importlib_metadata_get(self.group)
                    if entry_point.name == name
                ),
                None,
            )
            if found is None:
                from mako import exceptions

                raise exceptions.RuntimeException(
                    "Can't load plugin %s %s" % (self.group, name)
                )
            # Cache the bound loader so later lookups skip the scan.
            self.impls[name] = found.load
            return found.load()
        return self.impls[name]()

    def register(self, name, modulepath, objname):
        """Register a lazy loader for ``modulepath.objname`` under *name*.

        The module is not imported until :meth:`load` is called for *name*.
        """

        def load():
            # __import__("a.b.c") hands back the top-level package "a", so
            # walk down the remaining dotted components one at a time.
            target = __import__(modulepath)
            _, *attr_path = modulepath.split(".")
            for attr in attr_path:
                target = getattr(target, attr)
            return getattr(target, objname)

        self.impls[name] = load
51
52
def verify_directory(dir_):
    """Create *dir_* (including missing parents) unless it already exists.

    Creation is attempted with mode ``0o755``.  A failing ``os.makedirs``
    is retried while the path still does not exist; the ``OSError`` is
    re-raised once six attempts have failed.  If anything already exists
    at *dir_*, nothing is done.

    :raises OSError: if the directory still cannot be created after the
        retries.
    """

    tries = 0

    while not os.path.exists(dir_):
        try:
            tries += 1
            os.makedirs(dir_, 0o755)
        except OSError:
            # Presumably another thread/process may be creating the same
            # path concurrently; the loop condition re-checks existence.
            # Only OS-level failures are retried -- a bare ``except`` would
            # also swallow KeyboardInterrupt/SystemExit and genuine bugs.
            if tries > 5:
                raise
65
66
def to_list(x, default=None):
    """Coerce *x* into a list-like sequence.

    ``None`` yields *default*; an existing ``list`` or ``tuple`` is returned
    as-is (same object); any other value is wrapped in a one-element list.
    """
    if x is None:
        return default
    return x if isinstance(x, (list, tuple)) else [x]
74
75
class memoized_property:

    """A @property-like descriptor whose getter runs at most once per
    instance.

    On first access the getter's result is stored in the instance
    ``__dict__`` under the same name.  This class defines no ``__set__``,
    so that instance attribute then shadows the descriptor and later reads
    are ordinary attribute lookups.
    """

    def __init__(self, fget, doc=None):
        self.fget = fget
        self.__doc__ = doc or fget.__doc__
        self.__name__ = fget.__name__

    def __get__(self, obj, cls):
        if obj is not None:
            value = self.fget(obj)
            obj.__dict__[self.__name__] = value
            return value
        # Accessed on the class itself: hand back the descriptor.
        return self
90
91
class memoized_instancemethod:

    """Decorate a method to memoize its return value.

    Best applied to no-arg methods: memoization is not sensitive to
    argument values, and will always return the same value even when
    called with different arguments.

    The first call on an instance runs the wrapped method, then installs a
    stand-in that returns the cached result in the instance ``__dict__``
    under the method's name.  This class defines no ``__set__``, so that
    entry shadows the descriptor on every later lookup.
    """

    def __init__(self, fget, doc=None):
        self.fget = fget
        self.__doc__ = doc or fget.__doc__
        self.__name__ = fget.__name__

    def __get__(self, obj, cls):
        if obj is None:
            # Class-level access returns the descriptor itself.
            return self

        def oneshot(*args, **kw):
            cached = self.fget(obj, *args, **kw)

            def memo(*ignored_args, **ignored_kw):
                # Arguments are deliberately ignored (see class docstring).
                return cached

            memo.__name__, memo.__doc__ = self.__name__, self.__doc__
            obj.__dict__[self.__name__] = memo
            return cached

        oneshot.__name__, oneshot.__doc__ = self.__name__, self.__doc__
        return oneshot
125
126
class SetLikeDict(dict):

    """a dictionary that has some setlike methods on it"""

    def union(self, other):
        """produce a 'union' of this dict and another (at the key level).

        values in the second dict take precedence over that of the first.
        Neither input is modified; a new SetLikeDict is returned.
        """
        # Copy positionally rather than via ``**self``: keyword expansion
        # only accepts string keys and raises TypeError for e.g. int keys.
        x = SetLikeDict(self)
        x.update(other)
        return x
138
139
class FastEncodingBuffer:

    """A minimal write-only text buffer that is faster than StringIO and
    accepts unicode data.

    Writes are appended to a deque; :meth:`getvalue` joins them and, when
    an *encoding* was supplied, encodes the result with *errors*.
    """

    def __init__(self, encoding=None, errors="strict"):
        self.data = collections.deque()
        # ``write`` is the deque's own bound ``append``.
        self.write = self.data.append
        self.encoding, self.delim, self.errors = encoding, "", errors

    def truncate(self):
        """Discard everything written so far."""
        self.data = collections.deque()
        # Rebind so ``write`` targets the fresh deque, not the old one.
        self.write = self.data.append

    def getvalue(self):
        """Return the buffered text: ``bytes`` if an encoding is set
        (truthy), otherwise ``str``."""
        text = self.delim.join(self.data)
        if not self.encoding:
            return text
        return text.encode(self.encoding, self.errors)
163
164
class LRUCache(dict):

    """A dictionary-like object that stores a limited number of items,
    discarding lesser used items periodically.

    this is a rewrite of LRUCache from Myghty to use a periodic timestamp-based
    paradigm so that synchronization is not really needed.  the size management
    is inexact.

    Values are stored internally wrapped in ``_Item`` records that carry a
    last-access timestamp.  Once more than ``capacity * (1 + threshold)``
    entries are held, the cache is trimmed back to the ``capacity`` most
    recently used ones.

    ``[]``, ``get()``, ``setdefault()`` and ``values()`` unwrap the records;
    other inherited dict methods (e.g. ``items()``, ``pop()``) still expose
    the ``_Item`` wrappers.
    """

    class _Item:
        # One cache entry: the key (needed for eviction), the stored value,
        # and the time it was last read or created.
        def __init__(self, key, value):
            self.key = key
            self.value = value
            self.timestamp = timeit.default_timer()

        def __repr__(self):
            return repr(self.value)

    def __init__(self, capacity, threshold=0.5):
        self.capacity = capacity
        self.threshold = threshold

    def __getitem__(self, key):
        item = dict.__getitem__(self, key)
        # A read counts as a "use" for eviction purposes.
        item.timestamp = timeit.default_timer()
        return item.value

    def get(self, key, default=None):
        """Return the value stored for *key*, or *default* if absent.

        Overrides ``dict.get``, which would otherwise return the internal
        ``_Item`` wrapper rather than the value.  A hit refreshes the entry's
        timestamp exactly like ``cache[key]`` does.
        """
        try:
            return self[key]
        except KeyError:
            return default

    def values(self):
        return [i.value for i in dict.values(self)]

    def setdefault(self, key, value):
        if key in self:
            return self[key]
        self[key] = value
        return value

    def __setitem__(self, key, value):
        item = dict.get(self, key)
        if item is None:
            item = self._Item(key, value)
            dict.__setitem__(self, key, item)
        else:
            # Overwriting an existing key replaces the value but does not
            # refresh its timestamp.
            item.value = value
        self._manage_size()

    def _manage_size(self):
        # Trim only once the threshold overshoot is exceeded, then cut back
        # to ``capacity`` entries, keeping the most recently used.
        while len(self) > self.capacity + self.capacity * self.threshold:
            bytime = sorted(
                dict.values(self),
                key=operator.attrgetter("timestamp"),
                reverse=True,
            )
            for item in bytime[self.capacity :]:
                try:
                    del self[item.key]
                except KeyError:
                    # if we couldn't find a key, most likely some other thread
                    # broke in on us. loop around and try again
                    break
225
226
# Regexp to match a Python "magic" encoding comment (PEP 263), e.g.
# ``# -*- coding: utf-8 -*-``; group 1 captures the encoding name.
_PYTHON_MAGIC_COMMENT_re = re.compile(
    r"[ \t\f]* \# .* coding[=:][ \t]*([-\w.]+)", re.VERBOSE
)


def parse_encoding(fp):
    """Deduce the encoding of a Python source file (binary mode) from its
    magic comment.

    It does this in the same way as the `Python interpreter`__:

    * A leading UTF-8 byte-order mark yields ``"utf_8"``.
    * Otherwise the magic comment is looked for on line 1, and on line 2
      only if line 1 parses as complete Python by itself (so line 2 is not
      a continuation of line 1).
    * ``None`` is returned when no encoding is declared.

    .. __: https://peps.python.org/pep-0263/

    The ``fp`` argument should be a seekable file object in binary mode.
    Its position is restored before returning (or raising).

    :raises SyntaxError: if the file has both a UTF-8 BOM and a magic
        encoding comment.
    """
    pos = fp.tell()
    fp.seek(0)
    try:
        line1 = fp.readline()
        has_bom = line1.startswith(codecs.BOM_UTF8)
        if has_bom:
            line1 = line1[len(codecs.BOM_UTF8) :]

        m = _PYTHON_MAGIC_COMMENT_re.match(line1.decode("ascii", "ignore"))
        if not m:
            try:
                parse(line1.decode("ascii", "ignore"))
            except (ImportError, SyntaxError, ValueError):
                # Either it's a real syntax error, in which case the source
                # is not valid python source, or line2 is a continuation of
                # line1, in which case we don't want to scan line2 for a magic
                # comment.  ValueError is what ast.parse raises for NUL bytes
                # before Python 3.12 -- also "not valid python source".
                pass
            else:
                line2 = fp.readline()
                m = _PYTHON_MAGIC_COMMENT_re.match(
                    line2.decode("ascii", "ignore")
                )

        if has_bom:
            if m:
                raise SyntaxError(
                    "python refuses to compile code with both a UTF8"
                    " byte-order-mark and a magic encoding comment"
                )
            return "utf_8"
        elif m:
            return m.group(1)
        else:
            return None
    finally:
        fp.seek(pos)
280
281
def sorted_dict_repr(d):
    """Return a dict-style repr of *d* with its keys in sorted order.

    Gives a deterministic string regardless of insertion order; used by
    the lexer unit test to compare parse trees as strings.
    """
    ordered_keys = sorted(d.keys())
    pairs = ["%r: %r" % (key, d[key]) for key in ordered_keys]
    return "{" + ", ".join(pairs) + "}"
291
292
def restore__ast(_ast):
    """Attempt to restore the required classes to the _ast module if it
    appears to be missing them.

    If ``_ast`` already has an ``AST`` attribute this returns immediately.
    Otherwise it compiles a small snippet containing a function, a class,
    an ``if``, an assignment and a spread of unary, binary, boolean and
    comparison operators.  It then reads node classes off fixed positions
    in the resulting tree and assigns them back onto ``_ast`` under their
    standard names (``Module``, ``FunctionDef``, ``Add``, ``Eq``, ...).

    :param _ast: module-like object that is patched in place; the function
        returns ``None``.
    """
    if hasattr(_ast, "AST"):
        return
    # 2 << 9 == 0x400, the value used here as the PyCF_ONLY_AST compile
    # flag: with it, compile() returns the parsed tree (accessed below via
    # ``m.body``) instead of a code object.
    _ast.PyCF_ONLY_AST = 2 << 9
    # Each line of the snippet is exactly one statement, so m.body[i] is
    # the i-th line.  The index paths below depend on that order and on
    # Python's operator precedence/associativity.
    m = compile(
        """\
def foo(): pass
class Bar: pass
if False: pass
baz = 'mako'
1 + 2 - 3 * 4 / 5
6 // 7 % 8 << 9 >> 10
11 & 12 ^ 13 | 14
15 and 16 or 17
-baz + (not +18) - ~17
baz and 'foo' or 'bar'
(mako is baz == baz) is not baz != mako
mako > baz < mako >= baz <= mako
mako in baz not in mako""",
        "<unknown>",
        "exec",
        _ast.PyCF_ONLY_AST,
    )
    _ast.Module = type(m)

    # The abstract ``mod`` and ``AST`` base classes are picked out of
    # Module's MRO by class name.
    for cls in _ast.Module.__mro__:
        if cls.__name__ == "mod":
            _ast.mod = cls
        elif cls.__name__ == "AST":
            _ast.AST = cls

    # body[0..2]: ``def``, ``class`` and ``if`` statements.
    _ast.FunctionDef = type(m.body[0])
    _ast.ClassDef = type(m.body[1])
    _ast.If = type(m.body[2])

    # body[3]: ``baz = 'mako'`` -- target Name in Store context, string value.
    # NOTE(review): on Python 3.8+ string literals parse to ast.Constant, so
    # this would bind ``_ast.Str`` to Constant -- confirm if this path matters.
    _ast.Name = type(m.body[3].targets[0])
    _ast.Store = type(m.body[3].targets[0].ctx)
    _ast.Str = type(m.body[3].value)

    # body[4]: ``(1 + 2) - ((3 * 4) / 5)``
    _ast.Sub = type(m.body[4].value.op)
    _ast.Add = type(m.body[4].value.left.op)
    _ast.Div = type(m.body[4].value.right.op)
    _ast.Mult = type(m.body[4].value.right.left.op)

    # body[5]: ``(((6 // 7) % 8) << 9) >> 10``
    _ast.RShift = type(m.body[5].value.op)
    _ast.LShift = type(m.body[5].value.left.op)
    _ast.Mod = type(m.body[5].value.left.left.op)
    _ast.FloorDiv = type(m.body[5].value.left.left.left.op)

    # body[6]: ``((11 & 12) ^ 13) | 14``
    _ast.BitOr = type(m.body[6].value.op)
    _ast.BitXor = type(m.body[6].value.left.op)
    _ast.BitAnd = type(m.body[6].value.left.left.op)

    # body[7]: ``(15 and 16) or 17`` -- outer BoolOp is Or, inner is And.
    _ast.Or = type(m.body[7].value.op)
    _ast.And = type(m.body[7].value.values[0].op)

    # body[8]: ``((-baz) + (not (+18))) - (~17)``
    _ast.Invert = type(m.body[8].value.right.op)
    _ast.Not = type(m.body[8].value.left.right.op)
    _ast.UAdd = type(m.body[8].value.left.right.operand.op)
    _ast.USub = type(m.body[8].value.left.left.op)

    # body[9] has the same and/or shape as body[7]; these lines reassign the
    # same Or/And classes (redundant but harmless).
    _ast.Or = type(m.body[9].value.op)
    _ast.And = type(m.body[9].value.values[0].op)

    # body[10]: a chained Compare (ops IsNot, NotEq) whose left operand is
    # itself a Compare (ops Is, Eq).
    _ast.IsNot = type(m.body[10].value.ops[0])
    _ast.NotEq = type(m.body[10].value.ops[1])
    _ast.Is = type(m.body[10].value.left.ops[0])
    _ast.Eq = type(m.body[10].value.left.ops[1])

    # body[11]: one chained Compare with ops Gt, Lt, GtE, LtE in order.
    _ast.Gt = type(m.body[11].value.ops[0])
    _ast.Lt = type(m.body[11].value.ops[1])
    _ast.GtE = type(m.body[11].value.ops[2])
    _ast.LtE = type(m.body[11].value.ops[3])

    # body[12]: one chained Compare with ops In, NotIn.
    _ast.In = type(m.body[12].value.ops[0])
    _ast.NotIn = type(m.body[12].value.ops[1])
372
373
def read_file(path, mode="rb"):
    """Return the entire contents of the file at *path*.

    *mode* is passed straight to ``open()``; with the default ``"rb"`` the
    result is ``bytes``.  The file is always closed before returning.
    """
    with open(path, mode) as handle:
        contents = handle.read()
    return contents
377
378
def read_python_file(path):
    """Read the Python source file at *path*.

    The declared source encoding is detected with ``parse_encoding``.  If
    one is found, the contents are decoded and returned as ``str``;
    otherwise the raw ``bytes`` are returned unchanged.
    """
    # The context manager closes the file even if encoding detection raises.
    with open(path, "rb") as fp:
        encoding = parse_encoding(fp)
        data = fp.read()
    if encoding:
        data = data.decode(encoding)
    return data
389