# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import copy
import re
import reprlib
from dataclasses import fields
from enum import IntEnum
from typing import Any, List, Optional, TextIO

import torch
from executorch.exir.error import ExportError, ExportErrorType, InternalError

from executorch.exir.schema import (
    Bool,
    BoolList,
    DelegateCall,
    Double,
    DoubleList,
    EValue,
    Frame,
    FrameList,
    FreeCall,
    Int,
    IntList,
    JumpFalseCall,
    KernelCall,
    MoveCall,
    Null,
    OptionalTensorList,
    Program,
    ScalarType,
    String,
    Tensor,
    TensorList,
    TensorShapeDynamism,
)


def _scalar_type_str(scalar_type: ScalarType) -> str:
    """
    Return the short dtype tag used in program dumps for ``scalar_type``.

    Raises:
        RuntimeError: if ``scalar_type`` has no entry in the tag table below.
    """
    type2str = {
        ScalarType.BYTE: "bt",
        ScalarType.CHAR: "c",
        ScalarType.SHORT: "s",
        ScalarType.INT: "i",
        ScalarType.LONG: "l",
        ScalarType.HALF: "h",
        ScalarType.FLOAT: "f",
        ScalarType.DOUBLE: "d",
        ScalarType.COMPLEX32: "c32",
        ScalarType.COMPLEX64: "c64",
        ScalarType.COMPLEX128: "c128",
        ScalarType.BOOL: "b",
        ScalarType.QINT8: "qi8",
        ScalarType.QUINT8: "qui8",
        ScalarType.QINT32: "qi32",
        ScalarType.BFLOAT16: "bf16",
        ScalarType.QUINT4x2: "qui4x2",
        ScalarType.QUINT2x4: "qui2x4",
    }
    ret = type2str.get(scalar_type)
    if not ret:
        raise RuntimeError(f"Unrecognized scalar_type: {scalar_type}")
    return ret


def _is_dynamic_shape_tensor(tensor: Tensor) -> bool:
    """Return True when the tensor's shape dynamism is anything but STATIC."""
    return tensor.shape_dynamism != TensorShapeDynamism.STATIC


