xref: /aosp_15_r20/external/pytorch/torch/jit/frontend.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import ast
3import dataclasses
4import inspect
5import re
6import string
7import sys
8from collections import namedtuple
9from textwrap import dedent
10from typing import List, Tuple  # noqa: F401
11
12import torch
13import torch.jit.annotations
14from torch import _jit_internal
15from torch._C._jit_tree_views import (
16    Apply,
17    Assert,
18    Assign,
19    Attribute,
20    AugAssign,
21    BinOp,
22    Break,
23    ClassDef,
24    Const,
25    Continue,
26    Decl,
27    Def,
28    Delete,
29    DictComp,
30    DictLiteral,
31    Dots,
32    EmptyTypeAnnotation,
33    ExprStmt,
34    FalseLiteral,
35    For,
36    Ident,
37    If,
38    ListComp,
39    ListLiteral,
40    NoneLiteral,
41    Param,
42    Pass,
43    Property,
44    Raise,
45    Return,
46    Select,
47    SliceExpr,
48    Starred,
49    Stmt,
50    StringLiteral,
51    Subscript,
52    TernaryIf,
53    TrueLiteral,
54    TupleLiteral,
55    UnaryOp,
56    Var,
57    While,
58    With,
59    WithItem,
60)
61from torch._jit_internal import (  # noqa: F401
62    _is_drop_fn,
63    FunctionModifiers,
64    is_static_fn,
65    should_drop,
66)
67from torch._sources import (
68    get_source_lines_and_file,
69    make_source_context,
70    parse_def,
71    ParsedDef as _ParsedDef,
72)
73from torch.jit._dataclass_impls import DATACLASS_MAGIC_METHODS
74from torch.jit._monkeytype_config import get_qualified_name, monkeytype_trace
75
76
# Records whether the optional `astunparse` package could be imported.
# build_With consults this flag before handling torch.jit._IgnoreContextManager,
# whose rewrite (build_ignore_context_manager) calls astunparse.unparse.
_IS_ASTUNPARSE_INSTALLED = False
try:
    import astunparse  # type: ignore[import]

    _IS_ASTUNPARSE_INSTALLED = True
except ImportError:
    # Optional dependency: leave the flag False and carry on.
    pass

# Borrowed from cPython implementation
# https://github.com/python/cpython/blob/561612d8456cfab5672c9b445521113b847bd6b3/Lib/textwrap.py#L411#
# NOTE(review): no code borrowed from textwrap follows this attribution in the
# visible source; it looks like a leftover comment -- confirm before removing.

# Names starting with this prefix are treated as reserved (see is_reserved_name).
_reserved_prefix = "__jit"
# Exact names that are treated as reserved (see is_reserved_name).
_reserved_names = {"print"}
# The set of ASCII letters and digits.
_identifier_chars = set(string.ascii_lowercase + string.ascii_uppercase + string.digits)
91
92
def is_reserved_name(name):
    """Return True if ``name`` starts with ``_reserved_prefix`` or is listed in ``_reserved_names``."""
    if name.startswith(_reserved_prefix):
        return True
    return name in _reserved_names
95
96
# Plural, human-readable name for each kind of ``ast`` node; used by
# UnsupportedNodeError to build its "... aren't supported" message.
pretty_node_names = {
    ast.FunctionDef: "function definitions",
    ast.For: "for loops",
    ast.Delete: "del statements",
    ast.ClassDef: "class definitions",
    ast.With: "with statements",
    ast.Raise: "raise statements",
    ast.Assert: "assertions",
    ast.Import: "import statements",
    ast.ImportFrom: "import statements",
    ast.Global: "global variables",
    ast.Break: "break statements",
    ast.Continue: "continue statements",
    # Async variants, try blocks and nonlocal declarations
    ast.AsyncFunctionDef: "async function definitions",
    ast.AsyncFor: "async for loops",
    ast.AsyncWith: "async with statements",
    ast.Try: "try blocks",
    ast.Nonlocal: "nonlocal variables",
    # Annotated assignments
    ast.AnnAssign: "annotated assignments",
}

# Keyword text that begins each kind of statement; UnsupportedNodeError uses
# its length as the width of the highlighted source range.
# NB: no specific token for AnnAssign
node_start_tokens = {
    ast.FunctionDef: "def",
    ast.For: "for",
    ast.Delete: "del",
    ast.ClassDef: "class",
    ast.With: "with",
    ast.Raise: "raise",
    ast.Assert: "assert",
    ast.Import: "import",
    ast.ImportFrom: "from",
    ast.Global: "global",
    ast.Break: "break",
    ast.Continue: "continue",
    # Async variants, try blocks and nonlocal declarations
    ast.AsyncFunctionDef: "async def",
    ast.AsyncFor: "async for",
    ast.AsyncWith: "async with",
    ast.Try: "try",
    ast.Nonlocal: "nonlocal",
}
152# NB: no specific token for AnnAssign
153
154
class FrontendError(Exception):
    """Base error of this module: a message paired with a source range.

    Attributes:
        source_range: the range object the error refers to.
        msg: the message text.
        error_report: a ``torch._C.ErrorReport`` built from ``source_range``.
    """

    def __init__(self, source_range, msg):
        self.msg = msg
        self.source_range = source_range

        # This has to be instantiated here so the ErrorReport is accurate to the
        # call stack when the FrontendError was raised
        self.error_report = torch._C.ErrorReport(source_range)

    def __str__(self):
        # The message followed by the report text with its leading whitespace removed.
        report_text = self.error_report.what()
        return self.msg + report_text.lstrip()
166
167
class NotSupportedError(FrontendError):
    """A :class:`FrontendError` raised for constructs this module rejects as unsupported."""
170
171
class UnsupportedNodeError(NotSupportedError):
    """Raised for an ``ast`` node kind that has no builder.

    Args:
        ctx: source context; its ``make_range`` builds the reported range.
        offending_node: the ``ast`` node (must have ``lineno`` and ``col_offset``).
        reason: optional qualifier inserted before "aren't supported".
    """

    def __init__(self, ctx, offending_node, reason=""):
        node_cls = type(offending_node)
        # Highlight the statement's leading keyword when one is known;
        # otherwise fall back to a single character.
        token = node_start_tokens.get(node_cls, " ")
        begin = offending_node.col_offset
        source_range = ctx.make_range(
            offending_node.lineno, begin, begin + len(token)
        )
        feature_name = pretty_node_names.get(node_cls, node_cls.__name__)
        qualifier = reason + " " if reason else ""
        super().__init__(source_range, f"{feature_name} {qualifier}aren't supported")
