1# mako/pyparser.py
2# Copyright 2006-2023 the Mako authors and contributors <see AUTHORS file>
3#
4# This module is part of Mako and is released under
5# the MIT License: http://www.opensource.org/licenses/mit-license.php
6
7"""Handles parsing of Python code.
8
Source is parsed to an AST with the standard ``_ast`` module (via
``mako._ast_util``); there is no longer any ``compiler``-module fallback.
11"""
12
13import operator
14
15import _ast
16
17from mako import _ast_util
18from mako import compat
19from mako import exceptions
20from mako import util
21
# Names that template code can never assign to, and which are therefore
# never reported as undeclared identifiers.  This is deliberately much
# smaller than the full set of keys in __builtins__.
reserved = {"None", "True", "False", "print"}

# Fetch the parameter name from an ``ast.arg`` node (its ``arg``
# attribute), e.g. the entries of a def/lambda ``args.args`` list.
arg_id = operator.attrgetter("arg")
28
# Called once, at import time, with the raw ``_ast`` module.
# NOTE(review): presumably this (re)establishes AST node classes/attributes
# that the visitors below rely on -- confirm against mako/util.py.
util.restore__ast(_ast)
30
31
def parse(code, mode="exec", **exception_kwargs):
    """Parse a string of Python source into an AST.

    :param code: the Python source text to parse.
    :param mode: parse mode forwarded to ``_ast_util.parse`` (e.g.
        ``"exec"``; presumably ``"eval"`` for a lone expression -- confirm
        against callers).
    :param exception_kwargs: keyword arguments forwarded unchanged to
        :class:`mako.exceptions.SyntaxException` when parsing fails.
    :return: the parsed AST.
    :raises mako.exceptions.SyntaxException: when parsing raises any
        exception; the original exception is chained as ``__cause__``.
    """
    try:
        return _ast_util.parse(code, "<unknown>", mode)
    except Exception as e:
        # Use the bound exception directly instead of re-fetching the
        # in-flight exception through a helper; the message shows the
        # exception class, its text, and the first 50 chars of the code.
        raise exceptions.SyntaxException(
            "(%s) %s (%r)" % (type(e).__name__, e, code[:50]),
            **exception_kwargs,
        ) from e
47
48
class FindIdentifiers(_ast_util.NodeVisitor):
    """Collect the identifiers a piece of Python code declares and uses.

    Outside of any function, names that are bound are added to
    ``listener.declared_identifiers``. This covers assignment targets,
    imports, ``def``/``class`` names, ``except ... as`` names and loop
    targets.

    Names that are read are added to ``listener.undeclared_identifiers``
    unless one of the following applies:

    * they are in :data:`reserved`;
    * they were already declared;
    * they are bound locally (including as a parameter) inside an
      enclosing ``def``/``lambda``.

    :param listener: object whose ``declared_identifiers`` and
        ``undeclared_identifiers`` sets are updated.
    :param exception_kwargs: keyword arguments forwarded to any
        :class:`mako.exceptions.CompileException` raised while visiting.
    """

    def __init__(self, listener, **exception_kwargs):
        # True while visiting a def/lambda body: bindings there are locals,
        # not template-level declarations.
        self.in_function = False
        # Toggled while visiting assignment targets (only written here).
        self.in_assign_targets = False
        # Names bound inside the enclosing function(s), parameters included.
        self.local_ident_stack = set()
        self.listener = listener
        self.exception_kwargs = exception_kwargs

    def _add_declared(self, name):
        """Record ``name`` as declared, or as a local inside a function."""
        if not self.in_function:
            self.listener.declared_identifiers.add(name)
        else:
            self.local_ident_stack.add(name)

    def visit_ClassDef(self, node):
        # Only the class name is declared; the class body is not visited.
        self._add_declared(node.name)

    def visit_Assign(self, node):
        # flip around the visiting of Assign so the expression gets
        # evaluated first, in the case of a clause like "x=x+5" (x
        # is undeclared)

        self.visit(node.value)
        in_a = self.in_assign_targets
        self.in_assign_targets = True
        for n in node.targets:
            self.visit(n)
        self.in_assign_targets = in_a

    def visit_ExceptHandler(self, node):
        # The "as" name is bound before the handler body runs.
        if node.name is not None:
            self._add_declared(node.name)
        if node.type is not None:
            self.visit(node.type)
        for statement in node.body:
            self.visit(statement)

    def visit_Lambda(self, node, *args):
        self._visit_function(node, True)

    def visit_FunctionDef(self, node):
        self._add_declared(node.name)
        self._visit_function(node, False)

    def _expand_tuples(self, args):
        """Flatten tuple-shaped argument entries into their elements."""
        for arg in args:
            if isinstance(arg, _ast.Tuple):
                yield from arg.elts
            else:
                yield arg

    def _function_arg_names(self, args):
        """Return every parameter name bound by an ``ast.arguments`` node.

        The names cover positional-only, regular, keyword-only, ``*args``
        and ``**kwargs`` parameters. This way a body such as
        ``lambda *a: a`` does not report ``a`` as undeclared.
        """
        positional = list(getattr(args, "posonlyargs", ())) + list(args.args)
        names = [arg_id(arg) for arg in self._expand_tuples(positional)]
        names.extend(arg_id(arg) for arg in args.kwonlyargs)
        for star_arg in (args.vararg, args.kwarg):
            if star_arg is not None:
                names.append(arg_id(star_arg))
        return names

    def _visit_function(self, node, islambda):
        # Push function state onto the stack. Until we leave the function,
        # don't log any more identifiers as "declared" (they are locals),
        # but keep logging identifiers as "undeclared". Every parameter
        # name in the function header is tracked as a local so it isn't
        # counted as "undeclared".
        inf = self.in_function
        self.in_function = True

        local_ident_stack = self.local_ident_stack
        self.local_ident_stack = local_ident_stack.union(
            self._function_arg_names(node.args)
        )
        try:
            if islambda:
                # A lambda body is a single expression node.
                self.visit(node.body)
            else:
                for n in node.body:
                    self.visit(n)
        finally:
            # Pop function state even if visiting the body raised.
            self.in_function = inf
            self.local_ident_stack = local_ident_stack

    def visit_For(self, node):
        # flip around visit: the iterable is evaluated before the target
        # is bound

        self.visit(node.iter)
        self.visit(node.target)
        for statement in node.body:
            self.visit(statement)
        for statement in node.orelse:
            self.visit(statement)

    def visit_Name(self, node):
        if isinstance(node.ctx, _ast.Store):
            # this is equivalent to visit_AssName in
            # compiler
            self._add_declared(node.id)
        elif (
            node.id not in reserved
            and node.id not in self.listener.declared_identifiers
            and node.id not in self.local_ident_stack
        ):
            self.listener.undeclared_identifiers.add(node.id)

    def visit_Import(self, node):
        # "import a.b.c" binds only the top-level package name "a".
        for name in node.names:
            if name.asname is not None:
                self._add_declared(name.asname)
            else:
                self._add_declared(name.name.split(".")[0])

    def visit_ImportFrom(self, node):
        for name in node.names:
            if name.asname is not None:
                self._add_declared(name.asname)
            elif name.name == "*":
                # Star imports would bind names we cannot see statically.
                raise exceptions.CompileException(
                    "'import *' is not supported, since all identifier "
                    "names must be explicitly declared.  Please use the "
                    "form 'from <modulename> import <name1>, <name2>, "
                    "...' instead.",
                    **self.exception_kwargs,
                )
            else:
                self._add_declared(name.name)
