import torch


# Provides the HIP/CUDA driver code that AOTI-generated modules use to load and launch kernels.
# This file is also used for unit testing purposes.


def cuda_kernel_driver() -> str:
    """Return C++ source for the CUDA/HIP driver helpers (loadKernel/launchKernel)."""
    source_codes = """
            #define CUDA_DRIVER_CHECK(EXPR)                    \\
            do {                                               \\
                CUresult code = EXPR;                          \\
                const char *msg;                               \\
                cuGetErrorString(code, &msg);                  \\
                if (code != CUDA_SUCCESS) {                    \\
                    throw std::runtime_error(                  \\
                        std::string("CUDA driver error: ") +   \\
                        std::string(msg));                     \\
                }                                              \\
            } while (0);

            namespace {

            struct Grid {
                Grid(uint32_t x, uint32_t y, uint32_t z)
                  : grid_x(x), grid_y(y), grid_z(z) {}
                uint32_t grid_x;
                uint32_t grid_y;
                uint32_t grid_z;

                bool is_non_zero() {
                    return grid_x > 0 && grid_y > 0 && grid_z > 0;
                }
            };

            }  // anonymous namespace

            static inline CUfunction loadKernel(
                    std::string filePath,
                    const std::string &funcName,
                    uint32_t sharedMemBytes,
                    const std::optional<std::string> &cubinDir = std::nullopt) {
                if (cubinDir) {
                    std::filesystem::path p1{*cubinDir};
                    std::filesystem::path p2{filePath};
                    filePath = (p1 / p2.filename()).string();
                }

                CUmodule mod;
                CUfunction func;
                CUDA_DRIVER_CHECK(cuModuleLoad(&mod, filePath.c_str()));
                CUDA_DRIVER_CHECK(cuModuleGetFunction(&func, mod, funcName.c_str()));
                if (sharedMemBytes > 0) {
                    CUDA_DRIVER_CHECK(cuFuncSetAttribute(
                        func,
                        CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
                        sharedMemBytes
                    ))
                }
                return func;
            }

            static inline void launchKernel(
                    CUfunction func,
                    uint32_t gridX,
                    uint32_t gridY,
                    uint32_t gridZ,
                    uint32_t numWarps,
                    uint32_t sharedMemBytes,
                    void* args[],
                    cudaStream_t stream) {
                CUDA_DRIVER_CHECK(cuLaunchKernel(
                    func, gridX, gridY, gridZ, 32*numWarps, 1, 1, sharedMemBytes, stream, args, nullptr
                ));
            }
    """
    if torch.version.hip is not None:
        # Adjust the warp size to the wavefront size supported by the AMD GPU.
        prop = torch.cuda.get_device_properties(torch.cuda.current_device())
        source_codes = source_codes.replace(
            "32*numWarps", str(prop.warp_size) + "*numWarps"
        )
    return source_codes


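# A minimal, illustrative sketch (not part of the original module and not consumed by
# Inductor itself) of how AOTI-generated wrapper C++ would typically call the helpers
# emitted by cuda_kernel_driver(). The kernel name, cubin filename, grid shape,
# numWarps, shared-memory size, and the `arg0`/`stream` variables below are assumptions
# made up for demonstration.
_example_generated_launch = """
    static CUfunction example_triton_kernel = nullptr;
    if (example_triton_kernel == nullptr) {
        // Load the compiled kernel once and cache its CUfunction handle.
        example_triton_kernel = loadKernel(
            "example_triton_kernel.cubin", "example_triton_kernel", /*sharedMemBytes=*/4096);
    }
    // Kernel arguments are passed to the driver API as an array of pointers.
    CUdeviceptr var_0 = reinterpret_cast<CUdeviceptr>(arg0.data_ptr());
    void* kernel_args[] = {&var_0};
    Grid grid(8, 1, 1);
    if (grid.is_non_zero()) {
        // numWarps=4 gives a block size of 32*4 threads (warp_size*4 on ROCm).
        launchKernel(example_triton_kernel, grid.grid_x, grid.grid_y, grid.grid_z,
                     /*numWarps=*/4, /*sharedMemBytes=*/4096, kernel_args, stream);
    }
"""

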
def cuda_kernel_header() -> str:
    """Return the C++ #include directives needed by AOTI's generated CUDA wrapper code."""
    source_codes = """
        #include <c10/cuda/CUDAGuard.h>
        #include <c10/cuda/CUDAStream.h>
        #include <ATen/cuda/EmptyTensor.h>
    """
    return source_codes
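

if __name__ == "__main__":
    # Minimal demo / smoke test, assuming this file is run directly: print the C++
    # prelude that AOTI wrapper codegen would prepend to its kernel-launch code. On a
    # ROCm build, cuda_kernel_driver() also queries the current device's wavefront
    # size, so that path needs an available AMD GPU.
    print(cuda_kernel_header())
    print(cuda_kernel_driver())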