185
186
class FrontendTypeError(FrontendError):
    """A :class:`FrontendError` subclass that adds no behavior of its own."""
189
190
def build_withitems(ctx, items):
    """Apply ``build_withitem`` to every element of ``items`` and return the results as a new list."""
    return [build_withitem(ctx, item) for item in items]
194
195
def build_stmts(ctx, stmts):
    """Build every statement in ``stmts`` and return the truthy results as a list.

    Falsy results (for example the ``None`` a builder may return for a
    statement it chooses to ignore) are dropped.
    """
    built = [build_stmt(ctx, stmt) for stmt in stmts]
    return [tree for tree in built if tree]
199
200
def get_class_properties(cls, self_name):
    """
    Collect the properties of a class as Property tree views.

    Args:
        cls:  The class whose ``property`` members are inspected.
        self_name: The name of the class that the properties should belong to.
    Returns:
        A list of Property objects, one per property of cls that is neither
        listed in ``cls.__jit_unused_properties__`` nor has a getter for which
        ``should_drop`` is true. Property here refers to the subclass of TreeView.
    """
    members = inspect.getmembers(cls, predicate=lambda m: isinstance(m, property))
    # Any property that should not be compiled must be in this list on the Module.
    skipped = getattr(cls, "__jit_unused_properties__", [])

    properties = []
    for prop_name, prop in members:
        if prop_name in skipped or should_drop(prop.fget):
            continue

        getter = get_jit_def(prop.fget, f"__{prop_name}_getter", self_name=self_name)
        # A setter is compiled only when the property defines one.
        setter = None
        if prop.fset:
            setter = get_jit_def(
                prop.fset, f"__{prop_name}_setter", self_name=self_name
            )
        properties.append(
            Property(getter.range(), Ident(getter.range(), prop_name), getter, setter)
        )

    return properties
233
234
def get_class_assigns(ctx, cls_ast):
    """Build tree views for the assignments found directly in a class body.

    Only ``ast.Assign`` and ``ast.AnnAssign`` entries of ``cls_ast.body`` are
    considered. An entry whose builder raises ``NotSupportedError`` is skipped.

    Returns:
        The list of successfully built assignments, in body order.
    """
    assigns = []
    for entry in cls_ast.body:
        if isinstance(entry, ast.Assign):
            builder = StmtBuilder.build_Assign
        elif isinstance(entry, ast.AnnAssign):
            builder = StmtBuilder.build_AnnAssign
        else:
            continue

        try:
            built = builder(ctx, entry)
        except NotSupportedError:
            # Best effort: unsupported class-level assignments are left out.
            continue
        assigns.append(built)
    return assigns
251
252
def get_jit_class_def(cls, self_name):
    """Build the tree-view definition of a class from its source code.

    Each eligible method is converted independently with ``get_jit_def``;
    properties and class-level assignments are collected separately.

    Args:
        cls: The class to get definition of.
        self_name: The name of the class that the properties should belong to.

    Returns:
        torch._C._jit_tree_views.ClassDef: A representation of the class,
            the methods in the class and their definition as a tree.
    """

    def is_compilable_method(member):
        # Functions/methods defined on this very class that are neither
        # static nor marked to be dropped.
        if not (inspect.ismethod(member) or inspect.isfunction(member)):
            return False
        return (
            not is_static_fn(cls, member.__name__)
            and member.__name__ in cls.__dict__
            and not _is_drop_fn(member)
        )

    def is_classmethod(fn):
        return inspect.ismethod(fn) and getattr(fn, "__self__", None) == cls

    # TODO: proper overriding analysis when implementing class inheritance
    methods = inspect.getmembers(cls, predicate=is_compilable_method)

    # Get and parse the source code for this class
    sourcelines, file_lineno, filename = get_source_lines_and_file(
        cls, torch._C.ErrorReport.call_stack()
    )
    source = "".join(sourcelines)
    dedent_src = dedent(source)

    class_ast = ast.parse(dedent_src).body[0]
    assert isinstance(class_ast, ast.ClassDef)

    # Special case for dataclasses. In general we need access to the source code for
    # an object in order to JIT compile it. But the dataclasses module dynamically synthesizes
    # magic methods for classes, and we can't get the source code for these methods. As a
    # workaround, we synthesize TorchScript-friendly implementations ourselves.
    if dataclasses.is_dataclass(cls):
        # Magic methods the user wrote by hand must not be synthesized/overridden.
        overrides = {
            node.name
            for node in class_ast.body
            if isinstance(node, ast.FunctionDef)
            and node.name in DATACLASS_MAGIC_METHODS
        }
        for idx, (name, _) in enumerate(methods):
            # Is this a magic method we can synthesize?
            synthesizer_fn = DATACLASS_MAGIC_METHODS.get(name)
            if not synthesizer_fn or name in overrides:
                continue
            parsed_def = synthesizer_fn(cls)
            methods[idx] = (name, parsed_def)
            # Register the synthesized source against the real method object.
            _jit_internal.loader.cache(getattr(cls, name), parsed_def.source)

    method_defs = []
    for name, obj in methods:
        method_defs.append(
            get_jit_def(
                obj, name, self_name=self_name, is_classmethod=is_classmethod(obj)
            )
        )
    properties = get_class_properties(cls, self_name)

    # Indentation removed from the first line by dedent().
    first_line = source.partition("\n")[0]
    first_line_dedented = dedent_src.partition("\n")[0]
    leading_whitespace_len = len(first_line) - len(first_line_dedented)

    ctx = make_source_context(
        source, filename, file_lineno, leading_whitespace_len, False
    )
    assigns = get_class_assigns(ctx, class_ast)

    return build_class_def(ctx, class_ast, method_defs, properties, self_name, assigns)
