# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import copy
import re
import reprlib
from dataclasses import fields
from enum import IntEnum
from typing import Any, Dict, List, Optional, TextIO

import torch
from executorch.exir.error import ExportError, ExportErrorType, InternalError

from executorch.exir.schema import (
    Bool,
    BoolList,
    DelegateCall,
    Double,
    DoubleList,
    EValue,
    Frame,
    FrameList,
    FreeCall,
    Int,
    IntList,
    JumpFalseCall,
    KernelCall,
    MoveCall,
    Null,
    OptionalTensorList,
    Program,
    ScalarType,
    String,
    Tensor,
    TensorList,
    TensorShapeDynamism,
)


# Short dtype tags used in the program dump produced by `print_program`.
# Built once at import time rather than on every `_scalar_type_str` call.
_SCALAR_TYPE_TO_STR: Dict[ScalarType, str] = {
    ScalarType.BYTE: "bt",
    ScalarType.CHAR: "c",
    ScalarType.SHORT: "s",
    ScalarType.INT: "i",
    ScalarType.LONG: "l",
    ScalarType.HALF: "h",
    ScalarType.FLOAT: "f",
    ScalarType.DOUBLE: "d",
    ScalarType.COMPLEX32: "c32",
    ScalarType.COMPLEX64: "c64",
    ScalarType.COMPLEX128: "c128",
    ScalarType.BOOL: "b",
    ScalarType.QINT8: "qi8",
    ScalarType.QUINT8: "qui8",
    ScalarType.QINT32: "qi32",
    ScalarType.BFLOAT16: "bf16",
    ScalarType.QUINT4x2: "qui4x2",
    ScalarType.QUINT2x4: "qui2x4",
}


def _scalar_type_str(scalar_type: ScalarType) -> str:
    """
    Return the short dtype tag for ``scalar_type`` (e.g. ``"f"`` for FLOAT).

    Raises:
        RuntimeError: if ``scalar_type`` has no entry in the tag table.
    """
    if not (ret := _SCALAR_TYPE_TO_STR.get(scalar_type, None)):
        raise RuntimeError(f"Unrecognized scalar_type: {scalar_type}")
    return ret


def _is_dynamic_shape_tensor(tensor: Tensor) -> bool:
    """Return True if ``tensor``'s shape dynamism is anything other than STATIC."""
    return tensor.shape_dynamism != TensorShapeDynamism.STATIC


