xref: /aosp_15_r20/external/pytorch/torch/csrc/cuda/GdsFile.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <pybind11/pybind11.h>
2 #include <torch/csrc/utils/pybind.h>
3 
4 #if defined(USE_CUFILE)
5 #include <c10/cuda/CUDAGuard.h>
6 
7 #include <cuda_runtime.h>
8 #include <cufile.h>
9 
10 namespace {
11 // To get error message for cuFileRead/Write APIs that return ssize_t (-1 for
12 // filesystem error and a negative CUfileOpError enum value otherwise).
13 template <
14     class T,
15     typename std::enable_if<std::is_integral<T>::value, std::nullptr_t>::type =
16         nullptr>
cuGDSFileGetErrorString(T status)17 std::string cuGDSFileGetErrorString(T status) {
18   status = std::abs(status);
19   return IS_CUFILE_ERR(status) ? std::string(CUFILE_ERRSTR(status))
20                                : std::string(std::strerror(errno));
21 }
22 
23 // To get error message for Buf/Handle registeration APIs that return
24 // CUfileError_t
25 template <
26     class T,
27     typename std::enable_if<!std::is_integral<T>::value, std::nullptr_t>::type =
28         nullptr>
cuGDSFileGetErrorString(T status)29 std::string cuGDSFileGetErrorString(T status) {
30   std::string errStr = cuGDSFileGetErrorString(static_cast<int>(status.err));
31   if (IS_CUDA_ERR(status))
32     errStr.append(".").append(
33         cudaGetErrorString(static_cast<cudaError_t>(status.cu_err)));
34   return errStr;
35 }
36 } // namespace
37 
gds_load_storage(int64_t handle,const at::Storage & storage,off_t offset)38 void gds_load_storage(
39     int64_t handle,
40     const at::Storage& storage,
41     off_t offset) {
42   // NOLINTNEXTLINE(performance-no-int-to-ptr)
43   CUfileHandle_t cf_handle = reinterpret_cast<CUfileHandle_t>(handle);
44   c10::cuda::CUDAGuard gpuGuard(storage.device());
45 
46   void* dataPtr = storage.mutable_data();
47   const size_t nbytes = storage.nbytes();
48 
49   // Read the binary file
50   ssize_t ret = cuFileRead(cf_handle, (void*)dataPtr, nbytes, offset, 0);
51   TORCH_CHECK(ret >= 0, "cuFileRead failed: ", cuGDSFileGetErrorString(ret));
52 }
53 
gds_save_storage(int64_t handle,const at::Storage & storage,off_t offset)54 void gds_save_storage(
55     int64_t handle,
56     const at::Storage& storage,
57     off_t offset) {
58   // NOLINTNEXTLINE(performance-no-int-to-ptr)
59   CUfileHandle_t cf_handle = reinterpret_cast<CUfileHandle_t>(handle);
60   c10::cuda::CUDAGuard gpuGuard(storage.device());
61 
62   void* dataPtr = storage.mutable_data();
63   const size_t nbytes = storage.nbytes();
64 
65   // Write device memory contents to the file
66   ssize_t ret = cuFileWrite(cf_handle, dataPtr, nbytes, offset, 0);
67   TORCH_CHECK(ret >= 0, "cuFileWrite failed: ", cuGDSFileGetErrorString(ret));
68 }
69 
gds_register_buffer(const at::Storage & storage)70 void gds_register_buffer(const at::Storage& storage) {
71   void* dataPtr = storage.mutable_data();
72   const size_t nbytes = storage.nbytes();
73 
74   CUfileError_t status = cuFileBufRegister(dataPtr, nbytes, 0);
75   TORCH_CHECK(
76       status.err == CU_FILE_SUCCESS,
77       "cuFileBufRegister failed: ",
78       cuGDSFileGetErrorString(status));
79   return;
80 }
81 
gds_deregister_buffer(const at::Storage & storage)82 void gds_deregister_buffer(const at::Storage& storage) {
83   void* dataPtr = storage.mutable_data();
84   CUfileError_t status = cuFileBufDeregister(dataPtr);
85   TORCH_CHECK(
86       status.err == CU_FILE_SUCCESS,
87       "cuFileBufDeregister failed: ",
88       cuGDSFileGetErrorString(status));
89   return;
90 }
91 
gds_register_handle(int fd)92 int64_t gds_register_handle(int fd) {
93   CUfileDescr_t cf_descr;
94   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
95   CUfileHandle_t cf_handle;
96   memset((void*)&cf_descr, 0, sizeof(CUfileDescr_t));
97   cf_descr.handle.fd = fd;
98   cf_descr.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD;
99   CUfileError_t status = cuFileHandleRegister(&cf_handle, &cf_descr);
100   if (status.err != CU_FILE_SUCCESS) {
101     TORCH_CHECK(
102         false,
103         "cuFileHandleRegister failed: ",
104         cuGDSFileGetErrorString(status));
105   }
106 
107   // Returning cuFileHandle_t as int64_t
108   return reinterpret_cast<int64_t>(cf_handle);
109 }
110 
gds_deregister_handle(int64_t handle)111 void gds_deregister_handle(int64_t handle) {
112   // NOLINTNEXTLINE(performance-no-int-to-ptr)
113   CUfileHandle_t cf_handle = reinterpret_cast<CUfileHandle_t>(handle);
114   cuFileHandleDeregister(cf_handle);
115 }
116 
117 #endif
118 
119 namespace torch::cuda::shared {
120 
initGdsBindings(PyObject * module)121 void initGdsBindings(PyObject* module) {
122   auto m = py::handle(module).cast<py::module>();
123 
124 #if defined(USE_CUFILE)
125   m.def("_gds_register_handle", &gds_register_handle);
126   m.def("_gds_deregister_handle", &gds_deregister_handle);
127   m.def("_gds_register_buffer", &gds_register_buffer);
128   m.def("_gds_deregister_buffer", &gds_deregister_buffer);
129   m.def("_gds_load_storage", &gds_load_storage);
130   m.def("_gds_save_storage", &gds_save_storage);
131 #endif
132 }
133 
134 } // namespace torch::cuda::shared
135