325
326
def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
    """
    Build a JIT AST (TreeView) from the given function.

    Args:
        fn: A function object to compile or a pre-parsed ParsedDef object
        def_name: The name to give to the resulting AST object. This is not
            always the same as `fn.__name__`, for example:
                def _forward(self):
                    ...
                forward = _forward
            In this case, the `__name__` attribute of the function object is "_forward",
            but we want the result AST to have the name "forward".
        self_name: If this function is a method, what the type name of `self` is.
        is_classmethod: When true, a statement assigning ``self_name`` to the
            function's first argument is inserted at the top of its body.
    """
    parsed_def = fn if isinstance(fn, _ParsedDef) else parse_def(fn)
    type_line = torch.jit.annotations.get_type_line(parsed_def.source)
    fn_def = parsed_def.ast.body[0]

    if is_classmethod:
        # Insert a statement that assigns the first argument to the class
        cls_arg = fn_def.args.args[0].arg
        binding = ast.parse(f"{cls_arg} = {self_name}").body[0]
        fn_def.body.insert(0, binding)

    # Swap out the function signature and body if it is unused
    if should_drop(fn):
        stub_body = ast.parse(
            'def unused_fn(self: Any):\n\traise RuntimeError("Cannot call @unused methods")'
        ).body
        if len(stub_body) != 1 or not isinstance(stub_body[0], ast.FunctionDef):
            raise RuntimeError(
                f"Expected a single top-level function: {parsed_def.filename}:{parsed_def.file_lineno}"
            )
        stub = stub_body[0]
        fn_def.body = stub.body
        # kwarg/vararg not supported by `build_def`
        fn_def.args.kwarg = fn_def.args.vararg = None
        # Replace potentially unsupported type annotations by "Any"
        any_annotation = stub.args.args[0].annotation
        for arg in fn_def.args.args + fn_def.args.kwonlyargs:
            arg.annotation = any_annotation
        if _is_drop_fn(fn):
            # Dropping potentially unsupported return type annotation for jit._drop
            fn_def.returns = None
            fn_def.type_comment = None

    # If MonkeyType is installed, get all the consolidated type traces
    # for the arguments from type_trace_db
    type_trace_db = torch.jit._script._get_type_trace_db()
    pdt_arg_types = None
    if monkeytype_trace and not isinstance(fn, _ParsedDef):  # type: ignore[truthy-function]
        pdt_arg_types = type_trace_db.get_args_types(get_qualified_name(fn))

    return build_def(
        parsed_def.ctx,
        fn_def,
        type_line,
        def_name,
        self_name=self_name,
        pdt_arg_types=pdt_arg_types,
    )
391
392
# TODO: more robust handling of recognizing ignore context manager
def is_torch_jit_ignore_context_manager(stmt):
    """Return True if the first item of a ``with`` statement is a call spelled
    exactly ``torch.jit._IgnoreContextManager(...)``.

    Only that literal three-part attribute chain rooted at the name ``torch``
    is recognized; aliases and other spellings return False.
    """
    call = stmt.items[0].context_expr
    if not isinstance(call, ast.Call):
        return False

    # `<something>._IgnoreContextManager`
    callee = call.func
    if not isinstance(callee, ast.Attribute) or callee.attr != "_IgnoreContextManager":
        return False

    # `<something>.jit`
    owner = callee.value
    if not isinstance(owner, ast.Attribute) or owner.attr != "jit":
        return False

    # the chain must start at the bare name `torch`
    root = owner.value
    return isinstance(root, ast.Name) and root.id == "torch"
410
411
class Builder:
    """Callable dispatcher: a node of class ``X`` is forwarded to ``self.build_X``."""

    def __call__(self, ctx, node):
        """Return ``self.build_<NodeClass>(ctx, node)``.

        Raises:
            UnsupportedNodeError: if no such ``build_*`` attribute exists.
        """
        handler_name = f"build_{node.__class__.__name__}"
        handler = getattr(self, handler_name, None)
        if handler is not None:
            return handler(ctx, node)
        raise UnsupportedNodeError(ctx, node)
418
419
def build_class_def(ctx, py_def, methods, properties, self_name, assigns):
    """Assemble a ClassDef tree view named ``self_name``.

    The name's source range covers the ``class`` keyword of ``py_def``; each
    entry of ``methods`` is wrapped in ``Stmt``.
    """
    start = py_def.col_offset
    r = ctx.make_range(py_def.lineno, start, start + len("class"))
    name_ident = Ident(r, self_name)
    method_stmts = [Stmt(method) for method in methods]
    return ClassDef(name_ident, method_stmts, properties, assigns)
427
428
def build_def(ctx, py_def, type_line, def_name, self_name=None, pdt_arg_types=None):
    """Build a Def tree view from a Python function definition node.

    Args:
        ctx: source context used to create ranges.
        py_def: the function definition ``ast`` node.
        type_line: type-comment text or None; when given it is parsed and
            merged into the declaration.
        def_name: name given to the resulting Def.
        self_name: type name of ``self`` for methods, else None.
        pdt_arg_types: optional argument types forwarded to ``build_param_list``.
    """
    # The reported range covers the `def` keyword.
    start = py_def.col_offset
    r = ctx.make_range(py_def.lineno, start, start + len("def"))

    param_list = build_param_list(ctx, py_def.args, self_name, pdt_arg_types)

    returns = getattr(py_def, "returns", None)
    return_type = build_expr(ctx, returns) if returns is not None else None

    decl = Decl(r, param_list, return_type)
    if type_line is not None:
        is_method = self_name is not None
        type_comment_decl = torch._C.parse_type_comment(type_line)
        decl = torch._C.merge_type_from_type_comment(decl, type_comment_decl, is_method)

    return Def(Ident(r, def_name), decl, build_stmts(ctx, py_def.body))
445
446
# Message of the NotSupportedError raised by build_param_list when a signature
# has *args, **kwargs, or a keyword-only argument with a non-None default.
_vararg_kwarg_err = (
    "Compiled functions can't take variable number of arguments "
    "or use keyword-only arguments with defaults"
)
451
452
def build_param_list(ctx, py_args, self_name, pdt_arg_types=None):
    """Build the list of Param tree views for a function signature.

    Args:
        ctx: source context used to create ranges.
        py_args: the ``ast.arguments`` node of the function.
        self_name: type name of ``self`` for methods, else None.
        pdt_arg_types: optional mapping from argument name to a type inferred
            by profile directed typing.

    Returns:
        Params for the positional arguments followed by the keyword-only ones.

    Raises:
        NotSupportedError: for ``**kwargs``, ``*args`` or a keyword-only
            argument with a non-None default.
    """
    # **kwargs is checked before *args.
    for variadic in (py_args.kwarg, py_args.vararg):
        if variadic is not None:
            ctx_range = ctx.make_range(
                variadic.lineno,
                variadic.col_offset - 1,
                variadic.col_offset + len(variadic.arg),
            )
            raise NotSupportedError(ctx_range, _vararg_kwarg_err)

    # kw_defaults is a list of the values for the kwargs (which default to None),
    # so they don't actually have line numbers.
    for default in py_args.kw_defaults:
        if default is not None:
            raise NotSupportedError(build_expr(ctx, default).range(), _vararg_kwarg_err)

    def traced_type(arg):
        # Type inferred by profile directed typing, or None when unavailable/empty.
        if pdt_arg_types and bool(pdt_arg_types[arg.arg]):
            return pdt_arg_types[arg.arg]
        return None

    # List of Tuple of args and type as inferred by profile directed typing
    positional = [(arg, traced_type(arg)) for arg in py_args.args]
    keyword_only = [(arg, traced_type(arg)) for arg in py_args.kwonlyargs]

    result = [
        build_param(ctx, arg, self_name, kwarg_only=False, pdt_arg_type=arg_type)
        for arg, arg_type in positional
    ]
    result.extend(
        build_param(ctx, arg, self_name, kwarg_only=True, pdt_arg_type=arg_type)
        for arg, arg_type in keyword_only
    )
    return result
