1# mypy: allow-untyped-defs
2import contextlib
3import logging
4import math
5from functools import lru_cache
6from typing import Any, Callable, cast, List, Optional, Set, Union
7from unittest.mock import patch
8
9import torch
10import torch.utils
11
12from ..._dynamo.utils import counters
13from .. import config, ir, lowering as L
14from ..kernel.mm_common import mm_args
15from ..select_algorithm import DataProcessorTemplateWrapper
16from ..utils import cache_on_self, has_free_symbols, parallel_num_threads
17from ..virtualized import ops, V
18from .cpp import get_export_declaration
19from .cpp_micro_gemm import CppMicroGemmAMX, create_micro_gemm, LayoutType
20from .cpp_template import CppTemplate
21from .cpp_template_kernel import CppTemplateKernel
22from .cpp_utils import (
23    create_epilogue_with_attr,
24    DTYPE_TO_CPP,
25    GemmBlocking,
26    get_gemm_template_output_and_compute_dtype,
27)
28
29
30log = logging.getLogger(__name__)
31
32GEMM_TEMPLATE = r"""
33{{template.header().getvalue()}}
34
35{{micro_gemm.codegen_define(kernel)}}
36
37{%- if x_scale is not none %}
38    {%- set kernel_args = {"X": X, "W": W, "inp": inp, "x_scale": x_scale, "x_zp": x_zp, "w_scale": w_scale, "w_zp": w_zp,} %}
39{%- else %}
40    {%- set kernel_args = {"X": X, "W": W, "inp": inp} %}
41{%- endif %}
42
43extern "C" {{export_declaration}}
44{{kernel.def_kernel(inputs=kernel_args, outputs={"Y": Y}, aliases=aliases)}}
45{
46    {{kernel.maybe_codegen_profile()}}
47    constexpr int64_t num_threads = {{num_threads}};
48    constexpr int64_t N = {{N}};
49    constexpr int64_t K = {{K}};
50    constexpr int64_t Mr = {{micro_gemm.register_blocking.block_m}};
51    constexpr int64_t Nr = {{micro_gemm.register_blocking.block_n}};
52    constexpr int64_t Kr = {{micro_gemm.register_blocking.block_k}};
53    constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr;
54    constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr;
55
56{%- if is_dynamic_M %}
57    const int64_t M = {{kernel.size(GemmOut, 0)}};
58    const int64_t Mr_blocks = (M + Mr - 1) / Mr;
59    {%- if num_threads > 1 %}
60    int64_t Mt_blocks, Nt_blocks, Kt_blocks;
61    mm_get_thread_blocking(num_threads, {{config.cpp.gemm_max_k_slices}}, M, N, K, Mr, Nr, Kr, Mt_blocks, Nt_blocks, Kt_blocks);
62    {%- else %}
63    const auto Mt_blocks = Mr_blocks;
64    const auto Nt_blocks = Nr_blocks;
65    const auto Kt_blocks = Kr_blocks;
66    {%- endif %}
67    int64_t Mc_blocks, Nc_blocks, Kc_blocks;
68    uint32_t L1_cache_size = {{L1_cache_size}};
69    uint32_t L2_cache_size = {{L2_cache_size}};
70    mm_get_cache_blocking<{{kernel.dtype(X)}}, {{kernel.dtype(W)}}>(
71        num_threads,
72        M,
73        N,
74        K,
75        Mr,
76        Nr,
77        Kr,
78        Mt_blocks,
79        Nt_blocks,
80        Kt_blocks,
81        Mc_blocks,
82        Nc_blocks,
83        Kc_blocks,
84        L1_cache_size,
85        L2_cache_size
86    );
87    const int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
88    const int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
89    const int64_t num_k_slices = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
90{%- else %}
91    constexpr int64_t M = {{kernel.size(GemmOut, 0)}};
92    constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr;
93    constexpr int64_t Mt_blocks = {{template.thread_blocking().block_m}};
94    constexpr int64_t Nt_blocks = {{template.thread_blocking().block_n}};
95    constexpr int64_t Kt_blocks = {{template.thread_blocking().block_k}};
96    constexpr int64_t Mc_blocks = {{template.cache_blocking().block_m}};
97    constexpr int64_t Nc_blocks = {{template.cache_blocking().block_n}};
98    constexpr int64_t Kc_blocks = {{template.cache_blocking().block_k}};
99    constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
100    constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
101    constexpr int64_t num_k_slices = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
102{%- endif %}
103
104    // make sure all partitions are assigned
105    {{kernel.assert_function}}(
106        Mt_blocks * Nt_blocks * Kt_blocks * {{num_threads}} >= Mr_blocks * Nr_blocks * Kr_blocks,
107        "Not all partitions are assigned."
108    );
109
110{%- if maybe_k_slicing %}
111    std::unique_ptr<std::unique_ptr<{{DTYPE_TO_CPP[acc_buf_dtype]}}[]>[]> local_buf_ptrs;
112    if (num_k_slices > 1) {
113        local_buf_ptrs.reset(new std::unique_ptr<{{DTYPE_TO_CPP[acc_buf_dtype]}}[]>[num_Mc_blocks * num_Nc_blocks * num_k_slices]);
114    }
115{%- endif %}
116
117{%- if num_threads > 1 %}
118    #pragma omp parallel num_threads({{num_threads}})
119    {
120        const int tid = omp_get_thread_num();
121        int64_t m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end;
122        mm_get_thread_blocks(
123            tid, Mr_blocks, Nr_blocks, Kr_blocks, Mt_blocks, Nt_blocks, Kt_blocks,
124            m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end);
125    {%- if maybe_k_slicing %}
126        const int64_t k_group_id = tid / num_k_slices;
127        const int64_t k_slice_id = tid % num_k_slices;
128    {%- endif %}
129{%- else %}
130    {
131        const int tid = 0;
132        const int64_t m_block_start = 0;
133        const int64_t m_block_end = Mr_blocks;
134        const int64_t n_block_start = 0;
135        const int64_t n_block_end = Nr_blocks;
136        const int64_t k_block_start = 0;
137        const int64_t k_block_end = Kr_blocks;
138{%- endif %}
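        // At this point each thread owns the half-open ranges [m_block_start, m_block_end) x
        // [n_block_start, n_block_end) x [k_block_start, k_block_end), counted in Mr/Nr/Kr register blocks.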
139        {{ micro_gemm.codegen_init(kernel) }}
140{%- if use_local_acc %}
141    {%- set acc_buf_name = "local_acc_buf" %}
142        {{ kernel.define_buffer(acc_buf_name, ["Mc_blocks*Mr", "Nc_blocks*Nr"], acc_buf_dtype) }}
143{%- endif %}
144        for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) {
145            const int64_t m_start = mc * Mr;
146            const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
147            const int64_t m_size = m_end - m_start;
148            for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
149                const int64_t n_start = nc * Nr;
150                const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N);
151                const int64_t n_size = n_end - n_start;
152                // NB: since we pad N, nc_block_end won't exceed the padded N here.
153                const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end);
154{%- if use_local_acc %}
155    {%- set acc = kernel.local_buffers[acc_buf_name] %}
156                {{ kernel.reinit_buffer_if_null(acc_buf_name) }}
157{%- else %}
158    {%- set acc = kernel.slice_nd(GemmOut, [("m_start", "m_end"), ("n_start", "n_end")]) %}
159{%- endif %}
160                for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
161                    int64_t k_start = kc * Kr;
162                    int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K);
163{%- set tile_X = kernel.slice_nd(X, [("m_start", "m_end"), ("k_start", "k_end")]) %}
164                    for (int64_t nci = nc; nci < nc_block_end; nci++) {
165{%- set acc_slice = kernel.slice_nd(acc, [("0", "m_end - m_start"), ("(nci - nc)*Nr", "(nci - nc + 1)*Nr")]) %}
166{%- set tile_W_3d = kernel.slice_nd(W, [("nci", "nci + 1"), ("k_start", "k_end"), ()]) %}
167{%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %}
168                        if (kc == k_block_start) {
169                            {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc_slice, accum=False)|indent(28, false) }}
170                        } else {
171                            {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc_slice, accum=True)|indent(28, false) }}
172                        }
173                    }
174                }
175{%- if maybe_k_slicing %}
176                if (num_k_slices > 1) {
177                    const int64_t mxn_cache_block_id = (mc / Mc_blocks) * num_Nc_blocks + nc;
178                    local_buf_ptrs[mxn_cache_block_id * num_k_slices + k_slice_id].reset({{ kernel.release_buffer(acc_buf_name) }});
179                } else
180{%- endif %}
181                {
182{%- set tile_Y = kernel.slice_nd(Y_2d, [("m_start", "m_end"), ("n_start", "n_end")]) %}
183{%- set tile_acc = kernel.slice_nd(acc, [("0", "m_end - m_start"), ("0", "n_end - n_start")]) %}
184                    {{ kernel.store_output(
185                        tile_Y, tile_acc, GemmOut, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers
186                    )|indent(20, false)
187                    }}
188                }
189            }
190        }
191{%- if maybe_k_slicing %}
192        if (num_k_slices > 1) {
193            #pragma omp barrier
194            for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) {
195                // We slice M-dim and each thread in the k-slicing group works on a slice
196                const int64_t m_start_unsliced = mc * Mr;
197                const int64_t m_end_unsliced = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
198                const int64_t m_size_unsliced = m_end_unsliced - m_start_unsliced;
199                const int64_t m_slice_size = (m_size_unsliced + num_k_slices - 1) / num_k_slices;
200                const int64_t m_start = std::min(m_start_unsliced + m_slice_size * k_slice_id, m_end_unsliced);
201                const int64_t m_end = std::min(m_start_unsliced + m_slice_size * (k_slice_id + 1), m_end_unsliced);
202                const int64_t m_size = m_end - m_start;
203                const int64_t m_offset = m_start - m_start_unsliced;
204                for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
205                    const int64_t n_start = nc * Nr;
206                    const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N);
207                    const int64_t n_size = n_end - n_start;
208                    const int64_t mxn_cache_block_id = (mc / Mc_blocks) * num_Nc_blocks + nc;
209                    auto {{acc_buf_name}} = local_buf_ptrs[mxn_cache_block_id * num_k_slices].get();
210                    for (int64_t other_slice = 1; other_slice < num_k_slices; other_slice++) {
211                        auto other_acc = local_buf_ptrs[mxn_cache_block_id * num_k_slices + other_slice].get();
212                        for (int64_t m = m_offset; m < m_offset + m_size; m++) {
213                            #pragma omp simd
214                            for (int64_t n = 0; n < n_size; n++) {
215                                {{acc_buf_name}}[m*Nr + n] += other_acc[m*Nr + n];
216                            }
217                        }
218                    }
219    {%- set tile_acc_m_slice = kernel.slice_nd(tile_acc, [("m_offset", "m_offset + m_end - m_start"), ()]) %}
220                    {{ kernel.store_output(
221                        tile_Y, tile_acc_m_slice, GemmOut, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers
222                    )|indent(20, false)
223                    }}
224                }
225            }
226        }
227{%- endif %}
228        {{ micro_gemm.codegen_finalize(kernel) }}
229    }
230}
231"""
232
233
234def get_padded_n(n, block_n):
235    return (n + block_n - 1) // block_n * block_n
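    # get_padded_n rounds n up to the next multiple of the micro-kernel's N register block.
    # For example (illustrative values):
    #   get_padded_n(1000, 32) -> 1024   # ceil(1000 / 32) == 32 blocks of 32 columns
    #   get_padded_n(1024, 32) -> 1024   # already a multiple of block_n, unchanged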
236
237
238class CppPackedGemmTemplate(CppTemplate):
239    def __init__(
240        self,
241        input_nodes,
242        layout: ir.Layout,
243        num_threads: int,
244        register_blocking: GemmBlocking,
245        beta=1,
246        alpha=1,
247        has_bias=False,
248        epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
249    ) -> None:
250        assert layout.dtype in [torch.float, torch.bfloat16, torch.half, torch.uint8]
251        super().__init__(
252            "packed_gemm",
253            input_nodes,
254            layout,
255            num_threads,
256            epilogue_creator=epilogue_creator,
257        )
258        self.beta = beta
259        self.alpha = alpha
260        self.has_bias = has_bias
261        self.register_blocking = register_blocking
262        m, n = layout.size
263        _, k = input_nodes[0].get_size()
264        self.m, self.n, self.k = m, n, k
265        self.padded_n = get_padded_n(n, self.register_blocking.block_n)
266        self.is_dynamic_M = has_free_symbols((m,))
267
268    @cache_on_self
269    def thread_blocking(self) -> GemmBlocking:
270        """
271        NOTE [Thread blocking in Cpp GEMM]
272        We use simple heuristics to decide the thread blocking:
273        1. Make sure all threads are occupied as much as possible.
274        2. For (m, n) blocks, favor more square-sized thread blocks for better data reuse.
275        3. If (m, n) blocks cannot occupy all the threads, we consider k-slicing.
276        TODO(jgong5): allow tuning various blocking options
277        """
278
279        @lru_cache(maxsize=100)
280        def get_factors(number):
281            factors = []
282            for i in range(int(number**0.5), 0, -1):
283                if number % i == 0:
284                    factors.append(number // i)
285                    factors.append(i)
286            return factors
287
288        def get_blocking(m_factor, n_factor, k_factor, m_blocks, n_blocks, k_blocks):
289            thread_block_k = math.ceil(k_blocks / k_factor)
290            thread_block_n = math.ceil(n_blocks / n_factor)
291            thread_block_m = math.ceil(m_blocks / m_factor)
292            return GemmBlocking(thread_block_m, thread_block_n, thread_block_k)
293
294        assert (
295            not self.is_dynamic_M
296        ), "Unable to determine thread blocking for dynamic M."
297        register_blocking = self.register_blocking
298        m_blocks = math.ceil(self.m / register_blocking.block_m)
299        n_blocks = math.ceil(self.n / register_blocking.block_n)
300        k_blocks = math.ceil(self.k / register_blocking.block_k)
301        factors = get_factors(self.num_threads)
302        assert len(factors) > 0
303
304        if config.cpp.gemm_thread_factors is not None:
305            factors = [int(i) for i in config.cpp.gemm_thread_factors.split(",")]
306            assert len(factors) == 3
307            assert math.prod(factors) == self.num_threads
308            return get_blocking(
309                factors[0], factors[1], factors[2], m_blocks, n_blocks, k_blocks
310            )
311
312        # we favor square-sized thread blocks for good data reuse
313        def get_better_blocking(blocking, best_blocking):
314            if best_blocking is None:
315                best_blocking = blocking
316            else:
317                block_m_size = blocking.block_m * register_blocking.block_m
318                block_n_size = blocking.block_n * register_blocking.block_n
319                best_block_m_size = best_blocking.block_m * register_blocking.block_m
320                best_block_n_size = best_blocking.block_n * register_blocking.block_n
321                if blocking.block_k > best_blocking.block_k:
322                    best_blocking = blocking
323                elif (
324                    blocking.block_k == best_blocking.block_k
325                    and block_m_size + block_n_size
326                    < best_block_m_size + best_block_n_size
327                ):
328                    best_blocking = blocking
329            return best_blocking
330
331        best_blocking = None
332        # check if we can have a thread-blocking to occupy all threads without k-slicing
333        for n_factor in factors:
334            m_factor = self.num_threads // n_factor
335            if n_blocks >= n_factor and m_blocks >= m_factor:
336                blocking = get_blocking(
337                    m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks
338                )
339                best_blocking = get_better_blocking(blocking, best_blocking)
340
341        if best_blocking is None:
342            for k_factor in factors:
343                if k_blocks >= k_factor and (
344                    config.cpp.gemm_max_k_slices == 0
345                    or k_factor <= config.cpp.gemm_max_k_slices
346                ):
347                    n_factors = get_factors(self.num_threads // k_factor)
348                    for n_factor in n_factors:
349                        m_factor = (self.num_threads // k_factor) // n_factor
350                        if n_blocks >= n_factor and m_blocks >= m_factor:
351                            blocking = get_blocking(
352                                m_factor,
353                                n_factor,
354                                k_factor,
355                                m_blocks,
356                                n_blocks,
357                                k_blocks,
358                            )
359                            best_blocking = get_better_blocking(blocking, best_blocking)
360
361        if best_blocking is None:
362            for n_factor in factors:
363                m_factor = self.num_threads // n_factor
364                if n_blocks >= n_factor or m_blocks >= m_factor:
365                    blocking = get_blocking(
366                        m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks
367                    )
368                    best_blocking = get_better_blocking(blocking, best_blocking)
369
370        assert best_blocking is not None
371        return best_blocking
372
373    @cache_on_self
374    def cache_blocking(self) -> GemmBlocking:
375        def get_cache_blocking(register_blocking, thread_blocking):
376            Mr = register_blocking.block_m
377            Nr = register_blocking.block_n
378            Kr = register_blocking.block_k
379
380            Mt_blocks = thread_blocking.block_m
381            Nt_blocks = thread_blocking.block_n
382            Kt_blocks = thread_blocking.block_k
383
384            if config.cpp.gemm_cache_blocking is not None:
385                blockings = [int(i) for i in config.cpp.gemm_cache_blocking.split(",")]
386                assert len(blockings) == 3
387                Mc_blocks, Nc_blocks, Kc_blocks = blockings
388                return (
389                    min(Mc_blocks, Mt_blocks),
390                    min(Nc_blocks, Nt_blocks),
391                    min(Kc_blocks, Kt_blocks),
392                )
393
394            # The ratios below are empirically determined to decide
395            # the effective sizes of L1 and L2.
396            # TODO: tune the factor here
397            L1_limit_factor = 0.8
398            L2_limit_factor = 0.5
399
400            L1_cache_size = (
401                torch._C._cpu._L1d_cache_size()
402            )  # per core cache size in Bytes
403            assert (
404                L1_cache_size > 0
405            ), f"Expect L1_cache_size > 0 but got {L1_cache_size}"
406            L1 = L1_cache_size * L1_limit_factor
407
408            L2_cache_size = (
409                torch._C._cpu._L2_cache_size()
410            )  # per core cache size in Bytes
411            assert (
412                L2_cache_size > 0
413            ), f"Expect L2_cache_size > 0 but got {L2_cache_size}"
414            L2 = L2_cache_size * L2_limit_factor
415
416            def get_num_byte(dtype):
417                return torch.tensor([], dtype=dtype).element_size()
418
419            num_byte_A = get_num_byte(self.input_nodes[0].get_dtype())
420            num_byte_B = get_num_byte(self.input_nodes[1].get_dtype())
421
422            # NOTE [CPP GEMM Cache Blocking Algorithm]
423            # Our overall strategy is to
424            # 1) Make cache blocks of B L1-resident and reuse them across multiple rows of A (i.e. Mc rows).
425            #    Here, B is Kc x Nr where Nr is a single register block. We use the L1 size to
426            #    decide Kc. We want Mc to be large enough to reuse B effectively.
427            # 2) Make cache blocks of A L2-resident, which limits Mc. We want to reuse A
428            #    along N, with two sub-strategies (see notes below) to decide Mc and Nc.
429
430            # Step 1: Decide Kc assuming the B block is L1-resident.
431            size_cache_B = Kr * Kt_blocks * Nr * num_byte_B
432            Kc_blocks = Kt_blocks
433            if size_cache_B > L1:
434                Kc_blocks = math.floor(L1 / (Kr * Nr * num_byte_B))
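            # For example (assumed numbers, not from any particular CPU): with bf16 B
            # (num_byte_B == 2), Kr = Nr = 32, Kt_blocks = 64 and L1 = 0.8 * 48 KiB,
            # size_cache_B = 32 * 64 * 32 * 2 bytes = 128 KiB > L1, so Kc_blocks shrinks to
            # floor(39321.6 / (32 * 32 * 2)) = 19, i.e. a Kc x Nr panel of B with Kc = 19 * Kr
            # is what fits in L1.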
435
436            # Step 2: Decide Mc assuming the A block is L2-resident.
437            min_Mc_ratio = 2  # TODO(jgong5): something to tune?
438            min_Mc_blocks = math.ceil(min_Mc_ratio * Mr / Nr)
439            assert min_Mc_blocks >= 1
440            Kt_bytes = Kt_blocks * Kr * num_byte_A
441            if min_Mc_blocks * Mr * Kt_bytes < L2:
442                # Strategy 1: A (Mc x Kt) resides in L2 and is reused by all Nt
443                # when Nc_blocks is kept 1. Mc should be large enough (>= min_Mc_blocks)
444                # to reuse B (Kc x Nr) in L1. This makes C (Mc x Nr) small enough to reside
445                # in L1.
446                Mc_blocks = min(Mt_blocks, math.floor(L2 / (Mr * Kt_bytes)))
447                Nc_blocks = 1
448            else:
449                # Strategy 2: Kt is too large to hold A (Mc x Kt) in L2, so we reuse
450                # A (Mc x Kc) in L2 across B (Kc x Nc). C (Mc x Nc) resides in L2.
451                Mc_blocks = Mt_blocks
452                Nc_blocks = min(math.ceil(Mc_blocks * Mr / Nr), Nt_blocks)
453                Nc_bytes = Nc_blocks * Nr * 4  # assume C or acc is float32/int32
454                Kc_bytes = Kc_blocks * Kr * num_byte_A
455                if Mc_blocks * Mr * (Kc_bytes + Nc_bytes) > L2:
456                    # The following is the solution for 4*Mc*Nc + Mc*Kc_bytes = L2,
457                    # assuming Mc == Nc for good data reuse.
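                    # (Writing M for the block height in rows and setting Nc == Mc, this is the
                    # quadratic 4*M*M + Kc_bytes*M - L2 = 0; its positive root via the quadratic
                    # formula is M = (-Kc_bytes + sqrt(Kc_bytes^2 + 16*L2)) / 8, i.e. M_max below.)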
458                    M_max = (math.sqrt(Kc_bytes * Kc_bytes + 16 * L2) - Kc_bytes) / 8
459                    if M_max < Mc_blocks * Mr:
460                        Mc_blocks = math.floor(M_max / Mr)
461                        Nc_blocks = min(math.ceil(Mc_blocks * Mr / Nr), Nt_blocks)
462
463            return Mc_blocks, Nc_blocks, Kc_blocks
464
465        assert (
466            not self.is_dynamic_M
467        ), "Unable to determine cache blocking for dynamic M."
468        register_blocking = self.register_blocking
469        thread_blocking = self.thread_blocking()
470
471        return GemmBlocking(*get_cache_blocking(register_blocking, thread_blocking))
472
473    def log_blockings(self):
474        log.debug(f"Register blocking: {self.register_blocking}")  # noqa: G004
475        if self.is_dynamic_M:
476            # thread and cache blockings are determined at runtime for dynamic shapes
477            return
478        log.debug(f"Cache blocking: {self.cache_blocking()}")  # noqa: G004
479        thread_blocking = self.thread_blocking()
480        log.debug(f"Thread blocking: {thread_blocking}")  # noqa: G004
481
482        def get_occupancy():
483            m_blocks = math.ceil(self.m / self.register_blocking.block_m)
484            n_blocks = math.ceil(self.n / self.register_blocking.block_n)
485            k_blocks = math.ceil(self.k / self.register_blocking.block_k)
486            m = math.ceil(m_blocks / thread_blocking.block_m)
487            n = math.ceil(n_blocks / thread_blocking.block_n)
488            k = math.ceil(k_blocks / thread_blocking.block_k)
489            return (m, n, k)
490
491        log.debug(
492            f"Number of threads: {self.num_threads}, occupancy: {get_occupancy()}"  # noqa: G004
493        )
494
495    def maybe_k_slicing(self):
496        if self.num_threads == 1:
497            return False
498        if self.is_dynamic_M:
499            # TODO(jgong5): perhaps use size hint to decide?
500            return True
501        register_blocking = self.register_blocking
502        k_blocks = math.ceil(self.k / register_blocking.block_k)
503        thread_blocking = self.thread_blocking()
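        # k_blocks > thread_blocking.block_k holds exactly when thread_blocking() split K across
        # threads (k_factor > 1); each thread then computes a partial sum and the
        # "num_k_slices > 1" branches of GEMM_TEMPLATE reduce the slices after an omp barrier.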
504        return k_blocks > thread_blocking.block_k
505
506    @staticmethod
507    def add_choices(
508        choices,
509        layout,
510        input_nodes,
511        beta=1,
512        alpha=1,
513        has_bias=False,
514        trans_w=False,
515        input_indices=None,
516        epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
517    ):
518        if input_indices is None:
519            input_indices = list(range(len(input_nodes)))
520
521        def reorder_and_filter(inputs, layout_or_out):
522            if has_bias:
523                assert len(input_indices) >= 3
524                # Assume the input order is [inp, x, w] and we reorder it to [x, w, inp]
525                inp_idx = input_indices[0]
526                x_idx = input_indices[1]
527                w_idx = input_indices[2]
528                return [
529                    inputs[x_idx],
530                    inputs[w_idx],
531                    inputs[inp_idx],
532                    *[inputs[idx] for idx in input_indices[3:]],
533                ], layout_or_out
534            else:
535                assert len(input_indices) >= 2
536                return [inputs[idx] for idx in input_indices], layout_or_out
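        # For example, for a linear with bias the incoming order is [inp, x, w] with
        # input_indices == [0, 1, 2]; reorder_and_filter returns [x, w, inp] so that the template
        # always sees X at index 0, W at index 1 and the bias (when present) at index 2.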
537
538        def maybe_to_dense(inputs, layout_or_out):
539            new_inputs = list(inputs)
540            if isinstance(inputs[1], torch.Tensor):
541                W = inputs[1]
542                new_inputs[1] = W.to_dense() if W.is_mkldnn else W
543            return new_inputs, layout_or_out
544
545        def normalize_shapes(inputs, layout_or_out):
546            if not trans_w:
547                return inputs, layout_or_out
548            new_inputs = list(inputs)
549            X = inputs[0]
550            W = inputs[1]
551            B = inputs[2] if has_bias else None
552            if isinstance(W, ir.IRNode):
553                if trans_w:
554                    if not isinstance(W, ir.TensorBox):
555                        W = ir.TensorBox(W)
556                    W = L.permute(W, [1, 0])
557            else:
558                if trans_w:
559                    assert isinstance(W, torch.Tensor)
560                    W = W.transpose(0, 1)
561            if B is not None:
562                if isinstance(B, ir.IRNode):
563                    if not isinstance(B, ir.TensorBox):
564                        B = ir.TensorBox(B)
565                    B = L.expand(B, (X.get_size()[0], B.get_size()[-1]))
566                else:
567                    assert isinstance(B, torch.Tensor)
568                    B = B.expand(X.shape[0], B.shape[-1])
569            new_inputs[1] = W
570            if B is not None:
571                new_inputs[2] = B
572            return new_inputs, layout_or_out
573
574        # TODO(jgong5): decide proper number of threads per problem size
575        num_threads = parallel_num_threads()
576        new_inputs, _ = normalize_shapes(
577            *maybe_to_dense(*reorder_and_filter(input_nodes, layout))
578        )
579        m, n, k, *_ = mm_args(new_inputs[0], new_inputs[1])
580        output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype(
581            new_inputs[0].get_dtype()
582        )
583        micro_gemm = create_micro_gemm(
584            "micro_gemm",
585            m,
586            n,
587            k,
588            input_dtype=new_inputs[0].get_dtype(),
589            input2_dtype=new_inputs[1].get_dtype(),
590            output_dtype=output_dtype,
591            compute_dtype=compute_dtype,
592            alpha=alpha,
593            num_threads=num_threads,
594        )
595        assert micro_gemm is not None
596        _, block_n, _ = micro_gemm.register_blocking
597        padded_n = get_padded_n(n, block_n)
598
599        def pack_weight(inputs, layout_or_out):
600            W = inputs[1]
601            new_inputs = list(inputs)
602            blocked_w: Union[ir.IRNode, torch.Tensor] = W
603            if isinstance(W, ir.IRNode):
604                new_size = [padded_n // block_n, k, block_n]
605                blocked_w = ir.Buffer(
606                    W.get_name(),  # Borrow the registered buffer name
607                    ir.FixedLayout(
608                        W.get_device(),
609                        W.get_dtype(),
610                        new_size,
611                        ir.FlexibleLayout.contiguous_strides(new_size),
612                        0,
613                    ),
614                )
615            else:
616                blocked_w = (
617                    torch.nn.functional.pad(W, (0, padded_n - n))
618                    .reshape(k, padded_n // block_n, block_n)
619                    .transpose(0, 1)
620                    .contiguous()
621                )
622                if micro_gemm.get_b_layout() != LayoutType.NORMAL:
623                    layout_str = (
624                        "VNNI4"
625                        if micro_gemm.get_b_layout() == LayoutType.VNNI4
626                        else "VNNI2"
627                    )
628                    assert micro_gemm.get_b_layout() in [
629                        LayoutType.VNNI2,
630                        LayoutType.VNNI4,
631                    ], f"We only support {layout_str} for now"
632                    vnni_size = (
633                        4 if micro_gemm.get_b_layout() == LayoutType.VNNI4 else 2
634                    )
635                    assert (
636                        k % vnni_size == 0
637                    ), f"k should be divisible by vnni_size for {layout_str} layout"
638                    blocked_w = (
639                        blocked_w.view(
640                            padded_n // block_n, k // vnni_size, vnni_size, block_n
641                        )
642                        .transpose(-1, -2)
643                        .contiguous()
644                        .view(padded_n // block_n, k, block_n)
645                    )
646                # normalize stride to be "contiguous_strides" per size
647                # this avoids the problems in L.view during template codegen
648                new_stride = [1]
649                for sz in reversed(blocked_w.shape[1:]):
650                    new_stride.insert(0, new_stride[0] * sz)
651                blocked_w = blocked_w.as_strided(blocked_w.shape, new_stride)
652            new_inputs[1] = blocked_w
653
654            def _is_int8_gemm(inputs):
655                return (
656                    isinstance(inputs[0], ir.IRNode)
657                    and inputs[0].get_dtype() == torch.uint8
658                ) or (
659                    isinstance(inputs[0], torch.Tensor)
660                    and inputs[0].dtype == torch.uint8
661                )
662
663            if _is_int8_gemm(new_inputs):
664                BCompensate = None
665                if isinstance(W, ir.IRNode):
666                    BCompensate = V.graph.add_tensor_constant(
667                        V.graph.constants[W.get_name() + "_BMatrixCompens"],
668                        W.get_name() + "_BMatrixCompens",
669                    )
670                else:
671                    BCompensate = torch.sum(W.to_dense().to(torch.float), dim=0)  # type: ignore[assignment]
672                new_inputs.append(BCompensate)
673            return new_inputs, layout_or_out
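        # A shape sketch of pack_weight for a constant-tensor W (illustrative sizes): with
        # k = 64, n = 100 and block_n = 32, padded_n = 128 and
        #   W (64, 100) --pad--> (64, 128) --reshape--> (64, 4, 32) --transpose(0, 1)--> (4, 64, 32)
        # i.e. one (k, block_n) panel per Nr-wide column block. For a VNNI2 micro-kernel the panel
        # is further interleaved pairwise along k:
        #   (4, 64, 32) --view--> (4, 32, 2, 32) --transpose(-1, -2)--> (4, 32, 32, 2) --view--> (4, 64, 32)
        # For an ir.IRNode W only the layout metadata is rewritten here; the constant itself is
        # packed later by the postprocessor below.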
674
675        def preprocessor(inputs, layout):
676            return pack_weight(
677                *normalize_shapes(*maybe_to_dense(*reorder_and_filter(inputs, layout)))
678            )
679
680        def postprocessor(output):
681            if isinstance(output, ir.TensorBox):
682                # prepack the weight as input to the template buffer
683                template_buffer = ir.InputsKernel.unwrap_storage_for_input(output)
684                assert isinstance(template_buffer, ir.CppTemplateBuffer)
685                new_input_nodes, _ = reorder_and_filter(input_nodes, layout)
686
687                W_node = new_input_nodes[1]
688                assert W_node.get_name() in V.graph.constants
689                W = V.graph.constants[W_node.get_name()]
690                new_input_nodes[1] = W
691                new_input_nodes, _ = pack_weight(
692                    *normalize_shapes(*maybe_to_dense(new_input_nodes, layout))
693                )
694
695                # By using the new packed weight for the GEMM template, we can prune the
696                # old weight if it has no other users. This saves memory but makes the FX graph
697                # non-retraceable. To support retracing, we can add a repack node to the
698                # FX graph. For example:
699                # mkldnn._linear_pointwise <- repack_linear_wgt <- packed_wgt_for_template
700                W_tensor_users = 0
701                for node in reversed(V.graph.graph.nodes):
702                    # This case may happen when the wgt tensor is used by more than one get_attr node
703                    # https://github.com/pytorch/pytorch/issues/134998
704                    if node.op == "get_attr" and hasattr(
705                        V.graph.module, node.name
706                    ):  # wgt might already be deleted
707                        comp_tensor = getattr(V.graph.module, node.name)
708                        if (
709                            W.is_mkldnn == comp_tensor.is_mkldnn
710                            and W.dtype == comp_tensor.dtype
711                            and W.device == comp_tensor.device
712                            and (
713                                (
714                                    not W.is_mkldnn
715                                    and (
716                                        W.untyped_storage().data_ptr()
717                                        == comp_tensor.untyped_storage().data_ptr()
718                                    )
719                                )
720                                or (
721                                    W.is_mkldnn
722                                    and (
723                                        torch.ops.mkldnn.data_ptr(W)
724                                        == torch.ops.mkldnn.data_ptr(comp_tensor)
725                                    )
726                                )
727                            )
728                        ):
729                            W_tensor_users += 1
730
731                for node in reversed(V.graph.graph.nodes):
732                    # Prune the weight only if the wgt tensor is used by exactly one get_attr node
733                    # and that get_attr node has exactly one user fx node
734                    if (
735                        node.name == W_node.get_name()
736                        and len(node.users) == 1
737                        and W_tensor_users == 1
738                    ):
739                        del V.graph.constants[node.name]
740                        delattr(V.graph.module, node.name)
741                        delattr(V.graph.graph.owning_module, node.name)
742
743                W_packed = new_input_nodes[1]
744                W_packed_constant = V.graph.add_tensor_constant(W_packed)
745                template_buffer.inputs[1] = ir.InputsKernel.unwrap_storage_for_input(
746                    W_packed_constant
747                )
748            return output
749
750        template = DataProcessorTemplateWrapper(
751            CppPackedGemmTemplate,
752            preprocessor,
753            postprocessor,
754            input_nodes=input_nodes,
755            layout=layout,
756            num_threads=num_threads,
757            register_blocking=micro_gemm.register_blocking,
758            beta=beta,
759            alpha=alpha,
760            has_bias=has_bias,
761            epilogue_creator=epilogue_creator,
762        )
763        template.maybe_append_choice(choices)
764        return template
765
766    def render(  # type: ignore[override,return]
767        self,
768        kernel: CppTemplateKernel,
769        template_buffer_node: Optional[ir.CppTemplateBuffer] = None,
770        flag_template_buffer_has_other_users: Optional[bool] = None,
771        epilogue_nodes: Optional[List[ir.IRNode]] = None,
772        **kwargs,
773    ) -> str:
774        assert len(self.input_nodes) >= 2
775
776        int8_gemm = self.input_nodes[0].get_dtype() == torch.uint8
777        x_scale = None
778        x_zp = None
779        w_scale = None
780        w_zp = None
781        if int8_gemm:
782            X, W = self.input_nodes[0], self.input_nodes[1]
783            bias_idx = 2 if self.has_bias else 1
784            inp = self.input_nodes[bias_idx] if self.has_bias else None
785            x_scale = self.input_nodes[bias_idx + 1]
786            x_zp = self.input_nodes[bias_idx + 2]
787            w_scale = self.input_nodes[bias_idx + 3]
788            w_zp = self.input_nodes[bias_idx + 4]
789            Y = self.output_node
790        else:
791            X, W = self.input_nodes[0], self.input_nodes[1]
792            Y = self.output_node
793            inp = self.input_nodes[2] if self.has_bias else None
794
795        template_buffer_has_other_users = None
796
797        if template_buffer_node is not None:
798            # Use the updated prepacked weight buffer
799            W = template_buffer_node.inputs[1]
800            Y = template_buffer_node
801
802            assert flag_template_buffer_has_other_users is not None
803            template_buffer_has_other_users = flag_template_buffer_has_other_users
804
805        template_buffer = Y
806        gemm_output_buffer = template_buffer
807
808        epilogues: List[ir.IRNode] = []
809        reindexers: List[Optional[Callable[[List[Any]], List[Any]]]] = []
810        epilogue_creators: List[Callable[[ir.Buffer], ir.Pointwise]] = []
811        fake_buffers: List[ir.Buffer] = []
812        Y_aliases: Set[str] = set()
813
814        use_local_acc = (
815            self.layout.dtype != torch.float
816            or template_buffer_has_other_users
817            or int8_gemm
818            or self.padded_n != self.n
819            or self.maybe_k_slicing()
820        )
821
822        # TODO(jgong5): for int8 gemm, bias-add is handled outside of the gemm template,
823        # but it would be better to move it here to align with the fp path.
824        if inp is not None and self.beta != 0 and not int8_gemm:
825            # add an epilogue for bias add
826            def _bias_add_epilogue(buf):
827                return create_epilogue_with_attr(
828                    buf, "bias_add", other=inp, beta=self.beta, dtype=self.layout.dtype
829                )
830
831            epilogue_creators.append(_bias_add_epilogue)
832
833        if self.epilogue_creator is not None:
834            epilogue_creators.append(self.epilogue_creator)
835
836        # When the GEMM output buffer is localized but it has users other than the epilogue nodes,
837        # we need to copy the value in the GEMM output local buffer to a global buffer.
838        def need_copy_from_local_to_global_buffer_epilogue(
839            use_local_acc, template_buffer_has_other_users, epilogue_creators
840        ):
841            # The GEMM output buffer is a global buffer, thus copy is not needed.
842            if not use_local_acc:
843                return False
844
845            # The possible values of template_buffer_has_other_users are None, False and True.
846            # It is None when generating the gemm template during autotune and has a concrete value during scheduler codegen.
847            # An extra copy_from_local_to_global_buffer_epilogue is not needed in either of the two cases below:
848            #   1. template_buffer_has_other_users is None (i.e. when doing the codegen during autotune)
849            #   2. template_buffer_has_other_users is False, which means it is safe to keep the GEMM output
850            #       in the local buffer only (no users outside of the epilogues will use its value).
851            if not template_buffer_has_other_users:
852                return False
853
854            # When bias is not None or self.epilogue_creator is not None,
855            # there will be epilogue_creators after the GEMM.
856            # The GEMM output buffer is localized while
857            # the output buffer of the epilogue_creators is a global buffer.
858            if epilogue_creators:
859                return False
860
861            return True
862
863        if need_copy_from_local_to_global_buffer_epilogue(
864            use_local_acc, template_buffer_has_other_users, epilogue_creators
865        ):
866
867            def copy_from_local_to_global_buffer_epilogue(input_buffer: ir.Buffer):
868                dtype = self.layout.dtype
869                input_loader = input_buffer.make_loader()
870
871                def copy_inner(index):
872                    input = input_loader(index)
873                    result = ops.to_dtype(input, dtype)
874                    return result
875
876                return ir.Pointwise(
877                    device=input_buffer.get_device(),
878                    dtype=self.layout.dtype,
879                    inner_fn=copy_inner,
880                    ranges=input_buffer.get_size(),
881                )
882
883            epilogue_creators.append(copy_from_local_to_global_buffer_epilogue)
884
885        # NOTE [How CPP GEMM template epilogues are organized]
886        #   gemm_output_buffer
887        #     --> zero or more in-template epilogues (created by `epilogue_creators`) -->
888        #   template_buffer
889        #     --> zero or more out-of-template epilogues (`epilogue_nodes`) -->
890        #   Y
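        # For example, with a bias add and one extra epilogue_creator (say a relu) plus one
        # out-of-template epilogue node, the chain is
        #   buf_GemmOut -> buf_GemmOut_epilogue_0 (bias_add) -> <template buffer> (relu) -> Y
        # and the intermediate buffer names are recorded in Y_aliases below.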
891        if epilogue_creators:
892            gemm_output_name = "buf_GemmOut"
893            gemm_output_buffer = ir.Buffer(gemm_output_name, template_buffer.layout)
894            current_input_buffer = gemm_output_buffer
895            for i, creator in enumerate(epilogue_creators):
896                if i == len(epilogue_creators) - 1:
897                    buffer_name = template_buffer.get_name()
898                else:
899                    buffer_name = f"buf_GemmOut_epilogue_{i}"
900                epilogues.append(
901                    ir.ComputedBuffer(
902                        name=buffer_name,
903                        layout=template_buffer.layout,
904                        data=creator(current_input_buffer),
905                    )
906                )
907                fake_buffers.append(current_input_buffer)
908                Y_aliases.add(current_input_buffer.get_name())
909                reindexers.append(None)
910                if i < len(epilogue_creators) - 1:
911                    current_input_buffer = ir.Buffer(
912                        buffer_name, template_buffer.layout
913                    )
914
915        Y_2d: Union[ir.Buffer, ir.ReinterpretView] = Y
916
917        if epilogue_nodes:
918            epilogues.extend(epilogue_nodes)
919            assert Y.get_numel() == epilogues[-1].get_numel()
920            Y = cast(ir.Buffer, epilogues[-1])
921
922            if not template_buffer_has_other_users:
923                Y_aliases.add(template_buffer.get_name())
924
925            if (
926                Y.get_size() == template_buffer.get_size()
927                and Y.get_stride() == template_buffer.get_stride()
928            ):
929                reindexers.extend([None] * len(epilogue_nodes))
930                Y_2d = Y
931            else:
932
933                def get_reindexer(epilogue_node):
934                    # From template_buffer to epilogue_node_ordered (ordered by stride decreasingly, in dense format), for example:
935                    #   template_buffer:
936                    #       size (324, 512), stride (512, 1)
937                    #   epilogue_node_ordered (ordered by stride decreasingly, in dense format):
938                    #       size (1, 18, 18, 512), stride (165888, 9216, 512, 1)
939                    stride_order = list(
940                        ir.get_stride_order(
941                            V.graph.sizevars.size_hints(epilogue_node.get_stride())
942                        )
943                    )
944                    fill_order = ir.stride_order2fill_order(stride_order)
945                    reversed_fill_order = list(reversed(fill_order))
946                    size_with_stride_ordered_decreasingly = [
947                        epilogue_node.get_size()[i] for i in reversed_fill_order
948                    ]
949                    reshape_reindex = ir.View.dynamic_reshape_indexer(
950                        size_with_stride_ordered_decreasingly,
951                        template_buffer.get_size(),
952                    )
953
954                    # From epilogue_node_ordered (ordered by stride decreasingly, in dense format) to epilogue_node, for example:
955                    #   epilogue_node_ordered (ordered by stride decreasingly, in dense format):
956                    #       size (1, 18, 18, 512), stride (165888, 9216, 512, 1)
957                    #   epilogue_node:
958                    #       size (1, 18, 18, 512), stride (165888, 1, 9216, 512)
959                    from_stride_ordered_decreasingly_to_epilogue_node_order = [
960                        (len(stride_order) - 1) - stride_order[i]
961                        for i in range(len(stride_order))
962                    ]
963                    stride_reindex = ir.same_reorder(
964                        from_stride_ordered_decreasingly_to_epilogue_node_order
965                    )
966
967                    reindexer = ir.fuse_reindexing(stride_reindex, reshape_reindex)
968                    return reindexer
969
970                reindexers.extend([get_reindexer(epilogue_node) for epilogue_node in epilogue_nodes])  # type: ignore[list-item]
971                if isinstance(Y, ir.BaseView):
972                    storage = ir.StorageBox(Y.unwrap_view())
973                else:
974                    assert isinstance(Y, ir.Buffer)
975                    storage = ir.StorageBox(Y)
976                Y_2d = ir.ReinterpretView(storage, template_buffer.get_layout())
977
978        output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype(
979            X.get_dtype()
980        )
981        micro_gemm = create_micro_gemm(
982            f"{kernel.kernel_name}_micro_gemm",
983            self.m,
984            self.n,
985            self.k,
986            input_dtype=X.get_dtype(),
987            input2_dtype=W.get_dtype(),
988            output_dtype=output_dtype,
989            compute_dtype=compute_dtype,
990            alpha=self.alpha,
991            num_threads=self.num_threads,
992        )
993        assert micro_gemm is not None
994        assert self.register_blocking == micro_gemm.register_blocking
995        self.log_blockings()
996        if isinstance(micro_gemm, CppMicroGemmAMX):
997            counters["inductor"]["cpp_micro_gemm_amx_counter"] += 1
998
999        L1_cache_size = torch._C._cpu._L1d_cache_size()  # per core cache size in Bytes
1000        assert L1_cache_size > 0, f"Expect L1_cache_size > 0 but got {L1_cache_size}"
1001
1002        L2_cache_size = torch._C._cpu._L2_cache_size()  # per core cache size in Bytes
1003        assert L2_cache_size > 0, f"Expect L2_cache_size > 0 but got {L2_cache_size}"
1004
1005        options = dict(
1006            X=X,
1007            W=W,
1008            inp=inp,
1009            Y=Y,
1010            N=self.n,
1011            K=self.k,
1012            PADDED_N=self.padded_n,
1013            GemmOut=gemm_output_buffer,
1014            aliases={alias: Y.get_name() for alias in Y_aliases},
1015            beta=self.beta,
1016            alpha=self.alpha,
1017            num_threads=self.num_threads,
1018            micro_gemm=micro_gemm,
1019            is_dynamic_M=self.is_dynamic_M,
1020            template=self,
1021            kernel=kernel,
1022            export_declaration=get_export_declaration(),
1023            epilogue_nodes=epilogues,
1024            reindexers=reindexers,
1025            Y_2d=Y_2d,
1026            use_local_acc=use_local_acc,
1027            maybe_k_slicing=self.maybe_k_slicing(),
1028            x_scale=x_scale,
1029            x_zp=x_zp,
1030            w_scale=w_scale,
1031            w_zp=w_zp,
1032            acc_buf_dtype=torch.int32 if int8_gemm else torch.float,
1033            DTYPE_TO_CPP=DTYPE_TO_CPP,
1034            L1_cache_size=L1_cache_size,
1035            L2_cache_size=L2_cache_size,
1036            config=config,
1037        )
1038        with contextlib.ExitStack() as stack:
1039            for buf in fake_buffers:
1040                stack.enter_context(
1041                    patch.object(V.graph, "get_dtype", self._fake_get_dtype(buf))
1042                )
1043            return self._template_from_string(GEMM_TEMPLATE).render(**options)
1044