xref: /aosp_15_r20/external/executorch/exir/print_program.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Worker# pyre-strict
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Workerimport copy
10*523fa7a6SAndroid Build Coastguard Workerimport re
11*523fa7a6SAndroid Build Coastguard Workerimport reprlib
12*523fa7a6SAndroid Build Coastguard Workerfrom dataclasses import fields
13*523fa7a6SAndroid Build Coastguard Workerfrom enum import IntEnum
14*523fa7a6SAndroid Build Coastguard Workerfrom typing import Any, List, Optional, TextIO
15*523fa7a6SAndroid Build Coastguard Worker
16*523fa7a6SAndroid Build Coastguard Workerimport torch
17*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.error import ExportError, ExportErrorType, InternalError
18*523fa7a6SAndroid Build Coastguard Worker
19*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.schema import (
20*523fa7a6SAndroid Build Coastguard Worker    Bool,
21*523fa7a6SAndroid Build Coastguard Worker    BoolList,
22*523fa7a6SAndroid Build Coastguard Worker    DelegateCall,
23*523fa7a6SAndroid Build Coastguard Worker    Double,
24*523fa7a6SAndroid Build Coastguard Worker    DoubleList,
25*523fa7a6SAndroid Build Coastguard Worker    EValue,
26*523fa7a6SAndroid Build Coastguard Worker    Frame,
27*523fa7a6SAndroid Build Coastguard Worker    FrameList,
28*523fa7a6SAndroid Build Coastguard Worker    FreeCall,
29*523fa7a6SAndroid Build Coastguard Worker    Int,
30*523fa7a6SAndroid Build Coastguard Worker    IntList,
31*523fa7a6SAndroid Build Coastguard Worker    JumpFalseCall,
32*523fa7a6SAndroid Build Coastguard Worker    KernelCall,
33*523fa7a6SAndroid Build Coastguard Worker    MoveCall,
34*523fa7a6SAndroid Build Coastguard Worker    Null,
35*523fa7a6SAndroid Build Coastguard Worker    OptionalTensorList,
36*523fa7a6SAndroid Build Coastguard Worker    Program,
37*523fa7a6SAndroid Build Coastguard Worker    ScalarType,
38*523fa7a6SAndroid Build Coastguard Worker    String,
39*523fa7a6SAndroid Build Coastguard Worker    Tensor,
40*523fa7a6SAndroid Build Coastguard Worker    TensorList,
41*523fa7a6SAndroid Build Coastguard Worker    TensorShapeDynamism,
42*523fa7a6SAndroid Build Coastguard Worker)
43*523fa7a6SAndroid Build Coastguard Worker
44*523fa7a6SAndroid Build Coastguard Worker
def _scalar_type_str(scalar_type: ScalarType) -> str:
    """
    Return the short dtype suffix print_program appends to a tensor's sizes.

    Args:
        scalar_type: The tensor's ScalarType.

    Raises:
        RuntimeError: If ``scalar_type`` has no registered abbreviation.
    """
    abbreviations = {
        ScalarType.BYTE: "bt",
        ScalarType.CHAR: "c",
        ScalarType.SHORT: "s",
        ScalarType.INT: "i",
        ScalarType.LONG: "l",
        ScalarType.HALF: "h",
        ScalarType.FLOAT: "f",
        ScalarType.DOUBLE: "d",
        ScalarType.COMPLEX32: "c32",
        ScalarType.COMPLEX64: "c64",
        ScalarType.COMPLEX128: "c128",
        ScalarType.BOOL: "b",
        ScalarType.QINT8: "qi8",
        ScalarType.QUINT8: "qui8",
        ScalarType.QINT32: "qi32",
        ScalarType.BFLOAT16: "bf16",
        ScalarType.QUINT4x2: "qui4x2",
        ScalarType.QUINT2x4: "qui2x4",
    }
    abbreviation = abbreviations.get(scalar_type)
    if abbreviation is None:
        raise RuntimeError(f"Unrecognized scalar_type: {scalar_type}")
    return abbreviation
70*523fa7a6SAndroid Build Coastguard Worker
71*523fa7a6SAndroid Build Coastguard Worker
def _is_dynamic_shape_tensor(tensor: Tensor) -> bool:
    """Return True unless ``tensor`` is declared with STATIC shape dynamism."""
    is_static = tensor.shape_dynamism == TensorShapeDynamism.STATIC
    return not is_static
