# mypy: allow-untyped-defs
from __future__ import annotations

import functools
import logging
from enum import Enum
from typing import List, Optional

from torch import dtype as torch_dtype

from .. import config
from ..virtualized import V
from .multi_kernel import MultiKernel


log = logging.getLogger(__name__)


# AOTI debug printing related configs
class IntermediateValueDebuggingLevel(Enum):
    # OFF: No intermediate tensor value debug info will be printed or saved.
    OFF = "0"
    # LEVEL 1: Save all intermediate tensor values to individual `.pt` files. No debug printing will be displayed.
    SAVE_ONLY = "1"
    # LEVEL 2: Print all intermediate tensor values to the console. No debug saving will be performed.
    PRINT_ONLY = "2"


class DebugPrinterManager:
    """
    Context manager that emits AOTI intermediate-value debug statements into the
    generated wrapper code around a kernel launch: entering it codegens the
    "before_launch" print/save calls and exiting it codegens the "after_launch"
    ones, according to the configured IntermediateValueDebuggingLevel.
    """

    def __init__(
        self,
        debug_printer_level,
        args_to_print_or_save: Optional[List[str]] = None,
        kernel_name: str = "",
        kernel=None,
        arg_signatures: Optional[List[type]] = None,
    ):
        self.debug_printer_level = IntermediateValueDebuggingLevel(debug_printer_level)
        if args_to_print_or_save is None:
            args_to_print_or_save = []
        self.args_to_print_or_save = args_to_print_or_save
        self.kernel_name = kernel_name
        self.arg_signatures: Optional[List[type]] = arg_signatures
        self.kernel = kernel
        self.filtered_kernel_names_to_print = self._get_debug_filtered_kernel_names()

    def __enter__(self):
        self._perform_debug_print_or_save_helper(
            self.args_to_print_or_save,
            self.kernel_name,
            before_launch=True,
            arg_signatures=self.arg_signatures,
        )

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._perform_debug_print_or_save_helper(
            self.args_to_print_or_save,
            self.kernel_name,
            before_launch=False,
            arg_signatures=self.arg_signatures,
        )

    def _perform_debug_print_or_save_helper(
        self,
        args_to_print_or_save,
        kernel_name,
        before_launch,
        arg_signatures: Optional[List[type]] = None,
    ):
        if self.debug_printer_level == IntermediateValueDebuggingLevel.OFF:
            return
        if self.debug_printer_level == IntermediateValueDebuggingLevel.SAVE_ONLY:
            # by default, save all the tensor values before launch
            self.codegen_intermediate_tensor_value_save(
                args_to_print_or_save,
                kernel_name,
                before_launch,
                arg_signatures=arg_signatures,
            )
        if self.debug_printer_level == IntermediateValueDebuggingLevel.PRINT_ONLY:
            # by default, print all the tensor values before launch
            self.codegen_intermediate_tensor_value_print(
                args_to_print_or_save,
                kernel_name,
                before_launch,
                arg_signatures=arg_signatures,
            )

    @functools.lru_cache  # noqa: B019
    def _get_debug_filtered_kernel_names(self) -> List[str]:
        if config.aot_inductor.filtered_kernel_names is None:
            return []
        return [
            x.strip()
            for x in config.aot_inductor.filtered_kernel_names.lower().split(",")
        ]

    def set_printer_args(
        self,
        args_to_print_or_save: List[str],
        kernel_name: str,
        arg_signatures: Optional[List[type]],
        kernel,
    ):
        # Note: MultiKernel debug printing is not supported for now
        if isinstance(kernel, MultiKernel):
            log.info(
                "MultiKernel type is not supported in the AOTI debug printer tool yet."
            )
            self.debug_printer_level = IntermediateValueDebuggingLevel.OFF
        self.args_to_print_or_save = args_to_print_or_save
        self.kernel_name = kernel_name
        self.arg_signatures = arg_signatures
        self.kernel = kernel

    def codegen_intermediate_tensor_value_save(
        self,
        args_to_save,
        kernel_name,
        before_launch=True,
        arg_signatures: Optional[List[type]] = None,
    ) -> None:
        for i, arg in enumerate(args_to_save):
            if arg_signatures is not None and not isinstance(
                arg_signatures[i], torch_dtype
            ):
                # infer from the arg signature (a torch.dtype) whether the arg is a
                # tensor; skip non-tensor args
                continue
            launch_prefix = "before_launch" if before_launch else "after_launch"
            if V.graph.cpp_wrapper:
                if config.abi_compatible:
                    V.graph.wrapper_code.writeline(
                        f'aoti_torch_save_tensor_handle({arg}, "{arg}", "{launch_prefix}", "{kernel_name}");'
                    )
                else:
                    # TODO: add debug saving for non-abi-compatible mode
                    pass
            else:
                # currently, non-cpp (python) wrapper codegen mode is not supported
                pass

    def codegen_intermediate_tensor_value_print(
        self,
        args_to_print,
        kernel_name,
        before_launch=True,
        arg_signatures: Optional[List[type]] = None,
    ) -> None:
        for i, arg in enumerate(args_to_print):
            if arg_signatures is not None and not isinstance(
                arg_signatures[i], torch_dtype
            ):
                # infer from the arg signature (a torch.dtype) whether the arg is a
                # tensor; skip non-tensor args
                continue
            if self.debug_printer_level == IntermediateValueDebuggingLevel.PRINT_ONLY:
                # when debug printing is enabled (IntermediateValueDebuggingLevel.PRINT_ONLY),
                # check whether a filtered kernel name list is provided
                if (
                    len(self.filtered_kernel_names_to_print) > 0
                    and kernel_name not in self.filtered_kernel_names_to_print
                ):
                    continue

            launch_prefix = "before_launch" if before_launch else "after_launch"
            if V.graph.cpp_wrapper:
                if config.abi_compatible:
                    V.graph.wrapper_code.writeline(
                        f'aoti_torch_print_tensor_handle({arg}, "{launch_prefix} - {kernel_name} - {arg}");'
                    )
                else:
                    # TODO: add debug printing for non-abi-compatible mode
                    pass
            else:
                line = f"print('{launch_prefix} - {kernel_name} - {arg}', {arg})"
                V.graph.wrapper_code.writeline(line)
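

# Illustrative usage sketch: wrapper codegen configures the manager for a kernel
# and wraps the emission of the launch line in a `with` block, so that
# before_launch/after_launch debug statements are codegen-ed around it. This is
# only a sketch; names such as `call_args`, `arg_types`, `kernel_name`, and
# `kernel`, and the `V.graph.wrapper_code.debug_printer` attribute, are
# assumptions about the caller rather than definitions from this module.
#
#     debug_printer_manager = V.graph.wrapper_code.debug_printer
#     debug_printer_manager.set_printer_args(call_args, kernel_name, arg_types, kernel)
#     with debug_printer_manager:
#         # the kernel launch line is written to the wrapper code here
#         ...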