xref: /aosp_15_r20/external/pytorch/torch/csrc/inductor/aoti_torch/shim_cuda.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 
2 #include <torch/csrc/inductor/aoti_torch/c/shim.h>
3 #include <torch/csrc/inductor/aoti_torch/utils.h>
4 
5 #include <c10/cuda/CUDAGuard.h>
6 #include <c10/cuda/CUDAStream.h>
7 
aoti_torch_create_cuda_guard(int32_t device_index,CUDAGuardHandle * ret_guard)8 AOTITorchError aoti_torch_create_cuda_guard(
9     int32_t device_index,
10     CUDAGuardHandle* ret_guard // returns new reference
11 ) {
12   AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
13     at::cuda::CUDAGuard* guard = new at::cuda::CUDAGuard(device_index);
14     *ret_guard = reinterpret_cast<CUDAGuardHandle>(guard);
15   });
16 }
17 
aoti_torch_delete_cuda_guard(CUDAGuardHandle guard)18 AOTITorchError aoti_torch_delete_cuda_guard(CUDAGuardHandle guard) {
19   AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
20       { delete reinterpret_cast<at::cuda::CUDAGuard*>(guard); });
21 }
22 
aoti_torch_cuda_guard_set_index(CUDAGuardHandle guard,int32_t device_index)23 AOTITorchError aoti_torch_cuda_guard_set_index(
24     CUDAGuardHandle guard,
25     int32_t device_index) {
26   AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
27     reinterpret_cast<at::cuda::CUDAGuard*>(guard)->set_index(device_index);
28   });
29 }
30 
aoti_torch_create_cuda_stream_guard(void * stream,int32_t device_index,CUDAStreamGuardHandle * ret_guard)31 AOTITorchError aoti_torch_create_cuda_stream_guard(
32     void* stream,
33     int32_t device_index,
34     CUDAStreamGuardHandle* ret_guard) {
35   AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
36     at::cuda::CUDAStreamGuard* guard =
37         new at::cuda::CUDAStreamGuard(at::cuda::getStreamFromExternal(
38             static_cast<cudaStream_t>(stream), device_index));
39     *ret_guard = reinterpret_cast<CUDAStreamGuardHandle>(guard);
40   });
41 }
42 
aoti_torch_delete_cuda_stream_guard(CUDAStreamGuardHandle guard)43 AOTITorchError aoti_torch_delete_cuda_stream_guard(
44     CUDAStreamGuardHandle guard) {
45   AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
46       { delete reinterpret_cast<at::cuda::CUDAStreamGuard*>(guard); });
47 }
48 
49 AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_current_cuda_stream(int32_t device_index,void ** ret_stream)50 aoti_torch_get_current_cuda_stream(int32_t device_index, void** ret_stream) {
51   AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
52     *(cudaStream_t*)(ret_stream) = at::cuda::getCurrentCUDAStream(device_index);
53   });
54 }
55