74*523fa7a6SAndroid Build Coastguard Worker
75*523fa7a6SAndroid Build Coastguard Worker
def _format_evalue(  # noqa: C901
    evalue: EValue, show_meminfo: bool, mark_dynamic_shape_tensor: bool
) -> str:
    """
    Render a single EValue as a compact, blue-colored token for print_program.

    Formats produced:
      * tensors with a positive ``data_buffer_idx``: ``CT<sizes><dtype>``
      * other tensors: ``[UB|DU]T[meminfo]<sizes><dtype>``
      * TensorList / OptionalTensorList / IntList / DoubleList / BoolList:
        ``TL`` / ``OTL`` / ``IL`` / ``DL`` / ``BL`` followed by ``str(items)``
      * Int / Double / Bool / String: ``I`` / ``D`` / ``B`` / ``S`` followed by
        the value (bools as 0/1)
      * Null: ``N``

    Args:
        evalue: The EValue to render.
        show_meminfo: For non-constant tensors, append
            ``m<memory_id>.<memory_offset>`` (or ``m.`` when there is no
            allocation info).
        mark_dynamic_shape_tensor: For non-constant tensors, prefix ``UB`` for
            DYNAMIC_BOUND and ``DU`` for DYNAMIC_UNBOUND shape dynamism.

    Raises:
        RuntimeError: If the EValue holds an unrecognized type, or a tensor has
            an unrecognized scalar type.
        AssertionError: If a constant tensor is dynamic-shape or has
            allocation info.
    """
    value = evalue.val

    if isinstance(value, Tensor):
        if value.data_buffer_idx > 0:
            # Tensors backed by a data buffer are printed as constants; they are
            # required to be static-shape and to carry no allocation info.
            assert not _is_dynamic_shape_tensor(
                value
            ), "A constant tensor can not be dynamic shape"
            kind = "CT"  # constant tensor
            assert value.allocation_info is None
        else:
            kind = ""
            if mark_dynamic_shape_tensor:
                if value.shape_dynamism == TensorShapeDynamism.DYNAMIC_BOUND:
                    kind = "UB"  # shown as 'UBT'
                elif value.shape_dynamism == TensorShapeDynamism.DYNAMIC_UNBOUND:
                    kind = "DU"  # shown as 'DUT'
            kind += "T"
            if show_meminfo:
                alloc = value.allocation_info
                if alloc:
                    kind += f"m{alloc.memory_id}.{alloc.memory_offset}"
                else:
                    kind += "m."
        body = f"{kind}{value.sizes}{_scalar_type_str(value.scalar_type)}"
    else:
        # Checked in the same order as the scalar branches below would be reached.
        list_tags = (
            (TensorList, "TL"),
            (OptionalTensorList, "OTL"),
            (IntList, "IL"),
            (DoubleList, "DL"),
            (BoolList, "BL"),
        )
        list_tag = next(
            (tag for list_cls, tag in list_tags if isinstance(value, list_cls)),
            None,
        )
        if list_tag is not None:
            # pyre-ignore
            body = list_tag + str(value.items)
        elif isinstance(value, Int):
            body = f"I{value.int_val}"
        elif isinstance(value, Double):
            body = f"D{value.double_val}"
        elif isinstance(value, Bool):
            # 0/1 is shorter than False/True.
            body = f"B{int(value.bool_val)}"
        elif isinstance(value, String):
            body = f"S{value.string_val}"
        elif isinstance(value, Null):
            body = "N"
        else:
            raise RuntimeError(f"Unrecognized type of evalue: {evalue}")

    return f"\033[34m{body}\033[0m"
