1# mypy: allow-untyped-defs 2import contextlib 3import logging 4import math 5from functools import lru_cache 6from typing import Any, Callable, cast, List, Optional, Set, Union 7from unittest.mock import patch 8 9import torch 10import torch.utils 11 12from ..._dynamo.utils import counters 13from .. import config, ir, lowering as L 14from ..kernel.mm_common import mm_args 15from ..select_algorithm import DataProcessorTemplateWrapper 16from ..utils import cache_on_self, has_free_symbols, parallel_num_threads 17from ..virtualized import ops, V 18from .cpp import get_export_declaration 19from .cpp_micro_gemm import CppMicroGemmAMX, create_micro_gemm, LayoutType 20from .cpp_template import CppTemplate 21from .cpp_template_kernel import CppTemplateKernel 22from .cpp_utils import ( 23 create_epilogue_with_attr, 24 DTYPE_TO_CPP, 25 GemmBlocking, 26 get_gemm_template_output_and_compute_dtype, 27) 28 29 30log = logging.getLogger(__name__) 31 32GEMM_TEMPLATE = r""" 33{{template.header().getvalue()}} 34 35{{micro_gemm.codegen_define(kernel)}} 36 37{%- if x_scale is not none %} 38 {%- set kernel_args = {"X": X, "W": W, "inp": inp, "x_scale": x_scale, "x_zp": x_zp, "w_scale": w_scale, "w_zp": w_zp,} %} 39{%- else %} 40 {%- set kernel_args = {"X": X, "W": W, "inp": inp} %} 41{%- endif %} 42 43extern "C" {{export_declaration}} 44{{kernel.def_kernel(inputs=kernel_args, outputs={"Y": Y}, aliases=aliases)}} 45{ 46 {{kernel.maybe_codegen_profile()}} 47 constexpr int64_t num_threads = {{num_threads}}; 48 constexpr int64_t N = {{N}}; 49 constexpr int64_t K = {{K}}; 50 constexpr int64_t Mr = {{micro_gemm.register_blocking.block_m}}; 51 constexpr int64_t Nr = {{micro_gemm.register_blocking.block_n}}; 52 constexpr int64_t Kr = {{micro_gemm.register_blocking.block_k}}; 53 constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr; 54 constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr; 55 56{%- if is_dynamic_M %} 57 const int64_t M = {{kernel.size(GemmOut, 0)}}; 58 const int64_t Mr_blocks = (M + Mr - 1) / Mr; 59 {%- if num_threads > 1 %} 60 int64_t Mt_blocks, Nt_blocks, Kt_blocks; 61 mm_get_thread_blocking(num_threads, {{config.cpp.gemm_max_k_slices}}, M, N, K, Mr, Nr, Kr, Mt_blocks, Nt_blocks, Kt_blocks); 62 {%- else %} 63 const auto Mt_blocks = Mr_blocks; 64 const auto Nt_blocks = Nr_blocks; 65 const auto Kt_blocks = Kr_blocks; 66 {%- endif %} 67 int64_t Mc_blocks, Nc_blocks, Kc_blocks; 68 uint32_t L1_cache_size = {{L1_cache_size}}; 69 uint32_t L2_cache_size = {{L2_cache_size}}; 70 mm_get_cache_blocking<{{kernel.dtype(X)}}, {{kernel.dtype(W)}}>( 71 num_threads, 72 M, 73 N, 74 K, 75 Mr, 76 Nr, 77 Kr, 78 Mt_blocks, 79 Nt_blocks, 80 Kt_blocks, 81 Mc_blocks, 82 Nc_blocks, 83 Kc_blocks, 84 L1_cache_size, 85 L2_cache_size 86 ); 87 const int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks; 88 const int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks; 89 const int64_t num_k_slices = (Kr_blocks + Kt_blocks - 1) / Kt_blocks; 90{%- else %} 91 constexpr int64_t M = {{kernel.size(GemmOut, 0)}}; 92 constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr; 93 constexpr int64_t Mt_blocks = {{template.thread_blocking().block_m}}; 94 constexpr int64_t Nt_blocks = {{template.thread_blocking().block_n}}; 95 constexpr int64_t Kt_blocks = {{template.thread_blocking().block_k}}; 96 constexpr int64_t Mc_blocks = {{template.cache_blocking().block_m}}; 97 constexpr int64_t Nc_blocks = {{template.cache_blocking().block_n}}; 98 constexpr int64_t Kc_blocks = {{template.cache_blocking().block_k}}; 99 constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks; 100 constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks; 101 constexpr int64_t num_k_slices = (Kr_blocks + Kt_blocks - 1) / Kt_blocks; 102{%- endif %} 103 104 // make sure all partitions are assigned 105 {{kernel.assert_function}}( 106 Mt_blocks * Nt_blocks * Kt_blocks * {{num_threads}} >= Mr_blocks * Nr_blocks * Kr_blocks, 107 "Not all partitions are assigned." 108 ); 109 110{%- if maybe_k_slicing %} 111 std::unique_ptr<std::unique_ptr<{{DTYPE_TO_CPP[acc_buf_dtype]}}[]>[]> local_buf_ptrs; 112 if (num_k_slices > 1) { 113 local_buf_ptrs.reset(new std::unique_ptr<{{DTYPE_TO_CPP[acc_buf_dtype]}}[]>[num_Mc_blocks * num_Nc_blocks * num_k_slices]); 114 } 115{%- endif %} 116 117{%- if num_threads > 1 %} 118 #pragma omp parallel num_threads({{num_threads}}) 119 { 120 const int tid = omp_get_thread_num(); 121 int64_t m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end; 122 mm_get_thread_blocks( 123 tid, Mr_blocks, Nr_blocks, Kr_blocks, Mt_blocks, Nt_blocks, Kt_blocks, 124 m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end); 125 {%- if maybe_k_slicing %} 126 const int64_t k_group_id = tid / num_k_slices; 127 const int64_t k_slice_id = tid % num_k_slices; 128 {%- endif %} 129{%- else %} 130 { 131 const int tid = 0; 132 const int64_t m_block_start = 0; 133 const int64_t m_block_end = Mr_blocks; 134 const int64_t n_block_start = 0; 135 const int64_t n_block_end = Nr_blocks; 136 const int64_t k_block_start = 0; 137 const int64_t k_block_end = Kr_blocks; 138{%- endif %} 139 {{ micro_gemm.codegen_init(kernel) }} 140{%- if use_local_acc %} 141 {%- set acc_buf_name = "local_acc_buf" %} 142 {{ kernel.define_buffer(acc_buf_name, ["Mc_blocks*Mr", "Nc_blocks*Nr"], acc_buf_dtype) }} 143{%- endif %} 144 for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) { 145 const int64_t m_start = mc * Mr; 146 const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M); 147 const int64_t m_size = m_end - m_start; 148 for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) { 149 const int64_t n_start = nc * Nr; 150 const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N); 151 const int64_t n_size = n_end - n_start; 152 // NB: assume we pad N, nc_block_end won't exceed padded N here. 153 const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end); 154{%- if use_local_acc %} 155 {%- set acc = kernel.local_buffers[acc_buf_name] %} 156 {{ kernel.reinit_buffer_if_null(acc_buf_name) }} 157{%- else %} 158 {%- set acc = kernel.slice_nd(GemmOut, [("m_start", "m_end"), ("n_start", "n_end")]) %} 159{%- endif %} 160 for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) { 161 int64_t k_start = kc * Kr; 162 int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K); 163{%- set tile_X = kernel.slice_nd(X, [("m_start", "m_end"), ("k_start", "k_end")]) %} 164 for (int64_t nci = nc; nci < nc_block_end; nci++) { 165{%- set acc_slice = kernel.slice_nd(acc, [("0", "m_end - m_start"), ("(nci - nc)*Nr", "(nci - nc + 1)*Nr")]) %} 166{%- set tile_W_3d = kernel.slice_nd(W, [("nci", "nci + 1"), ("k_start", "k_end"), ()]) %} 167{%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %} 168 if (kc == k_block_start) { 169 {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc_slice, accum=False)|indent(28, false) }} 170 } else { 171 {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc_slice, accum=True)|indent(28, false) }} 172 } 173 } 174 } 175{%- if maybe_k_slicing %} 176 if (num_k_slices > 1) { 177 const int64_t mxn_cache_block_id = (mc / Mc_blocks) * num_Nc_blocks + nc; 178 local_buf_ptrs[mxn_cache_block_id * num_k_slices + k_slice_id].reset({{ kernel.release_buffer(acc_buf_name) }}); 179 } else 180{%- endif %} 181 { 182{%- set tile_Y = kernel.slice_nd(Y_2d, [("m_start", "m_end"), ("n_start", "n_end")]) %} 183{%- set tile_acc = kernel.slice_nd(acc, [("0", "m_end - m_start"), ("0", "n_end - n_start")]) %} 184 {{ kernel.store_output( 185 tile_Y, tile_acc, GemmOut, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers 186 )|indent(20, false) 187 }} 188 } 189 } 190 } 191{%- if maybe_k_slicing %} 192 if (num_k_slices > 1) { 193 #pragma omp barrier 194 for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) { 195 // We slice M-dim and each thread in the k-slicing group works on a slice 196 const int64_t m_start_unsliced = mc * Mr; 197 const int64_t m_end_unsliced = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M); 198 const int64_t m_size_unsliced = m_end_unsliced - m_start_unsliced; 199 const int64_t m_slice_size = (m_size_unsliced + num_k_slices - 1) / num_k_slices; 200 const int64_t m_start = std::min(m_start_unsliced + m_slice_size * k_slice_id, m_end_unsliced); 201 const int64_t m_end = std::min(m_start_unsliced + m_slice_size * (k_slice_id + 1), m_end_unsliced); 202 const int64_t m_size = m_end - m_start; 203 const int64_t m_offset = m_start - m_start_unsliced; 204 for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) { 205 const int64_t n_start = nc * Nr; 206 const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N); 207 const int64_t n_size = n_end - n_start; 208 const int64_t mxn_cache_block_id = (mc / Mc_blocks) * num_Nc_blocks + nc; 209 auto {{acc_buf_name}} = local_buf_ptrs[mxn_cache_block_id * num_k_slices].get(); 210 for (int64_t other_slice = 1; other_slice < num_k_slices; other_slice++) { 211 auto other_acc = local_buf_ptrs[mxn_cache_block_id * num_k_slices + other_slice].get(); 212 for (int64_t m = m_offset; m < m_offset + m_size; m++) { 213 #pragma omp simd 214 for (int64_t n = 0; n < n_size; n++) { 215 {{acc_buf_name}}[m*Nr + n] += other_acc[m*Nr + n]; 216 } 217 } 218 } 219 {%- set tile_acc_m_slice = kernel.slice_nd(tile_acc, [("m_offset", "m_offset + m_end - m_start"), ()]) %} 220 {{ kernel.store_output( 221 tile_Y, tile_acc_m_slice, GemmOut, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers 222 )|indent(20, false) 223 }} 224 } 225 } 226 } 227{%- endif %} 228 {{ micro_gemm.codegen_finalize(kernel) }} 229 } 230} 231""" 232 233 234def get_padded_n(n, block_n): 235 return (n + block_n - 1) // block_n * block_n 236 237 238class CppPackedGemmTemplate(CppTemplate): 239 def __init__( 240 self, 241 input_nodes, 242 layout: ir.Layout, 243 num_threads: int, 244 register_blocking: GemmBlocking, 245 beta=1, 246 alpha=1, 247 has_bias=False, 248 epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None, 249 ) -> None: 250 assert layout.dtype in [torch.float, torch.bfloat16, torch.half, torch.uint8] 251 super().__init__( 252 "packed_gemm", 253 input_nodes, 254 layout, 255 num_threads, 256 epilogue_creator=epilogue_creator, 257 ) 258 self.beta = beta 259 self.alpha = alpha 260 self.has_bias = has_bias 261 self.register_blocking = register_blocking 262 m, n = layout.size 263 _, k = input_nodes[0].get_size() 264 self.m, self.n, self.k = m, n, k 265 self.padded_n = get_padded_n(n, self.register_blocking.block_n) 266 self.is_dynamic_M = has_free_symbols((m,)) 267 268 @cache_on_self 269 def thread_blocking(self) -> GemmBlocking: 270 """ 271 NOTE [Thread blocking in Cpp GEMM] 272 We use simple heuristics to decide the thread blocking: 273 1. Make sure all threads are occupied as much as possible. 274 2. For (m, n) blocks, favor more square-sized thread blocks for better data reuse. 275 3. If (m, n) blocks cannot occupy all the threads, we consider k-slicing. 276 TODO(jgong5): allow tuning various blocking options 277 """ 278 279 @lru_cache(maxsize=100) 280 def get_factors(number): 281 factors = [] 282 for i in range(int(number**0.5), 0, -1): 283 if number % i == 0: 284 factors.append(number // i) 285 factors.append(i) 286 return factors 287 288 def get_blocking(m_factor, n_factor, k_factor, m_blocks, n_blocks, k_blocks): 289 thread_block_k = math.ceil(k_blocks / k_factor) 290 thread_block_n = math.ceil(n_blocks / n_factor) 291 thread_block_m = math.ceil(m_blocks / m_factor) 292 return GemmBlocking(thread_block_m, thread_block_n, thread_block_k) 293 294 assert ( 295 not self.is_dynamic_M 296 ), "Unable to determine thread blocking for dynamic M." 297 register_blocking = self.register_blocking 298 m_blocks = math.ceil(self.m / register_blocking.block_m) 299 n_blocks = math.ceil(self.n / register_blocking.block_n) 300 k_blocks = math.ceil(self.k / register_blocking.block_k) 301 factors = get_factors(self.num_threads) 302 assert len(factors) > 0 303 304 if config.cpp.gemm_thread_factors is not None: 305 factors = [int(i) for i in config.cpp.gemm_thread_factors.split(",")] 306 assert len(factors) == 3 307 assert math.prod(factors) == self.num_threads 308 return get_blocking( 309 factors[0], factors[1], factors[2], m_blocks, n_blocks, k_blocks 310 ) 311 312 # we favor square-sized thread blocks for good data reuse 313 def get_better_blocking(blocking, best_blocking): 314 if best_blocking is None: 315 best_blocking = blocking 316 else: 317 block_m_size = blocking.block_m * register_blocking.block_m 318 block_n_size = blocking.block_n * register_blocking.block_n 319 best_block_m_size = best_blocking.block_m * register_blocking.block_m 320 best_block_n_size = best_blocking.block_n * register_blocking.block_n 321 if blocking.block_k > best_blocking.block_k: 322 best_blocking = blocking 323 elif ( 324 blocking.block_k == best_blocking.block_k 325 and block_m_size + block_n_size 326 < best_block_m_size + best_block_n_size 327 ): 328 best_blocking = blocking 329 return best_blocking 330 331 best_blocking = None 332 # check if we can have a thread-blocking to occupy all threads without k-slicing 333 for n_factor in factors: 334 m_factor = self.num_threads // n_factor 335 if n_blocks >= n_factor and m_blocks >= m_factor: 336 blocking = get_blocking( 337 m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks 338 ) 339 best_blocking = get_better_blocking(blocking, best_blocking) 340 341 if best_blocking is None: 342 for k_factor in factors: 343 if k_blocks >= k_factor and ( 344 config.cpp.gemm_max_k_slices == 0 345 or k_factor <= config.cpp.gemm_max_k_slices 346 ): 347 n_factors = get_factors(self.num_threads // k_factor) 348 for n_factor in n_factors: 349 m_factor = (self.num_threads // k_factor) // n_factor 350 if n_blocks >= n_factor and m_blocks >= m_factor: 351 blocking = get_blocking( 352 m_factor, 353 n_factor, 354 k_factor, 355 m_blocks, 356 n_blocks, 357 k_blocks, 358 ) 359 best_blocking = get_better_blocking(blocking, best_blocking) 360 361 if best_blocking is None: 362 for n_factor in factors: 363 m_factor = self.num_threads // n_factor 364 if n_blocks >= n_factor or m_blocks >= m_factor: 365 blocking = get_blocking( 366 m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks 367 ) 368 best_blocking = get_better_blocking(blocking, best_blocking) 369 370 assert best_blocking is not None 371 return best_blocking 372 373 @cache_on_self 374 def cache_blocking(self) -> GemmBlocking: 375 def get_cache_blocking(register_blocking, thread_blocking): 376 Mr = register_blocking.block_m 377 Nr = register_blocking.block_n 378 Kr = register_blocking.block_k 379 380 Mt_blocks = thread_blocking.block_m 381 Nt_blocks = thread_blocking.block_n 382 Kt_blocks = thread_blocking.block_k 383 384 if config.cpp.gemm_cache_blocking is not None: 385 blockings = [int(i) for i in config.cpp.gemm_cache_blocking.split(",")] 386 assert len(blockings) == 3 387 Mc_blocks, Nc_blocks, Kc_blocks = blockings 388 return ( 389 min(Mc_blocks, Mt_blocks), 390 min(Nc_blocks, Nt_blocks), 391 min(Kc_blocks, Kt_blocks), 392 ) 393 394 # The ratios below are empirically determined to decide 395 # the effective sizes of L1 and L2. 396 # TODO: tune the factor here 397 L1_limit_factor = 0.8 398 L2_limit_factor = 0.5 399 400 L1_cache_size = ( 401 torch._C._cpu._L1d_cache_size() 402 ) # per core cache size in Bytes 403 assert ( 404 L1_cache_size > 0 405 ), f"Expect L1_cache_size > 0 but got {L1_cache_size}" 406 L1 = L1_cache_size * L1_limit_factor 407 408 L2_cache_size = ( 409 torch._C._cpu._L2_cache_size() 410 ) # per core cache size in Bytes 411 assert ( 412 L2_cache_size > 0 413 ), f"Expect L2_cache_size > 0 but got {L2_cache_size}" 414 L2 = L2_cache_size * L2_limit_factor 415 416 def get_num_byte(dtype): 417 return torch.tensor([], dtype=dtype).element_size() 418 419 num_byte_A = get_num_byte(self.input_nodes[0].get_dtype()) 420 num_byte_B = get_num_byte(self.input_nodes[1].get_dtype()) 421 422 # NOTE [CPP GEMM Cache Blocking Algorithm] 423 # Our overall strategy is to 424 # 1) Make cache blocks of B L1-reside and reused by multiple rows of A, i.e. Mc. 425 # Here, B is Kc x Nr where Nr is a single register block. We use L1 size to 426 # decide Kc. We want to make Mc large enough to better reuse B. 427 # 2) Make cache blocks of A L2-reside, which would limit Mc. We want to reuse A 428 # along N, where we have two sub-strategies (see notes below) to decide Mc and Nc. 429 430 # Step 1: Decide Kc assuming B block is L1-reside. 431 size_cache_B = Kr * Kt_blocks * Nr * num_byte_B 432 Kc_blocks = Kt_blocks 433 if size_cache_B > L1: 434 Kc_blocks = math.floor(L1 / (Kr * Nr * num_byte_B)) 435 436 # Step 2: Decide Mc assuming A block is L2-reside. 437 min_Mc_ratio = 2 # TODO(jgong5): something to tune? 438 min_Mc_blocks = math.ceil(min_Mc_ratio * Mr / Nr) 439 assert min_Mc_blocks >= 1 440 Kt_bytes = Kt_blocks * Kr * num_byte_A 441 if min_Mc_blocks * Mr * Kt_bytes < L2: 442 # Strategy 1: A (Mc x Kt) resides in L2 and reused by all Nt 443 # when Nc_blocks is kept 1. Mc should be large enough (>= min_Mc_blocks) 444 # to reuse B (Kc x Nr) in L1. This makes C (Mc x Nr) small enough to reside 445 # in L1. 446 Mc_blocks = min(Mt_blocks, math.floor(L2 / (Mr * Kt_bytes))) 447 Nc_blocks = 1 448 else: 449 # Strategy 2: Kt is too large to hold A (Mc x Kt) in L2, we reuse 450 # A (Mc x Kc) in L2 by B (Kc x Nc). C (Mc x Nc) resides in L2. 451 Mc_blocks = Mt_blocks 452 Nc_blocks = min(math.ceil(Mc_blocks * Mr / Nr), Nt_blocks) 453 Nc_bytes = Nc_blocks * Nr * 4 # assume C or acc is float32/int32 454 Kc_bytes = Kc_blocks * Kr * num_byte_A 455 if Mc_blocks * Mr * (Kc_bytes + Nc_bytes) > L2: 456 # The following is the solution for 4*Mc*Nc + Mc*Kc_bytes = L2, 457 # assuming Mc == Nc for good data reuse. 458 M_max = (math.sqrt(Kc_bytes * Kc_bytes + 16 * L2) - Kc_bytes) / 8 459 if M_max < Mc_blocks * Mr: 460 Mc_blocks = math.floor(M_max / Mr) 461 Nc_blocks = min(math.ceil(Mc_blocks * Mr / Nr), Nt_blocks) 462 463 return Mc_blocks, Nc_blocks, Kc_blocks 464 465 assert ( 466 not self.is_dynamic_M 467 ), "Unable to determine cache blocking for dynamic M." 468 register_blocking = self.register_blocking 469 thread_blocking = self.thread_blocking() 470 471 return GemmBlocking(*get_cache_blocking(register_blocking, thread_blocking)) 472 473 def log_blockings(self): 474 log.debug(f"Register blocking: {self.register_blocking}") # noqa: G004 475 if self.is_dynamic_M: 476 # thread and cache blockings are determined at runtime for dynamic shapes 477 return 478 log.debug(f"Cache blocking: {self.cache_blocking()}") # noqa: G004 479 thread_blocking = self.thread_blocking() 480 log.debug(f"Thread blocking: {thread_blocking}") # noqa: G004 481 482 def get_occupancy(): 483 m_blocks = math.ceil(self.m / self.register_blocking.block_m) 484 n_blocks = math.ceil(self.n / self.register_blocking.block_n) 485 k_blocks = math.ceil(self.k / self.register_blocking.block_k) 486 m = math.ceil(m_blocks / thread_blocking.block_m) 487 n = math.ceil(n_blocks / thread_blocking.block_n) 488 k = math.ceil(k_blocks / thread_blocking.block_k) 489 return (m, n, k) 490 491 log.debug( 492 f"Number of threads: {self.num_threads}, occupancy: {get_occupancy()}" # noqa: G004 493 ) 494 495 def maybe_k_slicing(self): 496 if self.num_threads == 1: 497 return False 498 if self.is_dynamic_M: 499 # TODO(jgong5): perhaps use size hint to decide? 500 return True 501 register_blocking = self.register_blocking 502 k_blocks = math.ceil(self.k / register_blocking.block_k) 503 thread_blocking = self.thread_blocking() 504 return k_blocks > thread_blocking.block_k 505 506 @staticmethod 507 def add_choices( 508 choices, 509 layout, 510 input_nodes, 511 beta=1, 512 alpha=1, 513 has_bias=False, 514 trans_w=False, 515 input_indices=None, 516 epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None, 517 ): 518 if input_indices is None: 519 input_indices = list(range(len(input_nodes))) 520 521 def reorder_and_filter(inputs, layout_or_out): 522 if has_bias: 523 assert len(input_indices) >= 3 524 # Assume the input order is [inp, x, w] and we reorder it to [x, w, inp] 525 inp_idx = input_indices[0] 526 x_idx = input_indices[1] 527 w_idx = input_indices[2] 528 return [ 529 inputs[x_idx], 530 inputs[w_idx], 531 inputs[inp_idx], 532 *[inputs[idx] for idx in input_indices[3:]], 533 ], layout_or_out 534 else: 535 assert len(input_indices) >= 2 536 return [inputs[idx] for idx in input_indices], layout_or_out 537 538 def maybe_to_dense(inputs, layout_or_out): 539 new_inputs = list(inputs) 540 if isinstance(inputs[1], torch.Tensor): 541 W = inputs[1] 542 new_inputs[1] = W.to_dense() if W.is_mkldnn else W 543 return new_inputs, layout_or_out 544 545 def normalize_shapes(inputs, layout_or_out): 546 if not trans_w: 547 return inputs, layout_or_out 548 new_inputs = list(inputs) 549 X = inputs[0] 550 W = inputs[1] 551 B = inputs[2] if has_bias else None 552 if isinstance(W, ir.IRNode): 553 if trans_w: 554 if not isinstance(W, ir.TensorBox): 555 W = ir.TensorBox(W) 556 W = L.permute(W, [1, 0]) 557 else: 558 if trans_w: 559 assert isinstance(W, torch.Tensor) 560 W = W.transpose(0, 1) 561 if B is not None: 562 if isinstance(B, ir.IRNode): 563 if not isinstance(B, ir.TensorBox): 564 B = ir.TensorBox(B) 565 B = L.expand(B, (X.get_size()[0], B.get_size()[-1])) 566 else: 567 assert isinstance(B, torch.Tensor) 568 B = B.expand(X.shape[0], B.shape[-1]) 569 new_inputs[1] = W 570 if B is not None: 571 new_inputs[2] = B 572 return new_inputs, layout_or_out 573 574 # TODO(jgong5): decide proper number of threads per problem size 575 num_threads = parallel_num_threads() 576 new_inputs, _ = normalize_shapes( 577 *maybe_to_dense(*reorder_and_filter(input_nodes, layout)) 578 ) 579 m, n, k, *_ = mm_args(new_inputs[0], new_inputs[1]) 580 output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype( 581 new_inputs[0].get_dtype() 582 ) 583 micro_gemm = create_micro_gemm( 584 "micro_gemm", 585 m, 586 n, 587 k, 588 input_dtype=new_inputs[0].get_dtype(), 589 input2_dtype=new_inputs[1].get_dtype(), 590 output_dtype=output_dtype, 591 compute_dtype=compute_dtype, 592 alpha=alpha, 593 num_threads=num_threads, 594 ) 595 assert micro_gemm is not None 596 _, block_n, _ = micro_gemm.register_blocking 597 padded_n = get_padded_n(n, block_n) 598 599 def pack_weight(inputs, layout_or_out): 600 W = inputs[1] 601 new_inputs = list(inputs) 602 blocked_w: Union[ir.IRNode, torch.Tensor] = W 603 if isinstance(W, ir.IRNode): 604 new_size = [padded_n // block_n, k, block_n] 605 blocked_w = ir.Buffer( 606 W.get_name(), # Borrow the registered buffer name 607 ir.FixedLayout( 608 W.get_device(), 609 W.get_dtype(), 610 new_size, 611 ir.FlexibleLayout.contiguous_strides(new_size), 612 0, 613 ), 614 ) 615 else: 616 blocked_w = ( 617 torch.nn.functional.pad(W, (0, padded_n - n)) 618 .reshape(k, padded_n // block_n, block_n) 619 .transpose(0, 1) 620 .contiguous() 621 ) 622 if micro_gemm.get_b_layout() != LayoutType.NORMAL: 623 layout_str = ( 624 "VNNI4" 625 if micro_gemm.get_b_layout() == LayoutType.VNNI4 626 else "VNNI2" 627 ) 628 assert micro_gemm.get_b_layout() in [ 629 LayoutType.VNNI2, 630 LayoutType.VNNI4, 631 ], f"We only support {layout_str} for now" 632 vnni_size = ( 633 4 if micro_gemm.get_b_layout() == LayoutType.VNNI4 else 2 634 ) 635 assert ( 636 k % vnni_size == 0 637 ), f"k should be divisible by vnni_size for {layout_str} layout" 638 blocked_w = ( 639 blocked_w.view( 640 padded_n // block_n, k // vnni_size, vnni_size, block_n 641 ) 642 .transpose(-1, -2) 643 .contiguous() 644 .view(padded_n // block_n, k, block_n) 645 ) 646 # normalize stride to be "contiguous_strides" per size 647 # this avoids the problems in L.view during template codegen 648 new_stride = [1] 649 for sz in reversed(blocked_w.shape[1:]): 650 new_stride.insert(0, new_stride[0] * sz) 651 blocked_w = blocked_w.as_strided(blocked_w.shape, new_stride) 652 new_inputs[1] = blocked_w 653 654 def _is_int8_gemm(inputs): 655 return ( 656 isinstance(inputs[0], ir.IRNode) 657 and inputs[0].get_dtype() == torch.uint8 658 ) or ( 659 isinstance(inputs[0], torch.Tensor) 660 and inputs[0].dtype == torch.uint8 661 ) 662 663 if _is_int8_gemm(new_inputs): 664 BCompensate = None 665 if isinstance(W, ir.IRNode): 666 BCompensate = V.graph.add_tensor_constant( 667 V.graph.constants[W.get_name() + "_BMatrixCompens"], 668 W.get_name() + "_BMatrixCompens", 669 ) 670 else: 671 BCompensate = torch.sum(W.to_dense().to(torch.float), dim=0) # type: ignore[assignment] 672 new_inputs.append(BCompensate) 673 return new_inputs, layout_or_out 674 675 def preprocessor(inputs, layout): 676 return pack_weight( 677 *normalize_shapes(*maybe_to_dense(*reorder_and_filter(inputs, layout))) 678 ) 679 680 def postprocessor(output): 681 if isinstance(output, ir.TensorBox): 682 # prepack the weight as input to the template buffer 683 template_buffer = ir.InputsKernel.unwrap_storage_for_input(output) 684 assert isinstance(template_buffer, ir.CppTemplateBuffer) 685 new_input_nodes, _ = reorder_and_filter(input_nodes, layout) 686 687 W_node = new_input_nodes[1] 688 assert W_node.get_name() in V.graph.constants 689 W = V.graph.constants[W_node.get_name()] 690 new_input_nodes[1] = W 691 new_input_nodes, _ = pack_weight( 692 *normalize_shapes(*maybe_to_dense(new_input_nodes, layout)) 693 ) 694 695 # By using the new packed weight for the GEMM template, we can prune the 696 # old weight if it has no other users. This saves memory but makes the FX graph 697 # non-retraceable. To support retracing, we can add a repack node to the 698 # FX graph. For example: 699 # mkldnn._linear_pointwise <- repack_linear_wgt <- packed_wgt_for_template 700 W_tensor_users = 0 701 for node in reversed(V.graph.graph.nodes): 702 # Case may happen when the wgt tensor is used by more than 1 get_attr node 703 # https://github.com/pytorch/pytorch/issues/134998 704 if node.op == "get_attr" and hasattr( 705 V.graph.module, node.name 706 ): # wgt might already be deleted 707 comp_tensor = getattr(V.graph.module, node.name) 708 if ( 709 W.is_mkldnn == comp_tensor.is_mkldnn 710 and W.dtype == comp_tensor.dtype 711 and W.device == comp_tensor.device 712 and ( 713 ( 714 not W.is_mkldnn 715 and ( 716 W.untyped_storage().data_ptr() 717 == comp_tensor.untyped_storage().data_ptr() 718 ) 719 ) 720 or ( 721 W.is_mkldnn 722 and ( 723 torch.ops.mkldnn.data_ptr(W) 724 == torch.ops.mkldnn.data_ptr(comp_tensor) 725 ) 726 ) 727 ) 728 ): 729 W_tensor_users += 1 730 731 for node in reversed(V.graph.graph.nodes): 732 # The wgt tensor has been used by only 1 get_attr node 733 # The get_attr node has only 1 user fx node 734 if ( 735 node.name == W_node.get_name() 736 and len(node.users) == 1 737 and W_tensor_users == 1 738 ): 739 del V.graph.constants[node.name] 740 delattr(V.graph.module, node.name) 741 delattr(V.graph.graph.owning_module, node.name) 742 743 W_packed = new_input_nodes[1] 744 W_packed_constant = V.graph.add_tensor_constant(W_packed) 745 template_buffer.inputs[1] = ir.InputsKernel.unwrap_storage_for_input( 746 W_packed_constant 747 ) 748 return output 749 750 template = DataProcessorTemplateWrapper( 751 CppPackedGemmTemplate, 752 preprocessor, 753 postprocessor, 754 input_nodes=input_nodes, 755 layout=layout, 756 num_threads=num_threads, 757 register_blocking=micro_gemm.register_blocking, 758 beta=beta, 759 alpha=alpha, 760 has_bias=has_bias, 761 epilogue_creator=epilogue_creator, 762 ) 763 template.maybe_append_choice(choices) 764 return template 765 766 def render( # type: ignore[override,return] 767 self, 768 kernel: CppTemplateKernel, 769 template_buffer_node: Optional[ir.CppTemplateBuffer] = None, 770 flag_template_buffer_has_other_users: Optional[bool] = None, 771 epilogue_nodes: Optional[List[ir.IRNode]] = None, 772 **kwargs, 773 ) -> str: 774 assert len(self.input_nodes) >= 2 775 776 int8_gemm = self.input_nodes[0].get_dtype() == torch.uint8 777 x_scale = None 778 x_zp = None 779 w_scale = None 780 w_zp = None 781 if int8_gemm: 782 X, W = self.input_nodes[0], self.input_nodes[1] 783 bias_idx = 2 if self.has_bias else 1 784 inp = self.input_nodes[bias_idx] if self.has_bias else None 785 x_scale = self.input_nodes[bias_idx + 1] 786 x_zp = self.input_nodes[bias_idx + 2] 787 w_scale = self.input_nodes[bias_idx + 3] 788 w_zp = self.input_nodes[bias_idx + 4] 789 Y = self.output_node 790 else: 791 X, W = self.input_nodes[0], self.input_nodes[1] 792 Y = self.output_node 793 inp = self.input_nodes[2] if self.has_bias else None 794 795 template_buffer_has_other_users = None 796 797 if template_buffer_node is not None: 798 # Use the updated prepacked weight buffer 799 W = template_buffer_node.inputs[1] 800 Y = template_buffer_node 801 802 assert flag_template_buffer_has_other_users is not None 803 template_buffer_has_other_users = flag_template_buffer_has_other_users 804 805 template_buffer = Y 806 gemm_output_buffer = template_buffer 807 808 epilogues: List[ir.IRNode] = [] 809 reindexers: List[Optional[Callable[[List[Any]], List[Any]]]] = [] 810 epilogue_creators: List[Callable[[ir.Buffer], ir.Pointwise]] = [] 811 fake_buffers: List[ir.Buffer] = [] 812 Y_aliases: Set[str] = set() 813 814 use_local_acc = ( 815 self.layout.dtype != torch.float 816 or template_buffer_has_other_users 817 or int8_gemm 818 or self.padded_n != self.n 819 or self.maybe_k_slicing() 820 ) 821 822 # TODO(jgong5): for int8 gemm, bias-add is handled outside of gemm template, 823 # but we'd better move it here to align with fp. 824 if inp is not None and self.beta != 0 and not int8_gemm: 825 # add an epilogue for bias add 826 def _bias_add_epilogue(buf): 827 return create_epilogue_with_attr( 828 buf, "bias_add", other=inp, beta=self.beta, dtype=self.layout.dtype 829 ) 830 831 epilogue_creators.append(_bias_add_epilogue) 832 833 if self.epilogue_creator is not None: 834 epilogue_creators.append(self.epilogue_creator) 835 836 # When the GEMM output buffer is localized but it has users other than the epilogue nodes, 837 # we need to copy the value in the GEMM output local buffer to a global buffer. 838 def need_copy_from_local_to_global_buffer_epilogue( 839 use_local_acc, template_buffer_has_other_users, epilogue_creators 840 ): 841 # The GEMM output buffer is a global buffer, thus copy is not needed. 842 if not use_local_acc: 843 return False 844 845 # The possible value of template_buffer_has_other_users is (None, False, True) 846 # It is None when generating the gemm template during autotune and it will have value during scheduler codegen. 847 # extra copy_from_local_to_global_buffer_epilogue is not needed in either of the below two cases: 848 # 1. template_buffer_has_other_users is None (i.e. when doing the codegen during autotune) 849 # 2. template_buffer_has_other_users is False, which means it's safe to keep the value in the 850 # GEMM output buffer in local buffer only (no users outside of the epilogues will use its value). 851 if not template_buffer_has_other_users: 852 return False 853 854 # When bias is not None or self.epilogue_creator is not None, 855 # there will be epilogue_creators after the GEMM. 856 # The GEMM output buffer is localized while 857 # the output buffer of the epilogue_creators is a global buffer. 858 if epilogue_creators: 859 return False 860 861 return True 862 863 if need_copy_from_local_to_global_buffer_epilogue( 864 use_local_acc, template_buffer_has_other_users, epilogue_creators 865 ): 866 867 def copy_from_local_to_global_buffer_epilogue(input_buffer: ir.Buffer): 868 dtype = self.layout.dtype 869 input_loader = input_buffer.make_loader() 870 871 def copy_inner(index): 872 input = input_loader(index) 873 result = ops.to_dtype(input, dtype) 874 return result 875 876 return ir.Pointwise( 877 device=input_buffer.get_device(), 878 dtype=self.layout.dtype, 879 inner_fn=copy_inner, 880 ranges=input_buffer.get_size(), 881 ) 882 883 epilogue_creators.append(copy_from_local_to_global_buffer_epilogue) 884 885 # NOTE [How CPP GEMM template epilogues are organized] 886 # gemm_output_buffer 887 # --> zero or more in-template epilogues (created by `epilogue_creators`) --> 888 # template_buffer 889 # --> zero or more out-of-template epilogues (`epilogue_nodes`) --> 890 # Y 891 if epilogue_creators: 892 gemm_output_name = "buf_GemmOut" 893 gemm_output_buffer = ir.Buffer(gemm_output_name, template_buffer.layout) 894 current_input_buffer = gemm_output_buffer 895 for i, creator in enumerate(epilogue_creators): 896 if i == len(epilogue_creators) - 1: 897 buffer_name = template_buffer.get_name() 898 else: 899 buffer_name = f"buf_GemmOut_epilogue_{i}" 900 epilogues.append( 901 ir.ComputedBuffer( 902 name=buffer_name, 903 layout=template_buffer.layout, 904 data=creator(current_input_buffer), 905 ) 906 ) 907 fake_buffers.append(current_input_buffer) 908 Y_aliases.add(current_input_buffer.get_name()) 909 reindexers.append(None) 910 if i < len(epilogue_creators) - 1: 911 current_input_buffer = ir.Buffer( 912 buffer_name, template_buffer.layout 913 ) 914 915 Y_2d: Union[ir.Buffer, ir.ReinterpretView] = Y 916 917 if epilogue_nodes: 918 epilogues.extend(epilogue_nodes) 919 assert Y.get_numel() == epilogues[-1].get_numel() 920 Y = cast(ir.Buffer, epilogues[-1]) 921 922 if not template_buffer_has_other_users: 923 Y_aliases.add(template_buffer.get_name()) 924 925 if ( 926 Y.get_size() == template_buffer.get_size() 927 and Y.get_stride() == template_buffer.get_stride() 928 ): 929 reindexers.extend([None] * len(epilogue_nodes)) 930 Y_2d = Y 931 else: 932 933 def get_reindexer(epilogue_node): 934 # From template_buffer to epilogue_node_ordered (ordered by stride decreasingly, in dense format), for example: 935 # template_buffer: 936 # size (324, 512), stride (512, 1) 937 # epilogue_node_ordered (ordered by stride decreasingly, in dense format): 938 # size (1, 18, 18, 512), stride (165888, 9216, 512, 1) 939 stride_order = list( 940 ir.get_stride_order( 941 V.graph.sizevars.size_hints(epilogue_node.get_stride()) 942 ) 943 ) 944 fill_order = ir.stride_order2fill_order(stride_order) 945 reversed_fill_order = list(reversed(fill_order)) 946 size_with_stride_ordered_decreasingly = [ 947 epilogue_node.get_size()[i] for i in reversed_fill_order 948 ] 949 reshape_reindex = ir.View.dynamic_reshape_indexer( 950 size_with_stride_ordered_decreasingly, 951 template_buffer.get_size(), 952 ) 953 954 # From epilogue_node_ordered (ordered by stride decreasingly, in dense format) to epilogue_node, for example: 955 # epilogue_node_ordered (ordered by stride decreasingly, in dense format): 956 # size (1, 18, 18, 512), stride (165888, 9216, 512, 1) 957 # epilogue_node: 958 # size (1, 18, 18, 512), stride (165888, 1, 9216, 512) 959 from_stride_ordered_decreasingly_to_epilogue_node_order = [ 960 (len(stride_order) - 1) - stride_order[i] 961 for i in range(len(stride_order)) 962 ] 963 stride_reindex = ir.same_reorder( 964 from_stride_ordered_decreasingly_to_epilogue_node_order 965 ) 966 967 reindexer = ir.fuse_reindexing(stride_reindex, reshape_reindex) 968 return reindexer 969 970 reindexers.extend([get_reindexer(epilogue_node) for epilogue_node in epilogue_nodes]) # type: ignore[list-item] 971 if isinstance(Y, ir.BaseView): 972 storage = ir.StorageBox(Y.unwrap_view()) 973 else: 974 assert isinstance(Y, ir.Buffer) 975 storage = ir.StorageBox(Y) 976 Y_2d = ir.ReinterpretView(storage, template_buffer.get_layout()) 977 978 output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype( 979 X.get_dtype() 980 ) 981 micro_gemm = create_micro_gemm( 982 f"{kernel.kernel_name}_micro_gemm", 983 self.m, 984 self.n, 985 self.k, 986 input_dtype=X.get_dtype(), 987 input2_dtype=W.get_dtype(), 988 output_dtype=output_dtype, 989 compute_dtype=compute_dtype, 990 alpha=self.alpha, 991 num_threads=self.num_threads, 992 ) 993 assert micro_gemm is not None 994 assert self.register_blocking == micro_gemm.register_blocking 995 self.log_blockings() 996 if isinstance(micro_gemm, CppMicroGemmAMX): 997 counters["inductor"]["cpp_micro_gemm_amx_counter"] += 1 998 999 L1_cache_size = torch._C._cpu._L1d_cache_size() # per core cache size in Bytes 1000 assert L1_cache_size > 0, f"Expect L1_cache_size > 0 but got {L1_cache_size}" 1001 1002 L2_cache_size = torch._C._cpu._L2_cache_size() # per core cache size in Bytes 1003 assert L2_cache_size > 0, f"Expect L2_cache_size > 0 but got {L2_cache_size}" 1004 1005 options = dict( 1006 X=X, 1007 W=W, 1008 inp=inp, 1009 Y=Y, 1010 N=self.n, 1011 K=self.k, 1012 PADDED_N=self.padded_n, 1013 GemmOut=gemm_output_buffer, 1014 aliases={alias: Y.get_name() for alias in Y_aliases}, 1015 beta=self.beta, 1016 alpha=self.alpha, 1017 num_threads=self.num_threads, 1018 micro_gemm=micro_gemm, 1019 is_dynamic_M=self.is_dynamic_M, 1020 template=self, 1021 kernel=kernel, 1022 export_declaration=get_export_declaration(), 1023 epilogue_nodes=epilogues, 1024 reindexers=reindexers, 1025 Y_2d=Y_2d, 1026 use_local_acc=use_local_acc, 1027 maybe_k_slicing=self.maybe_k_slicing(), 1028 x_scale=x_scale, 1029 x_zp=x_zp, 1030 w_scale=w_scale, 1031 w_zp=w_zp, 1032 acc_buf_dtype=torch.int32 if int8_gemm else torch.float, 1033 DTYPE_TO_CPP=DTYPE_TO_CPP, 1034 L1_cache_size=L1_cache_size, 1035 L2_cache_size=L2_cache_size, 1036 config=config, 1037 ) 1038 with contextlib.ExitStack() as stack: 1039 for buf in fake_buffers: 1040 stack.enter_context( 1041 patch.object(V.graph, "get_dtype", self._fake_get_dtype(buf)) 1042 ) 1043 return self._template_from_string(GEMM_TEMPLATE).render(**options) 1044