503
504
def build_param(ctx, py_arg, self_name, kwarg_only, pdt_arg_type=None):
    """Build a Param tree view for one argument.

    The annotation is chosen in this order: the explicit annotation on
    ``py_arg``; else ``pdt_arg_type`` when truthy; else ``self_name`` for an
    argument literally named ``self`` when ``self_name`` is given; else an
    empty type annotation.
    """
    # NB: In Python3 py_arg is a pair of (str arg, expr? annotation)
    name = py_arg.arg
    start = py_arg.col_offset
    r = ctx.make_range(py_arg.lineno, start, start + len(name))

    annotation = getattr(py_arg, "annotation", None)
    if annotation is not None:
        annotation_expr = build_expr(ctx, annotation)
    elif pdt_arg_type:
        annotation_expr = Var(Ident(r, pdt_arg_type))
    elif self_name is not None and name == "self":
        annotation_expr = Var(Ident(r, self_name))
    else:
        annotation_expr = EmptyTypeAnnotation(r)

    return Param(annotation_expr, Ident(r, name), kwarg_only)
518
519
def build_ignore_context_manager(ctx, stmt):
    """Rewrite a ``torch.jit._IgnoreContextManager`` ``with`` block as a call.

    The body of ``stmt`` is moved into a new ``@torch.jit.ignore`` function,
    which is unparsed with ``astunparse``, executed, and stored in this
    module's globals. The keyword arguments of the context manager call
    declare the function's inputs and outputs as ``"inp:<type>"`` /
    ``"out:<type>"`` strings.

    Returns:
        The ``ast`` statement ``<outs> = torch.jit.frontend.<func>(<ins>)``,
        or just the call when there are no outputs.
    """
    InputType = namedtuple("InputType", ["name", "ann"])
    OutputType = namedtuple("OutputType", ["name", "ann"])

    def process_ins_outs(args):
        # parse the context manager to figure out inputs and outputs
        # with their annotated types
        # TODO: add input, output validator
        inputs, outputs = [], []
        for kw in args:
            var_name = kw.arg
            var_decl_type, var_ann = kw.value.value.split(":")
            if var_decl_type == "inp":
                inputs.append(InputType(var_name, var_ann))
            elif var_decl_type == "out":
                outputs.append(OutputType(var_name, var_ann))
        return inputs, outputs

    def create_unique_name_ext(ctx, stmt):
        # extension will be based on the full path filename plus
        # the line number of original context manager
        sanitized = re.sub(r"[^a-zA-Z0-9_]", "_", ctx.filename)
        return f"{sanitized}_{stmt.lineno}"

    def build_return_ann_stmt(outputs):
        # Returns (return annotation text, return statement text).
        count = len(outputs)
        if count == 0:
            return " -> None", "return "
        if count == 1:
            only = outputs[0]
            return " -> " + only.ann, "return " + only.name
        anns = ", ".join([var.ann for var in outputs])
        names = ", ".join([var.name for var in outputs])
        return " -> Tuple[" + anns + "]", "return " + names

    def build_args(args):
        return ", ".join([arg.name for arg in args])

    inputs, outputs = process_ins_outs(stmt.items[0].context_expr.keywords)

    # build the replacement function str with given inputs and outputs
    ignore_function_name = "func_ignore_" + create_unique_name_ext(ctx, stmt)
    param_src = ", ".join([var.name + " :" + var.ann for var in inputs])
    return_ann, return_stmt = build_return_ann_stmt(outputs)
    ignore_function_str = f"\ndef {ignore_function_name}({param_src}){return_ann}: pass"

    # first create the functionDef object from just declaration
    ignore_function = ast.parse(ignore_function_str).body[0]

    # dump the body of context manager to dummy function
    # (the same list object as stmt.body, so the append below is visible there too)
    ignore_function.body = stmt.body  # type: ignore[attr-defined]

    # insert return statement to the function
    return_stmt = ast.parse(return_stmt).body[0]
    ignore_function.body.append(return_stmt)  # type: ignore[attr-defined]

    # registers the custom function in the global context
    ignore_func_str = "@torch.jit.ignore\n" + astunparse.unparse(ignore_function)
    ignore_func_str += f'\nglobals()["{ignore_function_name}"] = {ignore_function_name}'
    exec(ignore_func_str)  # noqa: P204

    # build the statements as:
    # <out_1>, <out_2>, ... = torch.jit.frontend.<func>(<in_1>, <in_2>)
    assign_str_lhs = build_args(outputs)
    # this function will be registered in torch.jit.frontend module by default
    assign_str_rhs = f"torch.jit.frontend.{ignore_function_name}({build_args(inputs)})"

    assign_str = f"{assign_str_lhs} = {assign_str_rhs}" if outputs else assign_str_rhs
    return ast.parse(assign_str).body[0]
604
605
def get_default_args(fn):
    """
    Collect the default values declared in a function's signature.

    Args:
        fn: Callable - The function to inspect for default arguments.
    Returns:
        (Dict[str, Any]): mapping argument names to their default values if
        :attr:`fn` is not None, else empty dictionary. Parameters without a
        default are omitted.
    """
    if fn is None:
        return {}

    defaults = {}
    for name, param in inspect.signature(fn).parameters.items():
        if param.default is not inspect.Parameter.empty:
            defaults[name] = param.default
    return defaults
626
627
def get_default_args_for_class(cls):
    """
    Collect default arguments for every non-static method defined on a class.

    Args:
        cls: type - The class type to inspect for default arguments.
    Returns:
        A Dict[str, Dict[str, Any]] which maps each method name to a Dict[str, Any]
        that maps each argument name to its default value.
    """

    def is_own_non_static_method(member):
        # Static methods are excluded because those are compiled separately as
        # if they were independent script functions.
        if not (inspect.ismethod(member) or inspect.isfunction(member)):
            return False
        return (
            not is_static_fn(cls, member.__name__)
            and member.__name__ in cls.__dict__
        )

    # Property defaults do not need to be considered
    # because setters cannot be invoked without a value.
    defaults = {}
    for method_name, method_impl in inspect.getmembers(
        cls, predicate=is_own_non_static_method
    ):
        defaults[method_name] = get_default_args(method_impl)
    return defaults