def _format_evalue(  # noqa: C901
    evalue: EValue, show_meminfo: bool, mark_dynamic_shape_tensor: bool
) -> str:
    """
    Render the type tag of one EValue, wrapped in the ANSI blue escape codes.

    Args:
        evalue: the value to describe; dispatch is on the type of ``evalue.val``.
        show_meminfo: for tensors, append ``m<memory_id>.<memory_offset>`` when
            allocation info is present, or ``m.`` when it is not.
        mark_dynamic_shape_tensor: for non-constant tensors, prefix the ``T``
            tag with ``UB`` (DYNAMIC_BOUND) or ``DU`` (DYNAMIC_UNBOUND).

    Returns:
        The tag string, e.g. ``CT...``/``T...`` for tensors, ``I<n>`` for ints,
        ``B0``/``B1`` for bools, ``N`` for null.

    Raises:
        RuntimeError: if ``evalue.val`` is none of the handled types, or the
            tensor's scalar type is not recognized.
    """
    evstr = "\033[34m"
    if isinstance(evalue.val, Tensor):
        tensor = evalue.val
        if tensor.data_buffer_idx > 0:
            assert not _is_dynamic_shape_tensor(
                tensor
            ), "A constant tensor can not be dynamic shape"
            evstr += "CT"  # constant tensor
            assert tensor.allocation_info is None
        else:
            if mark_dynamic_shape_tensor:
                if tensor.shape_dynamism == TensorShapeDynamism.DYNAMIC_BOUND:
                    evstr += "UB"  # upper bound tensor will be shown as 'UBT'
                elif tensor.shape_dynamism == TensorShapeDynamism.DYNAMIC_UNBOUND:
                    evstr += "DU"  # dynamic unbound tensor will be shown as 'DUT'
            evstr += "T"
        if show_meminfo:
            if tensor.allocation_info:
                evstr += f"m{tensor.allocation_info.memory_id}.{tensor.allocation_info.memory_offset}"
            else:
                evstr += "m."
        evstr += f"{tensor.sizes}{_scalar_type_str(tensor.scalar_type)}"
    elif isinstance(evalue.val, TensorList):
        evstr += "TL"
        tensorlist = evalue.val
        # pyre-ignore
        evstr += str(tensorlist.items)
    elif isinstance(evalue.val, OptionalTensorList):
        evstr += "OTL"
        optionaltensorlist = evalue.val
        # pyre-ignore
        evstr += str(optionaltensorlist.items)
    elif isinstance(evalue.val, IntList):
        evstr += "IL"
        intlist = evalue.val
        # pyre-ignore
        evstr += str(intlist.items)
    elif isinstance(evalue.val, DoubleList):
        evstr += "DL"
        doublelist = evalue.val
        # pyre-ignore
        evstr += str(doublelist.items)
    elif isinstance(evalue.val, BoolList):
        evstr += "BL"
        boollist = evalue.val
        # pyre-ignore
        evstr += str(boollist.items)
    elif isinstance(evalue.val, Int):
        intval = evalue.val
        evstr += f"I{intval.int_val}"
    elif isinstance(evalue.val, Double):
        doubleval = evalue.val
        evstr += f"D{doubleval.double_val}"
    elif isinstance(evalue.val, Bool):
        boolval = evalue.val
        evstr += f"B{int(boolval.bool_val)}"  # print 0, 1 since it's shorter than false, true
    elif isinstance(evalue.val, String):
        stringval = evalue.val
        evstr += f"S{stringval.string_val}"
    elif isinstance(evalue.val, Null):
        evstr += "N"  # for null
    else:
        raise RuntimeError(f"Unrecognized type of evalue: {evalue}")
    evstr += "\033[0m"
    return evstr


