# mypy: allow-untyped-defs
import functools
import os
from itertools import chain, count
from typing import Any, Callable, List, Optional, Tuple, TYPE_CHECKING, Union

import sympy

from torch import dtype as torch_dtype
from torch._inductor.codecache import get_cpp_wrapper_cubin_path_name
from torch._inductor.runtime.triton_heuristics import grid as default_grid

from .. import config
from ..codecache import CudaKernelParamCache
from ..utils import DeferredLineBase
from ..virtualized import V
from .aoti_hipify_utils import maybe_hipify_code_wrapper
from .codegen_device_driver import cuda_kernel_driver, cuda_kernel_header
from .cpp_utils import cexpr, DTYPE_TO_CPP
from .cpp_wrapper_cpu import CppWrapperCpu
from .wrapper import SymbolicCallArg


if TYPE_CHECKING:
    from ..graph import GraphLowering


class DeferredCudaKernelLine(DeferredLineBase):
    """
    When using the cpp wrapper, CUDA kernel load and launch need to wait for Triton kernels
    to be tuned and stored as cubin files, so a deferred line is used to backfill that information.
    """

    def __init__(
        self,
        kernel_name: str,
        line_template: str,
        keys: Tuple[str, ...],
    ):
        super().__init__(line_template)
        assert not isinstance(line_template, DeferredLineBase)
        self.kernel_name = kernel_name
        self.line_template = line_template
        self.keys = keys

    def __call__(self):
        params = CudaKernelParamCache.get(self.kernel_name)
        assert (
            params is not None
        ), f"{self.kernel_name} not found in CudaKernelParamCache"
        for key in self.keys:
            assert (
                key in params
            ), f"{key} not found in CudaKernelParamCache[{self.kernel_name}]"
            if key == get_cpp_wrapper_cubin_path_name():
                assert os.path.exists(params[key]), f"{params[key]} does not exist"

        return self.line_template % tuple(params[key] for key in self.keys)

    def _new_line(self, line):
        return DeferredCudaKernelLine(self.kernel_name, line, self.keys)


class DeferredCudaDefaultGrid:
    """
    A container for the default grid, which may be used by DeferredCudaGridLine
    """

    def __init__(
        self,
        kernel_name: str,
        grid,
        grid_callable: Optional[Callable[..., Any]] = None,
        **grid_extra_kwargs,
    ):
        self.kernel_name = kernel_name
        self.grid = grid
        self.grid_callable = grid_callable
        self.grid_extra_kwargs = grid_extra_kwargs

    def _process_grid(self, grid: Union[List[Any], Tuple[Any, ...]]):
        if isinstance(grid, (list, tuple)):
            return [self._process_grid(e) for e in grid]
        else:
            return grid.inner_expr if isinstance(grid, SymbolicCallArg) else grid

    def __call__(self):
        grid = self.grid
        assert isinstance(grid, (list, tuple)), f"expected {grid=} to be a list"
        grid = self._process_grid(grid)
        grid_callable = self.grid_callable or default_grid
        if not self.grid_extra_kwargs:
            grid_fn = grid_callable(*grid)
        else:
            grid_fn = grid_callable(*grid, **self.grid_extra_kwargs)

        params = CudaKernelParamCache.get(self.kernel_name)
        assert (
            params is not None
        ), f"{self.kernel_name} not found in CudaKernelParamCache"
        block_cfg = {
            "XBLOCK": params["x_block"],
            "YBLOCK": params["y_block"],
            "ZBLOCK": params["z_block"],
        }
        return grid_fn(block_cfg)


class DeferredCudaGridLine(DeferredLineBase):
    """
    When using the cpp wrapper, CUDA kernel load and launch need to wait for Triton kernels
    to be tuned and stored as cubin files, so a deferred line is used to backfill that information.
    """
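
    # Illustrative note (not part of the original source): once the autotuned
    # parameters for this kernel land in CudaKernelParamCache, evaluating this
    # deferred line emits a C++ grid declaration roughly of the form
    # (kernel name and exact grid expressions are hypothetical):
    #
    #     Grid triton_poi_fused_add_0_grid_0 = Grid((s0 + 255) / 256, 1, 1);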

    def __init__(
        self,
        kernel_name: str,
        grid_var: str,
        grid,
        autotune_configs,
    ):
        super().__init__("")
        self.kernel_name = kernel_name
        self.grid_var = grid_var
        self.grid = grid
        self.autotune_configs = autotune_configs

    def __call__(self):
        params = CudaKernelParamCache.get(self.kernel_name)
        assert (
            params is not None
        ), f"{self.kernel_name} not found in CudaKernelParamCache"

        if self.autotune_configs is not None:
            # This indicates the Triton kernel is a user-defined one.
            grid = None
            if len(self.grid) == 1:
                grid = self.grid[0]
            else:
                for i, c in enumerate(self.autotune_configs):
                    if all(arg == params["meta"][key] for key, arg in c.kwargs.items()):
                        grid = self.grid[i]
                        break
            assert grid is not None
        elif isinstance(self.grid, DeferredCudaDefaultGrid):
            grid = self.grid()
        else:
            grid = self.grid

        assert len(grid) != 0, "Grid can't be empty"
        grid_args_str = ", ".join(
            [cexpr(V.graph.sizevars.simplify(item)) for item in grid]
        )
        return f"    Grid {self.grid_var} = Grid({grid_args_str});"

    def _new_line(self, line):
        return DeferredCudaGridLine(
            self.kernel_name, self.grid_var, self.grid, self.autotune_configs
        )


class CppWrapperCuda(CppWrapperCpu):
    """
    Generates a cpp wrapper for running on GPU and calls CUDA kernels
    """

    def __init__(self) -> None:
        self.device = "cuda"
        super().__init__()
        self.grid_id = count()
        self.cuda = True

    def write_header(self):
        if V.graph.is_const_graph:
            # We do not write the header for a constant graph; it will be written by the main module.
            return

        super().write_header()

        self.header.splice("#include <filesystem>")
        if config.abi_compatible:
            self.header.splice(
                "#include <torch/csrc/inductor/aoti_runtime/utils_cuda.h>"
            )
        else:
            self.header.splice(maybe_hipify_code_wrapper(cuda_kernel_header()))
        self.header.splice(maybe_hipify_code_wrapper(cuda_kernel_driver()))

    def write_get_raw_stream(self, index, graph=None):
        name = f"stream{index}"
        self.writeline(maybe_hipify_code_wrapper(f"cudaStream_t {name};"))
        self.writeline(
            f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream({index}, (void**)&{name}));"
        )
        return name

    def define_kernel(
        self, name: str, kernel: str, metadata: Optional[str] = None, cuda=True
    ):
        if not cuda:
            return super().define_kernel(name, kernel, metadata, cuda)

    def generate(self, is_inference):
        self.prefix.writeline("\n")
        if not V.graph.aot_mode:
            for kernel in chain(
                sorted(self.src_to_kernel.values()),
                sorted([entry[0] for entry in self.user_defined_kernel_cache.values()]),
            ):
                self.prefix.writeline(
                    maybe_hipify_code_wrapper(f"static CUfunction {kernel} = nullptr;")
                )
            self.prefix.writeline("\n")
        return super().generate(is_inference)

    def generate_user_defined_triton_kernel(
        self,
        kernel_name: str,
        raw_args: List[Any],
        grid: List[Any],
        configs,
        triton_meta,
        constexprs,
    ):
        # In the C++ wrapper, we don't pass constexpr args, as they don't
        # get added as parameters to the PTX code compiled from the
        # user-defined Triton kernel (only non-constexpr args do).
        raw_args = [
            raw_arg for i, raw_arg in enumerate(raw_args) if i not in constexprs
        ]
        args = [self.val_to_arg_str(v) for v in raw_args]
        arg_types = [
            arg.get_dtype() if hasattr(arg, "get_dtype") else type(arg)
hasattr(arg, "get_dtype") else type(arg) 234 for arg in raw_args 235 ] 236 self.generate_kernel_call( 237 kernel_name, 238 args, 239 arg_types=arg_types, 240 raw_args=raw_args, 241 grid=grid, 242 cuda=True, 243 triton=True, 244 triton_meta=triton_meta, 245 autotune_configs=configs, 246 ) 247 248 @functools.lru_cache(None) # noqa: B019 249 def generate_load_kernel_once( 250 self, 251 kernel_name: str, 252 graph: "GraphLowering", # for per-graph caching 253 ): 254 keys = (get_cpp_wrapper_cubin_path_name(), "mangled_name", "shared_mem") 255 kernel_var_name = f"kernels.{kernel_name}" if V.graph.aot_mode else kernel_name 256 self.writeline(f"if ({kernel_var_name} == nullptr) {{") 257 self.writeline( 258 DeferredCudaKernelLine( 259 kernel_name, 260 """ """ 261 + kernel_var_name 262 + """ = loadKernel("%s", "%s", %s, this->cubin_dir_);""" 263 if V.graph.aot_mode 264 else """ """ 265 + kernel_var_name 266 + """ = loadKernel("%s", "%s", %s);""", 267 keys, 268 ) 269 ) 270 self.writeline("}") 271 return kernel_var_name 272 273 def generate_args_decl(self, call_args, arg_types): 274 new_args = [] 275 for arg, arg_type in zip(call_args, arg_types): 276 var_name = f"var_{next(self.arg_var_id)}" 277 if isinstance(arg_type, torch_dtype): 278 if arg.endswith(".item()"): 279 # Need to declare a scalar in this case 280 ctype = DTYPE_TO_CPP[arg_type] 281 arg = arg[:-7] 282 if config.abi_compatible: 283 self.codegen_tensor_item( 284 arg_type, 285 arg, 286 var_name, 287 ) 288 else: 289 from torch import bfloat16, float16 290 291 if arg_type in (float16, bfloat16): 292 var_name_tmp = f"{var_name}_tmp" 293 self.writeline( 294 f"{ctype} {var_name_tmp} = {arg}.item<{ctype}>();" 295 ) 296 self.writeline(f"float {var_name} = float({var_name_tmp});") 297 else: 298 self.writeline( 299 f"{ctype} {var_name} = {arg}.item<{ctype}>();" 300 ) 301 else: 302 if config.abi_compatible: 303 self.writeline( 304 maybe_hipify_code_wrapper(f"CUdeviceptr {var_name};") 305 ) 306 self.writeline( 307 f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr({arg}, reinterpret_cast<void**>(&{var_name})));" 308 ) 309 else: 310 self.writeline( 311 maybe_hipify_code_wrapper( 312 f"CUdeviceptr {var_name} = reinterpret_cast<CUdeviceptr>({arg}.data_ptr());" 313 ) 314 ) 315 elif arg_type in (sympy.Integer, int): 316 self.writeline(f"int {var_name} = {self.expr_printer(arg)};") 317 elif arg_type in (sympy.Float, float): 318 self.writeline(f"float {var_name} = {self.expr_printer(arg)};") 319 else: 320 self.writeline(f"auto {var_name} = {self.expr_printer(arg)};") 321 new_args.append(f"&{var_name}") 322 323 return ", ".join(new_args) 324 325 def generate_default_grid( 326 self, 327 kernel_name: str, 328 grid: List[Any], 329 cuda: bool = True, 330 grid_callable: Optional[Callable[..., Any]] = None, 331 **grid_extra_kwargs, 332 ): 333 """ 334 Generate grid configs for launching a CUDA kernel using the grid 335 function from triton_heuristics. Because its computation needs 336 to read kernel config after autotune, it is done in a deferred way 337 using DeferredCudaDefaultGrid. 
338 """ 339 if not cuda: 340 return grid 341 return DeferredCudaDefaultGrid( 342 kernel_name, grid, grid_callable, **grid_extra_kwargs 343 ) 344 345 def generate_kernel_call( 346 self, 347 kernel_name: str, 348 call_args, 349 grid=None, 350 device_index=None, 351 cuda=True, 352 triton=True, 353 arg_types=None, 354 raw_args=None, 355 grid_fn: str = "grid", 356 triton_meta=None, 357 autotune_configs=None, 358 grid_extra_kwargs="", 359 ): 360 assert arg_types is not None and len(call_args) == len( 361 arg_types 362 ), "call_args and arg_types do not match" 363 364 if not cuda: 365 # Even in CppWrapperCuda, we may see cpp kernels 366 return super().generate_kernel_call( 367 kernel_name, 368 call_args, 369 grid, 370 device_index, 371 cuda, 372 triton, 373 arg_types, 374 raw_args, 375 grid_fn, 376 triton_meta, 377 autotune_configs, 378 grid_extra_kwargs, 379 ) 380 381 device_index, call_args = self.prepare_triton_kernel_call( 382 device_index, call_args 383 ) 384 kernel_var_name = self.generate_load_kernel_once(kernel_name, V.graph) 385 386 # args with value 1 are added into equal_to_1 and constants 387 # in triton_meta (in the Python codegen) which makes them 388 # inlined in the PTX and compiled CUBIN 389 if ( 390 triton_meta is not None 391 and "configs" in triton_meta 392 and triton_meta["configs"] 393 ): 394 equal_to_1 = triton_meta["configs"][0].equal_to_1 395 call_args = [arg for i, arg in enumerate(call_args) if i not in equal_to_1] 396 arg_types = [t for i, t in enumerate(arg_types) if i not in equal_to_1] 397 398 call_args_str = self.generate_args_decl(call_args, arg_types) 399 kernel_args_var = f"kernel_args_var_{next(self.kernel_callsite_id)}" 400 self.writeline(f"void* {kernel_args_var}[] = {{{call_args_str}}};") 401 stream = ( 402 "stream" 403 if V.graph.aot_mode 404 else self.write_get_raw_stream(device_index, V.graph) 405 ) 406 407 grid_var = f"{kernel_name}_grid_{next(self.grid_id)}" 408 self.writeline( 409 DeferredCudaGridLine(kernel_name, grid_var, grid, autotune_configs) 410 ) 411 412 kernel_var_name = f"kernels.{kernel_name}" if V.graph.aot_mode else kernel_name 413 # add debug printer code for all triton kernel related calls 414 debug_printer_manager = V.graph.wrapper_code.debug_printer 415 debug_printer_manager.set_printer_args(call_args, kernel_name, arg_types, None) 416 with debug_printer_manager: 417 self.writeline(f"if ({grid_var}.is_non_zero()) {{") 418 self.writeline( 419 DeferredCudaKernelLine( 420 kernel_name, 421 r" launchKernel({}, {}, {}, {}, %s, %s, {}, {});".format( 422 kernel_var_name, 423 f"{grid_var}.grid_x", 424 f"{grid_var}.grid_y", 425 f"{grid_var}.grid_z", 426 kernel_args_var, 427 stream, 428 ), 429 ("num_warps", "shared_mem"), 430 ), 431 ) 432 self.writeline("}") 433