# mypy: allow-untyped-defs
import functools
import itertools
import logging
from typing import cast, List, Tuple

import sympy

import torch
from torch._inductor.select_algorithm import realize_inputs
from torch._inductor.virtualized import V

from .. import config as inductor_config
from ..runtime.runtime_utils import next_power_of_2
from ..utils import ceildiv as cdiv


log = logging.getLogger(__name__)


def triton_config(num_stages, num_warps, **kwargs):
    from triton import Config

    return Config(kwargs, num_stages=num_stages, num_warps=num_warps)


def filtered_configs(
    m: int,
    n: int,
    k: int,
    configs: List[Tuple[int, int, int, int, int]],
    has_int8_tensor=False,
):
    """Heuristic to shrink configs when they are bigger than the input size"""

    min_block_size = 16
    # block_k=16 seems to be causing issues
    # see: https://github.com/triton-lang/triton/issues/2156#issuecomment-1695897424
    min_block_size_k = 32 if has_int8_tensor else 16
    m = max(
        next_power_of_2(
            V.graph.sizevars.size_hint(
                m, fallback=torch._inductor.config.unbacked_symint_fallback  # type: ignore[arg-type]
            )
        ),
        min_block_size,
    )
    n = max(
        next_power_of_2(
            V.graph.sizevars.size_hint(
                n, fallback=torch._inductor.config.unbacked_symint_fallback  # type: ignore[arg-type]
            )
        ),
        min_block_size,
    )
    k = max(
        next_power_of_2(
            V.graph.sizevars.size_hint(
                k, fallback=torch._inductor.config.unbacked_symint_fallback  # type: ignore[arg-type]
            )
        ),
        min_block_size_k,
    )
    used = set()
    for block_m, block_n, block_k, num_stages, num_warps in configs:
        # shrink configs for small sizes
        block_m = max(min(block_m, m), min_block_size)
        block_n = max(min(block_n, n), min_block_size)
        block_k = max(min(block_k, k), min_block_size_k)
        # each warp computes 16x16 tile = 256
        num_warps = min(num_warps, block_m * block_n // 256)
        if torch.version.hip:
            for matrix_instr_nonkdim in [0, 16]:
                if matrix_instr_nonkdim != 0 and (
                    block_m % matrix_instr_nonkdim != 0
                    or block_n % matrix_instr_nonkdim != 0
                ):
                    # block_m and block_n must be a multiple of matrix_instr_nonkdim
                    continue
                if (
                    block_m,
                    block_n,
                    block_k,
                    num_stages,
                    num_warps,
                    matrix_instr_nonkdim,
                ) not in used:
                    used.add(
                        (
                            block_m,
                            block_n,
                            block_k,
                            num_stages,
                            num_warps,
                            matrix_instr_nonkdim,
                        )
                    )
                    yield triton_config(
                        BLOCK_M=block_m,
                        BLOCK_N=block_n,
                        BLOCK_K=block_k,
                        num_stages=num_stages,
                        num_warps=num_warps,
                        matrix_instr_nonkdim=matrix_instr_nonkdim,
                    )
        else:
            if (block_m, block_n, block_k, num_stages, num_warps, 0) not in used:
                used.add((block_m, block_n, block_k, num_stages, num_warps, 0))
                yield triton_config(
                    BLOCK_M=block_m,
                    BLOCK_N=block_n,
                    BLOCK_K=block_k,
                    num_stages=num_stages,
                    num_warps=num_warps,
                )
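# Illustrative example of the shrink heuristic above (a sketch for clarity,
# not executed at runtime): for a problem with m = n = k = 20, each size
# rounds up to the next power of two (32), so a candidate config
# (BLOCK_M=128, BLOCK_N=128, BLOCK_K=64, num_stages=5, num_warps=8) is
# clamped to BLOCK_M=32, BLOCK_N=32, BLOCK_K=32, and num_warps is capped at
# min(8, 32 * 32 // 256) = 4 before the config is yielded.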
# List of dictionaries to store the kernel configs. Configs whose "cond"
# evaluates to True will be utilised on the target platform. The configs are
# as follows:
# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
mm_kernel_configs = (
    [
        {"config": (32, 32, 16, 1, 2), "cond": True},
        {"config": (32, 32, 128, 2, 4), "cond": torch.version.hip is None},
        {"config": (32, 64, 32, 5, 8), "cond": True},
        {"config": (64, 32, 32, 5, 8), "cond": True},
        {"config": (64, 32, 128, 5, 4), "cond": True},
        {"config": (64, 64, 16, 2, 4), "cond": True},
        {"config": (64, 64, 32, 2, 4), "cond": True},
        {"config": (64, 64, 64, 3, 8), "cond": True},
        {"config": (64, 64, 128, 5, 4), "cond": True},
        {"config": (64, 128, 32, 3, 4), "cond": True},
        {"config": (64, 128, 32, 4, 8), "cond": True},
        {"config": (64, 128, 64, 3, 4), "cond": True},
        {"config": (64, 128, 128, 4, 4), "cond": True},
        {"config": (128, 64, 32, 3, 4), "cond": True},
        {"config": (128, 64, 32, 4, 8), "cond": True},
        {"config": (128, 128, 32, 2, 8), "cond": True},
        {"config": (128, 128, 32, 3, 4), "cond": True},
        {"config": (128, 128, 64, 3, 4), "cond": True},
        {"config": (128, 128, 64, 5, 8), "cond": True},
    ]
    if inductor_config.max_autotune_gemm_search_space != "EXHAUSTIVE"
    else [
        {"config": (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps), "cond": True}
        for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product(
            [16, 32, 64, 128, 256], repeat=3
        )
        for num_stages in [1, 2, 3, 4, 5]
        for num_warps in [2, 4, 8]
    ]
)

# These configs are only used in tuned_mm when AutoHeuristic is enabled.
# The idea is that when AutoHeuristic collects data to learn a heuristic,
# more configs are autotuned. When the learned heuristic is used, it reduces
# the number of configs down to 10, which saves compilation time (since fewer
# configs are autotuned) and potentially increases performance, because the
# learned heuristic might predict a config that is not part of mm_configs.
extra_mm_kernel_configs = [
    {"config": (16, 32, 16, 3, 2), "cond": True},
    {"config": (16, 32, 32, 4, 2), "cond": True},
    {"config": (16, 32, 32, 5, 2), "cond": True},
    {"config": (64, 64, 128, 3, 4), "cond": True},
    {"config": (128, 64, 32, 2, 2), "cond": True},
    {"config": (128, 64, 64, 3, 8), "cond": True},
    {"config": (128, 64, 128, 4, 8), "cond": True},
    {"config": (128, 128, 32, 4, 4), "cond": True},
    {"config": (128, 128, 64, 3, 8), "cond": True},
    {"config": (128, 128, 64, 5, 4), "cond": True},
]

int8_mm_kernel_configs = [
    {"config": (64, 64, 32, 2, 4), "cond": True},
    {"config": (64, 128, 32, 3, 4), "cond": True},
    {"config": (128, 64, 32, 3, 4), "cond": True},
    {"config": (64, 128, 32, 4, 8), "cond": True},
    {"config": (128, 64, 32, 4, 8), "cond": True},
    {"config": (64, 32, 32, 5, 8), "cond": True},
    {"config": (32, 64, 32, 5, 8), "cond": True},
    {"config": (128, 128, 32, 2, 8), "cond": True},
    {"config": (64, 64, 64, 3, 8), "cond": True},
    # {"config": (32, 32, 128, 2, 4), "cond": True},
    # {"config": (64, 64, 16, 2, 4), "cond": True},
    # {"config": (32, 32, 16, 1, 2), "cond": True},
    {"config": (128, 256, 128, 3, 8), "cond": torch.version.hip is None},
    {"config": (256, 128, 128, 3, 8), "cond": torch.version.hip is None},
]
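# For example, an entry such as {"config": (64, 64, 32, 2, 4), "cond": True}
# in the lists above describes a kernel tiled with BLOCK_M=64, BLOCK_N=64,
# BLOCK_K=32 and compiled with num_stages=2 and num_warps=4, while "cond"
# gates whether the entry is considered on the current platform (e.g.
# `torch.version.hip is None` keeps a config out of ROCm builds).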
# Mixed precision kernel configs for small sizes of m for mm's like (16, 8192) x (8192, 8192).
mixed_mm_kernel_configs_small_m = [
    {"config": (16, 128, 256, 3, 4), "cond": True},
    {"config": (16, 128, 256, 5, 8), "cond": True},
]

mixed_mm_kernel_configs = (
    mm_kernel_configs + mixed_mm_kernel_configs_small_m
    if inductor_config.max_autotune_gemm_search_space != "EXHAUSTIVE"
    else mm_kernel_configs
)

scaled_mm_kernel_configs = [
    {"config": (128, 256, 32, 3, 8), "cond": True},
    {"config": (256, 128, 32, 3, 8), "cond": True},
    {"config": (256, 64, 32, 4, 4), "cond": True},
    {"config": (64, 256, 32, 4, 4), "cond": True},
    {"config": (128, 128, 32, 4, 4), "cond": True},
    {"config": (128, 64, 32, 4, 4), "cond": True},
    {"config": (64, 128, 32, 4, 4), "cond": True},
    {"config": (128, 32, 32, 4, 4), "cond": True},
    {"config": (64, 32, 32, 5, 2), "cond": True},
    {"config": (256, 128, 128, 3, 8), "cond": True},
    {"config": (256, 64, 128, 4, 4), "cond": True},
    {"config": (64, 256, 128, 4, 4), "cond": True},
    {"config": (128, 128, 128, 4, 4), "cond": True},
    {"config": (128, 64, 64, 4, 4), "cond": True},
    {"config": (64, 128, 64, 4, 4), "cond": True},
    {"config": (128, 32, 64, 4, 4), "cond": True},
    {"config": (64, 32, 64, 5, 2), "cond": True},
    {"config": (16, 32, 32, 2, 2), "cond": True},
    {"config": (16, 64, 32, 2, 2), "cond": True},
    {"config": (16, 128, 32, 2, 4), "cond": True},
    {"config": (16, 256, 32, 2, 4), "cond": True},
    {"config": (16, 32, 64, 2, 2), "cond": True},
    {"config": (16, 64, 64, 2, 2), "cond": True},
    {"config": (16, 128, 64, 2, 4), "cond": True},
    {"config": (16, 256, 64, 2, 4), "cond": True},
    {"config": (32, 32, 32, 2, 2), "cond": True},
    {"config": (32, 64, 32, 2, 2), "cond": True},
    {"config": (32, 128, 32, 2, 4), "cond": True},
    {"config": (32, 256, 32, 2, 4), "cond": True},
    {"config": (32, 32, 64, 2, 2), "cond": True},
    {"config": (32, 64, 64, 2, 2), "cond": True},
    {"config": (32, 128, 64, 2, 4), "cond": True},
    {"config": (32, 256, 64, 2, 4), "cond": True},
    {"config": (16, 32, 32, 3, 2), "cond": True},
    {"config": (16, 64, 32, 3, 2), "cond": True},
    {"config": (16, 128, 32, 3, 4), "cond": True},
    {"config": (16, 256, 32, 3, 4), "cond": True},
    {"config": (16, 32, 64, 3, 2), "cond": True},
    {"config": (16, 64, 64, 3, 2), "cond": True},
    {"config": (16, 128, 64, 3, 4), "cond": True},
    {"config": (16, 256, 64, 3, 4), "cond": True},
    {"config": (32, 32, 32, 3, 2), "cond": True},
    {"config": (32, 64, 32, 3, 2), "cond": True},
    {"config": (32, 128, 32, 3, 4), "cond": True},
    {"config": (32, 256, 32, 3, 4), "cond": True},
    {"config": (32, 32, 64, 3, 2), "cond": True},
    {"config": (32, 64, 64, 3, 2), "cond": True},
    {"config": (32, 128, 64, 3, 4), "cond": True},
    {"config": (32, 256, 64, 3, 4), "cond": True},
    {"config": (16, 32, 32, 4, 2), "cond": True},
    {"config": (16, 64, 32, 4, 2), "cond": True},
    {"config": (16, 128, 32, 4, 4), "cond": True},
    {"config": (16, 256, 32, 4, 4), "cond": True},
    {"config": (16, 32, 64, 4, 2), "cond": True},
    {"config": (16, 64, 64, 4, 2), "cond": True},
    {"config": (16, 128, 64, 4, 4), "cond": True},
    {"config": (16, 256, 64, 4, 4), "cond": True},
    {"config": (32, 32, 32, 4, 2), "cond": True},
    {"config": (32, 64, 32, 4, 2), "cond": True},
    {"config": (32, 128, 32, 4, 4), "cond": True},
    {"config": (32, 256, 32, 4, 4), "cond": True},
    {"config": (32, 32, 64, 4, 2), "cond": True},
    {"config": (32, 64, 64, 4, 2), "cond": True},
    {"config": (32, 128, 64, 4, 4), "cond": True},
    {"config": (32, 256, 64, 4, 4), "cond": True},
    {"config": (16, 32, 32, 5, 2), "cond": True},
    {"config": (16, 64, 32, 5, 2), "cond": True},
    {"config": (16, 128, 32, 5, 4), "cond": True},
    {"config": (16, 256, 32, 5, 4), "cond": True},
    {"config": (16, 32, 64, 5, 2), "cond": True},
    {"config": (16, 64, 64, 5, 2), "cond": True},
    {"config": (16, 128, 64, 5, 4), "cond": True},
    {"config": (16, 256, 64, 5, 4), "cond": True},
    {"config": (32, 32, 32, 5, 2), "cond": True},
    {"config": (32, 64, 32, 5, 2), "cond": True},
    {"config": (32, 128, 32, 5, 4), "cond": True},
    {"config": (32, 256, 32, 5, 4), "cond": True},
    {"config": (32, 32, 64, 5, 2), "cond": True},
    {"config": (32, 64, 64, 5, 2), "cond": True},
    {"config": (32, 128, 64, 5, 4), "cond": True},
    {"config": (32, 256, 64, 5, 4), "cond": True},
    {"config": (16, 32, 32, 6, 2), "cond": True},
    {"config": (16, 64, 32, 6, 2), "cond": True},
    {"config": (16, 128, 32, 6, 4), "cond": True},
    {"config": (16, 256, 32, 6, 4), "cond": True},
    {"config": (16, 32, 64, 6, 2), "cond": True},
    {"config": (16, 64, 64, 6, 2), "cond": True},
    {"config": (16, 128, 64, 6, 4), "cond": True},
    {"config": (16, 256, 64, 6, 4), "cond": True},
    {"config": (32, 32, 32, 6, 2), "cond": True},
    {"config": (32, 64, 32, 6, 2), "cond": True},
    {"config": (32, 128, 32, 6, 4), "cond": True},
    {"config": (32, 256, 32, 6, 4), "cond": True},
    {"config": (32, 32, 64, 6, 2), "cond": True},
    {"config": (32, 64, 64, 6, 2), "cond": True},
    {"config": (32, 128, 64, 6, 4), "cond": True},
    {"config": (32, 256, 64, 6, 4), "cond": True},
]


# Create filtered list of configs based on cond evaluation
mm_platform_configs = tuple(
    cast(Tuple[int, int, int, int, int], config["config"])
    for config in mm_kernel_configs
    if config["cond"]
)
extra_mm_platform_configs = tuple(
    cast(Tuple[int, int, int, int, int], config["config"])
    for config in extra_mm_kernel_configs
    if config["cond"]
)
int8_platform_configs = tuple(
    cast(Tuple[int, int, int, int, int], config["config"])
    for config in int8_mm_kernel_configs
    if config["cond"]
)
mixed_mm_platform_configs = tuple(
    cast(Tuple[int, int, int, int, int], config["config"])
    for config in mixed_mm_kernel_configs
    if config["cond"]
)
scaled_mm_platform_configs = tuple(
    cast(Tuple[int, int, int, int, int], config["config"])
    for config in scaled_mm_kernel_configs
    if config["cond"]
)

# On ROCm convert num_stages to 0 to enable software pipelining
if torch.version.hip:
    mm_platform_configs = tuple(
        (config[0], config[1], config[2], 0, config[4])
        for config in mm_platform_configs
    )
    extra_mm_platform_configs = tuple(
        (config[0], config[1], config[2], 0, config[4])
        for config in extra_mm_platform_configs
    )
    int8_platform_configs = tuple(
        (config[0], config[1], config[2], 0, config[4])
        for config in int8_platform_configs
    )
    mixed_mm_platform_configs = tuple(
        (config[0], config[1], config[2], 0, config[4])
        for config in mixed_mm_platform_configs
    )
    scaled_mm_platform_configs = tuple(
        (config[0], config[1], config[2], 0, config[4])
        for config in scaled_mm_platform_configs
    )

mm_configs = functools.partial(
    filtered_configs,
    configs=mm_platform_configs,
)

extra_mm_configs = functools.partial(
    filtered_configs,
    configs=extra_mm_platform_configs,
)

int8_mm_configs = functools.partial(
    filtered_configs,
    configs=int8_platform_configs,
)

mixed_mm_configs = functools.partial(
    filtered_configs,
    configs=mixed_mm_platform_configs,
)

scaled_mm_configs = functools.partial(
    filtered_configs,
    configs=scaled_mm_platform_configs,
)
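# Illustrative usage of the partials above (a sketch, not part of this
# module's control flow): callers typically invoke e.g. `mm_configs(m, n, k)`,
# which runs filtered_configs with the platform-specific config list already
# bound and yields triton.Config candidates for autotuning.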
def mm_grid(m, n, meta):
    """
    The CUDA grid size for matmul triton templates.
    """
    return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), 1, 1)


def acc_type(dtype):
    if dtype in (torch.float16, torch.bfloat16):
        return "tl.float32"
    return f"tl.{dtype}".replace("torch.", "")


def mm_options(config, sym_m, sym_n, sym_k, layout, b_prologue_cast_type=None):
    """
    Common options for matmul triton templates.
    """
    even_k_symbolic = (
        # it isn't worth guarding on this
        sympy.gcd(sym_k, config.kwargs["BLOCK_K"])
        == config.kwargs["BLOCK_K"]
    )
    allow_tf32 = torch.backends.cuda.matmul.allow_tf32 and (
        not inductor_config.force_same_precision
        or ((sym_m % 16) == 0 and (sym_n % 16) == 0 and (sym_k % 8) == 0)
    )
    return dict(
        GROUP_M=8,
        EVEN_K=even_k_symbolic,
        ALLOW_TF32=allow_tf32,
        ACC_TYPE=acc_type(layout.dtype),
        B_PROLOGUE_CAST_TYPE=b_prologue_cast_type,
        num_stages=config.num_stages,
        num_warps=config.num_warps,
        **config.kwargs,
    )


def mm_args(
    mat1,
    mat2,
    *others,
    layout=None,
    out_dtype=None,
    use_4x2_dim=False,
    mat2_transposed=False,
):
    """
    Common arg processing for mm, bmm, addmm, etc.
    """
    mat1, mat2 = realize_inputs(mat1, mat2)
    *b1, m, k1 = mat1.get_size()
    if mat2_transposed:
        *b2, n, k2 = mat2.get_size()
    else:
        *b2, k2, n = mat2.get_size()
    b = [V.graph.sizevars.guard_equals(a, b) for a, b in zip(b1, b2)]
    if use_4x2_dim:
        k2 = k2 * 2
    k = V.graph.sizevars.guard_equals(k1, k2)
    if layout is None:
        from torch._inductor.ir import FixedLayout

        if out_dtype is None:
            out_dtype = mat1.get_dtype()

        layout = FixedLayout(
            mat1.get_device(),
            out_dtype,
            [*b, m, n],
        )
    else:
        assert out_dtype is None, "out_dtype is ignored if layout is specified."
        from ..lowering import expand

        others = [realize_inputs(expand(x, layout.size)) for x in others]

    return [m, n, k, layout, mat1, mat2, *others]


def addmm_epilogue(dtype, alpha, beta):
    def epilogue(acc, bias):
        if alpha != 1:
            acc = V.ops.mul(acc, V.ops.constant(alpha, dtype))
        if beta != 1:
            bias = V.ops.mul(bias, V.ops.constant(beta, dtype))
        return V.ops.add(acc, bias)

    return epilogue
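# Note on the epilogue above: the closure returned by addmm_epilogue computes
# beta * bias + alpha * acc elementwise, i.e. for addmm it realizes
# out = beta * input + alpha * (mat1 @ mat2), skipping the scalar multiply
# whenever the corresponding coefficient equals 1.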