def print_program(  # noqa: C901
    program: Program,
    show_meminfo: bool = True,
    mark_dynamic_shape_tensor: bool = False,
    out: Optional[TextIO] = None,
) -> None:
    """
    Dump the instruction list of a program in a more human readable fashion.

    Only the first execution plan of the program and the first chain of that
    plan are printed.

    The dump follows the following BNF syntax (I combine some regex syntax
    so the grammar becomes shorter. The grammar is not strict but the main
    purpose is to let people understand the dump):
    ```
      PROGRAM: (INSTRUCTION)+
      INSTRUCTION: SEQUENCE_NO ':' (CALL_KERNEL | JUMP_FALSE)
      JUMP_FALSE: 'JF' '(' EVALUE ')' '->' TARGET_SEQUENCE_NO
      CALL_KERNEL: OVERLOADED_OP_NAME ARGS
      ARGS: EVALUE | ARGS ',' EVALUE
      EVALUE: EVALUE_IDX ( TENSOR | INT | BOOL | ...)
      INT: 'I' ACTUAL_INT_VALUE
      BOOL: 'B' ZERO_OR_ONE
      CONST_TENSOR_PREFIX: 'CT'
      TENSOR: ('T' | CONST_TENSOR_PREFIX) (MEM_ALLOCATION_INFO)? TENSOR_SHAPE TENSOR_DTYPE
      TENSOR_SHAPE: '[' dim0_size, dim1_size, ..., last_dim_size ']'
      MEM_ALLOCATION_INFO: PLANNED_MEM_INFO | UNPLANNED_MEM_INFO
      PLANNED_MEM_INFO: 'm' MEM_LAYER_ID '.' MEM_LAYER_OFFSET
      UNPLANNED_MEM_INFO: 'm.'
    ```

    Delegate calls, moves and frees are also printed, as
    ``BACKEND_ID ARGS``, ``MOVE EVALUE -> EVALUE`` and ``FREE EVALUE``.

    To make the dump easier to read, it's colored as follows:
    1. input/output EValues are marked as red
    2. EValue types (or more specifically tensor types with size and dtype) are marked as blue

    Args:
        program: the program whose instructions are dumped.
        show_meminfo: include memory planning info in tensor tags.
        mark_dynamic_shape_tensor: mark dynamic-shape tensors in tensor tags.
        out: stream to print to; passed to ``print(file=...)``, so ``None``
            means standard output.

    Raises:
        InternalError: if an instruction's arguments are of an unhandled type.
    """
    execution_plan = program.execution_plan[0]
    operators = execution_plan.operators
    delegates = execution_plan.delegates
    chain = execution_plan.chains[0]
    instructions = chain.instructions
    inputs: List[int] = execution_plan.inputs
    outputs: List[int] = execution_plan.outputs
    values: List[EValue] = execution_plan.values

    def _get_io_index(iolist: List[int], target_evalue_idx: int) -> int:
        """
        Return the first position of ``target_evalue_idx`` in ``iolist``, or -1.
        The list is short enough so linear scan is proper.
        """
        for io_idx, candidate_idx in enumerate(iolist):
            if candidate_idx == target_evalue_idx:
                return io_idx
        return -1

    def _format_arg(evalue_idx: int) -> str:
        """Render one argument: its index, red I/O markers, then its type tag."""
        argstr = str(evalue_idx)
        if (input_idx := _get_io_index(inputs, evalue_idx)) >= 0:
            argstr += f"\033[31mI{input_idx}\033[0m"
        if (output_idx := _get_io_index(outputs, evalue_idx)) >= 0:
            argstr += f"\033[31mO{output_idx}\033[0m"

        # EValue type
        evalue = values[evalue_idx]
        return argstr + _format_evalue(evalue, show_meminfo, mark_dynamic_shape_tensor)

    print(
        f"The program contains the following {len(instructions)} instructions", file=out
    )
    for idx, instr in enumerate(instructions):
        print(f"{idx:3}: ", end="", file=out)
        if isinstance(instr.instr_args, KernelCall):
            kernel = instr.instr_args
            op = operators[kernel.op_index]
            args = kernel.args

            opname = f"{op.name}.{op.overload}" if op.overload else op.name
            argstr = ",".join(map(_format_arg, args))
            print(f"{opname} {argstr}", file=out)
        elif isinstance(instr.instr_args, DelegateCall):
            delegate = instr.instr_args
            backend = delegates[delegate.delegate_index]
            args = delegate.args
            backend_id = f"{backend.id}"
            argstr = ",".join(map(_format_arg, args))
            print(f"{backend_id} {argstr}", file=out)
        elif isinstance(instr.instr_args, JumpFalseCall):
            jfcall = instr.instr_args
            print(
                f"JF ({_format_arg(jfcall.cond_value_index)}) -> {jfcall.destination_instruction}",
                file=out,
            )
        elif isinstance(instr.instr_args, MoveCall):
            move_call = instr.instr_args
            print(
                f"MOVE {_format_arg(move_call.move_from)} -> {_format_arg(move_call.move_to)}",
                file=out,
            )
        elif isinstance(instr.instr_args, FreeCall):
            print(f"FREE {_format_arg(instr.instr_args.value_index)}", file=out)
        else:
            raise InternalError(f"Unsupport instruction type {instr}")


