# mypy: allow-untyped-defs
import dataclasses
import tempfile
from collections import defaultdict

import torch
from torch.autograd import DeviceType

from .runtime.benchmarking import benchmarker
from .runtime.runtime_utils import create_bandwidth_info_str, get_num_bytes


_kernel_category_choices = [
    "foreach",
    "persistent_reduction",
    "pointwise",
    "reduction",
    "split_scan",
    "template",
]


def get_kernel_category_by_source_code(src_code):
    """
    Similar to get_kernel_category, but operates on the source code. Call this
    API if we have not yet compiled src_code into a module.
    """
    choices = [
        ch for ch in _kernel_category_choices if f"@triton_heuristics.{ch}" in src_code
    ]
    if len(choices) == 1:
        return choices[0]
    else:
        return "unknown"


def get_kernel_category(kernel_mod):
    """
    Given the module defining a triton kernel, return the category of the kernel.
    Category can be one of:
    - pointwise
    - reduction
    - persistent_reduction

    Currently we simply decide the category based on which heuristics decorator
    is imported by the kernel module.
    """
    choices = [ch for ch in _kernel_category_choices if ch in kernel_mod.__dict__]
    if len(choices) == 1:
        return choices[0]
    else:
        return "unknown"


def get_triton_kernel(mod):
    from torch._inductor.runtime.triton_heuristics import CachingAutotuner

    cand_list = [
        v
        for k, v in mod.__dict__.items()
        if k.startswith("triton_") and isinstance(v, CachingAutotuner)
    ]
    assert len(cand_list) == 1
    return cand_list[0]
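

# Illustrative example (the kernel below is hypothetical, but mirrors the shape of
# inductor-generated source): get_kernel_category_by_source_code() simply looks for
# an "@triton_heuristics.<category>" decorator in the source text.
#
#     src_code = '''
#     @triton_heuristics.pointwise(...)
#     @triton.jit
#     def triton_poi_fused_add_0(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
#         ...
#     '''
#     assert get_kernel_category_by_source_code(src_code) == "pointwise"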


def benchmark_all_kernels(benchmark_name, benchmark_all_configs):
    """
    An experimental API used only when config.benchmark_kernel is true.

    Run the kernel benchmarks for all the kernels cached in PyCodeCache.
    Used in the compiled modules.

    This function lives here rather than being codegen'd for convenience, since
    its implementation does not change across the different graph modules being
    compiled.
    """
    from torch._inductor.codecache import PyCodeCache

    nfound = 0
    for kernel_key, kernel_mod in PyCodeCache.cache.items():
        if not hasattr(kernel_mod, "get_args") or not hasattr(kernel_mod, "call"):
            continue

        triton_kernel = get_triton_kernel(kernel_mod)
        kernel_category = get_kernel_category(kernel_mod)
        args = kernel_mod.get_args()
        num_in_out_ptrs = len(
            [
                arg_name
                for arg_name in triton_kernel.fn.arg_names
                if arg_name.startswith("in_out_ptr")
            ]
        )
        num_gb = triton_kernel.inductor_meta.get("kernel_num_gb", None)
        if num_gb is None:
            num_gb = get_num_bytes(*args, num_in_out_args=num_in_out_ptrs) / 1e9

        def get_info_str(ms, n_regs, n_spills, shared, prefix=""):
            if not any(x is None for x in [n_regs, n_spills, shared]):
                kernel_detail_str = (
                    f" {n_regs:3} regs {n_spills:3} spills {shared:8} shared mem"
                )
            else:
                kernel_detail_str = ""

            gb_per_s = num_gb / (ms / 1e3)
            return create_bandwidth_info_str(
                ms, num_gb, gb_per_s, prefix=prefix, suffix=kernel_detail_str
            )

        kernel_desc = (
            f"{benchmark_name:20} {kernel_category[:3].upper()} {kernel_key[:10]}"
        )
        if benchmark_all_configs:
            assert hasattr(kernel_mod, "benchmark_all_configs")
            bench_result = kernel_mod.benchmark_all_configs(args)
            print(kernel_desc)
            for launcher, ms in bench_result.items():
                print(
                    f" {get_info_str(ms, launcher.n_regs, launcher.n_spills, launcher.shared)} @ {launcher.config}"
                )
        else:
            ms = benchmarker.benchmark_gpu(
                lambda: kernel_mod.call(args), rep=40, fast_flush=True
            )
            assert (
                len(triton_kernel.launchers) == 1
            ), "Autotuner should have selected the best config"
            launcher = triton_kernel.launchers[0]
            print(
                get_info_str(
                    ms,
                    launcher.n_regs,
                    launcher.n_spills,
                    launcher.shared,
                    prefix=f"{kernel_desc} ",
                )
            )

        nfound += 1
    if nfound == 0:
        print(
            "No kernel with benchmark functionality found. Make sure you run inductor with config.benchmark_kernel set to True"
        )
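

# A worked example of the bandwidth math above (numbers are made up): num_gb is the
# total bytes read and written divided by 1e9, and the reported bandwidth is
# gb_per_s = num_gb / (ms / 1e3). A kernel that moves 2.0 GB in 1.25 ms is thus
# reported as 2.0 / (1.25 / 1e3) = 1600 GB/s.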


@dataclasses.dataclass
class ProfileEvent:
    category: str
    key: str
    self_device_time_ms: float
    # the benchmark is run multiple times and we average the count across all
    # the runs. It should be an integer but is defined as a float just in case.
    count: float


def parse_profile_event_list(
    benchmark_name, event_list, wall_time_ms, nruns, device_name
):
    def get_self_device_time(ev):
        """
        ev.self_device_time_total is in microseconds. Convert it to milliseconds.
        """
        return ev.self_device_time_total / 1000 / nruns

    all_events = defaultdict(list)

    def add_event(ev, category):
        profile_ev = ProfileEvent(
            category=category,
            key=ev.key,
            self_device_time_ms=get_self_device_time(ev),
            count=ev.count / nruns,  # average across all runs
        )
        all_events[category].append(profile_ev)

    for ev in event_list:
        assert not ev.is_legacy, "The legacy profiler is not supported"
        if ev.device_type == DeviceType.CPU:
            # ignore events on the CPU side
            continue

        category = "unknown"
        if ev.key.startswith("triton_"):
            if ev.key.startswith("triton_poi"):
                category = "triton_pointwise"
            elif ev.key.startswith("triton_red"):
                category = "triton_reduction"
            elif ev.key.startswith("triton_per"):
                category = "triton_persistent_reduction"
            else:
                category = "triton_unknown"

        add_event(ev, category)

    def report_category(category, profile_events):
        from tabulate import tabulate

        profile_events.sort(key=lambda ev: ev.self_device_time_ms, reverse=True)

        rows = []
        total_time = 0.0
        print(f"\n == {category} category kernels == ")
        for ev in profile_events:
            total_time += ev.self_device_time_ms
            percent = f"{ev.self_device_time_ms / wall_time_ms * 100:.2f}%"
            rows.append([ev.key[:120], ev.self_device_time_ms, ev.count, percent])
        rows.append(
            ["Total", total_time, "", f"{total_time / wall_time_ms * 100:.2f}%"]
        )
        print(
            tabulate(
                rows,
                headers=[
                    "Kernel",
                    f"Self {device_name.upper()} TIME (ms)",
                    "Count",
                    "Percent",
                ],
            )
        )
        return total_time

    def report():
        category_list = [
            "triton_pointwise",
            "triton_reduction",
            "triton_persistent_reduction",
            "triton_unknown",
            "unknown",
        ]
        assert set(all_events.keys()).issubset(
            set(category_list)
        ), f"{list(all_events.keys())}"

        per_category_wall_time = {}
        total_device_ms = 0.0
        for category in category_list:
            if category in all_events:
                _time = report_category(category, all_events[category])
                per_category_wall_time[category] = _time
                total_device_ms += _time

        device_busy_percent = f"{total_device_ms / wall_time_ms * 100:.2f}%"
        print(
            f"\nPercent of time when {device_name.upper()} is busy: {device_busy_percent}"
        )
        print(f"Total wall time {wall_time_ms:.3f} ms")

        # Print a single summary line so results from all compiled modules across
        # all benchmarks can be gathered and tabulated.
        # Columns: benchmark_name, one percent per category in category_list order,
        # device_busy_percent, wall_time_ms
        tabulate_line = f"Output for tabulate: {benchmark_name}"
        for category in category_list:
            percent = (
                f"{per_category_wall_time.get(category, 0.0) / wall_time_ms * 100:.2f}%"
            )
            tabulate_line += f", {percent}"
        tabulate_line += f", {device_busy_percent}, {wall_time_ms:.3f}ms"

        print(tabulate_line)

    report()
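

# Illustrative output of the summary line above (numbers are made up): besides the
# per-category tables, one machine-friendly line is printed per compiled module so
# that results can be collected across benchmarks, e.g.
#
#     Output for tabulate: my_benchmark, 41.27%, 20.13%, 10.00%, 0.00%, 1.10%, 72.50%, 12.345ms
#
# The five percentages follow category_list order (triton_pointwise, triton_reduction,
# triton_persistent_reduction, triton_unknown, unknown), followed by the device busy
# percent and the wall time.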


def compiled_module_main(benchmark_name, benchmark_compiled_module_fn):
    """
    This is the function called in the __main__ block of a compiled module.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--benchmark-kernels",
        "-k",
        action="store_true",
        help="Whether to benchmark each individual kernel",
    )
    parser.add_argument(
        "--benchmark-all-configs",
        "-c",
        action="store_true",
        help="Whether to benchmark each individual config for a kernel",
    )
    parser.add_argument(
        "--profile",
        "-p",
        action="store_true",
        help="Whether to profile the compiled module",
    )
    args = parser.parse_args()

    if args.benchmark_kernels:
        benchmark_all_kernels(benchmark_name, args.benchmark_all_configs)
    else:
        times = 10
        repeat = 10
        wall_time_ms = benchmark_compiled_module_fn(times=times, repeat=repeat) * 1000

        if not args.profile:
            return

        with torch.profiler.profile(record_shapes=True) as p:
            benchmark_compiled_module_fn(times=times, repeat=repeat)

        path = f"{tempfile.gettempdir()}/compiled_module_profile.json"
        p.export_chrome_trace(path)
        print(f"Profiling result for a compiled module of benchmark {benchmark_name}:")
        print(f"Chrome trace for the profile is written to {path}")
        event_list = p.key_averages(group_by_input_shape=True)
        print(event_list.table(sort_by="self_device_time_total", row_limit=10))
        parse_profile_event_list(
            benchmark_name, event_list, wall_time_ms, times * repeat, p.use_device
        )
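

# Illustrative usage (paths and names are hypothetical): an inductor-generated module
# typically ends with a __main__ block along the lines of
#
#     if __name__ == "__main__":
#         from torch._inductor.wrapper_benchmark import compiled_module_main
#         compiled_module_main("my_benchmark", benchmark_compiled_module)
#
# which can then be run directly, e.g.
#
#     python output_code.py -p       # profile the compiled module
#     python output_code.py -k -c    # benchmark every cached kernel and config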