144*523fa7a6SAndroid Build Coastguard Worker
145*523fa7a6SAndroid Build Coastguard Worker
def print_program(  # noqa: C901
    program: Program,
    show_meminfo: bool = True,
    mark_dynamic_shape_tensor: bool = False,
    out: Optional[TextIO] = None,
) -> None:
    """
    Dump the instruction list of a program in a more human readable fashion.

    Only the first chain of the first execution plan is dumped. A header line
    with the instruction count is printed first, then one line per instruction.

    The dump follows the BNF-like syntax below. Some regex syntax is mixed in
    to keep the grammar short. The grammar is not strict; it is only meant to
    help people read the dump:
    ```
      PROGRAM: HEADER (INSTRUCTION)+
      INSTRUCTION: SEQUENCE_NO ':' (CALL_KERNEL | CALL_DELEGATE | JUMP_FALSE | MOVE | FREE)
      CALL_KERNEL: OP_NAME ('.' OVERLOAD_NAME)? ARGS
      CALL_DELEGATE: BACKEND_ID ARGS
      JUMP_FALSE: 'JF' '(' ARG ')' '->' TARGET_SEQUENCE_NO
      MOVE: 'MOVE' ARG '->' ARG
      FREE: 'FREE' ARG
      ARGS: ARG | ARGS ',' ARG
      ARG: EVALUE_IDX ('I' INPUT_IDX)? ('O' OUTPUT_IDX)? EVALUE
      EVALUE: TENSOR | LIST | INT | DOUBLE | BOOL | STRING | NULL
      INT: 'I' ACTUAL_INT_VALUE
      DOUBLE: 'D' ACTUAL_DOUBLE_VALUE
      BOOL: 'B' ZERO_OR_ONE
      STRING: 'S' ACTUAL_STRING_VALUE
      NULL: 'N'
      LIST: ('TL' | 'OTL' | 'IL' | 'DL' | 'BL') '[' item0, item1, ... ']'
      TENSOR: CONST_TENSOR | NON_CONST_TENSOR
      CONST_TENSOR: 'CT' TENSOR_SHAPE TENSOR_DTYPE
      NON_CONST_TENSOR: ('UB' | 'DU')? 'T' (MEM_ALLOCATION_INFO)? TENSOR_SHAPE TENSOR_DTYPE
      TENSOR_SHAPE: '[' dim0_size, dim1_size, ..., last_dim_size ']'
      MEM_ALLOCATION_INFO: PLANNED_MEM_INFO | UNPLANNED_MEM_INFO
      PLANNED_MEM_INFO: 'm' MEM_LAYER_ID '.' MEM_LAYER_OFFSET
      UNPLANNED_MEM_INFO: 'm.'
    ```
    'UB' / 'DU' only appear when ``mark_dynamic_shape_tensor`` is set, and
    MEM_ALLOCATION_INFO only appears when ``show_meminfo`` is set.

    To make the dump easier to read, it's colored as follows:
    1. input/output markers (``I<n>`` / ``O<n>`` after an EValue index) are red
    2. EValue types (or more specifically tensor types with size and dtype) are blue

    Args:
        program: The program to dump.
        show_meminfo: Include memory-planning info for non-constant tensors.
        mark_dynamic_shape_tensor: Prefix dynamic-shape tensors with 'UB'/'DU'.
        out: Stream to print to; ``None`` prints to ``sys.stdout``.

    Raises:
        InternalError: On an instruction type this function does not handle.
    """
    execution_plan = program.execution_plan[0]
    operators = execution_plan.operators
    delegates = execution_plan.delegates
    instructions = execution_plan.chains[0].instructions
    inputs: List[int] = execution_plan.inputs
    outputs: List[int] = execution_plan.outputs
    values: List[EValue] = execution_plan.values

    def _first_positions(iolist: List[int]) -> "dict[int, int]":
        """Map each EValue index to the position of its first occurrence in iolist."""
        positions: "dict[int, int]" = {}
        for pos, ev_idx in enumerate(iolist):
            # setdefault keeps the earliest position if an index repeats.
            positions.setdefault(ev_idx, pos)
        return positions

    # Built once so formatting an argument is a lookup rather than a scan of
    # both I/O lists.
    input_positions = _first_positions(inputs)
    output_positions = _first_positions(outputs)

    def _format_arg(evalue_idx: int) -> str:
        """Format one argument: its index, red I<n>/O<n> markers, and its EValue."""
        argstr = str(evalue_idx)
        if (input_idx := input_positions.get(evalue_idx)) is not None:
            argstr += f"\033[31mI{input_idx}\033[0m"
        if (output_idx := output_positions.get(evalue_idx)) is not None:
            argstr += f"\033[31mO{output_idx}\033[0m"
        evalue = values[evalue_idx]
        return argstr + _format_evalue(evalue, show_meminfo, mark_dynamic_shape_tensor)

    print(
        f"The program contains the following {len(instructions)} instructions", file=out
    )
    for idx, instr in enumerate(instructions):
        print(f"{idx:3}: ", end="", file=out)
        instr_args = instr.instr_args
        if isinstance(instr_args, KernelCall):
            op = operators[instr_args.op_index]
            opname = f"{op.name}.{op.overload}" if op.overload else op.name
            argstr = ",".join(map(_format_arg, instr_args.args))
            print(f"{opname} {argstr}", file=out)
        elif isinstance(instr_args, DelegateCall):
            backend = delegates[instr_args.delegate_index]
            backend_id = f"{backend.id}"
            argstr = ",".join(map(_format_arg, instr_args.args))
            print(f"{backend_id} {argstr}", file=out)
        elif isinstance(instr_args, JumpFalseCall):
            print(
                f"JF ({_format_arg(instr_args.cond_value_index)}) -> {instr_args.destination_instruction}",
                file=out,
            )
        elif isinstance(instr_args, MoveCall):
            print(
                f"MOVE {_format_arg(instr_args.move_from)} -> {_format_arg(instr_args.move_to)}",
                file=out,
            )
        elif isinstance(instr_args, FreeCall):
            print(f"FREE {_format_arg(instr_args.value_index)}", file=out)
        else:
            raise InternalError(f"Unsupport instruction type {instr}")