# pyre-ignore
def pretty_print(obj: Any, indent: int = 0, out: Optional[TextIO] = None) -> None:
    """
    Pretty prints the given object which is of the Program type and any of its
    attribute's types.

    Handling by type, in order:
      * IntEnum: printed as its integer value.
      * int / str / bool / float / None: printed as-is.
      * bytes: printed through ``reprlib`` with ``maxother = 1024``.
      * list: printed inline when it has fewer than 10 elements and all are
        ints; otherwise one element per line, each followed by its index.
      * anything else: treated as a dataclass instance and printed field by
        field; kept on one line when every field is a primitive.

    Args:
        obj: the object to print.
        indent: current nesting depth; a top-level call (0) ends a dataclass
            with a newline, nested calls do not.
        out: stream to print to; ``None`` means standard output.

    Raises:
        ExportError: if ``obj`` is a ``torch.fx.GraphModule``.
        TypeError: from ``dataclasses.fields`` if ``obj`` falls through to the
            dataclass case but is not a dataclass.
    """
    if isinstance(obj, torch.fx.GraphModule):
        raise ExportError(
            ExportErrorType.INVALID_INPUT_TYPE,
            "pretty_print() does not accept GraphModule as input.",
        )

    # Instruction types are IntEnum object
    if isinstance(obj, IntEnum):
        print(int(obj), end="", file=out)
        return

    primitives = (int, str, bool, float, type(None))
    if isinstance(obj, primitives):
        print(obj, end="", file=out)
        return

    if isinstance(obj, bytes):
        r = reprlib.Repr()
        r.maxother = 1024
        print(r.repr(obj), end="", file=out)
        return

    if isinstance(obj, list):
        if len(obj) < 10 and all(isinstance(elem, int) for elem in obj):
            print(obj, end="", file=out)
            return
        print("[", file=out)
        for index, elem in enumerate(obj):
            print(" " * (indent + 1), end="", file=out)
            pretty_print(elem, indent + 1, out=out)
            print(f"(index={index}),", file=out)
        print(" " * indent + "]", end="", file=out)
        return

    # Look the fields up once; the original re-queried them on every use.
    obj_fields = fields(obj)
    inline = all(isinstance(getattr(obj, field.name), primitives) for field in obj_fields)
    end = "" if inline else "\n"
    print(f"{type(obj).__name__}(", end=end, file=out)
    last_field_idx = len(obj_fields) - 1
    for i, _field in enumerate(obj_fields):
        if not inline:
            print(" " * (indent + 1), end="", file=out)
        print(_field.name + "=", end="", file=out)
        pretty_print(getattr(obj, _field.name), indent + 1, out=out)
        if i < last_field_idx:
            print(", ", end="", file=out)
        print("", end=end, file=out)
    if not inline:
        print(" " * indent, end="", file=out)
    print(")", end="" if indent else "\n", file=out)


def pretty_print_stacktraces(obj: FrameList) -> str:
    """
    Pretty prints the traceback for one instruction

    Args:
        obj: the frames to render, read from ``obj.items`` in order.

    Returns:
        A traceback-style string: a header line, then for every frame a
        ``File ..., line ..., in ...`` line followed by its context line, then
        a trailing blank line.
    """
    pretty = "Traceback (most recent call last): \n"
    for frame in obj.items:
        pretty += f' File "{frame.filename}", '
        pretty += f"line {str(frame.lineno)}, in {frame.name}\n"
        pretty += f"{frame.context} \n"
    pretty += "\n"
    return pretty


def add_cursor_to_graph(graph: torch.fx.Graph, finding_node: torch.fx.Node) -> str:
    """
    Insert a cursor at the node location in the fx.Graph.
    e.g:
    # graph():
    #     %x : [#users=1] = placeholder[target=x]
    #     %param : [#users=1] = get_attr[target=param]
    #     %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    # --> %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    #     %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    #     return clamp

    This is mostly used for error reporting

    Args:
        graph: the graph to render.
        finding_node: the node to mark with ``-->``. If it is not found in
            ``graph.nodes``, the graph is rendered without a cursor.

    Returns:
        The rendered graph, one node per line.
    """

    # The rendering below walks a deep copy, not the caller's graph.
    new_graph = copy.deepcopy(graph)

    # Position of the node in the original graph; -1 matches no position.
    found_at = -1
    for ix, node in enumerate(graph.nodes):
        if node == finding_node:
            found_at = ix
            break

    # This is heavily based on __str__ method of fx.Graph
    def _format_graph(graph: torch.fx.Graph, offending_node_idx: int) -> str:
        s = "graph():"
        for ix, node in enumerate(graph.nodes):
            node_str = node.format_node()
            if node_str:
                if ix != offending_node_idx:
                    s += "\n " + node_str
                else:
                    s += "\n--> " + node_str
        return s

    return _format_graph(new_graph, found_at)