655
656
class WithItemBuilder(Builder):
    """Builds WithItem tree views from ``ast.withitem`` nodes."""

    @staticmethod
    def build_withitem(ctx, item):
        """Build a WithItem from the context expression and optional ``as`` target of ``item``."""
        context_expr = item.context_expr
        start = context_expr.col_offset
        # NOTE(review): the width is the length of the pretty name
        # ("with statements"), not of the `with` keyword -- confirm this is intended.
        width = len(pretty_node_names[ast.With])
        r = ctx.make_range(context_expr.lineno, start, start + width)

        target = item.optional_vars
        return WithItem(
            r,
            build_expr(ctx, context_expr),
            build_expr(ctx, target) if target else None,
        )
671
672
class StmtBuilder(Builder):
    """Builds statement tree views from Python ``ast`` statement nodes.

    Dispatch is inherited from ``Builder``: a node of class ``X`` is handled
    by the ``build_X`` static method below. Every builder takes the source
    context ``ctx`` and the ``ast`` node ``stmt``.
    """

    # Operator token handed to AugAssign for each supported ``ast`` operator class.
    augassign_map = {
        ast.Add: "+",
        ast.Sub: "-",
        ast.Mult: "*",
        ast.Div: "/",
        ast.Mod: "%",
        ast.BitOr: "|",
        ast.BitAnd: "&",
        ast.BitXor: "^",
        ast.LShift: "<<",
        ast.RShift: ">>",
        ast.Pow: "**",
    }

    @staticmethod
    def build_Expr(ctx, stmt):
        """Build an ExprStmt, or return None for a node whose class is named ``Str``."""
        value = stmt.value
        # NOTE(review): since Python 3.8 `ast.parse` produces `ast.Constant` for
        # string literals, so this class-name check presumably never matches on
        # current interpreters -- confirm before relying on docstrings being dropped.
        if value.__class__.__name__ == "Str":
            # If a statement is a string literal expression,
            # then it is a docstring. Just ignore it.
            return None
        else:
            return ExprStmt(build_expr(ctx, value))

    @staticmethod
    def build_Assign(ctx, stmt):
        """Build an Assign from the value and all targets of ``stmt``."""
        rhs = build_expr(ctx, stmt.value)
        lhs = [build_expr(ctx, x) for x in stmt.targets]
        return Assign(lhs, rhs)

    @staticmethod
    def build_AnnAssign(ctx, stmt):
        """Build an annotated Assign.

        Raises:
            UnsupportedNodeError: if the statement has no assigned value.
            ValueError: if ``self.<attr>`` is annotated outside ``__init__``.
        """
        if stmt.value is None:
            raise UnsupportedNodeError(ctx, stmt, reason="without assigned value")

        # Disallow type annotations on instance attributes outside of __init__
        target = stmt.target
        if (
            isinstance(target, ast.Attribute)
            # The attribute base may be any expression (e.g. `a.b.c: int = 1`);
            # only a bare name has an `id` to compare.
            and isinstance(target.value, ast.Name)
            and target.value.id == "self"
            and ctx.funcname != "__init__"
        ):
            start = stmt.col_offset
            end = start + len(f"self.{target.attr}")
            if hasattr(stmt.annotation, "id"):
                end += len(f": {stmt.annotation.id}")
            sr = ctx.make_range(stmt.lineno, start, end)
            raise ValueError(
                "Type annotations on instance attributes must be declared in "
                f"__init__, not '{ctx.funcname}': {sr}"
            )

        rhs = build_expr(ctx, stmt.value)
        lhs = build_expr(ctx, stmt.target)
        the_type = build_expr(ctx, stmt.annotation)
        return Assign([lhs], rhs, the_type)

    @staticmethod
    def build_Delete(ctx, stmt):
        """Build a Delete whose range covers the ``del`` keyword."""
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("del"))

        return Delete(r, [build_expr(ctx, target) for target in stmt.targets])

    @staticmethod
    def build_Return(ctx, stmt):
        """Build a Return; a bare ``return`` carries None as its value."""
        r = ctx.make_range(
            stmt.lineno, stmt.col_offset, stmt.col_offset + len("return")
        )
        return Return(r, None if stmt.value is None else build_expr(ctx, stmt.value))

    @staticmethod
    def build_Raise(ctx, stmt):
        """Build a Raise from the raised expression ``stmt.exc``."""
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("raise"))
        expr = build_expr(ctx, stmt.exc)
        return Raise(r, expr)

    @staticmethod
    def build_Assert(ctx, stmt):
        """Build an Assert; the message is None when the statement has none."""
        r = ctx.make_range(
            stmt.lineno, stmt.col_offset, stmt.col_offset + len("assert")
        )
        test = build_expr(ctx, stmt.test)
        msg = build_expr(ctx, stmt.msg) if stmt.msg is not None else None
        return Assert(r, test, msg)

    @staticmethod
    def build_AugAssign(ctx, stmt):
        """Build an AugAssign.

        Raises:
            NotSupportedError: if the operator is not in ``augassign_map``.
        """
        lhs = build_expr(ctx, stmt.target)
        rhs = build_expr(ctx, stmt.value)
        op = type(stmt.op)
        if op in StmtBuilder.augassign_map:
            op_token = StmtBuilder.augassign_map[op]
        else:
            raise NotSupportedError(
                find_before(ctx, rhs.range().start, "=", offsets=(-1, 0)),
                "unsupported kind of augmented assignment: " + op.__name__,
            )
        return AugAssign(lhs, op_token, rhs)

    @staticmethod
    def build_While(ctx, stmt):
        """Build a While.

        Raises:
            NotSupportedError: if the loop has an ``else`` branch.
        """
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("while"))
        if stmt.orelse:
            # TODO: try to recover the location of else:? Python doesn't give us useful
            # annotations in this case, so point at the `while` keyword instead
            # (as build_For does) rather than passing no range at all.
            raise NotSupportedError(r, "else branches of while loops aren't supported")
        return While(r, build_expr(ctx, stmt.test), build_stmts(ctx, stmt.body))

    @staticmethod
    def build_For(ctx, stmt):
        """Build a For.

        Raises:
            NotSupportedError: if the loop has an ``else`` branch.
        """
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("for"))
        if stmt.orelse:
            raise NotSupportedError(r, "else branches of for loops aren't supported")

        return For(
            r,
            [build_expr(ctx, stmt.target)],
            [build_expr(ctx, stmt.iter)],
            build_stmts(ctx, stmt.body),
        )

    @staticmethod
    def build_If(ctx, stmt):
        """Build an If from the test, the body and the ``else`` statements."""
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("if"))
        return If(
            r,
            build_expr(ctx, stmt.test),
            build_stmts(ctx, stmt.body),
            build_stmts(ctx, stmt.orelse),
        )

    @staticmethod
    def build_Print(ctx, stmt):
        """Build a call to ``print`` from a print-statement node.

        Raises:
            NotSupportedError: if the statement has a destination (``stmt.dest``).
        """
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("print"))
        if stmt.dest:
            raise NotSupportedError(
                r, "print statements with non-default destinations aren't supported"
            )
        args = [build_expr(ctx, val) for val in stmt.values]
        return ExprStmt(Apply(Var(Ident(r, "print")), args, []))

    @staticmethod
    def build_Pass(ctx, stmt):
        """Build a Pass whose range covers the ``pass`` keyword."""
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("pass"))
        return Pass(r)

    @staticmethod
    def build_Break(ctx, stmt):
        """Build a Break whose range covers the ``break`` keyword."""
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("break"))
        return Break(r)

    @staticmethod
    def build_Continue(ctx, stmt):
        """Build a Continue whose range covers the ``continue`` keyword."""
        r = ctx.make_range(
            stmt.lineno, stmt.col_offset, stmt.col_offset + len("continue")
        )
        return Continue(r)

    @staticmethod
    def build_With(ctx, stmt):
        """Build a With, or the replacement call for ``torch.jit._IgnoreContextManager``.

        Raises:
            RuntimeError: if the ignore context manager is used while
                ``astunparse`` is not installed.
        """
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("with"))
        # Handle ignore context manager
        if is_torch_jit_ignore_context_manager(stmt):
            if not _IS_ASTUNPARSE_INSTALLED:
                raise RuntimeError(
                    "torch.jit._IgnoreContextManager requires installing Python library `astunparse`, \
                                   please install it in your Python environment"
                )
            # The block is replaced by a call statement, which is then built normally.
            assign_ast = build_ignore_context_manager(ctx, stmt)
            return build_stmt(ctx, assign_ast)
        return With(r, build_withitems(ctx, stmt.items), build_stmts(ctx, stmt.body))