def _format_evalue(  # noqa: C901
    evalue: EValue, show_meminfo: bool, mark_dynamic_shape_tensor: bool
) -> str:
    """
    Format one EValue as a short type/value tag wrapped in blue ANSI color codes.

    Tags produced:
      * Tensor: ``CT`` (constant) or ``T`` / ``UBT`` / ``DUT`` (see
        ``mark_dynamic_shape_tensor``), optional memory info, then the
        tensor's ``sizes`` and its dtype tag from ``_scalar_type_str``.
      * TensorList / OptionalTensorList / IntList / DoubleList / BoolList:
        ``TL`` / ``OTL`` / ``IL`` / ``DL`` / ``BL`` followed by ``items``.
      * Int / Double / Bool / String: ``I`` / ``D`` / ``B`` / ``S`` followed
        by the value (bools are printed as 0/1).
      * Null: ``N``.

    Args:
        evalue: The EValue to format.
        show_meminfo: If True, append memory-planning info to tensors:
            ``m<memory_id>.<memory_offset>`` when ``allocation_info`` is set,
            otherwise ``m.``.
        mark_dynamic_shape_tensor: If True, prefix non-constant tensors with
            ``UB`` when DYNAMIC_BOUND and ``DU`` when DYNAMIC_UNBOUND.

    Raises:
        RuntimeError: if the EValue payload type, or a tensor's scalar type,
            is not recognized.
        AssertionError: if a constant tensor is dynamic-shaped or carries
            allocation info.
    """
    evstr = "\033[34m"
    if isinstance(evalue.val, Tensor):
        tensor = evalue.val
        if tensor.data_buffer_idx > 0:
            # Backed by a data buffer: treated as a constant tensor, which must
            # be static-shaped and not memory-planned.
            assert not _is_dynamic_shape_tensor(
                tensor
            ), "A constant tensor can not be dynamic shape"
            evstr += "CT"  # constant tensor
            assert tensor.allocation_info is None
        else:
            if mark_dynamic_shape_tensor:
                if tensor.shape_dynamism == TensorShapeDynamism.DYNAMIC_BOUND:
                    evstr += "UB"  # upper bound tensor will be shown as 'UBT'
                elif tensor.shape_dynamism == TensorShapeDynamism.DYNAMIC_UNBOUND:
                    evstr += "DU"  # dynamic unbound tensor will be shown as 'DUT'
            evstr += "T"
        if show_meminfo:
            if tensor.allocation_info:
                evstr += f"m{tensor.allocation_info.memory_id}.{tensor.allocation_info.memory_offset}"
            else:
                evstr += "m."
        evstr += f"{tensor.sizes}{_scalar_type_str(tensor.scalar_type)}"
    elif isinstance(evalue.val, TensorList):
        evstr += "TL"
        tensorlist = evalue.val
        # pyre-ignore
        evstr += str(tensorlist.items)
    elif isinstance(evalue.val, OptionalTensorList):
        evstr += "OTL"
        optionaltensorlist = evalue.val
        # pyre-ignore
        evstr += str(optionaltensorlist.items)
    elif isinstance(evalue.val, IntList):
        evstr += "IL"
        intlist = evalue.val
        # pyre-ignore
        evstr += str(intlist.items)
    elif isinstance(evalue.val, DoubleList):
        evstr += "DL"
        doublelist = evalue.val
        # pyre-ignore
        evstr += str(doublelist.items)
    elif isinstance(evalue.val, BoolList):
        evstr += "BL"
        boollist = evalue.val
        # pyre-ignore
        evstr += str(boollist.items)
    elif isinstance(evalue.val, Int):
        intval = evalue.val
        evstr += f"I{intval.int_val}"
    elif isinstance(evalue.val, Double):
        doubleval = evalue.val
        evstr += f"D{doubleval.double_val}"
    elif isinstance(evalue.val, Bool):
        boolval = evalue.val
        evstr += f"B{int(boolval.bool_val)}"  # print 0, 1 since it's shorter than false, true
    elif isinstance(evalue.val, String):
        stringval = evalue.val
        evstr += f"S{stringval.string_val}"
    elif isinstance(evalue.val, Null):
        evstr += "N"  # for null
    else:
        raise RuntimeError(f"Unrecognized type of evalue: {evalue}")
    evstr += "\033[0m"
    return evstr