244*523fa7a6SAndroid Build Coastguard Worker
245*523fa7a6SAndroid Build Coastguard Worker
# pyre-ignore
def pretty_print(obj: Any, indent: int = 0, out: Optional[TextIO] = None) -> None:
    """
    Recursively pretty print a Program or any object nested inside it.

    Rendering rules:
      * IntEnum values (e.g. instruction types) print as their integer value.
      * int / str / bool / float / None print via ``print``.
      * bytes are abbreviated with ``reprlib`` (``maxother = 1024``).
      * Lists of fewer than 10 ints print inline; other lists print one element
        per line, each followed by ``(index=N),``.
      * Any other object is treated as a dataclass and printed as
        ``TypeName(field=value, ...)``: on one line when every field is a
        primitive, otherwise one field per line indented two spaces per level.
        A trailing newline is printed only at ``indent == 0``.

    Args:
        obj: The object to print. Must not be a ``torch.fx.GraphModule``.
        indent: Current nesting depth.
        out: Stream to print to; ``None`` prints to ``sys.stdout``.

    Raises:
        ExportError: If ``obj`` is a ``torch.fx.GraphModule``.
    """
    pad = "  "

    def write(value: Any, end: str = "") -> None:
        print(value, end=end, file=out)

    if isinstance(obj, torch.fx.GraphModule):
        raise ExportError(
            ExportErrorType.INVALID_INPUT_TYPE,
            "pretty_print() does not accept GraphModule as input.",
        )

    # IntEnum subclasses int, so it has to be handled before the primitives.
    if isinstance(obj, IntEnum):
        write(int(obj))
        return

    scalar_types = (int, str, bool, float, type(None))
    if isinstance(obj, scalar_types):
        write(obj)
        return

    if isinstance(obj, bytes):
        abbreviator = reprlib.Repr()
        abbreviator.maxother = 1024
        write(abbreviator.repr(obj))
        return

    if isinstance(obj, list):
        if len(obj) < 10 and all(isinstance(item, int) for item in obj):
            write(obj)
            return
        write("[", end="\n")
        for position, item in enumerate(obj):
            write(pad * (indent + 1))
            pretty_print(item, indent + 1, out=out)
            write(f"(index={position}),", end="\n")
        write(pad * indent + "]")
        return

    obj_fields = fields(obj)
    one_line = all(
        isinstance(getattr(obj, fld.name), scalar_types) for fld in obj_fields
    )
    separator = "" if one_line else "\n"
    last_position = len(obj_fields) - 1

    write(f"{type(obj).__name__}(", end=separator)
    for position, fld in enumerate(obj_fields):
        if not one_line:
            write(pad * (indent + 1))
        write(fld.name + "=")
        pretty_print(getattr(obj, fld.name), indent + 1, out=out)
        if position < last_position:
            write(", ")
        write("", end=separator)
    if not one_line:
        write(pad * indent)
    write(")", end="\n" if not indent else "")
302*523fa7a6SAndroid Build Coastguard Worker
303*523fa7a6SAndroid Build Coastguard Worker
def pretty_print_stacktraces(obj: FrameList) -> str:
    """
    Format the frames of one instruction's FrameList as a Python-style traceback.

    Each frame contributes a ``File "<filename>", line <lineno>, in <name>``
    line followed by its context line. The text starts with a
    ``Traceback (most recent call last):`` header and ends with a blank line.
    """
    pieces = ["Traceback (most recent call last): \n"]
    pieces.extend(
        f'    File "{frame.filename}", line {str(frame.lineno)}, in {frame.name}\n'
        f"{frame.context} \n"
        for frame in obj.items
    )
    pieces.append("\n")
    return "".join(pieces)