846
847
848class ExprBuilder(Builder):
    # Source token for each binary-operator ``ast`` class.
    binop_map = {
        ast.Add: "+",
        ast.Sub: "-",
        ast.Mult: "*",
        ast.Div: "/",
        ast.Pow: "**",
        ast.Mod: "%",
        ast.FloorDiv: "//",
        ast.BitAnd: "&",
        ast.BitXor: "^",
        ast.BitOr: "|",
        ast.LShift: "<<",
        ast.RShift: ">>",
    }

    # Matrix multiplication is registered separately from the literal above.
    binop_map[ast.MatMult] = "@"

    # Source token for each unary-operator ``ast`` class listed here.
    unop_map = {
        ast.Not: "not",
        ast.USub: "-",
        ast.Invert: "~",
    }

    # Source token for each boolean-operator ``ast`` class.
    boolop_map = {
        ast.And: "and",
        ast.Or: "or",
    }

    # Source token for each comparison-operator ``ast`` class.
    cmpop_map = {
        ast.Eq: "==",
        ast.NotEq: "!=",
        ast.LtE: "<=",
        ast.Lt: "<",
        ast.GtE: ">=",
        ast.Gt: ">",
        ast.Is: "is",
        ast.IsNot: "is not",
        ast.In: "in",
        ast.NotIn: "not in",
    }
