#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDADataType.h>
#include <ATen/cuda/CUDASparse.h>
#include <ATen/cuda/CUDASparseDescriptors.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/cuda/MiscUtils.h>

namespace at::cuda::sparse {

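// cusparseDestroyDnMat takes a non-const descriptor pointer, so a const_cast
// is needed here to destroy descriptors that wrap const (read-only) data.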
cusparseStatus_t destroyConstDnMat(const cusparseDnMatDescr* dnMatDescr) {
  return cusparseDestroyDnMat(const_cast<cusparseDnMatDescr*>(dnMatDescr));
}

#if AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API()

namespace {

// If a specific GPU model does not provide native support for a given data
// type, cuSPARSE routines return the CUSPARSE_STATUS_ARCH_MISMATCH error.
void check_supported_cuda_type(cudaDataType cuda_type) {
  if (cuda_type == CUDA_R_16F) {
    cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
    TORCH_CHECK(
        prop->major >= 5 && ((10 * prop->major + prop->minor) >= 53),
        "Sparse operations with CUDA tensors of Float16 type are not supported on GPUs with compute capability < 5.3 (current: ",
        prop->major,
        ".",
        prop->minor,
        ")");
  }
#if !defined(USE_ROCM)
  if (cuda_type == CUDA_R_16BF) {
    cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
    TORCH_CHECK(
        prop->major >= 8,
        "Sparse operations with CUDA tensors of BFloat16 type are not supported on GPUs with compute capability < 8.0 (current: ",
        prop->major,
        ".",
        prop->minor,
        ")");
  }
#endif
}

} // anonymous namespace

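// Maps an ATen index ScalarType to the corresponding cuSPARSE index type:
// Int -> 32-bit indices, Long -> 64-bit indices. Any other dtype is a bug in
// the caller and trips the internal assert below.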
cusparseIndexType_t getCuSparseIndexType(const c10::ScalarType& scalar_type) {
  if (scalar_type == c10::ScalarType::Int) {
    return CUSPARSE_INDEX_32I;
  } else if (scalar_type == c10::ScalarType::Long) {
    return CUSPARSE_INDEX_64I;
  } else {
    TORCH_INTERNAL_ASSERT(
        false, "Cannot convert type ", scalar_type, " to cusparseIndexType.");
  }
}

#if AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API()
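// Creates a raw cuSPARSE dense-matrix descriptor for a strided ATen tensor.
// The input must be BLAS-compatible row- or column-major (column-major only
// on ROCm). When batch_offset >= 0, the descriptor points at that single
// batch of a batched tensor; when batch_offset == -1 and the tensor has
// three or more dims, strided-batch mode is enabled on the descriptor.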
cusparseDnMatDescr_t createRawDnMatDescriptor(const Tensor& input, int64_t batch_offset, bool is_const = false) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.layout() == kStrided);
  IntArrayRef input_strides = input.strides();
  IntArrayRef input_sizes = input.sizes();
  auto ndim = input.dim();
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(ndim >= 2);
  auto rows = input_sizes[ndim - 2];
  auto cols = input_sizes[ndim - 1];

  bool is_column_major =
      at::native::is_blas_compatible_column_major_order(input);
  bool is_row_major = at::native::is_blas_compatible_row_major_order(input);
  TORCH_INTERNAL_ASSERT(
      is_column_major || is_row_major,
      "Expected either row or column major contiguous input.");

  auto leading_dimension =
      is_row_major ? input_strides[ndim - 2] : input_strides[ndim - 1];

#if !defined(USE_ROCM)
  auto order = is_row_major ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL;
#else
  TORCH_INTERNAL_ASSERT(is_column_major, "Expected column major input.");
  auto order = CUSPARSE_ORDER_COL;
#endif

  auto batch_stride = ndim > 2 && batch_offset >= 0 ? input_strides[ndim - 3] : 0;
  void* data_ptr = is_const ? const_cast<void*>(input.const_data_ptr()) : input.data_ptr();
  void* values_ptr = static_cast<char*>(data_ptr) +
      batch_offset * batch_stride * input.itemsize();

  cudaDataType value_type = ScalarTypeToCudaDataType(input.scalar_type());
  check_supported_cuda_type(value_type);

  // NOTE: Ideally, in the const case, we would use cusparseConstDnMatDescr_t
  // and cusparseCreateConstDnMat, but those were introduced in CUDA 12, and we
  // still need to support CUDA 11.
  cusparseDnMatDescr_t raw_descriptor = nullptr;
  TORCH_CUDASPARSE_CHECK(cusparseCreateDnMat(
      &raw_descriptor,
      rows,
      cols,
      leading_dimension,
      values_ptr,
      value_type,
      order));

  if (ndim >= 3 && batch_offset == -1) {
    int batch_count =
        at::native::cuda_int_cast(at::native::batchCount(input), "batch_count");
    TORCH_CUDASPARSE_CHECK(cusparseDnMatSetStridedBatch(
        raw_descriptor, batch_count, input_strides[ndim - 3]));
  }
  return raw_descriptor;
}

CuSparseDnMatDescriptor::CuSparseDnMatDescriptor(const Tensor& input, int64_t batch_offset) {
  descriptor_.reset(createRawDnMatDescriptor(input, batch_offset));
}

CuSparseConstDnMatDescriptor::CuSparseConstDnMatDescriptor(const Tensor& input, int64_t batch_offset) {
  descriptor_.reset(createRawDnMatDescriptor(input, batch_offset, /*is_const*/true));
}
#endif // AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API()

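// Wraps cusparseCreateDnVec for a contiguous 1-D tensor (or an Nx1 column).
// cuSPARSE has no batched or strided dense-vector descriptor, hence the
// shape and contiguity asserts below.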
CuSparseDnVecDescriptor::CuSparseDnVecDescriptor(const Tensor& input) {
  // cuSPARSE doesn't support batched vectors
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      input.dim() == 1 || (input.dim() == 2 && input.size(-1) == 1));

  // cuSPARSE doesn't support non-contiguous vectors
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.is_contiguous());

  cudaDataType value_type = ScalarTypeToCudaDataType(input.scalar_type());
  check_supported_cuda_type(value_type);

  cusparseDnVecDescr_t raw_descriptor = nullptr;
  TORCH_CUDASPARSE_CHECK(cusparseCreateDnVec(
      &raw_descriptor, input.numel(), input.data_ptr(), value_type));
  descriptor_.reset(raw_descriptor);
}

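// Builds a cuSPARSE CSR descriptor from a PyTorch sparse CSR tensor. The
// crow_indices/col_indices/values pointers are offset by batch_offset when a
// single batch of a batched tensor is addressed; batch_offset == -1 requests
// strided-batch mode (see the batch handling at the end of the constructor).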
CuSparseSpMatCsrDescriptor::CuSparseSpMatCsrDescriptor(const Tensor& input, int64_t batch_offset) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.is_sparse_csr());
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.dim() >= 2);

  IntArrayRef input_sizes = input.sizes();
  auto ndim = input.dim();
  auto rows = input_sizes[ndim - 2];
  auto cols = input_sizes[ndim - 1];

  auto crow_indices = input.crow_indices();
  auto col_indices = input.col_indices();
  auto values = input.values();
  auto nnz = values.size(-1);
  c10::MaybeOwned<Tensor> values_ = values.expect_contiguous();

  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(crow_indices.is_contiguous());
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(col_indices.is_contiguous());

  cusparseIndexType_t index_type =
      getCuSparseIndexType(crow_indices.scalar_type());
  cudaDataType value_type = ScalarTypeToCudaDataType(input.scalar_type());
  check_supported_cuda_type(value_type);

  auto crow_indices_batch_stride = crow_indices.dim() >= 2 && batch_offset >= 0
      ? crow_indices.stride(-2)
      : 0;
  auto col_indices_batch_stride =
      col_indices.dim() >= 2 && batch_offset >= 0 ? col_indices.stride(-2) : 0;
  auto values_batch_stride =
      values.dim() >= 2 && batch_offset >= 0 ? values_->stride(-2) : 0;

  cusparseSpMatDescr_t raw_descriptor = nullptr;
  TORCH_CUDASPARSE_CHECK(cusparseCreateCsr(
      &raw_descriptor, // output descriptor
      rows,
      cols,
      nnz,
      // row offsets of the sparse matrix, size = rows + 1
      static_cast<char*>(crow_indices.data_ptr()) +
          batch_offset * crow_indices_batch_stride * crow_indices.itemsize(),
      // column indices of the sparse matrix, size = nnz
      static_cast<char*>(col_indices.data_ptr()) +
          batch_offset * col_indices_batch_stride * col_indices.itemsize(),
      // values of the sparse matrix, size = nnz
      static_cast<char*>(values_->data_ptr()) +
          batch_offset * values_batch_stride * values.itemsize(),
      index_type, // data type of row offsets index
      index_type, // data type of col indices
      CUSPARSE_INDEX_BASE_ZERO, // base index of row offsets and col indices
      value_type // data type of values
      ));

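  // Truly batched CSR (per-batch indices/values with nonzero strides) is not
  // implemented: cuSPARSE would silently ignore the strides and reuse the
  // first batch, so that path asserts below. A single CSR matrix can, however,
  // be broadcast across batches for batched matmul via zero batch strides.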
  if (ndim == 3 && batch_offset == -1) {
    int batch_count =
        at::native::cuda_int_cast(at::native::batchCount(input), "batch_count");
    if (crow_indices.dim() >= 2 || values.dim() >= 2 ||
        col_indices.dim() >= 2) {
      // cuSPARSE ignores the strides and uses only the first batch
      TORCH_INTERNAL_ASSERT(
          false,
          "Support for batched CSR indices and values is not implemented.");
      TORCH_CUDASPARSE_CHECK(cusparseCsrSetStridedBatch(
          raw_descriptor,
          batch_count,
          crow_indices.stride(-2),
          values_->stride(-2)));
    } else {
      // cuSPARSE allows broadcasting of indices and values across batches for
      // batched matmul
      TORCH_CUDASPARSE_CHECK(
          cusparseCsrSetStridedBatch(raw_descriptor, batch_count, 0, 0));
    }
  }

  descriptor_.reset(raw_descriptor);
}
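
// For reference, a minimal hypothetical sketch of how these RAII wrappers are
// typically consumed together, here with cusparseSpMM_bufferSize. The tensors
// `csr`, `dense`, and `out`, the alpha/beta values, and the chosen compute
// type/algorithm are assumptions for illustration only, not part of this file:
//
//   CuSparseSpMatCsrDescriptor descA(csr, /*batch_offset=*/-1);
//   CuSparseDnMatDescriptor descB(dense, /*batch_offset=*/-1);
//   CuSparseDnMatDescriptor descC(out, /*batch_offset=*/-1);
//   float alpha = 1.0f, beta = 0.0f;
//   size_t buffer_size = 0;
//   TORCH_CUDASPARSE_CHECK(cusparseSpMM_bufferSize(
//       at::cuda::getCurrentCUDASparseHandle(),
//       CUSPARSE_OPERATION_NON_TRANSPOSE,
//       CUSPARSE_OPERATION_NON_TRANSPOSE,
//       &alpha,
//       descA.descriptor(),
//       descB.descriptor(),
//       &beta,
//       descC.descriptor(),
//       CUDA_R_32F,
//       CUSPARSE_SPMM_CSR_ALG2,
//       &buffer_size));
//   // ...then allocate the work buffer and call cusparseSpMM with the same
//   // arguments plus the buffer pointer.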

#endif // AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API()

} // namespace at::cuda::sparse