xref: /aosp_15_r20/external/executorch/exir/print_program.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-strict
8
9import copy
10import re
11import reprlib
12from dataclasses import fields
13from enum import IntEnum
14from typing import Any, List, Optional, TextIO
15
16import torch
17from executorch.exir.error import ExportError, ExportErrorType, InternalError
18
19from executorch.exir.schema import (
20    Bool,
21    BoolList,
22    DelegateCall,
23    Double,
24    DoubleList,
25    EValue,
26    Frame,
27    FrameList,
28    FreeCall,
29    Int,
30    IntList,
31    JumpFalseCall,
32    KernelCall,
33    MoveCall,
34    Null,
35    OptionalTensorList,
36    Program,
37    ScalarType,
38    String,
39    Tensor,
40    TensorList,
41    TensorShapeDynamism,
42)
43
44
45def _scalar_type_str(scalar_type: ScalarType) -> str:
46    type2str = {
47        ScalarType.BYTE: "bt",
48        ScalarType.CHAR: "c",
49        ScalarType.SHORT: "s",
50        ScalarType.INT: "i",
51        ScalarType.LONG: "l",
52        ScalarType.HALF: "h",
53        ScalarType.FLOAT: "f",
54        ScalarType.DOUBLE: "d",
55        ScalarType.COMPLEX32: "c32",
56        ScalarType.COMPLEX64: "c64",
57        ScalarType.COMPLEX128: "c128",
58        ScalarType.BOOL: "b",
59        ScalarType.QINT8: "qi8",
60        ScalarType.QUINT8: "qui8",
61        ScalarType.QINT32: "qi32",
62        ScalarType.BFLOAT16: "bf16",
63        ScalarType.QUINT4x2: "qui4x2",
64        ScalarType.QUINT2x4: "qui2x4",
65    }
66    if not (ret := type2str.get(scalar_type, None)):
67        raise RuntimeError(f"Unrecognized scalar_type: {scalar_type}")
68    else:
69        return ret
70
71
72def _is_dynamic_shape_tensor(tensor: Tensor) -> bool:
73    return tensor.shape_dynamism != TensorShapeDynamism.STATIC
74
75
76def _format_evalue(  # noqa: C901
77    evalue: EValue, show_meminfo: bool, mark_dynamic_shape_tensor: bool
78) -> str:
79    evstr = "\033[34m"
80    if isinstance(evalue.val, Tensor):
81        tensor = evalue.val
82        if tensor.data_buffer_idx > 0:
83            assert not _is_dynamic_shape_tensor(
84                tensor
85            ), "A constant tensor can not be dynamic shape"
86            evstr += "CT"  # constant tensor
87            assert tensor.allocation_info is None
88        else:
89            if mark_dynamic_shape_tensor:
90                if tensor.shape_dynamism == TensorShapeDynamism.DYNAMIC_BOUND:
91                    evstr += "UB"  # upper bound tensor will be shown as 'UBT'
92                elif tensor.shape_dynamism == TensorShapeDynamism.DYNAMIC_UNBOUND:
93                    evstr += "DU"  # dynamic unbound tensor will be shown as 'DUT'
94            evstr += "T"
95            if show_meminfo:
96                if tensor.allocation_info:
97                    evstr += f"m{tensor.allocation_info.memory_id}.{tensor.allocation_info.memory_offset}"
98                else:
99                    evstr += "m."
100        evstr += f"{tensor.sizes}{_scalar_type_str(tensor.scalar_type)}"
101    elif isinstance(evalue.val, TensorList):
102        evstr += "TL"
103        tensorlist = evalue.val
104        # pyre-ignore
105        evstr += str(tensorlist.items)
106    elif isinstance(evalue.val, OptionalTensorList):
107        evstr += "OTL"
108        optionaltensorlist = evalue.val
109        # pyre-ignore
110        evstr += str(optionaltensorlist.items)
111    elif isinstance(evalue.val, IntList):
112        evstr += "IL"
113        intlist = evalue.val
114        # pyre-ignore
115        evstr += str(intlist.items)
116    elif isinstance(evalue.val, DoubleList):
117        evstr += "DL"
118        doublelist = evalue.val
119        # pyre-ignore
120        evstr += str(doublelist.items)
121    elif isinstance(evalue.val, BoolList):
122        evstr += "BL"
123        boollist = evalue.val
124        # pyre-ignore
125        evstr += str(boollist.items)
126    elif isinstance(evalue.val, Int):
127        intval = evalue.val
128        evstr += f"I{intval.int_val}"
129    elif isinstance(evalue.val, Double):
130        doubleval = evalue.val
131        evstr += f"D{doubleval.double_val}"
132    elif isinstance(evalue.val, Bool):
133        boolval = evalue.val
134        evstr += f"B{int(boolval.bool_val)}"  # print 0, 1 since it's shorter than false, true
135    elif isinstance(evalue.val, String):
136        stringval = evalue.val
137        evstr += f"S{stringval.string_val}"
138    elif isinstance(evalue.val, Null):
139        evstr += "N"  # for null
140    else:
141        raise RuntimeError(f"Unrecognized type of evalue: {evalue}")
142    evstr += "\033[0m"
143    return evstr
144
145
146def print_program(  # noqa: C901
147    program: Program,
148    show_meminfo: bool = True,
149    mark_dynamic_shape_tensor: bool = False,
150    out: Optional[TextIO] = None,
151) -> None:
152    """
153    Dump the instruction list of a program in a more human readable fashion.
154
155    The dump follows the following BNF syntax (I combime some regex syntax
156    so the grammar becomes shorter. The grammar is not strict but the main
157    purpose is to let people understand the dump):
158    ```
159      PROGRAM: (INSTRUCTION)+
160      INSTRUCTION: SEQUENCE_NO ':' (CALL_KERNEL | JUMP_FALSE)
161      JUMP_FALSE: 'JF' '(' EVALUE ')' '->' TARGET_SEQUENCE_NO
162      CALL_KERNEL: OVERLOADDED_OP_NAME ARGS
163      ARGS: EVALUE | ARGS ',' EVALUE
164      EVALUE: EVALUE_IDX ( TENSOR | INT | BOOL | ...)
165      INT: 'I' ACTUAL_INT_VALUE
166      BOOL: 'B' ZERO_OR_ONE
167      CONST_TENSOR_PREFIX: 'CT'
168      TENSOR: ('T' | CONST_TENSOR_PREFIX) (MEM_ALLOCATION_INFO)? TENSOR_SHAPE TENSOR_DTYPE
169      TENSOR_SHAPE: '[' dim0_size, dim1_size, ..., last_dim_size ']'
170      MEM_ALLOCATION_INFO: PLANNED_MEM_INFO | UNPLANNED_MEM_INFO
171      PLANNED_MEM_INFO: 'm' MEM_LAYER_ID '.' MEM_LAYER_OFFSET
172      UNPLANNED_MEM_INFO: 'm.'
173    ```
174
175    To make the dump easier to read, it's colored as follows:
176    1. input/output EValues are marked as red
177    2. EValue types (or more specifically tensor types with size and dtype) are marked as blue
178    """
179    execution_plan = program.execution_plan[0]
180    operators = execution_plan.operators
181    delegates = execution_plan.delegates
182    chain = execution_plan.chains[0]
183    instructions = chain.instructions
184    inputs: List[int] = execution_plan.inputs
185    outputs: List[int] = execution_plan.outputs
186    values: List[EValue] = execution_plan.values
187
188    def _format_arg(evalue_idx: int) -> str:
189        def _get_io_index(iolist: List[int], target_evalue_idx: int) -> int:
190            """
191            The list is short enough so linear scan is proper.
192            """
193            for io_idx, evalue_idx in enumerate(iolist):
194                if evalue_idx == target_evalue_idx:
195                    return io_idx
196            return -1
197
198        argstr = str(evalue_idx)
199        if (input_idx := _get_io_index(inputs, evalue_idx)) >= 0:
200            argstr += f"\033[31mI{input_idx}\033[0m"
201        if (output_idx := _get_io_index(outputs, evalue_idx)) >= 0:
202            argstr += f"\033[31mO{output_idx}\033[0m"
203
204        # EValue type
205        evalue = values[evalue_idx]
206        return argstr + _format_evalue(evalue, show_meminfo, mark_dynamic_shape_tensor)
207
208    print(
209        f"The program contains the following {len(instructions)} instructions", file=out
210    )
211    for idx, instr in enumerate(instructions):
212        print(f"{idx:3}: ", end="", file=out)
213        if isinstance(instr.instr_args, KernelCall):
214            kernel = instr.instr_args
215            op = operators[kernel.op_index]
216            args = kernel.args
217
218            opname = f"{op.name}.{op.overload}" if op.overload else op.name
219            argstr = ",".join(map(_format_arg, args))
220            print(f"{opname} {argstr}", file=out)
221        elif isinstance(instr.instr_args, DelegateCall):
222            delegate = instr.instr_args
223            backend = delegates[delegate.delegate_index]
224            args = delegate.args
225            backend_id = f"{backend.id}"
226            argstr = ",".join(map(_format_arg, args))
227            print(f"{backend_id} {argstr}", file=out)
228        elif isinstance(instr.instr_args, JumpFalseCall):
229            jfcall = instr.instr_args
230            print(
231                f"JF ({_format_arg(jfcall.cond_value_index)}) -> {jfcall.destination_instruction}",
232                file=out,
233            )
234        elif isinstance(instr.instr_args, MoveCall):
235            move_call = instr.instr_args
236            print(
237                f"MOVE {_format_arg(move_call.move_from)} -> {_format_arg(move_call.move_to)}",
238                file=out,
239            )
240        elif isinstance(instr.instr_args, FreeCall):
241            print(f"FREE {_format_arg(instr.instr_args.value_index)}", file=out)
242        else:
243            raise InternalError(f"Unsupport instruction type {instr}")
244
245
246# pyre-ignore
247def pretty_print(obj: Any, indent: int = 0, out: Optional[TextIO] = None) -> None:
248    """
249    Pretty prints the given object which is of the Program type and any of its
250    attribute’s types.
251    """
252    if isinstance(obj, torch.fx.GraphModule):
253        raise ExportError(
254            ExportErrorType.INVALID_INPUT_TYPE,
255            "pretty_print() does not accept GraphModule as input.",
256        )
257
258    # Instruction types are IntEnum object
259    if isinstance(obj, IntEnum):
260        print(int(obj), end="", file=out)
261        return
262
263    primitives = (int, str, bool, float, type(None))
264    if isinstance(obj, primitives):
265        print(obj, end="", file=out)
266        return
267
268    if isinstance(obj, bytes):
269        r = reprlib.Repr()
270        r.maxother = 1024
271        print(r.repr(obj), end="", file=out)
272        return
273
274    if isinstance(obj, list):
275        if len(obj) < 10 and all(isinstance(elem, int) for elem in obj):
276            print(obj, end="", file=out)
277            return
278        print("[", file=out)
279        for index, elem in enumerate(obj):
280            print("  " * (indent + 1), end="", file=out)
281            pretty_print(elem, indent + 1, out=out)
282            print(f"(index={index}),", file=out)
283        print("  " * indent + "]", end="", file=out)
284        return
285
286    inline = all(
287        isinstance(getattr(obj, field.name), primitives) for field in fields(obj)
288    )
289    end = "" if inline else "\n"
290    print(f"{type(obj).__name__}(", end=end, file=out)
291    for i, _field in enumerate(fields(obj)):
292        if not inline:
293            print("  " * (indent + 1), end="", file=out)
294        print(_field.name + "=", end="", file=out)
295        pretty_print(getattr(obj, _field.name), indent + 1, out=out)
296        if i < len(fields(obj)) - 1:
297            print(", ", end="", file=out)
298        print("", end=end, file=out)
299    if not inline:
300        print("  " * indent, end="", file=out)
301    print(")", end="" if indent else "\n", file=out)
302
303
304def pretty_print_stacktraces(obj: FrameList) -> str:
305    """
306    Pretty prints the traceback for one instruction
307    """
308    pretty = "Traceback (most recent call last): \n"
309    for frame in obj.items:
310        pretty += f'    File "{frame.filename}", '
311        pretty += f"line {str(frame.lineno)}, in {frame.name}\n"
312        pretty += f"{frame.context} \n"
313    pretty += "\n"
314    return pretty
315
316
317def add_cursor_to_graph(graph: torch.fx.Graph, finding_node: torch.fx.Node) -> str:
318    """
319    Insert a cursor at the node location in the fx.Graph.
320    e.g:
321    # graph():
322    #   %x : [#users=1] = placeholder[target=x]
323    #   %param : [#users=1] = get_attr[target=param]
324    #   %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
325    # --> %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
326    #   %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
327    #   return clamp
328
329    This is mostly used for error reporting
330    """
331
332    new_graph = copy.deepcopy(graph)
333
334    found_at = -1
335    for ix, node in enumerate(graph.nodes):
336        if node == finding_node:
337            found_at = ix
338
339    # This is heavily based on __str__ method of fx.Graph
340    def _format_graph(graph: torch.fx.Graph, offending_node_idx: int) -> str:
341        s = "graph():"
342        for ix, node in enumerate(graph.nodes):
343            node_str = node.format_node()
344            if node_str:
345                if ix != offending_node_idx:
346                    s += "\n    " + node_str
347                else:
348                    s += "\n--> " + node_str
349        return s
350
351    return _format_graph(new_graph, found_at)
352
353
354def _stacktrace_to_framelist(stacktrace: str) -> FrameList:
355    """Creates a frame list from a stacktrace string."""
356    pattern = r'File "(.*?)", line (\d+), in (.*?)\n'
357    matches = re.findall(pattern, stacktrace)
358    mapped_frame_list = [
359        Frame(
360            filename=match[0],
361            lineno=int(match[1]),
362            name=match[2],
363            context=stacktrace.split("\n")[i * 2 + 1].strip(),
364        )
365        for i, match in enumerate(matches)
366    ]
367    return FrameList(mapped_frame_list)
368
369
370def inspect_node(graph: torch.fx.Graph, node: torch.fx.Node) -> str:
371    """
372    Inspect a node by highlighting the node in the graph as well as the stacktrace.
373
374    Args:
375        graph: The graph containing the node
376        node: The node to be inspected
377
378    Return: A string. An example output is:
379
380    _param_constant0 error_msg:  Here is the failing node in the graph module:
381    graph():
382        %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
383    --> %_param_constant0 : [num_users=1] = get_attr[target=_param_constant0]
384        %_param_constant1 : [num_users=1] = get_attr[target=_param_constant1]
385        %aten_convolution_default : [num_users=2] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%arg0_1, %_param_constant0, %_param_constant1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
386        %_param_constant2 : [num_users=1] = get_attr[target=_param_constant2]
387        %_param_constant3 : [num_users=1] = get_attr[target=_param_constant3]
388        %aten_convolution_default_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_convolution_default, %_param_constant2, %_param_constant3, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
389        %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%aten_convolution_default, %aten_convolution_default_1), kwargs = {})
390        %_param_constant4 : [num_users=1] = get_attr[target=_param_constant4]
391        %_param_constant5 : [num_users=1] = get_attr[target=_param_constant5]
392        %aten_convolution_default_2 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_add_tensor, %_param_constant4, %_param_constant5, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
393        %aten_gelu_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.gelu.default](args = (%aten_convolution_default_2,), kwargs = {})
394        return [aten_gelu_default]
395    This node _param_constant0 has metadata of:
396    The node stacktrace:
397    Traceback (most recent call last):
398        File "/tmp/ipykernel_1204253/3382880687.py", line 7, in forward
399    return self.test_model(x)
400        File "/mnt/xarfuse/uid-25337/7b86ad0c-seed-nspid4026532987_cgpid2707357-ns-4026532984/torch/nn/modules/module.py", line 1528, in _call_impl
401    return forward_call(*args, **kwargs)
402        File "/tmp/ipykernel_1204253/712280972.py", line 10, in forward
403    a = self.conv1(x)
404
405    """
406    graph_str_with_cursor = add_cursor_to_graph(graph, node)
407    error_msg = (
408        f"Here is the node in the graph module:\n"
409        f"{graph_str_with_cursor}\n"
410        f"This node {node} has metadata of:\n"
411    )
412    # Node spec error message
413    if hasattr(node.meta, "spec"):
414        error_msg += f"The node spec:\n{node.meta['spec']}\n"
415
416    # Stacktrace error message
417    if "stack_trace" in node.meta:
418        framelist = _stacktrace_to_framelist(node.meta["stack_trace"])
419        error_msg += f"The node stacktrace:\n{pretty_print_stacktraces(framelist)}\n"
420    return error_msg
421