1 #pragma once
2
3 #ifdef USE_CUDA
4 // WARNING: Be careful when adding new includes here. This header will be used
5 // in model.so, and should not refer to any aten/c10 headers except the stable
6 // C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule
7 // applies to other files under torch/csrc/inductor/aoti_runtime/.
8 #include <torch/csrc/inductor/aoti_runtime/utils.h>
9
10 #include <cuda.h>
11 #include <cuda_runtime.h>
12
13 namespace torch::aot_inductor {
14
delete_cuda_guard(void * ptr)15 inline void delete_cuda_guard(void* ptr) {
16 AOTI_TORCH_ERROR_CODE_CHECK(
17 aoti_torch_delete_cuda_guard(reinterpret_cast<CUDAGuardHandle>(ptr)));
18 }
19
delete_cuda_stream_guard(void * ptr)20 inline void delete_cuda_stream_guard(void* ptr) {
21 AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_delete_cuda_stream_guard(
22 reinterpret_cast<CUDAStreamGuardHandle>(ptr)));
23 }
24
25 class AOTICudaGuard {
26 public:
AOTICudaGuard(int32_t device_index)27 AOTICudaGuard(int32_t device_index) : guard_(nullptr, delete_cuda_guard) {
28 CUDAGuardHandle ptr = nullptr;
29 AOTI_TORCH_ERROR_CODE_CHECK(
30 aoti_torch_create_cuda_guard(device_index, &ptr));
31 guard_.reset(ptr);
32 }
33
set_index(int32_t device_index)34 void set_index(int32_t device_index) {
35 AOTI_TORCH_ERROR_CODE_CHECK(
36 aoti_torch_cuda_guard_set_index(guard_.get(), device_index));
37 }
38
39 private:
40 std::unique_ptr<CUDAGuardOpaque, DeleterFnPtr> guard_;
41 };
42
43 class AOTICudaStreamGuard {
44 public:
AOTICudaStreamGuard(cudaStream_t stream,int32_t device_index)45 AOTICudaStreamGuard(cudaStream_t stream, int32_t device_index)
46 : guard_(nullptr, delete_cuda_stream_guard) {
47 CUDAStreamGuardHandle ptr = nullptr;
48 AOTI_TORCH_ERROR_CODE_CHECK(
49 aoti_torch_create_cuda_stream_guard(stream, device_index, &ptr));
50 guard_.reset(ptr);
51 }
52
53 private:
54 std::unique_ptr<CUDAStreamGuardOpaque, DeleterFnPtr> guard_;
55 };
56
57 } // namespace torch::aot_inductor
58 #endif // USE_CUDA
59