def print_program(  # noqa: C901
    program: Program,
    show_meminfo: bool = True,
    mark_dynamic_shape_tensor: bool = False,
    out: Optional[TextIO] = None,
) -> None:
    """
    Dump the instruction list of a program in a more human readable fashion.

    Only the first execution plan, and the first chain within it, is dumped.

    The dump follows the BNF below. Some regex syntax is mixed in so the
    grammar stays short. The grammar is not strict; its main purpose is to
    help people understand the dump:
    ```
    PROGRAM: (INSTRUCTION)+
    INSTRUCTION: SEQUENCE_NO ':' (CALL_KERNEL | CALL_DELEGATE | JUMP_FALSE | MOVE | FREE)
    JUMP_FALSE: 'JF' '(' EVALUE ')' '->' TARGET_SEQUENCE_NO
    MOVE: 'MOVE' EVALUE '->' EVALUE
    FREE: 'FREE' EVALUE
    CALL_KERNEL: OVERLOADED_OP_NAME ARGS
    CALL_DELEGATE: BACKEND_ID ARGS
    ARGS: EVALUE | ARGS ',' EVALUE
    EVALUE: EVALUE_IDX (IO_MARKER)* ( TENSOR | INT | BOOL | ...)
    IO_MARKER: 'I' PROGRAM_INPUT_IDX | 'O' PROGRAM_OUTPUT_IDX
    INT: 'I' ACTUAL_INT_VALUE
    BOOL: 'B' ZERO_OR_ONE
    CONST_TENSOR_PREFIX: 'CT'
    DYNAMIC_TENSOR_PREFIX: 'UBT' | 'DUT'   (only with mark_dynamic_shape_tensor)
    TENSOR: ('T' | DYNAMIC_TENSOR_PREFIX | CONST_TENSOR_PREFIX) (MEM_ALLOCATION_INFO)? TENSOR_SHAPE TENSOR_DTYPE
    TENSOR_SHAPE: '[' dim0_size, dim1_size, ..., last_dim_size ']'
    MEM_ALLOCATION_INFO: PLANNED_MEM_INFO | UNPLANNED_MEM_INFO   (only with show_meminfo)
    PLANNED_MEM_INFO: 'm' MEM_LAYER_ID '.' MEM_LAYER_OFFSET
    UNPLANNED_MEM_INFO: 'm.'
    ```

    To make the dump easier to read, it's colored as follows:
    1. input/output markers of EValues are red
    2. EValue types (or more specifically tensor types with size and dtype) are blue

    Args:
        program: The program to dump.
        show_meminfo: Include memory-planning info for tensors.
        mark_dynamic_shape_tensor: Tag dynamic-shape tensors as UBT/DUT.
        out: Stream to print to; ``None`` means ``sys.stdout`` (``print`` default).

    Raises:
        InternalError: on an instruction type this function does not know.
    """
    execution_plan = program.execution_plan[0]
    operators = execution_plan.operators
    delegates = execution_plan.delegates
    chain = execution_plan.chains[0]
    instructions = chain.instructions
    inputs: List[int] = execution_plan.inputs
    outputs: List[int] = execution_plan.outputs
    values: List[EValue] = execution_plan.values

    def _format_arg(evalue_idx: int) -> str:
        """Format an EValue index with its I/O markers and type tag."""

        def _get_io_index(iolist: List[int], target_evalue_idx: int) -> int:
            """
            Return the position of ``target_evalue_idx`` in ``iolist``, or -1.
            The list is short enough so linear scan is proper.
            """
            for io_idx, candidate in enumerate(iolist):
                if candidate == target_evalue_idx:
                    return io_idx
            return -1

        argstr = str(evalue_idx)
        if (input_idx := _get_io_index(inputs, evalue_idx)) >= 0:
            argstr += f"\033[31mI{input_idx}\033[0m"
        if (output_idx := _get_io_index(outputs, evalue_idx)) >= 0:
            argstr += f"\033[31mO{output_idx}\033[0m"

        # EValue type
        evalue = values[evalue_idx]
        return argstr + _format_evalue(evalue, show_meminfo, mark_dynamic_shape_tensor)

    print(
        f"The program contains the following {len(instructions)} instructions", file=out
    )
    for idx, instr in enumerate(instructions):
        print(f"{idx:3}: ", end="", file=out)
        if isinstance(instr.instr_args, KernelCall):
            kernel = instr.instr_args
            op = operators[kernel.op_index]
            args = kernel.args

            opname = f"{op.name}.{op.overload}" if op.overload else op.name
            argstr = ",".join(map(_format_arg, args))
            print(f"{opname} {argstr}", file=out)
        elif isinstance(instr.instr_args, DelegateCall):
            delegate = instr.instr_args
            backend = delegates[delegate.delegate_index]
            args = delegate.args
            backend_id = f"{backend.id}"
            argstr = ",".join(map(_format_arg, args))
            print(f"{backend_id} {argstr}", file=out)
        elif isinstance(instr.instr_args, JumpFalseCall):
            jfcall = instr.instr_args
            print(
                f"JF ({_format_arg(jfcall.cond_value_index)}) -> {jfcall.destination_instruction}",
                file=out,
            )
        elif isinstance(instr.instr_args, MoveCall):
            move_call = instr.instr_args
            print(
                f"MOVE {_format_arg(move_call.move_from)} -> {_format_arg(move_call.move_to)}",
                file=out,
            )
        elif isinstance(instr.instr_args, FreeCall):
            print(f"FREE {_format_arg(instr.instr_args.value_index)}", file=out)
        else:
            raise InternalError(f"Unsupport instruction type {instr}")


