#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDADataType.h>
#include <ATen/cuda/CUDASparse.h>
#include <ATen/cuda/CUDASparseDescriptors.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/cuda/MiscUtils.h>

namespace at::cuda::sparse {

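// Frees a dense-matrix descriptor held through a pointer-to-const.
// cusparseDestroyDnMat only accepts a mutable handle, so constness is cast
// away before handing the descriptor back to cuSPARSE.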
cusparseStatus_t destroyConstDnMat(const cusparseDnMatDescr* dnMatDescr) {
  return cusparseDestroyDnMat(const_cast<cusparseDnMatDescr*>(dnMatDescr));
}

#if AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API()

namespace {

// If a specific GPU model does not provide native support for a given data
// type, cuSPARSE routines return a CUSPARSE_STATUS_ARCH_MISMATCH error.
void check_supported_cuda_type(cudaDataType cuda_type) {
  if (cuda_type == CUDA_R_16F) {
    cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
    TORCH_CHECK(
        prop->major >= 5 && ((10 * prop->major + prop->minor) >= 53),
        "Sparse operations with CUDA tensors of Float16 type are not supported on GPUs with compute capability < 5.3 (current: ",
        prop->major,
        ".",
        prop->minor,
        ")");
  }
#if !defined(USE_ROCM)
  if (cuda_type == CUDA_R_16BF) {
    cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
    TORCH_CHECK(
        prop->major >= 8,
        "Sparse operations with CUDA tensors of BFloat16 type are not supported on GPUs with compute capability < 8.0 (current: ",
        prop->major,
        ".",
        prop->minor,
        ")");
  }
#endif
}

} // anonymous namespace

cusparseIndexType_t getCuSparseIndexType(const c10::ScalarType& scalar_type) {
  if (scalar_type == c10::ScalarType::Int) {
    return CUSPARSE_INDEX_32I;
  } else if (scalar_type == c10::ScalarType::Long) {
    return CUSPARSE_INDEX_64I;
  } else {
    TORCH_INTERNAL_ASSERT(
        false, "Cannot convert type ", scalar_type, " to cusparseIndexType.");
  }
}

#if AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API()
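// Builds a raw cuSPARSE dense-matrix descriptor for a strided tensor whose
// last two dimensions are either row- or column-major contiguous. A
// non-negative batch_offset selects a single matrix out of a batched input,
// while batch_offset == -1 keeps the whole batch and registers it below via
// cusparseDnMatSetStridedBatch.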
cusparseDnMatDescr_t createRawDnMatDescriptor(const Tensor& input, int64_t batch_offset, bool is_const=false) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.layout() == kStrided);
  IntArrayRef input_strides = input.strides();
  IntArrayRef input_sizes = input.sizes();
  auto ndim = input.dim();
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(ndim >= 2);
  auto rows = input_sizes[ndim - 2];
  auto cols = input_sizes[ndim - 1];

  bool is_column_major =
      at::native::is_blas_compatible_column_major_order(input);
  bool is_row_major = at::native::is_blas_compatible_row_major_order(input);
  TORCH_INTERNAL_ASSERT(
      is_column_major || is_row_major,
      "Expected either row or column major contiguous input.");

  auto leading_dimension =
      is_row_major ? input_strides[ndim - 2] : input_strides[ndim - 1];

#if !defined(USE_ROCM)
  auto order = is_row_major ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL;
#else
  TORCH_INTERNAL_ASSERT(is_column_major, "Expected column major input.");
  auto order = CUSPARSE_ORDER_COL;
#endif

  auto batch_stride = ndim > 2 && batch_offset >= 0 ? input_strides[ndim - 3] : 0;
  void* data_ptr = is_const ? const_cast<void*>(input.const_data_ptr()) : input.data_ptr();
  void* values_ptr = static_cast<char*>(data_ptr) +
      batch_offset * batch_stride * input.itemsize();

  cudaDataType value_type = ScalarTypeToCudaDataType(input.scalar_type());
  check_supported_cuda_type(value_type);

  // NOTE: Ideally, in the const case, we would use cusparseConstDnMatDescr_t
  // and cusparseCreateConstDnMat, but those were introduced in CUDA 12, and we
  // still need to support CUDA 11
  cusparseDnMatDescr_t raw_descriptor = nullptr;
  TORCH_CUDASPARSE_CHECK(cusparseCreateDnMat(
      &raw_descriptor,
      rows,
      cols,
      leading_dimension,
      values_ptr,
      value_type,
      order));

  if (ndim >= 3 && batch_offset == -1) {
    int batch_count =
        at::native::cuda_int_cast(at::native::batchCount(input), "batch_count");
    TORCH_CUDASPARSE_CHECK(cusparseDnMatSetStridedBatch(
        raw_descriptor, batch_count, input_strides[ndim - 3]));
  }
  return raw_descriptor;
}

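// Example (a sketch, not part of this translation unit): wrapping a dense CUDA
// tensor so the generic cuSPARSE routines can consume it. This assumes the
// descriptor() accessor exposed by the RAII wrappers in CUDASparseDescriptors.h.
//
//   at::Tensor dense = at::randn(
//       {8, 4}, at::TensorOptions().dtype(at::kFloat).device(at::kCUDA));
//   at::cuda::sparse::CuSparseDnMatDescriptor dn_descr(dense, /*batch_offset=*/-1);
//   cusparseDnMatDescr_t raw = dn_descr.descriptor(); // e.g. the dense operand of cusparseSpMM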
CuSparseDnMatDescriptor::CuSparseDnMatDescriptor(const Tensor& input, int64_t batch_offset) {
  descriptor_.reset(createRawDnMatDescriptor(input, batch_offset));
}

CuSparseConstDnMatDescriptor::CuSparseConstDnMatDescriptor(const Tensor& input, int64_t batch_offset) {
  descriptor_.reset(createRawDnMatDescriptor(input, batch_offset, /*is_const*/true));
}
#endif // AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API()

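// Wraps a contiguous 1-D CUDA tensor (or a 2-D column vector) as a cuSPARSE
// dense-vector descriptor for the generic API.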
CuSparseDnVecDescriptor::CuSparseDnVecDescriptor(const Tensor& input) {
  // cuSPARSE doesn't support batched vectors
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      input.dim() == 1 || (input.dim() == 2 && input.size(-1) == 1));

  // cuSPARSE doesn't support non-contiguous vectors
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.is_contiguous());

  cudaDataType value_type = ScalarTypeToCudaDataType(input.scalar_type());
  check_supported_cuda_type(value_type);

  cusparseDnVecDescr_t raw_descriptor = nullptr;
  TORCH_CUDASPARSE_CHECK(cusparseCreateDnVec(
      &raw_descriptor, input.numel(), input.data_ptr(), value_type));
  descriptor_.reset(raw_descriptor);
}

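// Builds a cuSPARSE sparse-matrix descriptor from a sparse CSR tensor. As in
// the dense case, a non-negative batch_offset selects one matrix of a batched
// input, while batch_offset == -1 configures a strided-batch descriptor, which
// currently requires the CSR indices and values to be shared across batches.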
CuSparseSpMatCsrDescriptor::CuSparseSpMatCsrDescriptor(const Tensor& input, int64_t batch_offset) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.is_sparse_csr());
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.dim() >= 2);

  IntArrayRef input_sizes = input.sizes();
  auto ndim = input.dim();
  auto rows = input_sizes[ndim - 2];
  auto cols = input_sizes[ndim - 1];

  auto crow_indices = input.crow_indices();
  auto col_indices = input.col_indices();
  auto values = input.values();
  auto nnz = values.size(-1);
  c10::MaybeOwned<Tensor> values_ = values.expect_contiguous();

  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(crow_indices.is_contiguous());
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(col_indices.is_contiguous());

  cusparseIndexType_t index_type =
      getCuSparseIndexType(crow_indices.scalar_type());
  cudaDataType value_type = ScalarTypeToCudaDataType(input.scalar_type());
  check_supported_cuda_type(value_type);

  auto crow_indices_batch_stride = crow_indices.dim() >= 2 && batch_offset >= 0
      ? crow_indices.stride(-2)
      : 0;
  auto col_indices_batch_stride =
      col_indices.dim() >= 2 && batch_offset >= 0 ? col_indices.stride(-2) : 0;
  auto values_batch_stride =
      values.dim() >= 2 && batch_offset >= 0 ? values_->stride(-2) : 0;

  cusparseSpMatDescr_t raw_descriptor = nullptr;
  TORCH_CUDASPARSE_CHECK(cusparseCreateCsr(
      &raw_descriptor, // output descriptor
      rows,
      cols,
      nnz,
      // row offsets of the sparse matrix, size = rows + 1
      static_cast<char*>(crow_indices.data_ptr()) +
          batch_offset * crow_indices_batch_stride * crow_indices.itemsize(),
      // column indices of the sparse matrix, size = nnz
      static_cast<char*>(col_indices.data_ptr()) +
          batch_offset * col_indices_batch_stride * col_indices.itemsize(),
      // values of the sparse matrix, size = nnz
      static_cast<char*>(values_->data_ptr()) +
          batch_offset * values_batch_stride * values.itemsize(),
      index_type, // data type of row offsets index
      index_type, // data type of col indices
      CUSPARSE_INDEX_BASE_ZERO, // base index of row offsets and col indices
      value_type // data type of values
      ));

  if (ndim == 3 && batch_offset == -1) {
    int batch_count =
        at::native::cuda_int_cast(at::native::batchCount(input), "batch_count");
    if (crow_indices.dim() >= 2 || values.dim() >= 2 ||
        col_indices.dim() >= 2) {
      // cuSPARSE ignores the strides and uses only the first batch
      TORCH_INTERNAL_ASSERT(
          false,
          "Support for batched CSR indices and values is not implemented.");
      TORCH_CUDASPARSE_CHECK(cusparseCsrSetStridedBatch(
          raw_descriptor,
          batch_count,
          crow_indices.stride(-2),
          values_->stride(-2)));
    } else {
      // cuSPARSE allows broadcasting of indices and values across batches for
      // batched matmul
      TORCH_CUDASPARSE_CHECK(
          cusparseCsrSetStridedBatch(raw_descriptor, batch_count, 0, 0));
    }
  }

  descriptor_.reset(raw_descriptor);
}

#endif // AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API()

} // namespace at::cuda::sparse