889
890    @staticmethod
891    def build_Attribute(ctx, expr):
892        base = build_expr(ctx, expr.value)
893        # expr.attr is just a string, so it's not annotated in any way, so we have
894        # to build the range manually
895        source = ctx.source.encode("utf-8")
896
897        def get_char(index):
898            return chr(source[index])
899
900        start_pos = base.range().end + 1
901        while get_char(start_pos) in string.whitespace:  # Skip whitespace
902            start_pos += 1
903        end_pos = start_pos + len(expr.attr)
904        name_range = ctx.make_raw_range(start_pos, end_pos)
905        return Select(base, Ident(name_range, expr.attr))
906
907    @staticmethod
908    def build_Call(ctx, expr):
909        func = build_expr(ctx, expr.func)
910        args = [build_expr(ctx, py_arg) for py_arg in expr.args]
911        if hasattr(expr, "starargs") and expr.starargs:
912            stararg_expr = build_expr(ctx, expr.starargs)
913            args += [Starred(stararg_expr.range(), stararg_expr)]
914        kwargs = []
915        for kw in expr.keywords:
916            kw_expr = build_expr(ctx, kw.value)
917            # XXX: we could do a better job at figuring out the range for the name here
918            if not kw.arg:
919                raise NotSupportedError(
920                    kw_expr.range(), "keyword-arg expansion is not supported"
921                )
922            kwargs.append(Attribute(Ident(kw_expr.range(), kw.arg), kw_expr))
923        return Apply(func, args, kwargs)
924
925    @staticmethod
926    def build_Ellipsis(ctx, expr):
927        r = ctx.make_range(
928            expr.lineno, expr.col_offset, expr.col_offset + 3
929        )  # len("...") == 3
930        return Dots(r)
931
932    @staticmethod
933    def build_Name(ctx, expr):
934        r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(expr.id))
935        if expr.id.startswith(_reserved_prefix):
936            raise NotSupportedError(
937                r,
938                "names of variables used in JIT-ed functions "
939                "can't start with " + _reserved_prefix,
940            )
941        if expr.id == "True":
942            return TrueLiteral(r)
943        elif expr.id == "False":
944            return FalseLiteral(r)
945        elif expr.id == "None":
946            return NoneLiteral(r)
947        elif expr.id == "Ellipsis":
948            return Dots(r)
949        return Var(Ident(r, expr.id))
950
951    @staticmethod
952    def build_NameConstant(ctx, expr):
953        r = ctx.make_range(
954            expr.lineno, expr.col_offset, expr.col_offset + len(str(expr.value))
955        )
956        if expr.value is True:
957            return TrueLiteral(r)
958        elif expr.value is False:
959            return FalseLiteral(r)
960        elif expr.value is None:
961            return NoneLiteral(r)
962        elif expr.value == Ellipsis:
963            return Dots(r)
964        else:
965            raise ValueError("Name constant value unsupported: " + str(expr.value))
966
967    @staticmethod
968    def build_BinOp(ctx, expr):
969        lhs = build_expr(ctx, expr.left)
970        rhs = build_expr(ctx, expr.right)
971        op = type(expr.op)
972
973        if op == ast.Div and not ctx.uses_true_division:
974            err_range = ctx.make_raw_range(lhs.range().end, rhs.range().start)
975            raise FrontendError(
976                err_range,
977                "Division of ints in TorchScript uses Python 3 true "
978                "division semantics. Please put `from __future__ "
979                "import division` at the top of your file",
980            )
981        op_token = ExprBuilder.binop_map.get(op)
982        if op_token is None:
983            err_range = ctx.make_raw_range(lhs.range().end, rhs.range().start)
984            raise NotSupportedError(
985                err_range, "unsupported binary operator: " + op.__name__
986            )
987        return BinOp(op_token, lhs, rhs)
988
989    @staticmethod
990    def build_UnaryOp(ctx, expr):
991        sub_expr = build_expr(ctx, expr.operand)
992        op = type(expr.op)
993        op_token = ExprBuilder.unop_map.get(op)
994        if op_token is None:
995            raise NotSupportedError(
996                expr.range(), "unsupported unary operator: " + op.__name__
997            )
998        r = ctx.make_range(
999            expr.lineno, expr.col_offset, expr.col_offset + len(op_token)
1000        )
1001        return UnaryOp(r, op_token, sub_expr)
1002
1003    @staticmethod
1004    def build_BoolOp(ctx, expr):
1005        if len(expr.values) < 2:
1006            raise AssertionError(
1007                "expected at least 2 values in BoolOp, but got " + str(len(expr.values))
1008            )
1009        sub_exprs = [build_expr(ctx, sub_expr) for sub_expr in expr.values]
1010        op = type(expr.op)
1011        op_token = ExprBuilder.boolop_map.get(op)
1012        if op_token is None:
1013            err_range = ctx.make_raw_range(
1014                sub_exprs[0].range().end, sub_exprs[1].range().start
1015            )
1016            raise NotSupportedError(
1017                err_range, "unsupported boolean operator: " + op.__name__
1018            )
1019        lhs = sub_exprs[0]
1020        for rhs in sub_exprs[1:]:
1021            lhs = BinOp(op_token, lhs, rhs)
1022        return lhs
1023
1024    @staticmethod
1025    def build_IfExp(ctx, expr):
1026        return TernaryIf(
1027            build_expr(ctx, expr.test),
1028            build_expr(ctx, expr.body),
1029            build_expr(ctx, expr.orelse),
1030        )
1031
1032    @staticmethod
1033    def build_Compare(ctx, expr):
1034        operands = [build_expr(ctx, e) for e in [expr.left] + list(expr.comparators)]
1035        result = None
1036        for lhs, op_, rhs in zip(operands, expr.ops, operands[1:]):
1037            op = type(op_)
1038            op_token = ExprBuilder.cmpop_map.get(op)
1039            r = ctx.make_raw_range(lhs.range().end, rhs.range().start)
1040            if op_token is None:
1041                raise NotSupportedError(
1042                    r, "unsupported comparison operator: " + op.__name__
1043                )
1044
1045            if op == ast.NotIn:
1046                # NB: `not in` is just `not( in )`, so we don't introduce new tree view
1047                # but just make it a nested call in our tree view structure
1048                in_expr = BinOp("in", lhs, rhs)
1049                cmp_expr = UnaryOp(r, "not", in_expr)
1050            else:
1051                cmp_expr = BinOp(op_token, lhs, rhs)
1052
1053            if result is None:
1054                result = cmp_expr
1055            else:
1056                result = BinOp("and", result, cmp_expr)
1057        return result
1058
1059    @staticmethod
    def build_Subscript(ctx, expr):
        """Build a `Subscript` tree view for an indexing expression `base[...]`.

        The shape of `expr.slice` selects the branch: `ast.Index`,
        `ast.Slice`, `ast.ExtSlice`, or (on Python >= 3.9) a bare expression
        or `ast.Tuple`. Anything else raises `NotSupportedError`.
        """

        def build_SliceExpr(ctx, base, slice_expr):
            # Each of lower/upper/step is built only when present; the
            # resulting SliceExpr reuses the range of the indexed base.
            lower = (
                build_expr(ctx, slice_expr.lower)
                if slice_expr.lower is not None
                else None
            )
            upper = (
                build_expr(ctx, slice_expr.upper)
                if slice_expr.upper is not None
                else None
            )
            step = (
                build_expr(ctx, slice_expr.step)
                if slice_expr.step is not None
                else None
            )
            return SliceExpr(base.range(), lower, upper, step)

        def build_Index(ctx, base, index_expr):
            # Used for entries of an ExtSlice; a tuple index is rejected here.
            if isinstance(index_expr.value, ast.Tuple):
                raise NotSupportedError(
                    base.range(),
                    "slicing multiple dimensions with tuples not supported yet",
                )
            return build_expr(ctx, index_expr.value)

        def build_ExtSlice(ctx, base, extslice):
            # Builds one sub-expression per dimension of the extended slice.
            sub_exprs = []
            for expr in extslice.dims:
                sub_type = type(expr)
                if sub_type is ast.Index:
                    sub_exprs.append(build_Index(ctx, base, expr))
                elif sub_type is ast.Slice:
                    sub_exprs.append(build_SliceExpr(ctx, base, expr))
                elif sub_type is ast.Constant and expr.value is Ellipsis:
                    sub_exprs.append(Dots(base.range()))
                else:
                    raise NotSupportedError(
                        base.range(),
                        f"slicing multiple dimensions with {sub_type} not supported",
                    )
            return sub_exprs

        base = build_expr(ctx, expr.value)
        sub_type = type(expr.slice)
        if sub_type is ast.Index:
            if isinstance(expr.slice.value, ast.Tuple):
                # N-dimensional indexing using Tuple: x[(i, j, k)] is equivalent to x[i, j, k]
                # XXX: Indexing using a list is **different**! It triggers advanced indexing.
                indices = [
                    build_expr(ctx, index_expr) for index_expr in expr.slice.value.elts
                ]
                if not indices:
                    # `col_offset` is an int, but `end_col_offset` is
                    # `Optional[int]`. The magic number is here to make
                    # sure we can parse `()` on any machine
                    r = ctx.make_range(
                        expr.lineno,
                        expr.slice.value.col_offset,
                        expr.slice.value.col_offset + 2,
                    )
                    tup = TupleLiteral(r, [])
                    indices.append(tup)
                return Subscript(base, indices)
            else:
                return Subscript(base, [build_expr(ctx, expr.slice.value)])
        elif sub_type is ast.Slice:
            return Subscript(base, [build_SliceExpr(ctx, base, expr.slice)])
        elif sub_type is ast.ExtSlice:
            return Subscript(base, build_ExtSlice(ctx, base, expr.slice))
        elif sys.version_info >= (
            3,
            9,
        ):  # In Python3.9 array indices are not wrapped in ast.Index
            if sub_type is ast.Tuple:
                # N-dimensional indexing using Tuple: x[(i, j, k)] is equivalent to x[i, j, k]
                indices = []
                for index_expr in expr.slice.elts:
                    if isinstance(index_expr, ast.Slice):
                        indices.append(build_SliceExpr(ctx, base, index_expr))
                    else:
                        indices.append(build_expr(ctx, index_expr))
                # Special-case logic for `typing.Tuple[()]`
                if not indices:
                    # See note above r.e. magic number
                    r = ctx.make_range(
                        expr.lineno, expr.slice.col_offset, expr.slice.col_offset + 2
                    )
                    tup = TupleLiteral(r, [])
                    indices.append(tup)
                return Subscript(base, indices)
            # Any other slice node is built as a single index expression.
            return Subscript(base, [build_expr(ctx, expr.slice)])
        else:  # Ellipsis (can only happen in Python 2)
            raise NotSupportedError(base.range(), "ellipsis is not supported")
