#pragma once

#include <torch/csrc/Export.h>
#include <torch/csrc/jit/codegen/fuser/fused_kernel.h>

#include <cuda.h>
#include <cuda_runtime.h>
#include <nvrtc.h>

#include <cstdint>
#include <string>
#include <vector>

namespace torch::jit::fuser::cuda {

// Queries the architecture and compilation target for codegen output.
TORCH_CUDA_CU_API void codegenOutputQuery(
    const cudaDeviceProp* const prop,
    int& major,
    int& minor,
    bool& compile_to_sass);

// A class holding metadata for an actual CUDA function.
// Note: CUDA functions are per device.
struct TORCH_CUDA_CU_API FusedKernelCUDA
    : public ::torch::jit::fuser::FusedKernel {
  FusedKernelCUDA(
      at::DeviceIndex device,
      std::string name,
      std::string code,
      std::vector<TensorDesc> input_desc,
      std::vector<TensorDesc> output_desc,
      std::vector<PartitionDesc> chunk_desc,
      std::vector<PartitionDesc> concat_desc,
      bool has_random);

  ~FusedKernelCUDA() override;

  void launch_raw(const uint32_t numel, std::vector<void*>& arguments)
      const override;

  at::Backend backend() const override {
    return at::Backend::CUDA;
  }

 private:
  static constexpr auto kBlockSize = 128;

  // Note: stored per device to cache device properties and compute launch
  // heuristics; acquiring these values at launch time would be too slow.
  at::DeviceIndex device_;
  int maxBlocks_{};
  cudaDeviceProp* prop_{};
  std::vector<char> ptx_;
  CUmodule module_{};
  CUfunction function_{};
};

} // namespace torch::jit::fuser::cuda
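
// Usage sketch for codegenOutputQuery (illustrative only; assumes device 0 is
// a valid CUDA device and omits error handling). The reported compute
// capability is what a caller would hand to NVRTC:
//
//   cudaDeviceProp prop;
//   cudaGetDeviceProperties(&prop, /*device=*/0);
//   int major = 0;
//   int minor = 0;
//   bool compile_to_sass = false;
//   torch::jit::fuser::cuda::codegenOutputQuery(
//       &prop, major, minor, compile_to_sass);
//   // When compile_to_sass is true, the kernel can be compiled straight to
//   // SASS (NVRTC option "--gpu-architecture=sm_<major><minor>"); otherwise
//   // compile to PTX ("--gpu-architecture=compute_<major><minor>").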
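
// Launch sketch for FusedKernelCUDA::launch_raw (illustrative only; `kernel`
// is a hypothetical, already-compiled instance, and arg0/arg1 stand in for
// whatever argument values the generated kernel expects). Since the class
// holds a CUfunction, the launch goes through the CUDA driver API, whose
// kernel-parameter array holds pointers to argument values rather than
// device pointers themselves:
//
//   uint32_t numel = 1024; // number of elements the kernel processes
//   // Each entry points at an argument value; whether numel itself also
//   // appears in the array depends on the generated kernel's signature.
//   std::vector<void*> args{&numel, &arg0, &arg1};
//   kernel.launch_raw(numel, args);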