def _stacktrace_to_framelist(stacktrace: str) -> FrameList:
    """
    Creates a frame list from a stacktrace string.

    Frames are found by matching ``File "<filename>", line <n>, in <name>``
    lines. The context of the i-th frame is read from line ``2 * i + 1`` of the
    stacktrace, i.e. the text is expected to alternate frame line / context
    line starting at the first line.

    Raises:
        IndexError: if the stacktrace has fewer lines than that layout needs.
    """
    pattern = r'File "(.*?)", line (\d+), in (.*?)\n'
    matches = re.findall(pattern, stacktrace)
    # Split once up front instead of once per matched frame.
    stacktrace_lines = stacktrace.split("\n")
    mapped_frame_list = [
        Frame(
            filename=match[0],
            lineno=int(match[1]),
            name=match[2],
            context=stacktrace_lines[i * 2 + 1].strip(),
        )
        for i, match in enumerate(matches)
    ]
    return FrameList(mapped_frame_list)


def inspect_node(graph: torch.fx.Graph, node: torch.fx.Node) -> str:
    """
    Inspect a node by highlighting the node in the graph as well as the stacktrace.

    The node spec is included when ``node.meta`` has a ``"spec"`` entry, and
    the stacktrace when it has a ``"stack_trace"`` entry.

    Args:
        graph: The graph containing the node
        node: The node to be inspected

    Return: A string. An example output is:

    _param_constant0 error_msg:  Here is the failing node in the graph module:
    graph():
        %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    --> %_param_constant0 : [num_users=1] = get_attr[target=_param_constant0]
        %_param_constant1 : [num_users=1] = get_attr[target=_param_constant1]
        %aten_convolution_default : [num_users=2] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%arg0_1, %_param_constant0, %_param_constant1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
        %_param_constant2 : [num_users=1] = get_attr[target=_param_constant2]
        %_param_constant3 : [num_users=1] = get_attr[target=_param_constant3]
        %aten_convolution_default_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_convolution_default, %_param_constant2, %_param_constant3, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
        %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%aten_convolution_default, %aten_convolution_default_1), kwargs = {})
        %_param_constant4 : [num_users=1] = get_attr[target=_param_constant4]
        %_param_constant5 : [num_users=1] = get_attr[target=_param_constant5]
        %aten_convolution_default_2 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_add_tensor, %_param_constant4, %_param_constant5, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
        %aten_gelu_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.gelu.default](args = (%aten_convolution_default_2,), kwargs = {})
        return [aten_gelu_default]
    This node _param_constant0 has metadata of:
    The node stacktrace:
    Traceback (most recent call last):
        File "/tmp/ipykernel_1204253/3382880687.py", line 7, in forward
    return self.test_model(x)
        File "/mnt/xarfuse/uid-25337/7b86ad0c-seed-nspid4026532987_cgpid2707357-ns-4026532984/torch/nn/modules/module.py", line 1528, in _call_impl
    return forward_call(*args, **kwargs)
        File "/tmp/ipykernel_1204253/712280972.py", line 10, in forward
    a = self.conv1(x)

    """
    graph_str_with_cursor = add_cursor_to_graph(graph, node)
    error_msg = (
        "Here is the node in the graph module:\n"
        f"{graph_str_with_cursor}\n"
        f"This node {node} has metadata of:\n"
    )
    # Node spec error message
    # node.meta is a dict, so test for the key; hasattr() on it was never true
    # and the spec section was silently dropped.
    if "spec" in node.meta:
        error_msg += f"The node spec:\n{node.meta['spec']}\n"

    # Stacktrace error message
    if "stack_trace" in node.meta:
        framelist = _stacktrace_to_framelist(node.meta["stack_trace"])
        error_msg += f"The node stacktrace:\n{pretty_print_stacktraces(framelist)}\n"
    return error_msg