315*523fa7a6SAndroid Build Coastguard Worker
316*523fa7a6SAndroid Build Coastguard Worker
def add_cursor_to_graph(graph: torch.fx.Graph, finding_node: torch.fx.Node) -> str:
    """
    Render ``graph`` one node per line, marking ``finding_node`` with a cursor.
    e.g:
    # graph():
    #     %x : [#users=1] = placeholder[target=x]
    #     %param : [#users=1] = get_attr[target=param]
    #     %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    # --> %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    #     %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    #     return clamp

    Nodes are listed in graph order; nodes whose ``format_node()`` result is
    empty are skipped. If ``finding_node`` is not in ``graph``, no line is marked.

    This is mostly used for error reporting.

    Args:
        graph: The graph to render.
        finding_node: The node to mark with ``-->``.

    Returns:
        The rendered graph text (no trailing newline).
    """
    # This is heavily based on the __str__ method of fx.Graph. The graph is
    # rendered directly: format_node() is called without its out-parameters,
    # so only reads are needed and a deep copy of the graph is unnecessary.
    lines = ["graph():"]
    for node in graph.nodes:
        node_str = node.format_node()
        if node_str:
            cursor = "--> " if node is finding_node else "    "
            lines.append(cursor + node_str)
    return "\n".join(lines)
352*523fa7a6SAndroid Build Coastguard Worker
353*523fa7a6SAndroid Build Coastguard Worker
def _stacktrace_to_framelist(stacktrace: str) -> FrameList:
    """
    Create a frame list from a Python-style stacktrace string.

    Every line of the form ``File "<filename>", line <lineno>, in <name>``
    that is followed by a newline becomes one Frame. The frame's context is
    the stripped text of the line immediately after it. The context is ``""``
    when that following line is itself another ``File ...`` header, i.e. the
    frame has no source line.

    Lines that are not frame headers, such as a leading
    ``Traceback (most recent call last):`` line or context lines, never shift
    the association between a header and its context.

    Args:
        stacktrace: Text such as ``node.meta["stack_trace"]``.

    Returns:
        A FrameList with one Frame per matched header, in order.
    """
    header = re.compile(r'File "(.*?)", line (\d+), in (.*?)$')
    lines = stacktrace.split("\n")
    frames: List[Frame] = []
    # The final split piece has no trailing newline, so it cannot be a header.
    for i, line in enumerate(lines[:-1]):
        match = header.search(line)
        if match is None:
            continue
        next_line = lines[i + 1]
        context = "" if header.search(next_line) else next_line.strip()
        frames.append(
            Frame(
                filename=match.group(1),
                lineno=int(match.group(2)),
                name=match.group(3),
                context=context,
            )
        )
    return FrameList(frames)
368*523fa7a6SAndroid Build Coastguard Worker
369*523fa7a6SAndroid Build Coastguard Worker
def inspect_node(graph: torch.fx.Graph, node: torch.fx.Node) -> str:
    """
    Inspect a node by highlighting it in the graph and listing its metadata.

    Args:
        graph: The graph containing the node.
        node: The node to be inspected.

    Returns:
        A string containing the graph listing with the node marked by ``-->``.
        It is followed by the node's spec when ``node.meta`` has a ``"spec"``
        entry, and by its formatted stack trace when ``node.meta`` has a
        ``"stack_trace"`` entry. An example output is:

        Here is the node in the graph module:
        graph():
            %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
        --> %_param_constant0 : [num_users=1] = get_attr[target=_param_constant0]
            %_param_constant1 : [num_users=1] = get_attr[target=_param_constant1]
            %aten_convolution_default : [num_users=2] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%arg0_1, %_param_constant0, %_param_constant1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
            ...
            %aten_gelu_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.gelu.default](args = (%aten_convolution_default_2,), kwargs = {})
            return [aten_gelu_default]
        This node _param_constant0 has metadata of:
        The node stacktrace:
        Traceback (most recent call last):
            File "/tmp/ipykernel_1204253/3382880687.py", line 7, in forward
        return self.test_model(x)
            File ".../torch/nn/modules/module.py", line 1528, in _call_impl
        return forward_call(*args, **kwargs)
            File "/tmp/ipykernel_1204253/712280972.py", line 10, in forward
        a = self.conv1(x)
    """
    graph_str_with_cursor = add_cursor_to_graph(graph, node)
    error_msg = (
        f"Here is the node in the graph module:\n"
        f"{graph_str_with_cursor}\n"
        f"This node {node} has metadata of:\n"
    )
    # Node spec error message. node.meta is a dict, so test key membership;
    # hasattr() on a dict only sees attributes and would never find "spec".
    if "spec" in node.meta:
        error_msg += f"The node spec:\n{node.meta['spec']}\n"

    # Stacktrace error message
    if "stack_trace" in node.meta:
        framelist = _stacktrace_to_framelist(node.meta["stack_trace"])
        error_msg += f"The node stacktrace:\n{pretty_print_stacktraces(framelist)}\n"
    return error_msg
420*523fa7a6SAndroid Build Coastguard Worker    return error_msg
421