import torch


# Provides the C++ source snippets used to load and launch CUDA/HIP kernels
# from AOTInductor-generated code. This file is also used by unit tests.


def cuda_kernel_driver() -> str:
    """Return C++ source text with driver-API helpers for launching kernels.

    The returned snippet defines:

    * ``CUDA_DRIVER_CHECK(EXPR)`` - evaluates a driver call and throws
      ``std::runtime_error`` when the result is not ``CUDA_SUCCESS``.
    * ``Grid`` - a small struct holding three grid dimensions.
    * ``loadKernel`` - loads a module from a file, looks up a function in it
      and, when ``sharedMemBytes > 0``, sets its max dynamic shared size.
    * ``launchKernel`` - calls ``cuLaunchKernel`` with ``32*numWarps`` threads
      in the x dimension of the block.

    When ``torch.version.hip`` is not ``None``, the literal ``32`` in
    ``32*numWarps`` is replaced by the ``warp_size`` reported for the current
    device.

    Returns:
        The C++ source as a single string.
    """
    # NOTE: the error message is looked up only on the failure path, and the
    # lookup itself is checked. cuGetErrorString sets the pointer to NULL for
    # an unrecognized code, and constructing std::string from a null pointer
    # is undefined behavior. No C++ comments are placed inside the macro body
    # because a `//` comment would swallow the line continuation.
    source_codes = """
            #define CUDA_DRIVER_CHECK(EXPR)                              \\
            do {                                                         \\
                CUresult code = EXPR;                                    \\
                if (code != CUDA_SUCCESS) {                              \\
                    const char *msg = nullptr;                           \\
                    if (cuGetErrorString(code, &msg) != CUDA_SUCCESS ||  \\
                        msg == nullptr) {                                \\
                        msg = "unknown error code";                      \\
                    }                                                    \\
                    throw std::runtime_error(                            \\
                        std::string("CUDA driver error: ") +             \\
                        std::string(msg));                               \\
                }                                                        \\
            } while (0);

            namespace {

            struct Grid {
                Grid(uint32_t x, uint32_t y, uint32_t z)
                  : grid_x(x), grid_y(y), grid_z(z) {}
                uint32_t grid_x;
                uint32_t grid_y;
                uint32_t grid_z;

                bool is_non_zero() {
                    return grid_x > 0 && grid_y > 0 && grid_z > 0;
                }
            };

            }  // anonymous namespace

            static inline CUfunction loadKernel(
                    std::string filePath,
                    const std::string &funcName,
                    uint32_t sharedMemBytes,
                    const std::optional<std::string> &cubinDir = std::nullopt) {
                if (cubinDir) {
                    std::filesystem::path p1{*cubinDir};
                    std::filesystem::path p2{filePath};
                    filePath = (p1 / p2.filename()).string();
                }

                CUmodule mod;
                CUfunction func;
                CUDA_DRIVER_CHECK(cuModuleLoad(&mod, filePath.c_str()));
                CUDA_DRIVER_CHECK(cuModuleGetFunction(&func, mod, funcName.c_str()));
                if (sharedMemBytes > 0) {
                    CUDA_DRIVER_CHECK(cuFuncSetAttribute(
                        func,
                        CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
                        sharedMemBytes
                    ));
                }
                return func;
            }

            static inline void launchKernel(
                    CUfunction func,
                    uint32_t gridX,
                    uint32_t gridY,
                    uint32_t gridZ,
                    uint32_t numWarps,
                    uint32_t sharedMemBytes,
                    void* args[],
                    cudaStream_t stream) {
                CUDA_DRIVER_CHECK(cuLaunchKernel(
                    func, gridX, gridY, gridZ, 32*numWarps, 1, 1, sharedMemBytes, stream, args, nullptr
                ));
            }
    """
    if torch.version.hip is not None:
        # Adjust the warp size to the wavefront size supported by the AMD GPU.
        prop = torch.cuda.get_device_properties(torch.cuda.current_device())
        source_codes = source_codes.replace(
            "32*numWarps", str(prop.warp_size) + "*numWarps"
        )
    return source_codes


def cuda_kernel_header() -> str:
    """Return the C++ ``#include`` lines that accompany the kernel driver.

    Returns:
        A string containing the three c10/ATen CUDA include directives.
    """
    source_codes = """
            #include <c10/cuda/CUDAGuard.h>
            #include <c10/cuda/CUDAStream.h>
            #include <ATen/cuda/EmptyTensor.h>
    """
    return source_codes