1import os
2
3from mako.cache import CacheImpl
4from mako.cache import register_plugin
5from mako.template import Template
6from .assertions import eq_
7from .config import config
8
9
class TemplateTest:
    """Mixin with helpers for rendering Mako templates in tests.

    Templates are built either from files located under
    ``config.template_base`` or from in-memory source text.  The rendered
    output is optionally post-processed by a filter callable and then
    compared against an expected value with ``eq_``.
    """

    def _file_template(self, filename, **kw):
        """Build a ``Template`` for the template file *filename*.

        The template's ``uri`` is *filename* as given.  Its source path is
        resolved by ``_file_path``, and ``module_directory`` is set to
        ``config.module_base``.  Any extra keyword arguments are passed
        straight through to ``Template``.
        """
        resolved_path = self._file_path(filename)
        return Template(
            uri=filename,
            filename=resolved_path,
            module_directory=config.module_base,
            **kw,
        )

    def _file_path(self, filename):
        """Return the on-disk path of *filename* inside ``config.template_base``.

        If a ``<stem>_py3k<ext>`` variant of the file exists, that variant
        is preferred.  Otherwise *filename* is joined onto the base
        directory unchanged.
        """
        stem, extension = os.path.splitext(filename)
        variant = os.path.join(config.template_base, stem + "_py3k" + extension)
        if not os.path.exists(variant):
            return os.path.join(config.template_base, filename)
        return variant

    def _do_file_test(
        self,
        filename,
        expected,
        filters=None,
        unicode_=True,
        template_args=None,
        **kw,
    ):
        """Render the template file *filename* and assert its output.

        ``kw`` is forwarded to ``_file_template``.  *expected*, *filters*,
        *unicode_* and *template_args* are handed to ``_do_test``.
        """
        self._do_test(
            self._file_template(filename, **kw),
            expected,
            template_args=template_args,
            unicode_=unicode_,
            filters=filters,
        )

    def _do_memory_test(
        self,
        source,
        expected,
        filters=None,
        unicode_=True,
        template_args=None,
        **kw,
    ):
        """Render the template text *source* and assert its output.

        ``kw`` is forwarded to ``Template(text=source, ...)``.  *expected*,
        *filters*, *unicode_* and *template_args* are handed to
        ``_do_test``.
        """
        self._do_test(
            Template(text=source, **kw),
            expected,
            template_args=template_args,
            unicode_=unicode_,
            filters=filters,
        )

    def _do_test(
        self,
        template,
        expected,
        filters=None,
        template_args=None,
        unicode_=True,
    ):
        """Render *template* and assert the result equals *expected*.

        :param template: the template object to render.
        :param expected: value the (possibly filtered) output must equal.
        :param filters: optional callable applied to the rendered output
            before the comparison; skipped when falsy.
        :param template_args: keyword arguments used for rendering;
            ``None`` means render with no arguments.
        :param unicode_: when true, render with ``render_unicode()``;
            otherwise with ``render()``.
        """
        render_kwargs = {} if template_args is None else template_args
        renderer = template.render_unicode if unicode_ else template.render
        rendered = renderer(**render_kwargs)
        result = filters(rendered) if filters else rendered
        eq_(result, expected)

    def indicates_unbound_local_error(self, rendered_output, unbound_var):
        """Return True if *rendered_output* contains an unbound-local message.

        The message must name *unbound_var*.  Both the pre-3.11 and the
        3.11+ CPython wordings of the error are recognised.
        """
        quoted = f"'{unbound_var}'"
        known_messages = (
            # Python < 3.11
            f"local variable {quoted} referenced before assignment",
            # Python >= 3.11
            f"cannot access local variable {quoted} where it is not associated",
        )
        for message in known_messages:
            if message in rendered_output:
                return True
        return False
92
93
class PlainCacheImpl(CacheImpl):
    """Simple memory cache impl so that tests which
    use caching can run without beaker.

    Values live in a plain dict (``self.data``) for the lifetime of the
    instance.  The ``**kw`` accepted by each method is cache configuration
    from Mako's ``Cache`` layer.  It is accepted for interface
    compatibility and otherwise ignored.
    """

    def __init__(self, cache):
        self.cache = cache
        # key -> cached value
        self.data = {}

    def get_or_create(self, key, creation_function, **kw):
        """Return the value cached under *key*, creating it on a miss.

        On a miss, ``creation_function()`` is called with no arguments and
        its result is stored under *key* before being returned.  ``kw`` is
        cache configuration rather than input for the creation function,
        so it is deliberately not forwarded; forwarding it would raise
        ``TypeError`` for the zero-argument callables Mako passes in.
        """
        if key in self.data:
            return self.data[key]
        self.data[key] = data = creation_function()
        return data

    def put(self, key, value, **kw):
        """Store *value* under *key*, replacing any existing entry."""
        self.data[key] = value

    def get(self, key, **kw):
        """Return the value stored under *key*; raise ``KeyError`` if absent."""
        return self.data[key]

    def invalidate(self, key, **kw):
        """Remove *key* from the cache.

        A key that was never cached is ignored rather than raising.  This
        lets callers invalidate unconditionally, for example before a
        cached def has been rendered for the first time.
        """
        self.data.pop(key, None)
117
118
# Register PlainCacheImpl with Mako's cache plugin registry under the name
# "plain".  The implementation is located by this module's path (__name__)
# and the class name "PlainCacheImpl".  Presumably templates under test
# select it via cache_impl="plain"; confirm against the callers.
register_plugin("plain", __name__, "PlainCacheImpl")
120