xref: /aosp_15_r20/external/pytorch/torch/csrc/inductor/aoti_runtime/utils_cuda.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #ifdef USE_CUDA
4 // WARNING: Be careful when adding new includes here. This header will be used
5 // in model.so, and should not refer to any aten/c10 headers except the stable
6 // C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule
7 // applies to other files under torch/csrc/inductor/aoti_runtime/.
8 #include <torch/csrc/inductor/aoti_runtime/utils.h>
9 
10 #include <cuda.h>
11 #include <cuda_runtime.h>
12 
13 namespace torch::aot_inductor {
14 
delete_cuda_guard(void * ptr)15 inline void delete_cuda_guard(void* ptr) {
16   AOTI_TORCH_ERROR_CODE_CHECK(
17       aoti_torch_delete_cuda_guard(reinterpret_cast<CUDAGuardHandle>(ptr)));
18 }
19 
delete_cuda_stream_guard(void * ptr)20 inline void delete_cuda_stream_guard(void* ptr) {
21   AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_delete_cuda_stream_guard(
22       reinterpret_cast<CUDAStreamGuardHandle>(ptr)));
23 }
24 
25 class AOTICudaGuard {
26  public:
AOTICudaGuard(int32_t device_index)27   AOTICudaGuard(int32_t device_index) : guard_(nullptr, delete_cuda_guard) {
28     CUDAGuardHandle ptr = nullptr;
29     AOTI_TORCH_ERROR_CODE_CHECK(
30         aoti_torch_create_cuda_guard(device_index, &ptr));
31     guard_.reset(ptr);
32   }
33 
set_index(int32_t device_index)34   void set_index(int32_t device_index) {
35     AOTI_TORCH_ERROR_CODE_CHECK(
36         aoti_torch_cuda_guard_set_index(guard_.get(), device_index));
37   }
38 
39  private:
40   std::unique_ptr<CUDAGuardOpaque, DeleterFnPtr> guard_;
41 };
42 
43 class AOTICudaStreamGuard {
44  public:
AOTICudaStreamGuard(cudaStream_t stream,int32_t device_index)45   AOTICudaStreamGuard(cudaStream_t stream, int32_t device_index)
46       : guard_(nullptr, delete_cuda_stream_guard) {
47     CUDAStreamGuardHandle ptr = nullptr;
48     AOTI_TORCH_ERROR_CODE_CHECK(
49         aoti_torch_create_cuda_stream_guard(stream, device_index, &ptr));
50     guard_.reset(ptr);
51   }
52 
53  private:
54   std::unique_ptr<CUDAStreamGuardOpaque, DeleterFnPtr> guard_;
55 };
56 
57 } // namespace torch::aot_inductor
58 #endif // USE_CUDA
59