xref: /aosp_15_r20/external/pytorch/torch/_inductor/codegen/rocm/ck_template.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import torch
2from torch._inductor.codegen.rocm.rocm_template import ROCmTemplate
3from torch._inductor.ir import IRNode
4from torch._inductor.utils import IndentedBuffer
5
6
7class CKTemplate(ROCmTemplate):
8    """
9    Base class for generating CK templates, has common, i.e. non-gemm-specific, code generation logic
10    """
11
12    _TORCH_DTYPE_TO_CK = {
13        torch.float32: "F32",
14        torch.float64: "F64",
15        torch.float16: "F16",
16        torch.bfloat16: "BF16",
17        torch.int32: "I32",
18        torch.int8: "I8",
19        torch.float8_e4m3fnuz: "F8",
20        torch.float8_e5m2fnuz: "BF8",
21    }
22
23    def header(self) -> IndentedBuffer:
24        res = super().header()
25        res.splice(
26            """
27                // HIP headers
28
29                #include <hip/hip_bfloat16.h>
30
31                // CK headers
32
33                #ifdef DEBUG_LOG
34                #define DEBUG_LOG_TMP DEBUG_LOG
35                #undef DEBUG_LOG
36                #else
37                #define DEBUG_LOG_TMP 0
38                #endif
39                #include "ck/ck.hpp"
40                #undef DEBUG_LOG
41                #define DEBUG_LOG DEBUG_LOG_TMP
42
43                #include "ck/utility/data_type.hpp"
44                #include "ck/library/utility/check_err.hpp"
45                #include "ck/library/utility/device_memory.hpp"
46                #include "ck/library/utility/fill.hpp"
47                #include "ck/library/utility/host_tensor.hpp"
48                #include "ck/library/utility/host_tensor_generator.hpp"
49                #include "ck/library/utility/literals.hpp"
50            """
51        )
52        return res
53
54    def globals(self) -> IndentedBuffer:
55        res = super().globals()
56        res.splice(
57            """
58                // CK globals
59
60                template <ck::index_t... Is>
61                using S = ck::Sequence<Is...>;
62
63                template<typename... Ts>
64                using Tuple = ck::Tuple<Ts...>;
65
66                using PassThrough = ck::tensor_operation::element_wise::PassThrough;
67                using Bilinear = ck::tensor_operation::element_wise::Bilinear;
68
69                // see "composable_kernel/include/ck/utility/data_type.hpp"
70                using F8  = ck::f8_t;
71                using BF8 = ck::bf8_t;
72                using F16 = ck::half_t;
73                using F32 = float;
74                // using F64 = double;
75                using BF16 = ck::bhalf_t;
76                // using I32 = int32_t;
77                // using I8 = int8_t;
78                // using I4 = ck::int4_t;
79
80                #if DEBUG_LOG
81                static constexpr auto kDEBUG_LOG = 1;
82                #else
83                static constexpr auto kDEBUG_LOG = 0;
84                #endif
85            """
86        )
87        return res
88
89    def torch_type_to_ck(self, node: IRNode, ptr: str) -> str:
90        if node is None:
91            return ptr
92        else:
93            return f"({self._TORCH_DTYPE_TO_CK.get(node.get_dtype())}*)({ptr})"
94