# pyre-ignore
def pretty_print(obj: Any, indent: int = 0, out: Optional[TextIO] = None) -> None:
    """
    Pretty prints the given object which is of the Program type and any of its
    attribute's types.

    Handles IntEnums (printed as ints), primitives, ``bytes`` (truncated via
    ``reprlib``), lists, and dataclass instances (recursively, field by field).
    Any other object is passed to ``dataclasses.fields`` and will raise
    ``TypeError`` if it is not a dataclass.

    Args:
        obj: The object to print.
        indent: Current nesting depth; a top-level call (0) ends with a newline.
        out: Stream to print to; ``None`` means ``sys.stdout``.

    Raises:
        ExportError: if ``obj`` is a ``torch.fx.GraphModule``.
    """
    if isinstance(obj, torch.fx.GraphModule):
        raise ExportError(
            ExportErrorType.INVALID_INPUT_TYPE,
            "pretty_print() does not accept GraphModule as input.",
        )

    # Instruction types are IntEnum object
    if isinstance(obj, IntEnum):
        print(int(obj), end="", file=out)
        return

    primitives = (int, str, bool, float, type(None))
    if isinstance(obj, primitives):
        print(obj, end="", file=out)
        return

    if isinstance(obj, bytes):
        r = reprlib.Repr()
        r.maxother = 1024
        print(r.repr(obj), end="", file=out)
        return

    if isinstance(obj, list):
        # Short integer lists stay on one line.
        if len(obj) < 10 and all(isinstance(elem, int) for elem in obj):
            print(obj, end="", file=out)
            return
        print("[", file=out)
        for index, elem in enumerate(obj):
            print(" " * (indent + 1), end="", file=out)
            pretty_print(elem, indent + 1, out=out)
            print(f"(index={index}),", file=out)
        print(" " * indent + "]", end="", file=out)
        return

    obj_fields = fields(obj)
    # A dataclass whose fields are all primitives is printed on one line.
    inline = all(
        isinstance(getattr(obj, field.name), primitives) for field in obj_fields
    )
    end = "" if inline else "\n"
    last_field_idx = len(obj_fields) - 1
    print(f"{type(obj).__name__}(", end=end, file=out)
    for i, _field in enumerate(obj_fields):
        if not inline:
            print(" " * (indent + 1), end="", file=out)
        print(_field.name + "=", end="", file=out)
        pretty_print(getattr(obj, _field.name), indent + 1, out=out)
        if i < last_field_idx:
            print(", ", end="", file=out)
        print("", end=end, file=out)
    if not inline:
        print(" " * indent, end="", file=out)
    print(")", end="" if indent else "\n", file=out)


def pretty_print_stacktraces(obj: FrameList) -> str:
    """
    Format the frames in ``obj`` as a Python-style traceback string.

    Each frame contributes a ``File "...", line N, in name`` line followed by
    its ``context`` line. The string is returned, not printed.
    """
    pretty = "Traceback (most recent call last): \n"
    for frame in obj.items:
        pretty += f'  File "{frame.filename}", '
        pretty += f"line {str(frame.lineno)}, in {frame.name}\n"
        pretty += f"{frame.context} \n"
    pretty += "\n"
    return pretty


def add_cursor_to_graph(graph: torch.fx.Graph, finding_node: torch.fx.Node) -> str:
    """
    Insert a cursor at the node location in the fx.Graph.
    e.g:
        # graph():
        #     %x : [#users=1] = placeholder[target=x]
        #     %param : [#users=1] = get_attr[target=param]
        #     %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
        # --> %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
        #     %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
        #     return clamp

    If ``finding_node`` is not in ``graph``, the graph is rendered without a
    cursor. ``graph`` is only read, never modified.

    This is mostly used for error reporting
    """
    # Position of the node to mark; -1 means "not found" (no cursor drawn).
    # `Node.format_node` does not mutate nodes, so no defensive deepcopy of
    # the graph is needed.
    found_at = -1
    for ix, node in enumerate(graph.nodes):
        if node == finding_node:
            found_at = ix
            break

    # This is heavily based on __str__ method of fx.Graph
    def _format_graph(graph: torch.fx.Graph, offending_node_idx: int) -> str:
        s = "graph():"
        for ix, node in enumerate(graph.nodes):
            node_str = node.format_node()
            if node_str:
                if ix != offending_node_idx:
                    s += "\n    " + node_str
                else:
                    s += "\n--> " + node_str
        return s

    return _format_graph(graph, found_at)


