xref: /aosp_15_r20/external/pytorch/torch/_inductor/codegen/cpp_wrapper_cuda.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
import functools
import os
from itertools import chain, count
from typing import Any, Callable, List, Optional, Tuple, TYPE_CHECKING, Union

import sympy

from torch import dtype as torch_dtype
from torch._inductor.codecache import get_cpp_wrapper_cubin_path_name
from torch._inductor.runtime.triton_heuristics import grid as default_grid

from .. import config
from ..codecache import CudaKernelParamCache
from ..utils import DeferredLineBase
from ..virtualized import V
from .aoti_hipify_utils import maybe_hipify_code_wrapper
from .codegen_device_driver import cuda_kernel_driver, cuda_kernel_header
from .cpp_utils import cexpr, DTYPE_TO_CPP
from .cpp_wrapper_cpu import CppWrapperCpu
from .wrapper import SymbolicCallArg


if TYPE_CHECKING:
    from ..graph import GraphLowering


class DeferredCudaKernelLine(DeferredLineBase):
    """
    When using the cpp wrapper, CUDA kernel load and launch need to wait for Triton kernels
    to be tuned and stored as cubin files, so a deferred line is used to backfill that information.
    """
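    # Example (illustrative): given keys=("mangled_name", "shared_mem") and a
    # template such as 'k = loadKernel("%s", %s);', __call__() substitutes the
    # values that autotuning recorded in CudaKernelParamCache, once the cubin
    # file actually exists on disk.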

    def __init__(
        self,
        kernel_name: str,
        line_template: str,
        keys: Tuple[str, ...],
    ):
        super().__init__(line_template)
        assert not isinstance(line_template, DeferredLineBase)
        self.kernel_name = kernel_name
        self.line_template = line_template
        self.keys = keys

    def __call__(self):
        params = CudaKernelParamCache.get(self.kernel_name)
        assert (
            params is not None
        ), f"{self.kernel_name} not found in CudaKernelParamCache"
        for key in self.keys:
            assert (
                key in params
            ), f"{key} not found in CudaKernelParamCache[{self.kernel_name}]"
            if key == get_cpp_wrapper_cubin_path_name():
                assert os.path.exists(params[key]), f"{params[key]} does not exist"

        return self.line_template % tuple(params[key] for key in self.keys)

    def _new_line(self, line):
        return DeferredCudaKernelLine(self.kernel_name, line, self.keys)


class DeferredCudaDefaultGrid:
    """
    A container for the default grid, which may be used by DeferredCudaGridLine
    """

    def __init__(
        self,
        kernel_name: str,
        grid,
        grid_callable: Optional[Callable[..., Any]] = None,
        **grid_extra_kwargs,
    ):
        self.kernel_name = kernel_name
        self.grid = grid
        self.grid_callable = grid_callable
        self.grid_extra_kwargs = grid_extra_kwargs

    def _process_grid(self, grid: Union[List[Any], Tuple[Any, ...]]):
        if isinstance(grid, (list, tuple)):
            return [self._process_grid(e) for e in grid]
        else:
            return grid.inner_expr if isinstance(grid, SymbolicCallArg) else grid

    def __call__(self):
        grid = self.grid
        assert isinstance(grid, (list, tuple)), f"expected {grid=} to be a list or tuple"
        grid = self._process_grid(grid)
        grid_callable = self.grid_callable or default_grid
        if not self.grid_extra_kwargs:
            grid_fn = grid_callable(*grid)
        else:
            grid_fn = grid_callable(*grid, **self.grid_extra_kwargs)
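        # The grid function returned by triton_heuristics closes over the raw
        # grid sizes; given the winning block sizes below, it computes the
        # concrete (grid_x, grid_y, grid_z) launch dimensions.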

        params = CudaKernelParamCache.get(self.kernel_name)
        assert (
            params is not None
        ), f"{self.kernel_name} not found in CudaKernelParamCache"
        block_cfg = {
            "XBLOCK": params["x_block"],
            "YBLOCK": params["y_block"],
            "ZBLOCK": params["z_block"],
        }
        return grid_fn(block_cfg)


class DeferredCudaGridLine(DeferredLineBase):
    """
    When using the cpp wrapper, CUDA kernel load and launch need to wait for Triton kernels
    to be tuned and stored as cubin files, so a deferred line is used to backfill that information.
    """

    def __init__(
        self,
        kernel_name: str,
        grid_var: str,
        grid,
        autotune_configs,
    ):
        super().__init__("")
        self.kernel_name = kernel_name
        self.grid_var = grid_var
        self.grid = grid
        self.autotune_configs = autotune_configs

    def __call__(self):
        params = CudaKernelParamCache.get(self.kernel_name)
        assert (
            params is not None
        ), f"{self.kernel_name} not found in CudaKernelParamCache"

        if self.autotune_configs is not None:
            # This indicates the Triton kernel is a user-defined one.
            grid = None
            if len(self.grid) == 1:
                grid = self.grid[0]
            else:
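                # Multiple candidate configs: pick the grid whose config matches
                # the winning kwargs recorded in params["meta"] by autotuning.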
                for i, c in enumerate(self.autotune_configs):
                    if all(arg == params["meta"][key] for key, arg in c.kwargs.items()):
                        grid = self.grid[i]
                        break
            assert grid is not None
        elif isinstance(self.grid, DeferredCudaDefaultGrid):
            grid = self.grid()
        else:
            grid = self.grid

        assert len(grid) != 0, "Grid can't be empty"
        grid_args_str = ", ".join(
            [cexpr(V.graph.sizevars.simplify(item)) for item in grid]
        )
        return f"    Grid {self.grid_var} = Grid({grid_args_str});"

    def _new_line(self, line):
        return DeferredCudaGridLine(
            self.kernel_name, self.grid_var, self.grid, self.autotune_configs
        )


class CppWrapperCuda(CppWrapperCpu):
    """
    Generates a cpp wrapper for running on GPU and calling CUDA kernels.
    """

    def __init__(self) -> None:
        self.device = "cuda"
        super().__init__()
        self.grid_id = count()
        self.cuda = True

    def write_header(self):
        if V.graph.is_const_graph:
            # We do not write the header for a constant graph; it will be written by the main module.
            return

        super().write_header()

        self.header.splice("#include <filesystem>")
        if config.abi_compatible:
            self.header.splice(
                "#include <torch/csrc/inductor/aoti_runtime/utils_cuda.h>"
            )
        else:
            self.header.splice(maybe_hipify_code_wrapper(cuda_kernel_header()))
        self.header.splice(maybe_hipify_code_wrapper(cuda_kernel_driver()))

    def write_get_raw_stream(self, index, graph=None):
        name = f"stream{index}"
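        # Declare a cudaStream_t and populate it with the current stream for
        # this device index via the AOTI shim call emitted below.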
        self.writeline(maybe_hipify_code_wrapper(f"cudaStream_t {name};"))
        self.writeline(
            f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream({index}, (void**)&{name}));"
        )
        return name

    def define_kernel(
        self, name: str, kernel: str, metadata: Optional[str] = None, cuda=True
    ):
        if not cuda:
            return super().define_kernel(name, kernel, metadata, cuda)

    def generate(self, is_inference):
        self.prefix.writeline("\n")
        if not V.graph.aot_mode:
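            # In JIT (non-AOT) mode, declare one static CUfunction handle per
            # kernel so generate_load_kernel_once loads each cubin only on first use.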
            for kernel in chain(
                sorted(self.src_to_kernel.values()),
                sorted([entry[0] for entry in self.user_defined_kernel_cache.values()]),
            ):
                self.prefix.writeline(
                    maybe_hipify_code_wrapper(f"static CUfunction {kernel} = nullptr;")
                )
            self.prefix.writeline("\n")
        return super().generate(is_inference)

    def generate_user_defined_triton_kernel(
        self,
        kernel_name: str,
        raw_args: List[Any],
        grid: List[Any],
        configs,
        triton_meta,
        constexprs,
    ):
        # In the C++ wrapper, we don't pass constexpr args, as they don't
        # get added as parameters to the PTX code compiled from the
        # user-defined Triton kernel (only non-constexpr args do).
        raw_args = [
            raw_arg for i, raw_arg in enumerate(raw_args) if i not in constexprs
        ]
        args = [self.val_to_arg_str(v) for v in raw_args]
        arg_types = [
            arg.get_dtype() if hasattr(arg, "get_dtype") else type(arg)
            for arg in raw_args
        ]
        self.generate_kernel_call(
            kernel_name,
            args,
            arg_types=arg_types,
            raw_args=raw_args,
            grid=grid,
            cuda=True,
            triton=True,
            triton_meta=triton_meta,
            autotune_configs=configs,
        )

    @functools.lru_cache(None)  # noqa: B019
    def generate_load_kernel_once(
        self,
        kernel_name: str,
        graph: "GraphLowering",  # for per-graph caching
    ):
        keys = (get_cpp_wrapper_cubin_path_name(), "mangled_name", "shared_mem")
        kernel_var_name = f"kernels.{kernel_name}" if V.graph.aot_mode else kernel_name
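        # In AOT mode the handle is accessed as kernels.<name>; otherwise it is
        # the static CUfunction declared in generate(). The nullptr guard below
        # ensures the kernel is loaded only once.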
        self.writeline(f"if ({kernel_var_name} == nullptr) {{")
        self.writeline(
            DeferredCudaKernelLine(
                kernel_name,
                """    """
                + kernel_var_name
                + """ = loadKernel("%s", "%s", %s, this->cubin_dir_);"""
                if V.graph.aot_mode
                else """    """
                + kernel_var_name
                + """ = loadKernel("%s", "%s", %s);""",
                keys,
            )
        )
        self.writeline("}")
        return kernel_var_name

    def generate_args_decl(self, call_args, arg_types):
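        # Declare a C++ variable for each call arg (a CUdeviceptr for tensor
        # args, a scalar otherwise) and return the "&var_0, &var_1, ..." string
        # used to build the void* argument array for the kernel launch.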
        new_args = []
        for arg, arg_type in zip(call_args, arg_types):
            var_name = f"var_{next(self.arg_var_id)}"
            if isinstance(arg_type, torch_dtype):
                if arg.endswith(".item()"):
                    # Need to declare a scalar in this case
                    ctype = DTYPE_TO_CPP[arg_type]
                    arg = arg[:-7]
                    if config.abi_compatible:
                        self.codegen_tensor_item(
                            arg_type,
                            arg,
                            var_name,
                        )
                    else:
                        from torch import bfloat16, float16

                        if arg_type in (float16, bfloat16):
                            var_name_tmp = f"{var_name}_tmp"
                            self.writeline(
                                f"{ctype} {var_name_tmp} = {arg}.item<{ctype}>();"
                            )
                            self.writeline(f"float {var_name} = float({var_name_tmp});")
                        else:
                            self.writeline(
                                f"{ctype} {var_name} = {arg}.item<{ctype}>();"
                            )
                else:
                    if config.abi_compatible:
                        self.writeline(
                            maybe_hipify_code_wrapper(f"CUdeviceptr {var_name};")
                        )
                        self.writeline(
                            f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr({arg}, reinterpret_cast<void**>(&{var_name})));"
                        )
                    else:
                        self.writeline(
                            maybe_hipify_code_wrapper(
                                f"CUdeviceptr {var_name} = reinterpret_cast<CUdeviceptr>({arg}.data_ptr());"
                            )
                        )
            elif arg_type in (sympy.Integer, int):
                self.writeline(f"int {var_name} = {self.expr_printer(arg)};")
            elif arg_type in (sympy.Float, float):
                self.writeline(f"float {var_name} = {self.expr_printer(arg)};")
            else:
                self.writeline(f"auto {var_name} = {self.expr_printer(arg)};")
            new_args.append(f"&{var_name}")

        return ", ".join(new_args)

    def generate_default_grid(
        self,
        kernel_name: str,
        grid: List[Any],
        cuda: bool = True,
        grid_callable: Optional[Callable[..., Any]] = None,
        **grid_extra_kwargs,
    ):
        """
        Generate grid configs for launching a CUDA kernel using the grid
        function from triton_heuristics. Because this computation needs to
        read the kernel config after autotuning, it is done in a deferred way
        using DeferredCudaDefaultGrid.
        """
        if not cuda:
            return grid
        return DeferredCudaDefaultGrid(
            kernel_name, grid, grid_callable, **grid_extra_kwargs
        )

    def generate_kernel_call(
        self,
        kernel_name: str,
        call_args,
        grid=None,
        device_index=None,
        cuda=True,
        triton=True,
        arg_types=None,
        raw_args=None,
        grid_fn: str = "grid",
        triton_meta=None,
        autotune_configs=None,
        grid_extra_kwargs="",
    ):
        assert arg_types is not None and len(call_args) == len(
            arg_types
        ), "call_args and arg_types do not match"

        if not cuda:
            # Even in CppWrapperCuda, we may see cpp kernels
            return super().generate_kernel_call(
                kernel_name,
                call_args,
                grid,
                device_index,
                cuda,
                triton,
                arg_types,
                raw_args,
                grid_fn,
                triton_meta,
                autotune_configs,
                grid_extra_kwargs,
            )

        device_index, call_args = self.prepare_triton_kernel_call(
            device_index, call_args
        )
        kernel_var_name = self.generate_load_kernel_once(kernel_name, V.graph)

        # Args with value 1 are added to equal_to_1 and constants in
        # triton_meta (in the Python codegen), which makes them inlined in
        # the PTX and compiled CUBIN.
        if (
            triton_meta is not None
            and "configs" in triton_meta
            and triton_meta["configs"]
        ):
            equal_to_1 = triton_meta["configs"][0].equal_to_1
            call_args = [arg for i, arg in enumerate(call_args) if i not in equal_to_1]
            arg_types = [t for i, t in enumerate(arg_types) if i not in equal_to_1]

        call_args_str = self.generate_args_decl(call_args, arg_types)
        kernel_args_var = f"kernel_args_var_{next(self.kernel_callsite_id)}"
        self.writeline(f"void* {kernel_args_var}[] = {{{call_args_str}}};")
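        # The launcher forwards this void* array to the CUDA driver, which
        # expects a pointer to each argument value, hence the "&var" entries
        # produced by generate_args_decl.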
        stream = (
            "stream"
            if V.graph.aot_mode
            else self.write_get_raw_stream(device_index, V.graph)
        )

        grid_var = f"{kernel_name}_grid_{next(self.grid_id)}"
        self.writeline(
            DeferredCudaGridLine(kernel_name, grid_var, grid, autotune_configs)
        )

        kernel_var_name = f"kernels.{kernel_name}" if V.graph.aot_mode else kernel_name
        # add debug printer code for all triton kernel related calls
        debug_printer_manager = V.graph.wrapper_code.debug_printer
        debug_printer_manager.set_printer_args(call_args, kernel_name, arg_types, None)
        with debug_printer_manager:
            self.writeline(f"if ({grid_var}.is_non_zero()) {{")
            self.writeline(
                DeferredCudaKernelLine(
                    kernel_name,
                    r"    launchKernel({}, {}, {}, {}, %s, %s, {}, {});".format(
                        kernel_var_name,
                        f"{grid_var}.grid_x",
                        f"{grid_var}.grid_y",
                        f"{grid_var}.grid_z",
                        kernel_args_var,
                        stream,
                    ),
                    ("num_warps", "shared_mem"),
                ),
            )
            self.writeline("}")
433