1# mypy: allow-untyped-defs 2from __future__ import annotations 3 4import csv 5import dataclasses 6import inspect 7import os 8import re 9from dataclasses import dataclass 10from functools import lru_cache 11from typing import Dict, List, Set, Tuple, TYPE_CHECKING 12 13from torch._inductor import config 14from torch._inductor.utils import get_benchmark_name 15 16 17# Prevent circular import 18if TYPE_CHECKING: 19 from torch._inductor.scheduler import BaseSchedulerNode 20 21# counter for tracking how many kernels have been generated 22generated_kernel_count = 0 23generated_cpp_vec_kernel_count = 0 24num_bytes_accessed = 0 25nodes_num_elem: List[ 26 Tuple[ 27 BaseSchedulerNode, 28 int, 29 ] 30] = [] 31node_runtimes: List[Tuple[BaseSchedulerNode, float]] = [] 32 33# counters for tracking fusions 34ir_nodes_pre_fusion = 0 35 36# counters for tracking to_dtype inserted 37cpp_to_dtype_count = 0 38 39 40@dataclasses.dataclass 41class CppOuterLoopFusedCount: 42 inner_kernel_number: int 43 local_buffer_number: int = 0 44 45 46# The length counts the number of outer loop fusions. 47cpp_outer_loop_fused_inner_counts: List[CppOuterLoopFusedCount] = [] 48 49num_comprehensive_padding = 0 50num_matches_for_scatter_upon_const_tensor = 0 51 52num_loop_reordering = 0 53 54 55# reset all counters 56def reset(): 57 global generated_kernel_count 58 global generated_cpp_vec_kernel_count 59 global num_bytes_accessed, nodes_num_elem 60 global ir_nodes_pre_fusion 61 global cpp_to_dtype_count 62 global cpp_outer_loop_fused_inner_counts 63 global num_comprehensive_padding 64 global num_matches_for_scatter_upon_const_tensor 65 global num_loop_reordering 66 67 generated_kernel_count = 0 68 generated_cpp_vec_kernel_count = 0 69 num_bytes_accessed = 0 70 nodes_num_elem.clear() 71 node_runtimes.clear() 72 ir_nodes_pre_fusion = 0 73 cpp_to_dtype_count = 0 74 cpp_outer_loop_fused_inner_counts.clear() 75 num_comprehensive_padding = 0 76 num_matches_for_scatter_upon_const_tensor = 0 77 num_loop_reordering = 0 78 79 80@dataclass 81class CachedMetricsDeltas: 82 """ 83 The subset of metrics we want update across cache hits, e.g., the 84 FxGraphCache. 85 """ 86 87 generated_kernel_count: int 88 generated_cpp_vec_kernel_count: int 89 ir_nodes_pre_fusion: int 90 cpp_to_dtype_count: int 91 num_bytes_accessed: int 92 num_matches_for_scatter_upon_const_tensor: int 93 94 95def get_metric_fields(): 96 return [field.name for field in dataclasses.fields(CachedMetricsDeltas)] 97 98 99class CachedMetricsHelper: 100 """ 101 A helper class to help calculate and apply counter deltas for those 102 metrics we want to save with cache entries (e.g., FxGraphCache) and 103 apply on a cache hit. 104 """ 105 106 def __init__(self) -> None: 107 self.cached_metrics = {} 108 for metric in get_metric_fields(): 109 self.cached_metrics[metric] = globals()[metric] 110 111 def get_deltas(self) -> CachedMetricsDeltas: 112 delta_metrics = {} 113 for metric in get_metric_fields(): 114 delta_metrics[metric] = globals()[metric] - self.cached_metrics[metric] 115 116 return CachedMetricsDeltas(**delta_metrics) 117 118 @staticmethod 119 def apply_deltas(delta: CachedMetricsDeltas): 120 for metric in get_metric_fields(): 121 globals()[metric] += getattr(delta, metric) 122 123 124REGISTERED_METRIC_TABLES: Dict[str, MetricTable] = {} 125 126 127@dataclass 128class MetricTable: 129 table_name: str 130 column_names: List[str] 131 132 num_rows_added: int = 0 133 134 def add_row(self, row_fn): 135 if self.table_name not in enabled_metric_tables(): 136 return 137 138 row_dict = row_fn() 139 assert len(self.column_names) == len( 140 row_dict 141 ), f"{len(self.column_names)} v.s. {len(row_dict)}" 142 assert set(self.column_names) == set( 143 row_dict.keys() 144 ), f"{set(self.column_names)} v.s. {set(row_dict.keys())}" 145 146 row = [ 147 get_benchmark_name(), 148 ] 149 row += [row_dict[column_name] for column_name in self.column_names] 150 self._write_row(row) 151 152 def output_filename(self): 153 return f"metric_table_{self.table_name}.csv" 154 155 def write_header(self): 156 filename = self.output_filename() 157 with open(filename, "w") as fd: 158 writer = csv.writer(fd, lineterminator="\n") 159 writer.writerow(["model_name"] + self.column_names) 160 161 def _write_row(self, row): 162 filename = self.output_filename() 163 if self.num_rows_added == 0 and not os.path.exists(filename): 164 self.write_header() 165 166 self.num_rows_added += 1 167 168 for idx, orig_val in enumerate(row): 169 if isinstance(orig_val, float): 170 new_val = f"{orig_val:.6f}" 171 elif orig_val is None: 172 new_val = "" 173 else: 174 new_val = orig_val 175 row[idx] = new_val 176 177 with open(filename, "a") as fd: 178 writer = csv.writer(fd, lineterminator="\n") 179 writer.writerow(row) 180 181 @staticmethod 182 def register_table(name, column_names): 183 table = MetricTable(name, column_names) 184 REGISTERED_METRIC_TABLES[name] = table 185 186 187MetricTable.register_table( 188 "slow_fusion", 189 [ 190 "kernel1_path", 191 "kernel1_latency", 192 "kernel2_path", 193 "kernel2_latency", 194 "fused_kernel_path", 195 "fused_kernel_latency", 196 "slow_down_ratio", 197 ], 198) 199 200# track the fusion statistics for each graph 201MetricTable.register_table( 202 "graph_stats", 203 [ 204 "graph_id", 205 "num_nodes_before_fusion", 206 "num_nodes_after_fusion", 207 ], 208) 209 210# track the perf difference between persistent reduction and non-persistent 211# reductions 212MetricTable.register_table( 213 "persistent_red_perf", 214 [ 215 "kernel1_name", 216 "kernel2_name", 217 "kernel1_latency", 218 "kernel2_latency", 219 "size_hints", 220 "reduction_hint", 221 "speedup", 222 ], 223) 224 225# Log the fusion failures due to indexing mismatch 226MetricTable.register_table( 227 "fusion_failure_due_to_indexing_mismatch", 228 [ 229 "pre_grad_graph_id", 230 "post_grad_graph_id", 231 "node1_name", 232 "node2_name", 233 "node1_debug_str", 234 "node2_debug_str", 235 "common_buffer_names", 236 "failure_reason", 237 ], 238) 239 240# Log metadata for pointwise/reduction kernels. E.g., model name, kernel path, numel, rnumel, reduction hint 241MetricTable.register_table( 242 "kernel_metadata", 243 [ 244 "kernel_name", 245 "kernel_path", 246 "kernel_category", # pointwise/reduction/foreach etc. 247 "size_hints", 248 "reduction_hint", 249 "line_of_code", 250 "num_load", 251 "num_store", 252 "num_for_loop", 253 "num_atomic_add", 254 "num_args", 255 # xyz numel can be different to size_hints since size_hints are rounded 256 # up to the nearest power of 2. 257 # Inductor kernel will burn in the xyz numel in kernel code for static 258 # shape kernels. 259 # Logging them will be helpful to find unaligned shape for reduction 260 "xnumel", 261 "ynumel", 262 "rnumel", 263 "kernel_args_num_gb", 264 ], 265) 266 267 268def _parse_kernel_fn_code(kernel_module_code): 269 """ 270 The kernel_module_code is the python module that contains kernel function code. 271 kernel function is the proper triton kernel function annotated with 272 @triton.jit 273 """ 274 from .codecache import PyCodeCache 275 from .wrapper_benchmark import get_triton_kernel 276 277 mod = PyCodeCache.load(kernel_module_code) 278 kernel = get_triton_kernel(mod) 279 # kernel is a CachingAutotune; kernel.fn is the JITFunction; 280 # kernel.fn.fn is the function being decorate by triton.jit 281 return inspect.getsource(kernel.fn.fn) 282 283 284def _parse_kernel_line_of_code(proper_kernel_fn_code): 285 """ 286 Return the line of code for the kernel excluding the decorators. 287 """ 288 return len(proper_kernel_fn_code.splitlines()) 289 290 291def _parse_size_hints(kernel_module_code, kernel_category): 292 if kernel_category == "foreach": 293 # foreach kernel does not have size_hints 294 return None 295 m = re.search(r"size_hints=(\[[0-9, ]*\]),", kernel_module_code) 296 assert m, "size_hints missing!" 297 return m.group(1) 298 299 300def _parse_reduction_hint(kernel_category, kernel_module_code): 301 if kernel_category not in ("reduction", "persistent_reduction"): 302 return None 303 m = re.search(r"reduction_hint=ReductionHint\.(\w*),", kernel_module_code) 304 assert m, "reduction_hint not found in kernel source code!" 305 return m.group(1) 306 307 308def _count_pattern(proper_kernel_fn_code, pattern): 309 return proper_kernel_fn_code.count(pattern) 310 311 312def _count_args(proper_kernel_fn_code): 313 def_line = proper_kernel_fn_code.splitlines()[0] 314 assert def_line.startswith("def ") 315 start_idx = def_line.index("(") 316 end_idx = def_line.index("):") 317 decl_csv = def_line[start_idx + 1 : end_idx] 318 comps = decl_csv.split(",") 319 return len(comps) 320 321 322def _parse_proper_kernel_fn_code(kernel_fn_code): 323 """ 324 Skip decorators. 325 """ 326 start_pos = kernel_fn_code.index("def ") 327 return kernel_fn_code[start_pos:] 328 329 330def _parse_numel(proper_kernel_fn_code, numel_arg_name): 331 m = re.search(f"{numel_arg_name} = ([\\d]+)", proper_kernel_fn_code) 332 if m: 333 return int(m.group(1)) 334 else: 335 return None 336 337 338def _parse_kernel_args_num_gb(kernel_fn_code, kernel_category): 339 """ 340 inductor meta looks like: 341 inductor_meta={... 'mutated_arg_names': [], 'no_x_dim': False, 'kernel_num_gb': 2.0}, 342 """ 343 m = re.search(r".kernel_num_gb.:\s*([0-9.]+)", kernel_fn_code) 344 if m: 345 return float(m.group(1)) 346 else: 347 """ 348 There are a few cases that kernel_num_gdb field can be missing: 349 1. the field will be missing if config.benchmark_kernel and 350 config.profile_bandwidth are false 351 2. even if config.benchmark_kernel or config.profile_bandwidth is true. 352 foreach kernel does not have kernel_num_gb field in the metadata 353 """ 354 return None 355 356 357def log_kernel_metadata(kernel_name, kernel_path, kernel_module_code): 358 """ 359 An utility to log kernel metadata. We may parse metadata from kernel source code here. 360 361 It's fine to parse the generated kernel code here since the logging is 362 disabled by default. It would hurt compilation time. 363 """ 364 from .wrapper_benchmark import get_kernel_category_by_source_code 365 366 kernel_category = get_kernel_category_by_source_code(kernel_module_code) 367 reduction_hint = _parse_reduction_hint(kernel_category, kernel_module_code) 368 size_hints = _parse_size_hints(kernel_module_code, kernel_category) 369 kernel_fn_code = _parse_kernel_fn_code(kernel_module_code) 370 371 proper_kernel_fn_code = _parse_proper_kernel_fn_code(kernel_fn_code) 372 373 # the line of code excluding the decortors 374 kernel_line_of_code = _parse_kernel_line_of_code(proper_kernel_fn_code) 375 376 get_metric_table("kernel_metadata").add_row( 377 lambda: { 378 "kernel_name": kernel_name, 379 "kernel_path": kernel_path, 380 "kernel_category": kernel_category, 381 "size_hints": size_hints, 382 "reduction_hint": reduction_hint, 383 "line_of_code": kernel_line_of_code, 384 "num_load": _count_pattern(proper_kernel_fn_code, "tl.load"), 385 "num_store": _count_pattern(proper_kernel_fn_code, "tl.store"), 386 "num_for_loop": _count_pattern(proper_kernel_fn_code, "for "), 387 "num_atomic_add": _count_pattern(proper_kernel_fn_code, "tl.atomic_add"), 388 "num_args": _count_args(proper_kernel_fn_code), 389 "xnumel": _parse_numel(proper_kernel_fn_code, "xnumel"), 390 "ynumel": _parse_numel(proper_kernel_fn_code, "ynumel"), 391 "rnumel": _parse_numel(proper_kernel_fn_code, "rnumel"), 392 "kernel_args_num_gb": _parse_kernel_args_num_gb( 393 kernel_fn_code, kernel_category 394 ), 395 } 396 ) 397 398 399def purge_old_log_files(): 400 """ 401 Purge the old log file at the beginning when the benchmark script runs. 402 Should do it in the parent process rather than the child processes running 403 each individual model. 404 """ 405 for name, table in REGISTERED_METRIC_TABLES.items(): 406 if name in enabled_metric_tables(): 407 filename = table.output_filename() 408 if os.path.exists(filename): 409 os.unlink(filename) 410 411 table.write_header() 412 413 414@lru_cache 415def enabled_metric_tables() -> Set[str]: 416 config_str = config.enabled_metric_tables 417 418 enabled = set() 419 for name in config_str.split(","): 420 name = name.strip() 421 if not name: 422 continue 423 assert ( 424 name in REGISTERED_METRIC_TABLES 425 ), f"Metric table name {name} is not registered" 426 enabled.add(name) 427 return enabled 428 429 430def is_metric_table_enabled(name): 431 return name in enabled_metric_tables() 432 433 434def get_metric_table(name): 435 assert name in REGISTERED_METRIC_TABLES, f"Metric table {name} is not defined" 436 return REGISTERED_METRIC_TABLES[name] 437