# mypy: allow-untyped-defs
import dataclasses
import sys
from enum import Enum
from typing import Callable, Dict, List, Optional, Type

import sympy

import torch

from .. import ir
from ..cpu_vec_isa import pick_vec_isa, VecAMX, VecAVX2, VecAVX512, VecISA
from ..utils import IndentedBuffer, parallel_num_threads
from ..virtualized import V
from .common import KernelTemplate
from .cpp_template_kernel import CppTemplateKernel
from .cpp_utils import DTYPE_TO_CPP, GemmBlocking, value_to_cpp


class LayoutType(Enum):
    NORMAL = 0
    VNNI2 = 1
    VNNI4 = 2
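

# The VNNI layouts describe how the weight matrix B is expected to be pre-packed
# for the AMX micro-kernels: VNNI2 interleaves pairs of K-rows (bf16/fp16) and
# VNNI4 interleaves groups of four K-rows (int8), so that `vnni_size` consecutive
# K elements of each column are contiguous in memory. The helper below is a
# minimal illustrative sketch of that packing; it is not used by the codegen:
def _vnni_pack_sketch(B: torch.Tensor, vnni_size: int) -> torch.Tensor:
    # packed[k // vnni_size, n * vnni_size + k % vnni_size] == B[k, n]
    K, N = B.shape
    assert K % vnni_size == 0
    return (
        B.reshape(K // vnni_size, vnni_size, N)
        .transpose(1, 2)
        .reshape(K // vnni_size, N * vnni_size)
    )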


_IS_WINDOWS = sys.platform == "win32"


def get_restrict_keyword() -> str:
    if _IS_WINDOWS:
        # https://learn.microsoft.com/en-us/cpp/cpp/extension-restrict?view=msvc-170
        return "__restrict"
    else:
        return "__restrict__"


class CppMicroGemm:
    """
    The base class for code-generating kernels that compute small-sized
    matrix multiplications (micro GEMMs).

    A micro GEMM kernel is responsible for register blocking, instruction selection,
    and other CPU architecture-specific optimizations.

    Subclasses need to override `codegen_define` to define the kernel function
    that is called by the code generated by `codegen_call`.
    """

    # TODO(jgong5): support constant shapes and lds as template args.
    DECLARE_KERNEL = r"""
template <bool accum>
inline void {{kernel_name}}(
{%- if kernel_extra_args_declare %}
    {{kernel_extra_args_declare}}
{%- endif %}
    const {{input_t}}* {{restrict_keyword}} A,
    const {{input2_t}}* {{restrict_keyword}} B,
    {{output_t}}* {{restrict_keyword}} C,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc
)
"""

    def __init__(
        self,
        name,
        input_dtype,
        input2_dtype,
        output_dtype,
        compute_dtype,
        register_blocking,
        alpha=1,
    ) -> None:
        self.name = name
        self.input_dtype = input_dtype
        assert input2_dtype is not None
        self.input2_dtype = input2_dtype
        self.output_dtype = output_dtype
        self.compute_dtype = compute_dtype
        self.register_blocking = register_blocking
        self.alpha = alpha

    def get_common_options(self):
        if self.input_dtype == torch.uint8:
            assert self.compute_dtype == torch.int32
            assert self.output_dtype == torch.int32
            assert self.input2_dtype == torch.int8
        return {
            "torch": torch,
            "kernel_name": self.name,
            "input_dtype": self.input_dtype,
            "input2_dtype": self.input2_dtype,
            "output_dtype": self.output_dtype,
            "compute_dtype": self.compute_dtype,
            "input_t": DTYPE_TO_CPP[self.input_dtype],
            "input2_t": DTYPE_TO_CPP[self.input2_dtype],
            "output_t": DTYPE_TO_CPP[self.output_dtype],
            "compute_t": DTYPE_TO_CPP[self.compute_dtype],
            "alpha": self.alpha,
            "kernel_extra_args_declare": self.get_kernel_extra_args_declare(),
            "int8_gemm": self.input_dtype == torch.uint8,
            "vnni_size": 4 if self.input_dtype == torch.uint8 else 2,
            "restrict_keyword": get_restrict_keyword(),
        }

    def get_kernel_declaration(self):
        options = self.get_common_options()
        return KernelTemplate._template_from_string(self.DECLARE_KERNEL).render(options)

    def get_kernel_extra_args_declare(self) -> str:
        return ""

    def get_kernel_extra_args(self) -> str:
        return ""

    def codegen_define(self, kernel: CppTemplateKernel) -> str:
        raise NotImplementedError

    def codegen_call(
        self,
        kernel: CppTemplateKernel,
        A: ir.Buffer,
        B: ir.Buffer,
        C: ir.Buffer,
        accum: bool,
    ) -> str:
        """
        Generate the code for calling the templated kernel that computes
        `C += alpha * A @ B` if `accum` is True, or `C = alpha * A @ B` otherwise.
        """
        A_ptr = f"&({kernel.index(A, [0, 0])})"
        B_ptr = f"&({kernel.index(B, [0, 0])})"
        C_ptr = f"&({kernel.index(C, [0, 0])})"
        M = kernel.size(C, 0)
        N = kernel.size(C, 1)
        K = kernel.size(A, 1)
        lda = kernel.stride(A, 0)
        ldb = kernel.stride(B, 0)
        ldc = kernel.stride(C, 0)
        res = IndentedBuffer()
        res.writeline(f"{self.name}<{value_to_cpp(accum, 'bool')}>(")
        with res.indent():
            extra_args = self.get_kernel_extra_args()
            if extra_args:
                res.writeline(extra_args)
            res.writeline(f"{A_ptr},")
            res.writeline(f"{B_ptr},")
            res.writeline(f"{C_ptr},")
            res.writeline(f"{M},")
            res.writeline(f"{N},")
            res.writeline(f"{K},")
            res.writeline(f"{lda},")
            res.writeline(f"{ldb},")
            res.writeline(f"{ldc}")
        res.writeline(");")
        return res.getvalue()
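
    # For illustration, with name == "micro_gemm" and accum == False, the emitted
    # call site looks roughly like the following (the pointer expressions depend
    # on the outer kernel's buffer indexing):
    #
    #   micro_gemm<false>(
    #       &(X_ptr[...]), &(W_ptr[...]), &(Y_ptr[...]),
    #       M, N, K, lda, ldb, ldc
    #   );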

    def codegen_init(
        self,
        kernel: CppTemplateKernel,
    ) -> str:
        return ""

    def codegen_finalize(
        self,
        kernel: CppTemplateKernel,
    ) -> str:
        return ""

    def get_b_layout(self) -> LayoutType:
        return LayoutType.NORMAL


@dataclasses.dataclass
class CppMicroGemmConfig:
    input_dtype: torch.dtype
    input2_dtype: torch.dtype
    output_dtype: torch.dtype
    compute_dtype: torch.dtype
    vec_isa_cls: Type[VecISA]
    register_blocking: GemmBlocking
    extra_check: Optional[Callable[..., bool]] = None


micro_gemm_configs: Dict[Type[CppMicroGemm], List[CppMicroGemmConfig]] = {}


def register_micro_gemm(*configs):
    def inner(cls):
        assert (
            cls not in micro_gemm_configs
        ), f"Duplicate micro_gemm registration for {cls}"
        assert len(configs) > 0, f"No micro_gemm configs provided for {cls}"
        micro_gemm_configs[cls] = list(configs)
        return cls

    return inner


def generate_gemm_config(
    vec_isa_cls,
    register_blockings,
    input_dtype=torch.float,
    input2_dtype=None,
    output_dtype=None,
    compute_dtype=None,
    extra_check=None,
):
    if output_dtype is None:
        output_dtype = input_dtype
    if compute_dtype is None:
        compute_dtype = output_dtype
    if input2_dtype is None:
        input2_dtype = input_dtype
    return [
        CppMicroGemmConfig(
            input_dtype,
            input2_dtype,
            output_dtype,
            compute_dtype,
            vec_isa_cls,
            GemmBlocking(*blocking),
            extra_check,
        )
        for blocking in register_blockings
    ]
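
# For example (illustrative), generate_gemm_config(
#     VecAVX512, [(8, 48, 1)], input_dtype=torch.bfloat16, output_dtype=torch.float
# ) yields a single CppMicroGemmConfig with compute_dtype defaulted to torch.float
# (from output_dtype), input2_dtype defaulted to torch.bfloat16 (from input_dtype),
# and register blocking GemmBlocking(8, 48, 1).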


class CppMicroGemmRef(CppMicroGemm):
    """
    A reference implementation of the CppMicroGemm class with naive C++ code.
    It is used for correctness debugging.
    """

    TEMPLATE_ENTRY = r"""
{{declare_kernel}} {
    for (int64_t m = 0; m < M; ++m) {
        for (int64_t n = 0; n < N; ++n) {
            {{compute_t}} result = accum ? C[m * ldc + n] : 0;
            for (int64_t k = 0; k < K; ++k) {
                result += ({{compute_t}})A[m * lda + k] * ({{compute_t}})B[k * ldb + n] * {{alpha}};
            }
            C[m * ldc + n] = result;
        }
    }
}
"""

    def __init__(
        self, name, input_dtype, input2_dtype, output_dtype, compute_dtype, alpha
    ) -> None:
        super().__init__(
            name,
            input_dtype,
            input2_dtype,
            output_dtype,
            compute_dtype,
            GemmBlocking(1, 1, 1),
            alpha,
        )

    def codegen_define(self, kernel: CppTemplateKernel) -> str:
        options = {
            "declare_kernel": self.get_kernel_declaration(),
            **self.get_common_options(),
        }
        return KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(options)


@register_micro_gemm(
    *generate_gemm_config(
        VecAVX512,
        [(8, 48, 1), (8, 32, 1), (16, 16, 1)],
        input_dtype=torch.float,
    ),
    *generate_gemm_config(
        VecAVX512,
        [(8, 48, 1), (8, 32, 1), (16, 16, 1)],
        input_dtype=torch.bfloat16,
        output_dtype=torch.float,
    ),
    *generate_gemm_config(
        VecAVX512,
        [(8, 48, 1), (8, 32, 1), (16, 16, 1)],
        input_dtype=torch.half,
        output_dtype=torch.float,
    ),
    *generate_gemm_config(
        VecAVX512,
        [(8, 48, 1), (8, 32, 1), (16, 16, 1)],
        input_dtype=torch.bfloat16,
        input2_dtype=torch.int8,
        output_dtype=torch.float,
        compute_dtype=torch.float,
    ),
    *generate_gemm_config(
        VecAVX2,
        [(4, 24, 1), (4, 16, 1), (8, 8, 1)],
        input_dtype=torch.float,
    ),
    *generate_gemm_config(
        VecAVX2,
        [(4, 24, 1), (4, 16, 1), (8, 8, 1)],
        input_dtype=torch.bfloat16,
        output_dtype=torch.float,
    ),
    *generate_gemm_config(
        VecAVX2,
        [(4, 24, 1), (4, 16, 1), (8, 8, 1)],
        input_dtype=torch.half,
        output_dtype=torch.float,
    ),
    *generate_gemm_config(
        VecAVX2,
        [(4, 24, 1), (4, 16, 1), (8, 8, 1)],
        input_dtype=torch.bfloat16,
        input2_dtype=torch.int8,
        output_dtype=torch.float,
        compute_dtype=torch.float,
    ),
)
class CppMicroGemmFP32Vec(CppMicroGemm):
    """
    This class generates the code for micro gemm that computes with fp32 vec
    instructions. It supports input types of torch.float, torch.bfloat16, and
    torch.half, as well as torch.int8 weights (converted to fp32 on load).
    The micro-kernel always produces fp32 output; the GEMM template converts it
    to bf16/fp16 if that is the desired output dtype.
    """

    TEMPLATE_ENTRY = r"""
{{declare_kernel}} {
    TORCH_CHECK(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}");
    TORCH_CHECK(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}");
    // TODO(jgong5): loop unroll for M and N
    for (int64_t m = 0; m < M; m += {{block_m}}) {
        int64_t block_m = std::min<int64_t>(M - m, {{block_m}});
        for (int64_t n = 0; n < N; n += {{block_n}}) {
            if (block_m == {{block_m}}) {
                {{kernel_name}}_kernel<{{block_m}}, {{block_n}}, accum>(
                    A + m * lda,
                    B + n,
                    C + m * ldc + n,
                    K,
                    lda,
                    ldb,
                    ldc
                );
            } else {
                switch (block_m) {
{%- for b in range(block_m - 1, 0, -1) %}
                case {{b}}:
                    {{kernel_name}}_kernel<{{b}}, {{block_n}}, accum>(
                        A + m * lda,
                        B + n,
                        C + m * ldc + n,
                        K,
                        lda,
                        ldb,
                        ldc
                    );
                    break;
{%- endfor %}
                default:
                    {{kernel.assert_function}}(false, "Unsupported block_m: ", block_m);
                }
            }
        }
    }
}
"""

    TEMPLATE_KERNEL = r"""
template <int64_t BLOCK_M, int64_t BLOCK_N, bool accum>
inline void {{kernel_name}}_kernel(
    const {{input_t}}* {{restrict_keyword}} A,
    const {{input2_t}}* {{restrict_keyword}} B,
    {{output_t}}* {{restrict_keyword}} C,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc
) {
    using Vectorized = at::vec::Vectorized<{{compute_t}}>;
    using VectorizedIn = at::vec::Vectorized<{{input_t}}>;
    constexpr auto VLEN = Vectorized::size();
    constexpr auto ROWS = BLOCK_M;
    constexpr auto COLS = BLOCK_N / VLEN;

    Vectorized va;
    at::vec::VectorizedN<{{compute_t}}, COLS> vb;
    at::vec::VectorizedN<{{compute_t}}, ROWS*COLS> vc;

    auto loadc = [&](auto i) {
        if constexpr (accum) {
            constexpr int row = i / COLS;
            constexpr int col = i % COLS;
            vc[i] = Vectorized::loadu(C + row * ldc + col * VLEN);
        } else {
            vc[i] = Vectorized(0.0f);
        }
    };
    c10::ForcedUnroll<ROWS * COLS>{}(loadc);

    auto compute = [&, COLS](auto i, int k) {
        constexpr int row = i / COLS;
        constexpr int col = i % COLS;

        if constexpr (col == 0) {
{%- if alpha != 1 %}
            va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + k]) * {{alpha}});
{%- else %}
            va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + k]));
{%- endif %}
        }

        if constexpr (row == 0) {
{%- if input2_dtype in [torch.bfloat16, torch.float16] %}
            auto b = VectorizedIn::loadu(B + k * ldb + col * VLEN, VLEN);
            vb[col] = at::vec::convert<{{compute_t}}>(b);
{%- elif input2_dtype == torch.int8 %}
            // Convert VLEN int8 elements to int32, and then fp32
            auto b32 = at::vec::convert_to_int32<int8_t>(B + k * ldb + col * VLEN);
            vb[col] = at::vec::convert<float>(b32);
{%- else %}
            vb[col] = Vectorized::loadu(B + k * ldb + col * VLEN);
{%- endif %}
        }

        constexpr int idx = row * COLS + col;
        vc[idx] = at::vec::fmadd(va, vb[col], vc[idx]);
    };

    for (int k = 0; k < K; ++k) {
        c10::ForcedUnroll<ROWS * COLS>{}(compute, k);
    }

    // store to C
    auto storec = [&](auto i) {
        constexpr int row = i / COLS;
        constexpr int col = i % COLS;
        vc[i].store(C + row * ldc + col * VLEN);
    };
    c10::ForcedUnroll<ROWS * COLS>{}(storec);
}
"""

    def codegen_define(self, kernel: CppTemplateKernel) -> str:
        options = {
            "declare_kernel": self.get_kernel_declaration(),
            "kernel": kernel,
            "block_m": self.register_blocking.block_m,
            "block_n": self.register_blocking.block_n,
            "block_k": self.register_blocking.block_k,
            "restrict_keyword": get_restrict_keyword(),
            **self.get_common_options(),
        }
        result = KernelTemplate._template_from_string(self.TEMPLATE_KERNEL).render(
            options
        )
        result += KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(
            options
        )
        return result


# extra check for CppMicroGemmAMX
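# K must be divisible by the VNNI group size of the packed B (4 for u8s8 GEMM,
# 2 for bf16), and the AMX kernels do not apply alpha scaling, hence alpha == 1.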
def check_amx_extra(config, m, n, k, alpha, num_threads):
    vnni_size = 4 if config.input_dtype == torch.uint8 else 2
    return k % vnni_size == 0 and alpha == 1


@register_micro_gemm(
    *generate_gemm_config(
        VecAMX,
        [(32, 32, 32), (48, 16, 32), (16, 48, 32)],
        input_dtype=torch.bfloat16,
        input2_dtype=torch.int8,
        output_dtype=torch.float,
        compute_dtype=torch.float,
        extra_check=check_amx_extra,
    ),
    *generate_gemm_config(
        VecAMX,
        [(32, 32, 32), (48, 16, 32), (16, 48, 32)],
        input_dtype=torch.bfloat16,
        output_dtype=torch.float,
        extra_check=check_amx_extra,
    ),
    *generate_gemm_config(
        VecAMX,
        [(32, 32, 64), (48, 16, 64)],
        input_dtype=torch.uint8,
        input2_dtype=torch.int8,
        output_dtype=torch.int32,
        compute_dtype=torch.int32,
        extra_check=check_amx_extra,
    ),
)
class CppMicroGemmAMX(CppMicroGemm):
    """
    This class generates the code for micro gemm using Advanced Matrix Extensions (AMX)
    instructions, available since the 4th generation of Intel Xeon, for compute.
    It supports bfloat16 inputs with fp32 output as well as u8s8 int8 GEMM with
    int32 output.
    """

    TEMPLATE_ENTRY = r"""
{{declare_kernel}} {
    TORCH_CHECK(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}");
    TORCH_CHECK(K % 2 == 0, "K dimension must be multiple of 2");
    // TODO(jgong5): loop unroll for M and N
    for (int64_t m = 0; m < M; m += {{block_m}}) {
        int64_t block_m = std::min<int64_t>(M - m, {{block_m}});
        int64_t m_tail = m;
        for (int64_t n = 0; n < N; n += {{block_n}}) {
{%- for num_rows in range(block_m, 0, -16) %}
    {%- if num_rows != block_m %}
            else
    {%- endif %}
            if (block_m >= {{num_rows}}) {
                {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}<accum>(
                    amx_state,
                    A + m * lda,
                    B + n,
                    C + m * ldc + n,
                    K,
                    lda,
                    ldb,
                    ldc,
                    16
                );
                block_m -= {{num_rows}};
                m_tail += {{num_rows}};
            }
{%- endfor %}
            if (block_m > 0) {
                {{kernel_name}}_amx_kernel_16_{{num_columns}}<accum>(
                    amx_state,
                    A + m_tail * lda,
                    B + n,
                    C + m_tail * ldc + n,
                    K,
                    lda,
                    ldb,
                    ldc,
                    block_m
                );
            }
        }
    }
}
"""

    TEMPLATE_KERNEL = r"""
template <bool accum>
inline void {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}(
    AMXState& amx_state,
    const {{input_t}}* {{restrict_keyword}} A,
    const {{input2_t}}* {{restrict_keyword}} B,
    {{output_t}}* {{restrict_keyword}} C,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    uint8_t tilecfg_rows
) {
    // TODO(jgong5): add prefetch hint for A, B, C
    auto loadconfig = [](const amx_tilecfg& cfg) {
        _tile_loadconfig(&cfg);
    };
    const auto last_k_offset = K / {{block_k}} * {{block_k}};
    const auto tail_k_size = K - last_k_offset;
    if C10_LIKELY (last_k_offset > 0) {
        amx_state.configure(tilecfg_rows, 64, {{num_rows}} / 16, {{num_columns}}, loadconfig);
    } else {
        amx_state.configure(tilecfg_rows, tail_k_size * sizeof({{input_t}}), {{num_rows}} / 16, {{num_columns}}, loadconfig);
    }
    auto load_c = [&]() {
{%- for tile_row in range(num_rows // 16) %}
    {%- for tile_col in range(num_columns) %}
        {%- set tile_idx = tile_row * num_columns + tile_col %}
        _tile_loadd({{tile_idx}}, C + {{tile_row * 16}} * ldc + {{tile_col * 16}}, ldc * sizeof({{output_t}}));
    {%- endfor %}
{%- endfor %}
    };
    auto zero_c = [&]() {
{%- for tile_row in range(num_rows // 16) %}
    {%- for tile_col in range(num_columns) %}
        {%- set tile_idx = tile_row * num_columns + tile_col %}
        _tile_zero({{tile_idx}});
    {%- endfor %}
{%- endfor %}
    };

    if constexpr (accum) {
        load_c();
    } else {
        zero_c();
    }

{%- if input_dtype == torch.bfloat16 and input2_dtype == torch.int8 %}
    // create a buffer for tiles of B.
    alignas(64) {{input_t}} bf16_weights_buf[512];

    int num_b_rows = (last_k_offset > 0) ? 16 : (tail_k_size * sizeof({{input_t}})) / 4;
    int b_tile_ptr_stride = ldb * {{vnni_size}};

    auto load_B_row = [&]({{input2_t}}* src, {{input_t}}* dst) {
        {{kernel.unroll_pragma(2)}}
        for (int i = 0; i < 2; i++) {
            // int8 -> int32 -> fp32 -> bf16
            auto b32 = at::vec::convert_to_int32<int8_t>(src + i * 16);
            auto b_bf16 = at::vec::convert<{{input_t}}>(b32);
            b_bf16.store(dst + i * 16);
        }
    };

    auto load_B_in_buf = [&]({{input2_t}}* B_ptr) {
        {{kernel.unroll_pragma(8)}}
        for (int i = 0; i < num_b_rows; i++) {
            load_B_row(
                B_ptr + i * b_tile_ptr_stride,
                bf16_weights_buf + i * 32
            );
        }
    };
{%- endif %}

    auto compute = [&](int k) {
{%- set tile_offset_a = num_rows // 16 * num_columns %}
{%- set tile_offset_b = tile_offset_a + num_rows // 16 %}
{%- for tile_row in range(num_rows // 16) %}
    {%- for tile_col in range(num_columns) %}
        {%- set tile_idx_a = tile_offset_a + tile_row %}
        {%- set tile_idx_b = tile_offset_b + tile_col %}
        {%- set tile_idx_c = tile_row * num_columns + tile_col %}
        {%- if tile_col == 0 %}
        _tile_stream_loadd({{tile_idx_a}}, A + {{tile_row * 16}} * lda + k, lda * sizeof({{input_t}}));
        {%- endif %}
        {%- if tile_row == 0 %}
            {%- if input_dtype == torch.bfloat16 and input2_dtype == torch.int8 %}
        load_B_in_buf(const_cast<{{input2_t}}*>(B) + k * ldb + {{tile_col * 16 * vnni_size}});
        _tile_loadd({{tile_idx_b}}, bf16_weights_buf, 64);
            {%- else %}
        _tile_loadd({{tile_idx_b}}, B + k * ldb + {{tile_col * 16 * vnni_size}}, ldb * {{vnni_size}} * sizeof({{input_t}}));
            {%- endif %}
        {%- endif %}
        {%- if int8_gemm %}
        _tile_dpbusd({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}});
        {%- else %}
        _tile_dpbf16ps({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}});
        {%- endif %}
    {%- endfor %}
{%- endfor %}
    };

    {{kernel.unroll_pragma(4)}}
    for (int k = 0; k < last_k_offset; k += {{block_k}}) {
        compute(k);
    }

    auto store_c = [&]() {
    // store to C
{%- for tile_row in range(num_rows // 16) %}
    {%- for tile_col in range(num_columns) %}
        {%- set tile_idx = tile_row * num_columns + tile_col %}
        _tile_stored({{tile_idx}}, C + {{tile_row * 16}} * ldc + {{tile_col * 16}}, ldc * sizeof({{output_t}}));
    {%- endfor %}
{%- endfor %}
    };

    // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead
    if C10_UNLIKELY (tail_k_size > 0) {
        if C10_LIKELY (last_k_offset > 0) {
            store_c();
            amx_state.configure(tilecfg_rows, tail_k_size * sizeof({{input_t}}), {{num_rows}} / 16, {{num_columns}}, loadconfig);
            load_c();
        }
        compute(last_k_offset);
    }

    store_c();
}
"""

    def codegen_define(self, kernel: CppTemplateKernel) -> str:
        block_m, block_n, block_k = self.register_blocking
        assert block_m % 16 == 0, "Only support block_m % 16 == 0 for AMX"
        assert block_n % 16 == 0, "Only support block_n % 16 == 0 for AMX"
        if self.input_dtype == torch.uint8:
            assert block_k == 64, "Only support block_k = 64 for AMX INT8"
        else:
            assert block_k == 32, "Only support block_k = 32 for AMX Bfloat16/Float16"
        num_columns = block_n // 16
        options = {
            "declare_kernel": self.get_kernel_declaration(),
            "kernel": kernel,
            "block_m": block_m,
            "block_n": block_n,
            "block_k": block_k,
            "num_columns": num_columns,
            "restrict_keyword": get_restrict_keyword(),
            **self.get_common_options(),
        }
        result = ""
        for num_rows in range(block_m, 0, -16):
            amx_kernel_options = {**options, "num_rows": num_rows}
            result += KernelTemplate._template_from_string(self.TEMPLATE_KERNEL).render(
                amx_kernel_options
            )
        result += KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(
            options
        )
        return result

    def codegen_init(
        self,
        kernel: CppTemplateKernel,
    ) -> str:
        return "AMXState amx_state;"

    def codegen_finalize(
        self,
        kernel: CppTemplateKernel,
    ) -> str:
        return "amx_state.release([]() { _tile_release(); });"

    def get_kernel_extra_args_declare(self) -> str:
        return "AMXState& amx_state,"

    def get_kernel_extra_args(self) -> str:
        return "amx_state,"

    def get_b_layout(self):
        if self.input_dtype == torch.uint8:
            return LayoutType.VNNI4
        else:
            return LayoutType.VNNI2


def create_micro_gemm(
    name,
    m,
    n,
    k,
    input_dtype,
    input2_dtype,
    output_dtype=None,
    compute_dtype=None,
    alpha=1,
    num_threads=-1,
    use_ref=True,
) -> Optional[CppMicroGemm]:
    def create_from_config(cls, config: CppMicroGemmConfig):
        return cls(
            name,
            config.input_dtype,
            config.input2_dtype,
            config.output_dtype,
            config.compute_dtype,
            config.register_blocking,
            alpha,
        )

    assert isinstance(n, int) or n.is_number, n
    assert isinstance(k, int) or k.is_number, k
    m = V.graph.sizevars.size_hint(m, fallback=1) if isinstance(m, sympy.Expr) else m
    assert isinstance(m, int), m
    if output_dtype is None:
        output_dtype = input_dtype
    if compute_dtype is None:
        compute_dtype = output_dtype
    if num_threads < 0:
        num_threads = parallel_num_threads()
    vec_isa = pick_vec_isa()
    matched_configs = []
    for cls, configs in micro_gemm_configs.items():
        for config in configs:
            if not issubclass(vec_isa.__class__, config.vec_isa_cls):
                continue
            if (
                config.input_dtype == input_dtype
                and config.compute_dtype == compute_dtype
                and config.input2_dtype == input2_dtype
                and config.output_dtype == output_dtype
                # The output_dtype here is the output dtype of the micro-kernel.
                # In some cases, the actual output dtype of the op for which the
                # micro-kernel is being created would be the same as that of the
                # activation, but the micro-kernels compute output in float/int32,
                # which is converted in the GEMM template. This is subject to
                # change in the future.
            ):
                if config.extra_check is not None and not config.extra_check(
                    config, m, n, k, alpha, num_threads
                ):
                    continue
                block_m, block_n, block_k = config.register_blocking
                if (
                    config.vec_isa_cls == VecAMX
                    and m < block_m
                    and input_dtype == torch.bfloat16
                    and input2_dtype == torch.int8
                ):
                    # For int8 WoQ GEMM, the AMX micro-kernel may not perform well
                    # if m < block_m
                    continue
                # Criteria on the ranking of configurations
                # 1. ISA: AMX > VEC
                # 2. Divisible by block sizes (block_m, block_n, block_k)
                # 3. Number of mxn blocks is large enough to occupy all the threads
                # 4. Register blocks are larger
                isa_score = 0
                if config.vec_isa_cls == VecAMX:
                    isa_score += 1
                dividable_score = 0
                if m % block_m == 0:
                    dividable_score += 1
                if n % block_n == 0:
                    dividable_score += 1
                if k % block_k == 0:
                    dividable_score += 1
                occupancy_score = 0
                n_blocks = (n + block_n - 1) // block_n
                total_mxn_blocks = n_blocks * ((m + block_m - 1) // block_m)
                if n_blocks >= num_threads:
                    occupancy_score += 1
                if total_mxn_blocks >= num_threads:
                    occupancy_score += 1
                register_bytes = (
                    block_m * block_n * config.compute_dtype.itemsize
                    + (block_m * block_k + block_k * block_n)
                    * config.input_dtype.itemsize
                )
                matched_configs.append(
                    (
                        (isa_score, dividable_score, occupancy_score, register_bytes),
                        cls,
                        config,
                    )
                )
    if len(matched_configs) == 0:
        if use_ref:
            return CppMicroGemmRef(
                name, input_dtype, input2_dtype, output_dtype, compute_dtype, alpha
            )
        else:
            return None
    # TODO(jgong5): allow autotuning on choices of configs
    return create_from_config(*max(matched_configs, key=lambda x: x[0])[1:])
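

# Example (illustrative) of picking a micro-kernel for an M=128, N=512, K=256
# bf16 GEMM with fp32 output on 8 threads:
#
#   micro_gemm = create_micro_gemm(
#       "micro_gemm", 128, 512, 256,
#       input_dtype=torch.bfloat16, input2_dtype=torch.bfloat16,
#       output_dtype=torch.float, num_threads=8,
#   )
#   # micro_gemm.codegen_define(kernel) emits the C++ kernel definition and
#   # micro_gemm.codegen_call(kernel, A, B, C, accum=False) emits the call site.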
851