1
2 #include <torch/csrc/inductor/aoti_torch/c/shim.h>
3 #include <torch/csrc/inductor/aoti_torch/utils.h>
4
5 #include <c10/cuda/CUDAGuard.h>
6 #include <c10/cuda/CUDAStream.h>
7
aoti_torch_create_cuda_guard(int32_t device_index,CUDAGuardHandle * ret_guard)8 AOTITorchError aoti_torch_create_cuda_guard(
9 int32_t device_index,
10 CUDAGuardHandle* ret_guard // returns new reference
11 ) {
12 AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
13 at::cuda::CUDAGuard* guard = new at::cuda::CUDAGuard(device_index);
14 *ret_guard = reinterpret_cast<CUDAGuardHandle>(guard);
15 });
16 }
17
aoti_torch_delete_cuda_guard(CUDAGuardHandle guard)18 AOTITorchError aoti_torch_delete_cuda_guard(CUDAGuardHandle guard) {
19 AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
20 { delete reinterpret_cast<at::cuda::CUDAGuard*>(guard); });
21 }
22
aoti_torch_cuda_guard_set_index(CUDAGuardHandle guard,int32_t device_index)23 AOTITorchError aoti_torch_cuda_guard_set_index(
24 CUDAGuardHandle guard,
25 int32_t device_index) {
26 AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
27 reinterpret_cast<at::cuda::CUDAGuard*>(guard)->set_index(device_index);
28 });
29 }
30
aoti_torch_create_cuda_stream_guard(void * stream,int32_t device_index,CUDAStreamGuardHandle * ret_guard)31 AOTITorchError aoti_torch_create_cuda_stream_guard(
32 void* stream,
33 int32_t device_index,
34 CUDAStreamGuardHandle* ret_guard) {
35 AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
36 at::cuda::CUDAStreamGuard* guard =
37 new at::cuda::CUDAStreamGuard(at::cuda::getStreamFromExternal(
38 static_cast<cudaStream_t>(stream), device_index));
39 *ret_guard = reinterpret_cast<CUDAStreamGuardHandle>(guard);
40 });
41 }
42
aoti_torch_delete_cuda_stream_guard(CUDAStreamGuardHandle guard)43 AOTITorchError aoti_torch_delete_cuda_stream_guard(
44 CUDAStreamGuardHandle guard) {
45 AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
46 { delete reinterpret_cast<at::cuda::CUDAStreamGuard*>(guard); });
47 }
48
49 AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_current_cuda_stream(int32_t device_index,void ** ret_stream)50 aoti_torch_get_current_cuda_stream(int32_t device_index, void** ret_stream) {
51 AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
52 *(cudaStream_t*)(ret_stream) = at::cuda::getCurrentCUDAStream(device_index);
53 });
54 }
55