def _stacktrace_to_framelist(stacktrace: str) -> FrameList:
    """
    Create a frame list from a stacktrace string.

    Every line containing ``File "<filename>", line <lineno>, in <name>``
    becomes one ``Frame``. The line right after it is used, stripped, as the
    frame's ``context``. If that line is missing or is itself another
    ``File ...`` line (a frame without source), ``context`` is ``""``.
    Lines that match neither form, such as a ``Traceback ...`` header, are
    ignored.
    """
    frame_header = re.compile(r'File "(.*?)", line (\d+), in (.*)')
    lines = stacktrace.split("\n")
    frames: List[Frame] = []
    for i, line in enumerate(lines):
        match = frame_header.search(line)
        if match is None:
            continue
        context = ""
        if i + 1 < len(lines) and frame_header.search(lines[i + 1]) is None:
            context = lines[i + 1].strip()
        frames.append(
            Frame(
                filename=match.group(1),
                lineno=int(match.group(2)),
                name=match.group(3),
                context=context,
            )
        )
    return FrameList(frames)


def inspect_node(graph: torch.fx.Graph, node: torch.fx.Node) -> str:
    """
    Inspect a node by highlighting the node in the graph as well as the stacktrace.

    The node's ``spec`` and ``stack_trace`` metadata are included when present
    in ``node.meta``.

    Args:
        graph: The graph containing the node
        node: The node to be inspected

    Return: A string. An example output is:

    Here is the node in the graph module:
    graph():
        %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    --> %_param_constant0 : [num_users=1] = get_attr[target=_param_constant0]
        %_param_constant1 : [num_users=1] = get_attr[target=_param_constant1]
        %aten_convolution_default : [num_users=2] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%arg0_1, %_param_constant0, %_param_constant1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
        %_param_constant2 : [num_users=1] = get_attr[target=_param_constant2]
        %_param_constant3 : [num_users=1] = get_attr[target=_param_constant3]
        %aten_convolution_default_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_convolution_default, %_param_constant2, %_param_constant3, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
        %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%aten_convolution_default, %aten_convolution_default_1), kwargs = {})
        %_param_constant4 : [num_users=1] = get_attr[target=_param_constant4]
        %_param_constant5 : [num_users=1] = get_attr[target=_param_constant5]
        %aten_convolution_default_2 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_add_tensor, %_param_constant4, %_param_constant5, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
        %aten_gelu_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.gelu.default](args = (%aten_convolution_default_2,), kwargs = {})
        return [aten_gelu_default]
    This node _param_constant0 has metadata of:
    The node stacktrace:
    Traceback (most recent call last):
        File "/tmp/ipykernel_1204253/3382880687.py", line 7, in forward
    return self.test_model(x)
        File "/mnt/xarfuse/uid-25337/7b86ad0c-seed-nspid4026532987_cgpid2707357-ns-4026532984/torch/nn/modules/module.py", line 1528, in _call_impl
    return forward_call(*args, **kwargs)
        File "/tmp/ipykernel_1204253/712280972.py", line 10, in forward
    a = self.conv1(x)

    """
    graph_str_with_cursor = add_cursor_to_graph(graph, node)
    error_msg = (
        f"Here is the node in the graph module:\n"
        f"{graph_str_with_cursor}\n"
        f"This node {node} has metadata of:\n"
    )
    # Node spec error message. `node.meta` is a dict, so check for the key;
    # `hasattr` would look for an attribute and never find it.
    if "spec" in node.meta:
        error_msg += f"The node spec:\n{node.meta['spec']}\n"

    # Stacktrace error message
    if "stack_trace" in node.meta:
        framelist = _stacktrace_to_framelist(node.meta["stack_trace"])
        error_msg += f"The node stacktrace:\n{pretty_print_stacktraces(framelist)}\n"
    return error_msg