1import torch 2from torch._inductor.codegen.rocm.rocm_template import ROCmTemplate 3from torch._inductor.ir import IRNode 4from torch._inductor.utils import IndentedBuffer 5 6 7class CKTemplate(ROCmTemplate): 8 """ 9 Base class for generating CK templates, has common, i.e. non-gemm-specific, code generation logic 10 """ 11 12 _TORCH_DTYPE_TO_CK = { 13 torch.float32: "F32", 14 torch.float64: "F64", 15 torch.float16: "F16", 16 torch.bfloat16: "BF16", 17 torch.int32: "I32", 18 torch.int8: "I8", 19 torch.float8_e4m3fnuz: "F8", 20 torch.float8_e5m2fnuz: "BF8", 21 } 22 23 def header(self) -> IndentedBuffer: 24 res = super().header() 25 res.splice( 26 """ 27 // HIP headers 28 29 #include <hip/hip_bfloat16.h> 30 31 // CK headers 32 33 #ifdef DEBUG_LOG 34 #define DEBUG_LOG_TMP DEBUG_LOG 35 #undef DEBUG_LOG 36 #else 37 #define DEBUG_LOG_TMP 0 38 #endif 39 #include "ck/ck.hpp" 40 #undef DEBUG_LOG 41 #define DEBUG_LOG DEBUG_LOG_TMP 42 43 #include "ck/utility/data_type.hpp" 44 #include "ck/library/utility/check_err.hpp" 45 #include "ck/library/utility/device_memory.hpp" 46 #include "ck/library/utility/fill.hpp" 47 #include "ck/library/utility/host_tensor.hpp" 48 #include "ck/library/utility/host_tensor_generator.hpp" 49 #include "ck/library/utility/literals.hpp" 50 """ 51 ) 52 return res 53 54 def globals(self) -> IndentedBuffer: 55 res = super().globals() 56 res.splice( 57 """ 58 // CK globals 59 60 template <ck::index_t... Is> 61 using S = ck::Sequence<Is...>; 62 63 template<typename... Ts> 64 using Tuple = ck::Tuple<Ts...>; 65 66 using PassThrough = ck::tensor_operation::element_wise::PassThrough; 67 using Bilinear = ck::tensor_operation::element_wise::Bilinear; 68 69 // see "composable_kernel/include/ck/utility/data_type.hpp" 70 using F8 = ck::f8_t; 71 using BF8 = ck::bf8_t; 72 using F16 = ck::half_t; 73 using F32 = float; 74 // using F64 = double; 75 using BF16 = ck::bhalf_t; 76 // using I32 = int32_t; 77 // using I8 = int8_t; 78 // using I4 = ck::int4_t; 79 80 #if DEBUG_LOG 81 static constexpr auto kDEBUG_LOG = 1; 82 #else 83 static constexpr auto kDEBUG_LOG = 0; 84 #endif 85 """ 86 ) 87 return res 88 89 def torch_type_to_ck(self, node: IRNode, ptr: str) -> str: 90 if node is None: 91 return ptr 92 else: 93 return f"({self._TORCH_DTYPE_TO_CK.get(node.get_dtype())}*)({ptr})" 94