1 #include <pybind11/pybind11.h>
2 #include <torch/csrc/utils/pybind.h>
3
4 #if defined(USE_CUFILE)
5 #include <c10/cuda/CUDAGuard.h>
6
7 #include <cuda_runtime.h>
8 #include <cufile.h>
9
10 namespace {
11 // To get error message for cuFileRead/Write APIs that return ssize_t (-1 for
12 // filesystem error and a negative CUfileOpError enum value otherwise).
13 template <
14 class T,
15 typename std::enable_if<std::is_integral<T>::value, std::nullptr_t>::type =
16 nullptr>
cuGDSFileGetErrorString(T status)17 std::string cuGDSFileGetErrorString(T status) {
18 status = std::abs(status);
19 return IS_CUFILE_ERR(status) ? std::string(CUFILE_ERRSTR(status))
20 : std::string(std::strerror(errno));
21 }
22
23 // To get error message for Buf/Handle registeration APIs that return
24 // CUfileError_t
25 template <
26 class T,
27 typename std::enable_if<!std::is_integral<T>::value, std::nullptr_t>::type =
28 nullptr>
cuGDSFileGetErrorString(T status)29 std::string cuGDSFileGetErrorString(T status) {
30 std::string errStr = cuGDSFileGetErrorString(static_cast<int>(status.err));
31 if (IS_CUDA_ERR(status))
32 errStr.append(".").append(
33 cudaGetErrorString(static_cast<cudaError_t>(status.cu_err)));
34 return errStr;
35 }
36 } // namespace
37
gds_load_storage(int64_t handle,const at::Storage & storage,off_t offset)38 void gds_load_storage(
39 int64_t handle,
40 const at::Storage& storage,
41 off_t offset) {
42 // NOLINTNEXTLINE(performance-no-int-to-ptr)
43 CUfileHandle_t cf_handle = reinterpret_cast<CUfileHandle_t>(handle);
44 c10::cuda::CUDAGuard gpuGuard(storage.device());
45
46 void* dataPtr = storage.mutable_data();
47 const size_t nbytes = storage.nbytes();
48
49 // Read the binary file
50 ssize_t ret = cuFileRead(cf_handle, (void*)dataPtr, nbytes, offset, 0);
51 TORCH_CHECK(ret >= 0, "cuFileRead failed: ", cuGDSFileGetErrorString(ret));
52 }
53
gds_save_storage(int64_t handle,const at::Storage & storage,off_t offset)54 void gds_save_storage(
55 int64_t handle,
56 const at::Storage& storage,
57 off_t offset) {
58 // NOLINTNEXTLINE(performance-no-int-to-ptr)
59 CUfileHandle_t cf_handle = reinterpret_cast<CUfileHandle_t>(handle);
60 c10::cuda::CUDAGuard gpuGuard(storage.device());
61
62 void* dataPtr = storage.mutable_data();
63 const size_t nbytes = storage.nbytes();
64
65 // Write device memory contents to the file
66 ssize_t ret = cuFileWrite(cf_handle, dataPtr, nbytes, offset, 0);
67 TORCH_CHECK(ret >= 0, "cuFileWrite failed: ", cuGDSFileGetErrorString(ret));
68 }
69
gds_register_buffer(const at::Storage & storage)70 void gds_register_buffer(const at::Storage& storage) {
71 void* dataPtr = storage.mutable_data();
72 const size_t nbytes = storage.nbytes();
73
74 CUfileError_t status = cuFileBufRegister(dataPtr, nbytes, 0);
75 TORCH_CHECK(
76 status.err == CU_FILE_SUCCESS,
77 "cuFileBufRegister failed: ",
78 cuGDSFileGetErrorString(status));
79 return;
80 }
81
gds_deregister_buffer(const at::Storage & storage)82 void gds_deregister_buffer(const at::Storage& storage) {
83 void* dataPtr = storage.mutable_data();
84 CUfileError_t status = cuFileBufDeregister(dataPtr);
85 TORCH_CHECK(
86 status.err == CU_FILE_SUCCESS,
87 "cuFileBufDeregister failed: ",
88 cuGDSFileGetErrorString(status));
89 return;
90 }
91
gds_register_handle(int fd)92 int64_t gds_register_handle(int fd) {
93 CUfileDescr_t cf_descr;
94 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
95 CUfileHandle_t cf_handle;
96 memset((void*)&cf_descr, 0, sizeof(CUfileDescr_t));
97 cf_descr.handle.fd = fd;
98 cf_descr.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD;
99 CUfileError_t status = cuFileHandleRegister(&cf_handle, &cf_descr);
100 if (status.err != CU_FILE_SUCCESS) {
101 TORCH_CHECK(
102 false,
103 "cuFileHandleRegister failed: ",
104 cuGDSFileGetErrorString(status));
105 }
106
107 // Returning cuFileHandle_t as int64_t
108 return reinterpret_cast<int64_t>(cf_handle);
109 }
110
gds_deregister_handle(int64_t handle)111 void gds_deregister_handle(int64_t handle) {
112 // NOLINTNEXTLINE(performance-no-int-to-ptr)
113 CUfileHandle_t cf_handle = reinterpret_cast<CUfileHandle_t>(handle);
114 cuFileHandleDeregister(cf_handle);
115 }
116
117 #endif
118
119 namespace torch::cuda::shared {
120
initGdsBindings(PyObject * module)121 void initGdsBindings(PyObject* module) {
122 auto m = py::handle(module).cast<py::module>();
123
124 #if defined(USE_CUFILE)
125 m.def("_gds_register_handle", &gds_register_handle);
126 m.def("_gds_deregister_handle", &gds_deregister_handle);
127 m.def("_gds_register_buffer", &gds_register_buffer);
128 m.def("_gds_deregister_buffer", &gds_deregister_buffer);
129 m.def("_gds_load_storage", &gds_load_storage);
130 m.def("_gds_save_storage", &gds_save_storage);
131 #endif
132 }
133
134 } // namespace torch::cuda::shared
135