1# mypy: allow-untyped-defs 2import dataclasses 3import sys 4from enum import Enum 5from typing import Callable, Dict, List, Optional, Type 6 7import sympy 8 9import torch 10 11from .. import ir 12from ..cpu_vec_isa import pick_vec_isa, VecAMX, VecAVX2, VecAVX512, VecISA 13from ..utils import IndentedBuffer, parallel_num_threads 14from ..virtualized import V 15from .common import KernelTemplate 16from .cpp_template_kernel import CppTemplateKernel 17from .cpp_utils import DTYPE_TO_CPP, GemmBlocking, value_to_cpp 18 19 20class LayoutType(Enum): 21 NORMAL = 0 22 VNNI2 = 1 23 VNNI4 = 2 24 25 26_IS_WINDOWS = sys.platform == "win32" 27 28 29def get_restrict_keyword() -> str: 30 if _IS_WINDOWS: 31 # https://learn.microsoft.com/en-us/cpp/cpp/extension-restrict?view=msvc-170 32 return "__restrict" 33 else: 34 return "__restrict__" 35 36 37class CppMicroGemm: 38 """ 39 A class that codegens a kernel that computes small-sized matrix multiplication. 40 41 A micro GEMM kernel is responsible for register blocking, instruction selection, 42 and other CPU architecture-specific optimizations. 43 44 The subclasses need to override `codegen_define` to define the kernel function 45 that is called by the code generated by `codegen_call`. 46 """ 47 48 # TODO(jgong5): support constant shapes and lds as template args. 49 DECLARE_KERNEL = r""" 50template <bool accum> 51inline void {{kernel_name}}( 52{%- if kernel_extra_args_declare %} 53 {{kernel_extra_args_declare}} 54{%- endif %} 55 const {{input_t}}* {{restrict_keyword}} A, 56 const {{input2_t}}* {{restrict_keyword}} B, 57 {{output_t}}* {{restrict_keyword}} C, 58 int64_t M, 59 int64_t N, 60 int64_t K, 61 int64_t lda, 62 int64_t ldb, 63 int64_t ldc 64) 65""" 66 67 def __init__( 68 self, 69 name, 70 input_dtype, 71 input2_dtype, 72 output_dtype, 73 compute_dtype, 74 register_blocking, 75 alpha=1, 76 ) -> None: 77 self.name = name 78 self.input_dtype = input_dtype 79 assert input2_dtype is not None 80 self.input2_dtype = input2_dtype 81 self.output_dtype = output_dtype 82 self.compute_dtype = compute_dtype 83 self.register_blocking = register_blocking 84 self.alpha = alpha 85 86 def get_common_options(self): 87 if self.input_dtype == torch.uint8: 88 assert self.compute_dtype == torch.int32 89 assert self.output_dtype == torch.int32 90 assert self.input2_dtype == torch.int8 91 return { 92 "torch": torch, 93 "kernel_name": self.name, 94 "input_dtype": self.input_dtype, 95 "input2_dtype": self.input2_dtype, 96 "output_dtype": self.output_dtype, 97 "compute_dtype": self.compute_dtype, 98 "input_t": DTYPE_TO_CPP[self.input_dtype], 99 "input2_t": DTYPE_TO_CPP[self.input2_dtype], 100 "output_t": DTYPE_TO_CPP[self.output_dtype], 101 "compute_t": DTYPE_TO_CPP[self.compute_dtype], 102 "alpha": self.alpha, 103 "kernel_extra_args_declare": self.get_kernel_extra_args_declare(), 104 "int8_gemm": self.input_dtype == torch.uint8, 105 "vnni_size": 4 if self.input_dtype == torch.uint8 else 2, 106 "restrict_keyword": get_restrict_keyword(), 107 } 108 109 def get_kernel_declaration(self): 110 options = self.get_common_options() 111 return KernelTemplate._template_from_string(self.DECLARE_KERNEL).render(options) 112 113 def get_kernel_extra_args_declare(self) -> str: 114 return "" 115 116 def get_kernel_extra_args(self) -> str: 117 return "" 118 119 def codegen_define(self, kernel: CppTemplateKernel) -> str: 120 raise NotImplementedError 121 122 def codegen_call( 123 self, 124 kernel: CppTemplateKernel, 125 A: ir.Buffer, 126 B: ir.Buffer, 127 C: ir.Buffer, 128 accum: bool, 129 ) -> str: 130 """ 131 Generate the code for calling the templated kernel that computes 132 `C += alpha * A @ B` if `accum` is True, or `C = alpha * A @ B` otherwise. 133 """ 134 A_ptr = f"&({kernel.index(A, [0, 0])})" 135 B_ptr = f"&({kernel.index(B, [0, 0])})" 136 C_ptr = f"&({kernel.index(C, [0, 0])})" 137 M = kernel.size(C, 0) 138 N = kernel.size(C, 1) 139 K = kernel.size(A, 1) 140 lda = kernel.stride(A, 0) 141 ldb = kernel.stride(B, 0) 142 ldc = kernel.stride(C, 0) 143 res = IndentedBuffer() 144 res.writeline(f"{self.name}<{value_to_cpp(accum, 'bool')}>(") 145 with res.indent(): 146 extra_args = self.get_kernel_extra_args() 147 if extra_args: 148 res.writeline(extra_args) 149 res.writeline(f"{A_ptr},") 150 res.writeline(f"{B_ptr},") 151 res.writeline(f"{C_ptr},") 152 res.writeline(f"{M},") 153 res.writeline(f"{N},") 154 res.writeline(f"{K},") 155 res.writeline(f"{lda},") 156 res.writeline(f"{ldb},") 157 res.writeline(f"{ldc}") 158 res.writeline(");") 159 return res.getvalue() 160 161 def codegen_init( 162 self, 163 kernel: CppTemplateKernel, 164 ) -> str: 165 return "" 166 167 def codegen_finalize( 168 self, 169 kernel: CppTemplateKernel, 170 ) -> str: 171 return "" 172 173 def get_b_layout(self) -> LayoutType: 174 return LayoutType.NORMAL 175 176 177@dataclasses.dataclass 178class CppMicroGemmConfig: 179 input_dtype: torch.dtype 180 input2_dtype: torch.dtype 181 output_dtype: torch.dtype 182 compute_dtype: torch.dtype 183 vec_isa_cls: Type[VecISA] 184 register_blocking: GemmBlocking 185 extra_check: Optional[Callable[..., bool]] = None 186 187 188micro_gemm_configs: Dict[Type[CppMicroGemm], List[CppMicroGemmConfig]] = {} 189 190 191def register_micro_gemm(*configs): 192 def inner(cls): 193 assert ( 194 cls not in micro_gemm_configs 195 ), f"Duplicate micro_gemm registration for {cls}" 196 assert len(configs) > 0, f"No micro_gemm configs provided for {cls}" 197 micro_gemm_configs[cls] = list(configs) 198 return cls 199 200 return inner 201 202 203def generate_gemm_config( 204 vec_isa_cls, 205 register_blockings, 206 input_dtype=torch.float, 207 input2_dtype=None, 208 output_dtype=None, 209 compute_dtype=None, 210 extra_check=None, 211): 212 if output_dtype is None: 213 output_dtype = input_dtype 214 if compute_dtype is None: 215 compute_dtype = output_dtype 216 if input2_dtype is None: 217 input2_dtype = input_dtype 218 return [ 219 CppMicroGemmConfig( 220 input_dtype, 221 input2_dtype, 222 output_dtype, 223 compute_dtype, 224 vec_isa_cls, 225 GemmBlocking(*blocking), 226 extra_check, 227 ) 228 for blocking in register_blockings 229 ] 230 231 232class CppMicroGemmRef(CppMicroGemm): 233 """ 234 A reference implementation of the CppMicroGemm class with naive C++ code. 235 It is used for correctness debugging. 236 """ 237 238 TEMPLATE_ENTRY = r""" 239{{declare_kernel}} { 240 for (int64_t m = 0; m < M; ++m) { 241 for (int64_t n = 0; n < N; ++n) { 242 {{compute_t}} result = accum ? C[m * ldc + n] : 0; 243 for (int64_t k = 0; k < K; ++k) { 244 result += ({{compute_t}})A[m * lda + k] * ({{compute_t}})B[k * ldb + n] * {{alpha}}; 245 } 246 C[m * ldc + n] = result; 247 } 248 } 249} 250""" 251 252 def __init__( 253 self, name, input_dtype, input2_dtype, output_dtype, compute_dtype, alpha 254 ) -> None: 255 super().__init__( 256 name, 257 input_dtype, 258 input2_dtype, 259 output_dtype, 260 compute_dtype, 261 GemmBlocking(1, 1, 1), 262 alpha, 263 ) 264 265 def codegen_define(self, kernel: CppTemplateKernel) -> str: 266 options = { 267 "declare_kernel": self.get_kernel_declaration(), 268 **self.get_common_options(), 269 } 270 return KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(options) 271 272 273@register_micro_gemm( 274 *generate_gemm_config( 275 VecAVX512, 276 [(8, 48, 1), (8, 32, 1), (16, 16, 1)], 277 input_dtype=torch.float, 278 ), 279 *generate_gemm_config( 280 VecAVX512, 281 [(8, 48, 1), (8, 32, 1), (16, 16, 1)], 282 input_dtype=torch.bfloat16, 283 output_dtype=torch.float, 284 ), 285 *generate_gemm_config( 286 VecAVX512, 287 [(8, 48, 1), (8, 32, 1), (16, 16, 1)], 288 input_dtype=torch.half, 289 output_dtype=torch.float, 290 ), 291 *generate_gemm_config( 292 VecAVX512, 293 [(8, 48, 1), (8, 32, 1), (16, 16, 1)], 294 input_dtype=torch.bfloat16, 295 input2_dtype=torch.int8, 296 output_dtype=torch.float, 297 compute_dtype=torch.float, 298 ), 299 *generate_gemm_config( 300 VecAVX2, 301 [(4, 24, 1), (4, 16, 1), (8, 8, 1)], 302 input_dtype=torch.float, 303 ), 304 *generate_gemm_config( 305 VecAVX2, 306 [(4, 24, 1), (4, 16, 1), (8, 8, 1)], 307 input_dtype=torch.bfloat16, 308 output_dtype=torch.float, 309 ), 310 *generate_gemm_config( 311 VecAVX2, 312 [(4, 24, 1), (4, 16, 1), (8, 8, 1)], 313 input_dtype=torch.half, 314 output_dtype=torch.float, 315 ), 316 *generate_gemm_config( 317 VecAVX2, 318 [(4, 24, 1), (4, 16, 1), (8, 8, 1)], 319 input_dtype=torch.bfloat16, 320 input2_dtype=torch.int8, 321 output_dtype=torch.float, 322 compute_dtype=torch.float, 323 ), 324) 325class CppMicroGemmFP32Vec(CppMicroGemm): 326 """ 327 This class generates the code for micro gemm using fp32 vec instructions for compute. 328 It supports input types of torch.float, torch.bfloat16, and torch.half with fp32 output. 329 The output of the microkernel is in FP32, but it would be converted to BF16/FP16 in the template, 330 if the desired output is BF16/FP16. 331 """ 332 333 TEMPLATE_ENTRY = r""" 334{{declare_kernel}} { 335 TORCH_CHECK(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}"); 336 TORCH_CHECK(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}"); 337 // TODO(jgong5): loop unroll for M and N 338 for (int64_t m = 0; m < M; m += {{block_m}}) { 339 int64_t block_m = std::min<int64_t>(M - m, {{block_m}}); 340 for (int64_t n = 0; n < N; n += {{block_n}}) { 341 if (block_m == {{block_m}}) { 342 {{kernel_name}}_kernel<{{block_m}}, {{block_n}}, accum>( 343 A + m * lda, 344 B + n, 345 C + m * ldc + n, 346 K, 347 lda, 348 ldb, 349 ldc 350 ); 351 } else { 352 switch (block_m) { 353{%- for b in range(block_m - 1, 0, -1) %} 354 case {{b}}: 355 {{kernel_name}}_kernel<{{b}}, {{block_n}}, accum>( 356 A + m * lda, 357 B + n, 358 C + m * ldc + n, 359 K, 360 lda, 361 ldb, 362 ldc 363 ); 364 break; 365{%- endfor %} 366 default: 367 {{kernel.assert_function}}(false, "Unsupported block_m: ", block_m); 368 } 369 } 370 } 371 } 372} 373""" 374 375 TEMPLATE_KERNEL = r""" 376template <int64_t BLOCK_M, int64_t BLOCK_N, bool accum> 377inline void {{kernel_name}}_kernel( 378 const {{input_t}}* {{restrict_keyword}} A, 379 const {{input2_t}}* {{restrict_keyword}} B, 380 {{output_t}}* {{restrict_keyword}} C, 381 int64_t K, 382 int64_t lda, 383 int64_t ldb, 384 int64_t ldc 385) { 386 using Vectorized = at::vec::Vectorized<{{compute_t}}>; 387 using VectorizedIn = at::vec::Vectorized<{{input_t}}>; 388 constexpr auto VLEN = Vectorized::size(); 389 constexpr auto ROWS = BLOCK_M; 390 constexpr auto COLS = BLOCK_N / VLEN; 391 392 Vectorized va; 393 at::vec::VectorizedN<{{compute_t}}, COLS> vb; 394 at::vec::VectorizedN<{{compute_t}}, ROWS*COLS> vc; 395 396 auto loadc = [&](auto i) { 397 if constexpr (accum) { 398 constexpr int row = i / COLS; 399 constexpr int col = i % COLS; 400 vc[i] = Vectorized::loadu(C + row * ldc + col * VLEN); 401 } else { 402 vc[i] = Vectorized(0.0f); 403 } 404 }; 405 c10::ForcedUnroll<ROWS * COLS>{}(loadc); 406 407 auto compute = [&, COLS](auto i, int k) { 408 constexpr int row = i / COLS; 409 constexpr int col = i % COLS; 410 411 if constexpr (col == 0) { 412{%- if alpha != 1 %} 413 va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + k]) * {{alpha}}); 414{%- else %} 415 va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + k])); 416{%- endif %} 417 } 418 419 if constexpr (row == 0) { 420{%- if input2_dtype in [torch.bfloat16, torch.float16] %} 421 auto b = VectorizedIn::loadu(B + k * ldb + col * VLEN, VLEN); 422 vb[col] = at::vec::convert<{{compute_t}}>(b); 423{%- elif input2_dtype == torch.int8 %} 424 // Convert VLEN int8 elements to int32, and then fp32 425 auto b32 = at::vec::convert_to_int32<int8_t>(B + k * ldb + col * VLEN); 426 vb[col] = at::vec::convert<float>(b32); 427{%- else %} 428 vb[col] = Vectorized::loadu(B + k * ldb + col * VLEN); 429{%- endif %} 430 } 431 432 constexpr int idx = row * COLS + col; 433 vc[idx] = at::vec::fmadd(va, vb[col], vc[idx]); 434 }; 435 436 for (int k = 0; k < K; ++k) { 437 c10::ForcedUnroll<ROWS * COLS>{}(compute, k); 438 } 439 440 // store to C 441 auto storec = [&](auto i) { 442 constexpr int row = i / COLS; 443 constexpr int col = i % COLS; 444 vc[i].store(C + row * ldc + col * VLEN); 445 }; 446 c10::ForcedUnroll<ROWS * COLS>{}(storec); 447} 448""" 449 450 def codegen_define(self, kernel: CppTemplateKernel) -> str: 451 options = { 452 "declare_kernel": self.get_kernel_declaration(), 453 "kernel": kernel, 454 "block_m": self.register_blocking.block_m, 455 "block_n": self.register_blocking.block_n, 456 "block_k": self.register_blocking.block_k, 457 "restrict_keyword": get_restrict_keyword(), 458 **self.get_common_options(), 459 } 460 result = KernelTemplate._template_from_string(self.TEMPLATE_KERNEL).render( 461 options 462 ) 463 result += KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render( 464 options 465 ) 466 return result 467 468 469# extra check for CppMicroGemmAMX 470def check_amx_extra(config, m, n, k, alpha, num_threads): 471 vnni_size = 4 if config.input_dtype == torch.uint8 else 2 472 return k % vnni_size == 0 and alpha == 1 473 474 475@register_micro_gemm( 476 *generate_gemm_config( 477 VecAMX, 478 [(32, 32, 32), (48, 16, 32), (16, 48, 32)], 479 input_dtype=torch.bfloat16, 480 input2_dtype=torch.int8, 481 output_dtype=torch.float, 482 compute_dtype=torch.float, 483 extra_check=check_amx_extra, 484 ), 485 *generate_gemm_config( 486 VecAMX, 487 [(32, 32, 32), (48, 16, 32), (16, 48, 32)], 488 input_dtype=torch.bfloat16, 489 output_dtype=torch.float, 490 extra_check=check_amx_extra, 491 ), 492 *generate_gemm_config( 493 VecAMX, 494 [(32, 32, 64), (48, 16, 64)], 495 input_dtype=torch.uint8, 496 input2_dtype=torch.int8, 497 output_dtype=torch.int32, 498 compute_dtype=torch.int32, 499 extra_check=check_amx_extra, 500 ), 501) 502class CppMicroGemmAMX(CppMicroGemm): 503 """ 504 This class generates the code for micro gemm using Advanced Matrix eXtention (AMX) 505 instructions available in 4th generation Intel Xeon for compute. 506 It supports input types of torch.bfloat16 with fp32 output. 507 TODO(jgong5): support int8 data type. 508 """ 509 510 TEMPLATE_ENTRY = r""" 511{{declare_kernel}} { 512 TORCH_CHECK(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}"); 513 TORCH_CHECK(K % 2 == 0, "K dimension must be multiple of 2"); 514 // TODO(jgong5): loop unroll for M and N 515 for (int64_t m = 0; m < M; m += {{block_m}}) { 516 int64_t block_m = std::min<int64_t>(M - m, {{block_m}}); 517 int64_t m_tail = m; 518 for (int64_t n = 0; n < N; n += {{block_n}}) { 519{%- for num_rows in range(block_m, 0, -16) %} 520 {%- if num_rows != block_m %} 521 else 522 {%- endif %} 523 if (block_m >= {{num_rows}}) { 524 {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}<accum>( 525 amx_state, 526 A + m * lda, 527 B + n, 528 C + m * ldc + n, 529 K, 530 lda, 531 ldb, 532 ldc, 533 16 534 ); 535 block_m -= {{num_rows}}; 536 m_tail += {{num_rows}}; 537 } 538{%- endfor %} 539 if (block_m > 0) { 540 {{kernel_name}}_amx_kernel_16_{{num_columns}}<accum>( 541 amx_state, 542 A + m_tail * lda, 543 B + n, 544 C + m_tail * ldc + n, 545 K, 546 lda, 547 ldb, 548 ldc, 549 block_m 550 ); 551 } 552 } 553 } 554} 555""" 556 557 TEMPLATE_KERNEL = r""" 558template <bool accum> 559inline void {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}( 560 AMXState& amx_state, 561 const {{input_t}}* {{restrict_keyword}} A, 562 const {{input2_t}}* {{restrict_keyword}} B, 563 {{output_t}}* {{restrict_keyword}} C, 564 int64_t K, 565 int64_t lda, 566 int64_t ldb, 567 int64_t ldc, 568 uint8_t tilecfg_rows 569) { 570 // TODO(jgong5): add prefetch hint for A, B, C 571 auto loadconfig = [](const amx_tilecfg& cfg) { 572 _tile_loadconfig(&cfg); 573 }; 574 const auto last_k_offset = K / {{block_k}} * {{block_k}}; 575 const auto tail_k_size = K - last_k_offset; 576 if C10_LIKELY (last_k_offset > 0) { 577 amx_state.configure(tilecfg_rows, 64, {{num_rows}} / 16, {{num_columns}}, loadconfig); 578 } else { 579 amx_state.configure(tilecfg_rows, tail_k_size * sizeof({{input_t}}), {{num_rows}} / 16, {{num_columns}}, loadconfig); 580 } 581 auto load_c = [&]() { 582{%- for tile_row in range(num_rows // 16) %} 583 {%- for tile_col in range(num_columns) %} 584 {%- set tile_idx = tile_row * num_columns + tile_col %} 585 _tile_loadd({{tile_idx}}, C + {{tile_row * 16}} * ldc + {{tile_col * 16}}, ldc * sizeof({{output_t}})); 586 {%- endfor %} 587{%- endfor %} 588 }; 589 auto zero_c = [&]() { 590{%- for tile_row in range(num_rows // 16) %} 591 {%- for tile_col in range(num_columns) %} 592 {%- set tile_idx = tile_row * num_columns + tile_col %} 593 _tile_zero({{tile_idx}}); 594 {%- endfor %} 595{%- endfor %} 596 }; 597 598 if constexpr (accum) { 599 load_c(); 600 } else { 601 zero_c(); 602 } 603 604{%- if input_dtype == torch.bfloat16 and input2_dtype == torch.int8 %} 605 // create a buffer for tiles of B. 606 alignas(64) {{input_t}} bf16_weights_buf[512]; 607 608 int num_b_rows = (last_k_offset > 0) ? 16 : (tail_k_size * sizeof({{input_t}})) / 4; 609 int b_tile_ptr_stride = ldb * {{vnni_size}}; 610 611 auto load_B_row = [&]({{input2_t}}* src, {{input_t}}* dst) { 612 {{kernel.unroll_pragma(2)}} 613 for (int i = 0; i < 2; i++) { 614 // int8 -> int32 -> fp32 -> bf16 615 auto b32 = at::vec::convert_to_int32<int8_t>(src + i * 16); 616 auto b_bf16 = at::vec::convert<{{input_t}}>(b32); 617 b_bf16.store(dst + i * 16); 618 } 619 }; 620 621 auto load_B_in_buf = [&]({{input2_t}}* B_ptr) { 622 {{kernel.unroll_pragma(8)}} 623 for (int i = 0; i < num_b_rows; i++) { 624 load_B_row( 625 B_ptr + i * b_tile_ptr_stride, 626 bf16_weights_buf + i * 32 627 ); 628 } 629 }; 630{%- endif %} 631 632 auto compute = [&](int k) { 633{%- set tile_offset_a = num_rows // 16 * num_columns %} 634{%- set tile_offset_b = tile_offset_a + num_rows // 16 %} 635{%- for tile_row in range(num_rows // 16) %} 636 {%- for tile_col in range(num_columns) %} 637 {%- set tile_idx_a = tile_offset_a + tile_row %} 638 {%- set tile_idx_b = tile_offset_b + tile_col %} 639 {%- set tile_idx_c = tile_row * num_columns + tile_col %} 640 {%- if tile_col == 0 %} 641 _tile_stream_loadd({{tile_idx_a}}, A + {{tile_row * 16}} * lda + k, lda * sizeof({{input_t}})); 642 {%- endif %} 643 {%- if tile_row == 0 %} 644 {%- if input_dtype == torch.bfloat16 and input2_dtype == torch.int8 %} 645 load_B_in_buf(const_cast<{{input2_t}}*>(B) + k * ldb + {{tile_col * 16 * vnni_size}}); 646 _tile_loadd({{tile_idx_b}}, bf16_weights_buf, 64); 647 {%- else %} 648 _tile_loadd({{tile_idx_b}}, B + k * ldb + {{tile_col * 16 * vnni_size}}, ldb * {{vnni_size}} * sizeof({{input_t}})); 649 {%- endif %} 650 {%- endif %} 651 {%- if int8_gemm %} 652 _tile_dpbusd({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}}); 653 {%- else %} 654 _tile_dpbf16ps({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}}); 655 {%- endif %} 656 {%- endfor %} 657{%- endfor %} 658 }; 659 660 {{kernel.unroll_pragma(4)}} 661 for (int k = 0; k < last_k_offset; k += {{block_k}}) { 662 compute(k); 663 } 664 665 auto store_c = [&]() { 666 // store to C 667{%- for tile_row in range(num_rows // 16) %} 668 {%- for tile_col in range(num_columns) %} 669 {%- set tile_idx = tile_row * num_columns + tile_col %} 670 _tile_stored({{tile_idx}}, C + {{tile_row * 16}} * ldc + {{tile_col * 16}}, ldc * sizeof({{output_t}})); 671 {%- endfor %} 672{%- endfor %} 673 }; 674 675 // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead 676 if C10_UNLIKELY (tail_k_size > 0) { 677 if C10_LIKELY (last_k_offset > 0) { 678 store_c(); 679 amx_state.configure(tilecfg_rows, tail_k_size * sizeof({{input_t}}), {{num_rows}} / 16, {{num_columns}}, loadconfig); 680 load_c(); 681 } 682 compute(last_k_offset); 683 } 684 685 store_c(); 686} 687""" 688 689 def codegen_define(self, kernel: CppTemplateKernel) -> str: 690 block_m, block_n, block_k = self.register_blocking 691 assert block_m % 16 == 0, "Only support block_m % 16 == 0 for AMX" 692 assert block_n % 16 == 0, "Only support block_n % 16 == 0 for AMX" 693 if self.input_dtype == torch.uint8: 694 assert block_k == 64, "Only support block_k = 64 for AMX INT8" 695 else: 696 assert block_k == 32, "Only support block_k = 32 for AMX Bfloat16/Float16" 697 num_columns = block_n // 16 698 options = { 699 "declare_kernel": self.get_kernel_declaration(), 700 "kernel": kernel, 701 "block_m": block_m, 702 "block_n": block_n, 703 "block_k": block_k, 704 "num_columns": num_columns, 705 "restrict_keyword": get_restrict_keyword(), 706 **self.get_common_options(), 707 } 708 result = "" 709 for num_rows in range(block_m, 0, -16): 710 amx_kernel_options = {**options, "num_rows": num_rows} 711 result += KernelTemplate._template_from_string(self.TEMPLATE_KERNEL).render( 712 amx_kernel_options 713 ) 714 result += KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render( 715 options 716 ) 717 return result 718 719 def codegen_init( 720 self, 721 kernel: CppTemplateKernel, 722 ) -> str: 723 return "AMXState amx_state;" 724 725 def codegen_finalize( 726 self, 727 kernel: CppTemplateKernel, 728 ) -> str: 729 return "amx_state.release([]() { _tile_release(); });" 730 731 def get_kernel_extra_args_declare(self) -> str: 732 return "AMXState& amx_state," 733 734 def get_kernel_extra_args(self) -> str: 735 return "amx_state," 736 737 def get_b_layout(self): 738 if self.input_dtype == torch.uint8: 739 return LayoutType.VNNI4 740 else: 741 return LayoutType.VNNI2 742 743 744def create_micro_gemm( 745 name, 746 m, 747 n, 748 k, 749 input_dtype, 750 input2_dtype, 751 output_dtype=None, 752 compute_dtype=None, 753 alpha=1, 754 num_threads=-1, 755 use_ref=True, 756) -> Optional[CppMicroGemm]: 757 def create_from_config(cls, config: CppMicroGemmConfig): 758 return cls( 759 name, 760 config.input_dtype, 761 config.input2_dtype, 762 config.output_dtype, 763 config.compute_dtype, 764 config.register_blocking, 765 alpha, 766 ) 767 768 assert isinstance(n, int) or n.is_number, n 769 assert isinstance(k, int) or k.is_number, k 770 m = V.graph.sizevars.size_hint(m, fallback=1) if isinstance(m, sympy.Expr) else m 771 assert isinstance(m, int), m 772 if output_dtype is None: 773 output_dtype = input_dtype 774 if compute_dtype is None: 775 compute_dtype = output_dtype 776 if num_threads < 0: 777 num_threads = parallel_num_threads() 778 vec_isa = pick_vec_isa() 779 matched_configs = [] 780 for cls, configs in micro_gemm_configs.items(): 781 for config in configs: 782 if not issubclass(vec_isa.__class__, config.vec_isa_cls): 783 continue 784 if ( 785 config.input_dtype == input_dtype 786 and config.compute_dtype == compute_dtype 787 and config.input2_dtype == input2_dtype 788 and config.output_dtype == output_dtype 789 # The output_dtype here is the output dtype of the micro-kernel. 790 # In some cases, the actual output dtype of the op for which the micro-kernel 791 # is being created would be same as that of the activation, but the micro-kernels 792 # compute output in Float/int32, which is converted in the GEMM template. This is 793 # subject to change in the future. 794 ): 795 if config.extra_check is not None and not config.extra_check( 796 config, m, n, k, alpha, num_threads 797 ): 798 continue 799 block_m, block_n, block_k = config.register_blocking 800 if ( 801 config.vec_isa_cls == VecAMX 802 and m < block_m 803 and input_dtype == torch.bfloat16 804 and input2_dtype == torch.int8 805 ): 806 # For int8 WoQ GEMM, AMX micro-kernel may not perform well if m < block_m 807 continue 808 # Criteria on the ranking of configurations 809 # 1. ISA: AMX > VEC 810 # 2. Dividable by block sizes (block_m, block_n, block_k) 811 # 3. Number of mxn blocks is large enough to occupy all the threads 812 # 4. Register blocks are larger 813 isa_score = 0 814 if config.vec_isa_cls == VecAMX: 815 isa_score += 1 816 dividable_score = 0 817 if m % block_m == 0: 818 dividable_score += 1 819 if n % block_n == 0: 820 dividable_score += 1 821 if k % block_k == 0: 822 dividable_score += 1 823 occupancy_score = 0 824 n_blocks = (n + block_n - 1) // block_n 825 total_mxn_blocks = n_blocks * ((m + block_m - 1) // block_m) 826 if n_blocks >= num_threads: 827 occupancy_score += 1 828 if total_mxn_blocks >= num_threads: 829 occupancy_score += 1 830 register_bytes = ( 831 block_m * block_n * config.compute_dtype.itemsize 832 + (block_m * block_k + block_k * block_n) 833 * config.input_dtype.itemsize 834 ) 835 matched_configs.append( 836 ( 837 (isa_score, dividable_score, occupancy_score, register_bytes), 838 cls, 839 config, 840 ) 841 ) 842 if len(matched_configs) == 0: 843 if use_ref: 844 return CppMicroGemmRef( 845 name, input_dtype, input2_dtype, output_dtype, compute_dtype, alpha 846 ) 847 else: 848 return None 849 # TODO(jgong5): allow autotuning on choices of configs 850 return create_from_config(*max(matched_configs, key=lambda x: x[0])[1:]) 851