1155
1156    @staticmethod
1157    def build_List(ctx, expr):
1158        return ListLiteral(
1159            ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1),
1160            [build_expr(ctx, e) for e in expr.elts],
1161        )
1162
1163    @staticmethod
1164    def build_Tuple(ctx, expr):
1165        return TupleLiteral(
1166            ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1),
1167            [build_expr(ctx, e) for e in expr.elts],
1168        )
1169
1170    @staticmethod
1171    def build_Dict(ctx, expr):
1172        range = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1)
1173        if expr.keys and not expr.keys[0]:
1174            raise NotSupportedError(
1175                range, "Dict expansion (e.g. `{**dict}`) is not supported"
1176            )
1177        return DictLiteral(
1178            range,
1179            [build_expr(ctx, e) for e in expr.keys],
1180            [build_expr(ctx, e) for e in expr.values],
1181        )
1182
1183    @staticmethod
1184    def build_Num(ctx, expr):
1185        value = str(expr.value)
1186        r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(value))
1187        return Const(r, value)
1188
1189    @staticmethod
1190    def build_Constant(ctx, expr):
1191        value = expr.value
1192        if value is None or isinstance(value, bool):
1193            # NB: this check has to happen before the int check because bool is
1194            # a subclass of int
1195            return ExprBuilder.build_NameConstant(ctx, expr)
1196        if isinstance(value, (int, float, complex)):
1197            return ExprBuilder.build_Num(ctx, expr)
1198        elif isinstance(value, str):
1199            return ExprBuilder.build_Str(ctx, expr)
1200        elif isinstance(value, type(Ellipsis)):
1201            return ExprBuilder.build_Ellipsis(ctx, expr)
1202        else:
1203            error_range = ctx.make_range(
1204                expr.lineno, expr.col_offset, expr.col_offset + len(str(value))
1205            )
1206            raise FrontendError(error_range, "Unknown Constant expression type")
1207
1208    @staticmethod
1209    def build_Str(ctx, expr):
1210        value = str(expr.value)
1211        r = ctx.make_range(
1212            expr.lineno, expr.col_offset, expr.col_offset + len(value) + 1
1213        )
1214        return StringLiteral(r, value)
1215
1216    @staticmethod
1217    def build_JoinedStr(ctx, expr):
1218        s = ""
1219        args = []
1220        for value in expr.values:
1221            r = ctx.make_range(value.lineno, value.col_offset, value.col_offset + 1)
1222            if isinstance(value, ast.FormattedValue):
1223                if value.conversion != -1:
1224                    raise NotSupportedError(r, "Don't support conversion in JoinedStr")
1225                if value.format_spec is not None:
1226                    raise NotSupportedError(r, "Don't support formatting in JoinedStr")
1227                s += "{}"
1228                args.append(build_expr(ctx, value.value))
1229            elif isinstance(value, ast.Constant):
1230                s += value.value
1231            else:
1232                raise NotSupportedError(r, "Unsupported value in JoinedStr")
1233
1234        r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1)
1235        return Apply(Select(StringLiteral(r, s), Ident(r, "format")), args, [])
1236
1237    @staticmethod
1238    def build_ListComp(ctx, stmt):
1239        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset)
1240        if len(stmt.generators) != 1:
1241            raise NotSupportedError(r, "Only a single generator is currently supported")
1242
1243        if len(stmt.generators[0].ifs) != 0:
1244            raise NotSupportedError(r, "Comprehension ifs are not supported yet")
1245
1246        elt_expr = build_expr(ctx, stmt.elt)
1247        target_expr = build_expr(ctx, stmt.generators[0].target)
1248        iter_expr = build_expr(ctx, stmt.generators[0].iter)
1249
1250        return ListComp(r, elt_expr, target_expr, iter_expr)
1251
1252    @staticmethod
1253    def build_GeneratorExp(ctx, stmt):
1254        # Convert Generator expression to ListComp
1255        return ExprBuilder.build_ListComp(ctx, stmt)
1256
1257    @staticmethod
1258    def build_DictComp(ctx, stmt):
1259        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset)
1260        if len(stmt.generators) != 1:
1261            raise NotSupportedError(r, "Only a single generator is currently supported")
1262
1263        if len(stmt.generators[0].ifs) != 0:
1264            raise NotSupportedError(r, "Comprehension ifs are not supported yet")
1265
1266        key_expr = build_expr(ctx, stmt.key)
1267        value_expr = build_expr(ctx, stmt.value)
1268        target_expr = build_expr(ctx, stmt.generators[0].target)
1269        iter_expr = build_expr(ctx, stmt.generators[0].iter)
1270
1271        return DictComp(r, key_expr, value_expr, target_expr, iter_expr)
1272
1273    @staticmethod
1274    def build_Starred(ctx, expr):
1275        r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1)
1276        return Starred(r, build_expr(ctx, expr.value))
1277
1278
# Module-level builder instances. The builder methods above invoke them as
# callables, e.g. `build_expr(ctx, node)` and `build_stmt(ctx, node)`.
build_expr = ExprBuilder()
build_stmt = StmtBuilder()
build_withitem = WithItemBuilder()
1282
1283
def find_before(ctx, pos, substr, offsets=(0, 0)):
    """Return the raw range of the last `substr` found before `pos`.

    The search looks at `ctx.source[:pos]` and picks the right-most match.
    `offsets[0]` is added to the start of the range and `offsets[1]` to its
    end. `str.rindex` raises `ValueError` when there is no match.
    """
    match_start = ctx.source[:pos].rindex(substr)
    range_start = match_start + offsets[0]
    range_end = match_start + len(substr) + offsets[1]
    return ctx.make_raw_range(range_start, range_end)
1287