165
166
class FindTuple(_ast_util.NodeVisitor):
    """Split a tuple expression into per-element code objects.

    For each element of a visited ``Tuple`` node, the following is
    appended to ``listener``:

    * the code object built by ``code_factory`` (to ``codeargs``);
    * the element's regenerated source text (to ``args``).

    The element's declared/undeclared identifiers are also merged into
    ``listener``.
    """

    def __init__(self, listener, code_factory, **exception_kwargs):
        self.listener = listener
        self.exception_kwargs = exception_kwargs
        self.code_factory = code_factory

    def visit_Tuple(self, node):
        for element in node.elts:
            code = self.code_factory(element, **self.exception_kwargs)
            self.listener.codeargs.append(code)
            source_text = ExpressionGenerator(element).value()
            self.listener.args.append(source_text)

            # Rebind to fresh sets (rather than mutating in place) so any
            # previously shared set objects are left untouched.
            declared = self.listener.declared_identifiers
            self.listener.declared_identifiers = declared.union(
                code.declared_identifiers
            )
            undeclared = self.listener.undeclared_identifiers
            self.listener.undeclared_identifiers = undeclared.union(
                code.undeclared_identifiers
            )
186
187
class ParseFunc(_ast_util.NodeVisitor):
    """Copy the signature of a visited ``def`` statement onto ``listener``.

    Sets ``funcname``, ``argnames``, ``defaults``, ``kwargnames``,
    ``kwdefaults``, ``varargs`` and ``kwargs`` on the listener.
    """

    def __init__(self, listener, **exception_kwargs):
        self.listener = listener
        self.exception_kwargs = exception_kwargs

    def visit_FunctionDef(self, node):
        args = node.args
        self.listener.funcname = node.name

        # Positional-only parameters (``def f(a, /, b)``) precede the
        # regular ones. They must be included, because ``args.defaults``
        # applies to the tail of posonlyargs + args combined; dropping
        # them would both lose names and misalign the defaults.
        positional = list(getattr(args, "posonlyargs", ())) + list(args.args)
        argnames = [arg_id(arg) for arg in positional]
        if args.vararg:
            argnames.append(arg_id(args.vararg))

        kwargnames = [arg_id(arg) for arg in args.kwonlyargs]
        if args.kwarg:
            kwargnames.append(arg_id(args.kwarg))

        self.listener.argnames = argnames
        self.listener.defaults = args.defaults  # ast
        self.listener.kwargnames = kwargnames
        self.listener.kwdefaults = args.kw_defaults
        self.listener.varargs = args.vararg
        self.listener.kwargs = args.kwarg
209
210
class ExpressionGenerator:
    """Render an AST node back into Python source text."""

    def __init__(self, astnode):
        indent_with = "    "
        generator = _ast_util.SourceGenerator(indent_with)
        generator.visit(astnode)
        self.generator = generator

    def value(self):
        """Return the generated source fragments joined into one string."""
        pieces = self.generator.result
        return "".join(pieces)
218