# mypy: allow-untyped-defs, disable-error-code="attr-defined, valid-type"
import copy
import logging
import random
from typing import List, Optional

import sympy

import torch
from torch._inductor import config
from torch._inductor.codegen.rocm.ck_template import CKTemplate
from torch._inductor.codegen.rocm.rocm_kernel import ROCmTemplateKernel
from torch._inductor.ir import Buffer, Layout

from ...utils import IndentedBuffer, try_import_ck_lib


_, gen_ops_library, gen_ops_preselected, CKGemmOperation = try_import_ck_lib()


log = logging.getLogger(__name__)


def is_static_int(number):
    return isinstance(number, (int, sympy.Integer))


def torch_layout_to_ck_layout(torch_layout):
    if torch_layout.stride[-1] == 1:
        return "Row"
    elif torch_layout.stride[-2] == 1:
        return "Col"
    else:
        return None
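

# Illustration (comment only, not executed): for a contiguous (M, K) tensor the
# strides are (K, 1), so stride[-1] == 1 maps to "Row"; for a transposed view
# the strides are (1, M), so stride[-2] == 1 maps to "Col". A layout that is
# contiguous in neither of the last two dimensions yields None, which makes
# `filter_op` reject every instance.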


class CKGemmTemplate(CKTemplate):
    # the Jinja template for rendering CK Universal GEMMs
    gemm_template = r"""{{version_comment}}
    {{headers}}
    {{globals}}
    {{instance_definition}}
    extern "C" {
    {{kernel_definition}} {
        auto gemm = {{instance_type}} {};
        auto invoker = gemm.MakeInvoker();

        auto argument = gemm.MakeArgument(
            reinterpret_cast<const {{a_element_dtype}}*>(X),
            reinterpret_cast<const {{b_element_dtype}}*>(W),
            std::array<const void*, {{1 if has_bias else 0}}>{ {{'Bias' if has_bias else ''}} },
            reinterpret_cast<{{c_element_dtype}}*>(Y),
            M,
            N,
            K,
            LDA,
            LDB,
            std::array<ck::index_t, {{1 if has_bias else 0}}>{ {{'LDD' if has_bias else ''}} },
            LDC,
            1, // kBatch
            PassThrough {}, // a_elementwise_op
            PassThrough {}, // b_elementwise_op
            {{epilogue}} // c_elementwise_op
        );
        if (!gemm.IsSupportedArgument(argument)) {
            // we do our best to statically avoid this case in `filter_op`
            std::cerr << "invalid argument for gemm instance " << gemm.GetTypeString() << std::endl;
            argument.Print();
            return -23;
        }
        if (workspace_size) {
            *workspace_size = gemm.GetWorkSpaceSize(&argument);
            return 0;
        }
        // run the kernel
        float elapsed_time = invoker.Run(argument, StreamConfig{stream, /* time kernel */ false, /* log level */ kDEBUG_LOG});
        return 0;
    } // kernel definition
    } // extern C
    """
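
    # The rendered kernel follows a simple C contract: it returns 0 on success
    # and a negative value (-23 above) when CK rejects the argument; when the
    # caller passes a non-null workspace_size pointer, the kernel only reports
    # the required workspace size and returns without launching.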

    def __init__(
        self,
        input_nodes: List[Buffer],
        layout: Layout,
        alpha: float,
        beta: float,
        input_reorder: Optional[List[int]] = None,
    ) -> None:
        super().__init__(
            "ck_gemm_template",
            input_nodes=input_nodes,
            layout=layout,
            input_reorder=input_reorder,
        )
        self.alpha = alpha
        self.beta = beta
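        # alpha and beta feed the Bilinear epilogue used when a bias input is
        # present (see `render`): the output is alpha * (X @ W) + beta * Bias.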

    def header(self) -> IndentedBuffer:
        res = super().header()
        res.splice(
            """
                // CK GEMM header(s)

                #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp"
            """
        )
        return res

    def globals(self) -> IndentedBuffer:
        res = super().globals()
        res.splice(
            """
                // CK GEMM globals

                using Row = ck::tensor_layout::gemm::RowMajor;
                using Col = ck::tensor_layout::gemm::ColumnMajor;

                using BlockGemmPipelineScheduler = ck::BlockGemmPipelineScheduler;
                using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization;
                using BlockGemmPipelineVersion = ck::BlockGemmPipelineVersion;
            """
        )
        return res

    def filter_op(self, op: "CKGemmOperation"):
        """
        Determines whether a given op definition is suitable for the current
        input / output of the operation that this template implements.

        The filter is based on the inputs' dtypes, layouts, and statically
        inferred sizes.

        Returns None if the op is not suitable, otherwise returns the op to be used.
        """
        metas = [T.get_layout() for T in [*self.input_nodes, self.output_node]]
        X_meta = metas[0]
        W_meta = metas[1]
        Y_meta = metas[-1]
        # disable the instance if dtypes don't match
        if op.a_element_dtype != self._TORCH_DTYPE_TO_CK[X_meta.dtype]:
            return None
        if op.b_element_dtype != self._TORCH_DTYPE_TO_CK[W_meta.dtype]:
            return None
        if op.c_element_dtype != self._TORCH_DTYPE_TO_CK[Y_meta.dtype]:
            return None
        # disable the instance if layouts don't match
        if op.a_layout != torch_layout_to_ck_layout(X_meta):
            return None
        if op.b_layout != torch_layout_to_ck_layout(W_meta):
            return None
        if op.c_layout != torch_layout_to_ck_layout(Y_meta):
            return None
        # try to avoid launching the instance with invalid problem size
        # see GridwiseGemm_xdl_cshuffle_v3::CheckValidity

        M = X_meta.size[-2]
        K = X_meta.size[-1]
        N = W_meta.size[-1]

        if is_static_int(M):
            if not any(
                m_padding in op.gemm_specialization
                for m_padding in ["MPadding", "MNPadding", "MKPadding", "MNKPadding"]
            ):
                if M % op.m_per_block != 0:
                    return None
        if is_static_int(N):
            if not any(
                n_padding in op.gemm_specialization
                for n_padding in ["NPadding", "MNPadding", "NKPadding", "MNKPadding"]
            ):
                if N % op.n_per_block != 0:
                    return None
        if is_static_int(K):
            if not any(
                k_padding in op.gemm_specialization
                for k_padding in ["KPadding", "MKPadding", "NKPadding", "MNKPadding"]
            ):
                if K % op.k_per_block != 0:
                    return None
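
        # Illustration: with a static M == 1000 and op.m_per_block == 128
        # (values chosen for this comment only), an instance whose
        # gemm_specialization does not pad M is rejected here, since
        # 1000 % 128 != 0 would fail the kernel's validity check at launch.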

        a_contig_size = (
            K if op.a_layout == "Row" else M if op.a_layout == "Col" else None
        )
        if (
            is_static_int(a_contig_size)
            and a_contig_size % op.a_block_transfer_src_scalar_per_vector != 0
        ):
            return None
        b_contig_size = (
            N if op.b_layout == "Row" else K if op.b_layout == "Col" else None
        )
        if (
            is_static_int(b_contig_size)
            and b_contig_size % op.b_block_transfer_src_scalar_per_vector != 0
        ):
            return None
        c_contig_size = (
            N if op.c_layout == "Row" else M if op.c_layout == "Col" else None
        )
        if (
            is_static_int(c_contig_size)
            and c_contig_size
            % op.c_shuffle_block_transfer_scalar_per_vector_n_per_block
            != 0
        ):
            return None
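
        # The three checks above mirror CK's requirement that the contiguous
        # dimension of each operand be divisible by the scalar-per-vector width
        # used by the block transfers for vectorized global memory access.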

        # TBD: disable instances with an invalid number of pipeline prefetch stages.
        # This would avoid compiling the small percentage of unrunnable instances
        # which fail the gemm argument check at runtime.

        return op

    def emit_ck_instance(self, op: "CKGemmOperation"):
        # The Jinja template for generating a C++ type alias *definition* for a Universal GEMM instance
        template_definition = r"""
    // Gemm operator {{operation_name}}
    using Operation_{{operation_name}} =
        ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3<
            {{template_params}}>;

"""
        # The Jinja template for generating a C++ type alias *usage* for a Universal GEMM instance
        template_type = r"""
    Operation_{{operation_name}}
"""
        template_params = []
        for field_name, field_value in op.dict_items():
            if isinstance(field_value, tuple):
                tuple_elements = ", ".join(map(str, iter(field_value)))
                if "ds" in field_name:  # element type and layout for bias
                    arg = f"/* {field_name} */ Tuple<{tuple_elements}>"
                else:  # tile shape
                    arg = f"/* {field_name} */ S<{tuple_elements}>"
                template_params.append(arg)
            else:
                if field_value is not None:
                    template_params.append(f"/* {field_name} */ {field_value}")
        return self._template_from_string(template_definition).render(
            operation_name=op.name(),
            template_params=(",\n" + 12 * " ").join(template_params),
        ), self._template_from_string(template_type).render(operation_name=op.name())
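
    # Illustration (hypothetical rendering; the operation name is invented for
    # this comment): the first returned string is a type alias definition like
    #
    #     using Operation_ck_gemm_xyz =
    #         ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3<
    #             /* a_layout */ Row,
    #             /* b_layout */ Col,
    #             ...>;
    #
    # and the second is the bare type usage "Operation_ck_gemm_xyz".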

    def render(self, kernel: ROCmTemplateKernel, op: "CKGemmOperation", **kwargs) -> str:  # type: ignore[override]
        """
        The primary entry point for the code rendering process used in this template.
        """
        epilogue_nodes = kwargs.get("epilogue_nodes", None)
        assert epilogue_nodes is None or 0 == len(epilogue_nodes)
        template_buffer_node = kwargs.get("template_buffer_node", None)
        if template_buffer_node is not None:
            self.output_node = template_buffer_node
        X, W = self.input_nodes[0], self.input_nodes[1]
        Y = self.output_node
        Bias = self.input_nodes[2] if 3 == len(self.input_nodes) else None

        op = copy.deepcopy(op)

        # This parameter is converted into a tuple because of the change
        # from DeviceGemm_Xdl_CShuffleV3 to DeviceGemmMultiD_Xdl_CShuffle_V3.
        # The first tuple element corresponds to the matmul result...
        op.c_shuffle_block_transfer_scalar_per_vector_n_per_block = (
            op.c_shuffle_block_transfer_scalar_per_vector_n_per_block,
        )

        if Bias is not None:
            op.ds_layouts = (torch_layout_to_ck_layout(Bias.get_layout()),)
            op.ds_element_dtypes = ((self._TORCH_DTYPE_TO_CK[Bias.get_layout().dtype]),)
            op.c_elementwise_op = "Bilinear"
            # c_shuffle_dtype is also used for adding the bias to the matmul result
            # before converting down to the result dtype
            op.c_shuffle_dtype = op.acc_dtype
            # this parameter needs to be set according to the bias stride for correct accumulation
            if op.ds_layouts[0] == "Row":
                # bias has (N, ) shape
                bias_shuffle_block_transfer_scalar_per_vector_n_per_block = (
                    op.c_shuffle_block_transfer_scalar_per_vector_n_per_block
                )
            else:
                # bias has (M, 1) shape
                bias_shuffle_block_transfer_scalar_per_vector_n_per_block = (1,)
            # ...and the second tuple element corresponds to the bias
            op.c_shuffle_block_transfer_scalar_per_vector_n_per_block += (
                bias_shuffle_block_transfer_scalar_per_vector_n_per_block
            )
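
            # After this block, the scalar_per_vector field is a 2-tuple:
            # element 0 covers the matmul result and element 1 covers the bias,
            # matching the tuple-of-Ds interface of DeviceGemmMultiD.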

        instance_definition, instance_type = self.emit_ck_instance(op)

        version_comment = rf"""/**
* Generated code for CK inductor backend
* See {type(self).__module__}.{type(self).__qualname__}
*
* Template instance {op}
*
* {torch.__version__=}
* {torch.version.git_version=}
*/
"""

        return self._template_from_string(self.gemm_template).render(
            headers=self.header().getvalue(),
            globals=self.globals().getvalue(),
            instance_definition=instance_definition,
            kernel_definition=kernel.def_kernel(
                inputs=[X, W, Bias],  # type: ignore[list-item]
                outputs=[Y],
                names_str="X, W, Bias, Y",
                input_reorder=self.input_reorder,
                size_args=[
                    f"ck::index_t {arg}"
                    for arg in ["M", "N", "K", "LDA", "LDB", "LDC", "LDD"]
                ],
            ),
            instance_type=instance_type,
            a_element_dtype=op.a_element_dtype,
            b_element_dtype=op.b_element_dtype,
            c_element_dtype=op.c_element_dtype,
            bias_element_dtype=op.ds_element_dtypes[0] if Bias is not None else "",
            alpha=self.alpha,
            beta=self.beta,
            epilogue=f"Bilinear {{ {self.alpha}, {self.beta} }}"
            if Bias is not None
            else "PassThrough {}",
            has_bias=Bias is not None,
            version_comment=version_comment,
        )

    def _is_rcr_f16(self):
        # NB: use explicit indices here so a third (bias) input node does not
        # break the 3-way unpacking below.
        X_meta, W_meta, Y_meta = (
            T.get_layout()
            for T in (self.input_nodes[0], self.input_nodes[1], self.output_node)
        )
        X_dtype, W_dtype, Y_dtype = (
            self._TORCH_DTYPE_TO_CK[m.dtype] for m in (X_meta, W_meta, Y_meta)
        )
        X_layout, W_layout, Y_layout = (
            torch_layout_to_ck_layout(m) for m in (X_meta, W_meta, Y_meta)
        )

        return (
            X_dtype == "F16"
            and W_dtype == "F16"
            and Y_dtype == "F16"
            and X_layout == "Row"
            and W_layout == "Col"
            and Y_layout == "Row"
        )
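
    # "rcr" refers to the (Row, Col, Row) layout combination for (X, W, Y);
    # together with all-F16 dtypes this is the case covered by the preselected
    # instance list used in `gen_ops`.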

    def gen_ops(self):
        """
        Creates a list of `CKGemmOperation` instances that match the GEMM operation this template represents.
        The instances are guaranteed to have the correct layout, dtype and dimension padding for the GEMM input arguments.

        An instance may still fail the GEMM argument check at runtime.
        Such instances are assigned +inf runtime by the autotuning process.
        """
        unfiltered_instances = (
            gen_ops_preselected()
            if config.rocm.use_preselected_instances and self._is_rcr_f16()
            else gen_ops_library()
        )
        filtered_instances = list(
            filter(lambda op: self.filter_op(op), unfiltered_instances)
        )
        # NB: with a fixed list order, we would most likely pick a subset of
        # instances which are very similar to each other. Randomizing the
        # choice avoids this.
        random.seed(-11)
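        # The fixed seed keeps the randomly sampled subset deterministic
        # across runs.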
        chosen_instances = (
            random.sample(
                filtered_instances,
                min(len(filtered_instances), config.rocm.n_max_profiling_configs),
            )
            if config.rocm.n_max_profiling_configs
            else filtered_instances
        )
        log.debug(
            "generated %d ck instances after filter: %s",
            len(chosen_instances),
            chosen_instances,
        )
        return chosen_instances

    @staticmethod
    def add_ck_gemm_choices(
        choices,
        layout,
        input_nodes,
        alpha=1,
        beta=0,
        input_reorder=None,
    ):
        """
        Add Composable Kernel Universal GEMM instance choices to the auto-tuning list.
        """
        template = CKGemmTemplate(
            input_nodes,
            layout,
            alpha=alpha,
            beta=beta,
            input_reorder=input_reorder,
        )
        ops = template.gen_ops()
        for op in ops:
            template.maybe_append_choice(
                choices,
                op=op,
            )
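
    # Illustrative usage (sketch only; the variable names below are
    # assumptions, not part of this module):
    #
    #     choices: list = []
    #     CKGemmTemplate.add_ck_gemm_choices(choices, layout, [mat_a, mat_b])
    #     # `choices` now holds one autotune candidate per surviving instance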

    def size_args(self):
        X = self.input_nodes[0]
        W = self.input_nodes[1]
        Bias = self.input_nodes[2] if len(self.input_nodes) > 2 else None
        Y = self.output_node

        M = X.get_size()[0]
        K = X.get_size()[1]
        N = W.get_size()[1]
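        # The leading dimension is the stride of the non-contiguous dimension:
        # stride[0] for a row-major tensor (stride[1] == 1), stride[1] for a
        # column-major one.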
        LDA = X.get_stride()[0 if X.get_stride()[1] == 1 else 1]
        LDB = W.get_stride()[0 if W.get_stride()[1] == 1 else 1]
        LDC = Y.get_stride()[0 if Y.get_stride()[1] == 1 else 1]
        LDD = (
            0
            if Bias is None
            else Bias.get_stride()[0 if Bias.get_stride()[1] == 1 else 1]
        )

        return M, N, K, LDA, LDB, LDC, LDD