# mypy: allow-untyped-defs, disable-error-code="attr-defined, valid-type"
import copy
import logging
import random
from typing import List, Optional

import sympy

import torch
from torch._inductor import config
from torch._inductor.codegen.rocm.ck_template import CKTemplate
from torch._inductor.codegen.rocm.rocm_kernel import ROCmTemplateKernel
from torch._inductor.ir import Buffer, Layout

from ...utils import IndentedBuffer, try_import_ck_lib


_, gen_ops_library, gen_ops_preselected, CKGemmOperation = try_import_ck_lib()


log = logging.getLogger(__name__)


def is_static_int(number):
    """Return True if ``number`` is a concrete integer (Python ``int`` or ``sympy.Integer``)."""
    return isinstance(number, (int, sympy.Integer))


def torch_layout_to_ck_layout(torch_layout):
    """
    Map an inductor layout to the name of the matching CK tensor layout.

    Returns "Row" when the innermost dimension has stride 1, "Col" when the
    second-to-innermost dimension has stride 1, and None otherwise (the
    layout is not supported by the CK instances). A 1-D layout whose only
    stride is not 1 also yields None instead of raising IndexError.
    """
    stride = torch_layout.stride
    if stride[-1] == 1:
        return "Row"
    if len(stride) >= 2 and stride[-2] == 1:
        return "Col"
    return None


class CKGemmTemplate(CKTemplate):
    """
    Inductor template that renders CK ``DeviceGemmMultiD_Xdl_CShuffle_V3``
    GEMM instances, optionally with one bias ("D") tensor, and registers
    them as autotuning choices.
    """

    # the JINJA template for rendering CK Universal GEMMs
    gemm_template = r"""{{version_comment}}
    {{headers}}
    {{globals}}
    {{instance_definition}}
    extern "C" {
    {{kernel_definition}} {
        auto gemm = {{instance_type}} {};
        auto invoker = gemm.MakeInvoker();

        auto argument = gemm.MakeArgument(
            reinterpret_cast<const {{a_element_dtype}}*>(X),
            reinterpret_cast<const {{b_element_dtype}}*>(W),
            std::array<const void*, {{1 if has_bias else 0}}>{ {{'Bias' if has_bias else ''}} },
            reinterpret_cast<{{c_element_dtype}}*>(Y),
            M,
            N,
            K,
            LDA,
            LDB,
            std::array<ck::index_t, {{1 if has_bias else 0}}>{ {{'LDD' if has_bias else ''}} },
            LDC,
            1, // kBatch
            PassThrough {}, // a_elementwise_op
            PassThrough {}, // b_elementwise_op
            {{epilogue}} // c_elementwise_op
        );
        if (!gemm.IsSupportedArgument(argument)) {
            // we do our best to statically avoid this case in `filter_op`
            std::cerr << "invalid argument for gemm instance " << gemm.GetTypeString() << std::endl;
            argument.Print();
            return -23;
        }
        if (workspace_size) {
            *workspace_size = gemm.GetWorkSpaceSize(&argument);
            return 0;
        }
        // run the kernel
        float elapsed_time = invoker.Run(argument, StreamConfig{stream, /* time kernel */ false, /* log level */ kDEBUG_LOG});
        return 0;
    } // kernel definition
    } // extern C
    """

    def __init__(
        self,
        input_nodes: List[Buffer],
        layout: Layout,
        alpha: float,
        beta: float,
        input_reorder: Optional[List[int]] = None,
    ) -> None:
        """
        Args:
            input_nodes: [X, W] or [X, W, Bias]; `render` reads them in that order.
            layout: output layout, forwarded to the base template.
            alpha: scale used in the ``Bilinear`` epilogue when a bias is present.
            beta: bias scale used in the ``Bilinear`` epilogue when a bias is present.
            input_reorder: optional input permutation forwarded to the base
                template and to ``kernel.def_kernel``.
        """
        super().__init__(
            "ck_gemm_template",
            input_nodes=input_nodes,
            layout=layout,
            input_reorder=input_reorder,
        )
        self.alpha = alpha
        self.beta = beta

    def header(self) -> IndentedBuffer:
        """Base CK headers plus the header declaring the multi-D XDL CShuffle V3 GEMM."""
        res = super().header()
        res.splice(
            """
                // CK GEMM header(s)

                #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp"
            """
        )
        return res

    def globals(self) -> IndentedBuffer:
        """Base CK globals plus the C++ type aliases referenced by emitted GEMM instances."""
        res = super().globals()
        res.splice(
            """
                // CK GEMM globals

                using Row = ck::tensor_layout::gemm::RowMajor;
                using Col = ck::tensor_layout::gemm::ColumnMajor;

                using BlockGemmPipelineScheduler = ck::BlockGemmPipelineScheduler;
                using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization;
                using BlockGemmPipelineVersion = ck::BlockGemmPipelineVersion;
            """
        )
        return res

    def filter_op(self, op: "CKGemmOperation"):
        """
        Determines whether a given op definition is suitable for the current
        input / output of the operation that this template implements.

        Filter is based on inputs' dtype, layout and statically inferred size.

        Returns None if the op is not suitable, otherwise returns the op to be used.
        """
        metas = [T.get_layout() for T in [*self.input_nodes, self.output_node]]
        X_meta = metas[0]
        W_meta = metas[1]
        Y_meta = metas[-1]
        # disable the instance if dtypes don't match
        if op.a_element_dtype != self._TORCH_DTYPE_TO_CK[X_meta.dtype]:
            return None
        if op.b_element_dtype != self._TORCH_DTYPE_TO_CK[W_meta.dtype]:
            return None
        if op.c_element_dtype != self._TORCH_DTYPE_TO_CK[Y_meta.dtype]:
            return None
        # disable the instance if layouts don't match
        if op.a_layout != torch_layout_to_ck_layout(X_meta):
            return None
        if op.b_layout != torch_layout_to_ck_layout(W_meta):
            return None
        if op.c_layout != torch_layout_to_ck_layout(Y_meta):
            return None
        # try to avoid launching the instance with invalid problem size
        # see GridwiseGemm_xdl_cshuffle_v3::CheckValidity

        M = X_meta.size[-2]
        K = X_meta.size[-1]
        N = W_meta.size[-1]

        # A dimension that the instance does not pad must be an exact multiple
        # of the instance's per-block tile size. Symbolic sizes cannot be
        # checked here, so they are let through.
        if is_static_int(M):
            if not any(
                m_padding in op.gemm_specialization
                for m_padding in ["MPadding", "MNPadding", "MKPadding", "MNKPadding"]
            ):
                if M % op.m_per_block != 0:
                    return None
        if is_static_int(N):
            if not any(
                n_padding in op.gemm_specialization
                for n_padding in ["NPadding", "MNPadding", "NKPadding", "MNKPadding"]
            ):
                if N % op.n_per_block != 0:
                    return None
        if is_static_int(K):
            if not any(
                k_padding in op.gemm_specialization
                for k_padding in ["KPadding", "MKPadding", "NKPadding", "MNKPadding"]
            ):
                if K % op.k_per_block != 0:
                    return None

        # The contiguous extent of each operand must be divisible by the
        # instance's vector width for that operand.
        a_contig_size = (
            K if op.a_layout == "Row" else M if op.a_layout == "Col" else None
        )
        if (
            is_static_int(a_contig_size)
            and a_contig_size % op.a_block_transfer_src_scalar_per_vector != 0
        ):
            return None
        b_contig_size = (
            N if op.b_layout == "Row" else K if op.b_layout == "Col" else None
        )
        if (
            is_static_int(b_contig_size)
            and b_contig_size % op.b_block_transfer_src_scalar_per_vector != 0
        ):
            return None
        c_contig_size = (
            N if op.c_layout == "Row" else M if op.c_layout == "Col" else None
        )
        if (
            is_static_int(c_contig_size)
            and c_contig_size
            % op.c_shuffle_block_transfer_scalar_per_vector_n_per_block
            != 0
        ):
            return None

        # TBD disable instances with invalid number of pipeline prefetch stages
        # It will avoid compiling a small percentage of unrunnable instances which fail the gemm argument check

        return op

    def emit_ck_instance(self, op: "CKGemmOperation"):
        """
        Render the C++ for one GEMM instance.

        Returns a pair ``(definition, type_usage)``. ``definition`` is a
        ``using Operation_<name> = DeviceGemmMultiD_Xdl_CShuffle_V3<...>;``
        alias. ``type_usage`` is the bare ``Operation_<name>`` token used to
        instantiate it.

        Tuple-valued fields are rendered as ``Tuple<...>`` when the field name
        contains "ds" (bias layouts and dtypes) and as ``S<...>`` otherwise
        (tile shapes). Fields whose value is None are omitted.
        """
        # The Jinja template for generating a C++ type alias *definition* for a Universal GEMM instance
        template_definition = r"""
    // Gemm operator {{operation_name}}
    using Operation_{{operation_name}} =
        ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3<
            {{template_params}}>;

"""
        # The Jinja template for generating a C++ type alias *usage* for a Universal GEMM instance
        template_type = r"""
    Operation_{{operation_name}}
"""
        template_params = []
        for field_name, field_value in op.dict_items():
            if isinstance(field_value, tuple):
                tuple_elements = ", ".join(map(str, iter(field_value)))
                if "ds" in field_name:  # element type and layout for bias
                    arg = f"/* {field_name} */ Tuple<{tuple_elements}>"
                else:  # tile shape
                    arg = f"/* {field_name} */ S<{tuple_elements}>"
                template_params.append(arg)
            else:
                if field_value is not None:
                    template_params.append(f"/* {field_name} */ {field_value}")
        return self._template_from_string(template_definition).render(
            operation_name=op.name(),
            template_params=(",\n" + 12 * " ").join(template_params),
        ), self._template_from_string(template_type).render(operation_name=op.name())

    def render(self, kernel: ROCmTemplateKernel, op: "CKGemmOperation", **kwargs) -> str:  # type: ignore[override]
        """
        The primary entry point for the code rendering process used in this template.

        Renders the complete C++ source (headers, globals, instance alias and
        ``extern "C"`` kernel) for the GEMM instance ``op``.

        Recognized kwargs:
            epilogue_nodes: must be None or empty; epilogue fusion is not supported.
            template_buffer_node: if given, replaces ``self.output_node``.

        ``op`` is deep-copied before its fields are adjusted, so the caller's
        instance is not modified.
        """
        epilogue_nodes = kwargs.get("epilogue_nodes", None)
        assert epilogue_nodes is None or 0 == len(epilogue_nodes)
        template_buffer_node = kwargs.get("template_buffer_node", None)
        if template_buffer_node is not None:
            self.output_node = template_buffer_node
        X, W = self.input_nodes[0], self.input_nodes[1]
        Y = self.output_node
        Bias = self.input_nodes[2] if 3 == len(self.input_nodes) else None

        op = copy.deepcopy(op)

        # This parameter is converted into tuple because of change
        # from DeviceGemm_Xdl_CShuffleV3 to DeviceGemmMultiD_Xdl_CShuffle_V3.
        # The first tuple element corresponds to matmul result...
        op.c_shuffle_block_transfer_scalar_per_vector_n_per_block = (
            op.c_shuffle_block_transfer_scalar_per_vector_n_per_block,
        )

        if Bias is not None:
            op.ds_layouts = (torch_layout_to_ck_layout(Bias.get_layout()),)
            op.ds_element_dtypes = ((self._TORCH_DTYPE_TO_CK[Bias.get_layout().dtype]),)
            op.c_elementwise_op = "Bilinear"
            # c_shuffle_dtype is also used for adding bias to matmul result
            # before converting down to the result dtype
            op.c_shuffle_dtype = op.acc_dtype
            # this parameter needs to be set accordingly to bias stride for correct accumulation
            if op.ds_layouts[0] == "Row":
                # bias has (N, ) shape
                bias_shuffle_block_transfer_scalar_per_vector_n_per_block = (
                    op.c_shuffle_block_transfer_scalar_per_vector_n_per_block
                )
            else:
                # bias has (M, 1) shape
                bias_shuffle_block_transfer_scalar_per_vector_n_per_block = (1,)
            # ...and the second tuple element corresponds to the bias
            op.c_shuffle_block_transfer_scalar_per_vector_n_per_block += (
                bias_shuffle_block_transfer_scalar_per_vector_n_per_block
            )

        instance_definition, instance_type = self.emit_ck_instance(op)

        version_comment = rf"""/**
* Generated code for CK inductor backend
* See {type(self).__module__}.{type(self).__qualname__}
*
* Template instance {op}
*
* {torch.__version__=}
* {torch.version.git_version=}
*/
"""

        return self._template_from_string(self.gemm_template).render(
            headers=self.header().getvalue(),
            globals=self.globals().getvalue(),
            instance_definition=instance_definition,
            kernel_definition=kernel.def_kernel(
                inputs=[X, W, Bias],  # type: ignore[list-item]
                outputs=[Y],
                names_str="X, W, Bias, Y",
                input_reorder=self.input_reorder,
                size_args=[
                    f"ck::index_t {arg}"
                    for arg in ["M", "N", "K", "LDA", "LDB", "LDC", "LDD"]
                ],
            ),
            instance_type=instance_type,
            a_element_dtype=op.a_element_dtype,
            b_element_dtype=op.b_element_dtype,
            c_element_dtype=op.c_element_dtype,
            bias_element_dtype=op.ds_element_dtypes[0] if Bias is not None else "",
            alpha=self.alpha,
            beta=self.beta,
            epilogue=f"Bilinear {{ {self.alpha}, {self.beta} }}"
            if Bias is not None
            else "PassThrough {}",
            has_bias=Bias is not None,
            version_comment=version_comment,
        )

    def _is_rcr_f16(self):
        """
        Return True when X is "Row", W is "Col" and Y is "Row", all with CK
        dtype "F16".

        Only X (input 0), W (input 1) and the output are inspected. An optional
        bias passed as a third input does not affect the result.
        """
        # Slice to the first two inputs: with a bias there are three inputs,
        # and unpacking all of them plus the output into three names raised
        # "too many values to unpack".
        X_meta, W_meta, Y_meta = (
            T.get_layout() for T in [*self.input_nodes[:2], self.output_node]
        )
        X_dtype, W_dtype, Y_dtype = (
            self._TORCH_DTYPE_TO_CK[m.dtype] for m in (X_meta, W_meta, Y_meta)
        )
        X_layout, W_layout, Y_layout = (
            torch_layout_to_ck_layout(m) for m in (X_meta, W_meta, Y_meta)
        )

        return (
            X_dtype == "F16"
            and W_dtype == "F16"
            and Y_dtype == "F16"
            and X_layout == "Row"
            and W_layout == "Col"
            and Y_layout == "Row"
        )

    def gen_ops(self):
        """
        Creates a list of `CKGemmOperation` instances that match the GEMM operation this template represents.
        The instances are guaranteed to have the correct layout, dtype and dimension padding for the GEMM input arguments.

        An instance may invalidate the GEMM configuration at runtime.
        Such instances will be assigned +inf runtime by the autotune process.

        When ``config.rocm.n_max_profiling_configs`` is set, at most that many
        instances are returned. They are chosen by a fixed-seed random sample,
        so the choice is reproducible across runs.
        """
        unfiltered_instances = (
            gen_ops_preselected()
            if config.rocm.use_preselected_instances and self._is_rcr_f16()
            else gen_ops_library()
        )
        # filter_op returns None for unsuitable instances.
        filtered_instances = [
            op for op in unfiltered_instances if self.filter_op(op) is not None
        ]
        # NB: when using a fixed list order, most likely we will pick the subset of instances
        # which are very similar to each other. Randomizing the choice seems to solve this.
        # A private generator with the same fixed seed yields the same sample as
        # reseeding the module-level RNG, without clobbering global `random` state.
        rng = random.Random(-11)
        chosen_instances = (
            rng.sample(
                filtered_instances,
                min(len(filtered_instances), config.rocm.n_max_profiling_configs),
            )
            if config.rocm.n_max_profiling_configs
            else filtered_instances
        )
        log.debug(
            "generated %d ck instances after filter: %s",
            len(chosen_instances),
            chosen_instances,
        )
        return chosen_instances

    @staticmethod
    def add_ck_gemm_choices(
        choices,
        layout,
        input_nodes,
        alpha=1,
        beta=0,
        input_reorder=None,
    ):
        """
        Add Composable Kernel Universal GEMM instance choices to the auto-tuning list.

        Builds a template for ``input_nodes`` / ``layout`` and appends one
        choice to ``choices`` (via ``maybe_append_choice``) for each instance
        returned by ``gen_ops``.
        """
        template = CKGemmTemplate(
            input_nodes,
            layout,
            alpha=alpha,
            beta=beta,
            input_reorder=input_reorder,
        )
        ops = template.gen_ops()
        for op in ops:
            template.maybe_append_choice(
                choices,
                op=op,
            )

    def size_args(self):
        """
        Return the runtime size arguments ``(M, N, K, LDA, LDB, LDC, LDD)``.

        The order matches the ``size_args`` declared in ``render``. Each leading
        dimension is the stride of the non-contiguous axis: stride[0] when the
        last axis has stride 1, otherwise stride[1]. LDD is 0 when there is no
        bias.

        Assumes 2-D X, W, Y (and bias, if present); only indices 0 and 1 are read.
        """
        X = self.input_nodes[0]
        W = self.input_nodes[1]
        Bias = self.input_nodes[2] if len(self.input_nodes) > 2 else None
        Y = self.output_node

        M = X.get_size()[0]
        K = X.get_size()[1]
        N = W.get_size()[1]
        LDA = X.get_stride()[0 if X.get_stride()[1] == 1 else 1]
        LDB = W.get_stride()[0 if W.get_stride()[1] == 1 else 1]
        LDC = Y.get_stride()[0 if Y.get_stride()[1] == 1 else 1]
        LDD = (
            0
            if Bias is None
            else Bias.get_stride()[0 if Bias.get_stride()[1] == 1 else 1]
        )

        return M, N, K, LDA, LDB, LDC, LDD