/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_LINALG_CUDA_SPARSE_H_
#define TENSORFLOW_CORE_KERNELS_LINALG_CUDA_SPARSE_H_

// This header declares the class GpuSparse, which contains wrappers of the
// cuSparse (or, on ROCm builds, hipSPARSE) library for use in TensorFlow
// kernels.

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include <functional>
#include <vector>

#if GOOGLE_CUDA

#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cusparse.h"

using gpusparseStatus_t = cusparseStatus_t;
using gpusparseOperation_t = cusparseOperation_t;
using gpusparseMatDescr_t = cusparseMatDescr_t;
using gpusparseAction_t = cusparseAction_t;
using gpusparseHandle_t = cusparseHandle_t;
using gpuStream_t = cudaStream_t;
#if CUDA_VERSION >= 10020
using gpusparseDnMatDescr_t = cusparseDnMatDescr_t;
using gpusparseSpMatDescr_t = cusparseSpMatDescr_t;
using gpusparseSpMMAlg_t = cusparseSpMMAlg_t;
#endif

#define GPUSPARSE(postfix) CUSPARSE_##postfix
#define gpusparse(postfix) cusparse##postfix

#elif TENSORFLOW_USE_ROCM

#include "tensorflow/stream_executor/rocm/hipsparse_wrapper.h"

using gpusparseStatus_t = hipsparseStatus_t;
using gpusparseOperation_t = hipsparseOperation_t;
using gpusparseMatDescr_t = hipsparseMatDescr_t;
using gpusparseAction_t = hipsparseAction_t;
using gpusparseHandle_t = hipsparseHandle_t;
using gpuStream_t = hipStream_t;
#if TF_ROCM_VERSION >= 40200
using gpusparseDnMatDescr_t = hipsparseDnMatDescr_t;
using gpusparseSpMatDescr_t = hipsparseSpMatDescr_t;
using gpusparseSpMMAlg_t = hipsparseSpMMAlg_t;
#endif
#define GPUSPARSE(postfix) HIPSPARSE_##postfix
#define gpusparse(postfix) hipsparse##postfix

#endif
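
// The GPUSPARSE/gpusparse macros map a platform-neutral name onto the cuSparse
// or hipSPARSE spelling. For example, GPUSPARSE(STATUS_SUCCESS) expands to
// CUSPARSE_STATUS_SUCCESS on CUDA and HIPSPARSE_STATUS_SUCCESS on ROCm, and
// gpusparse(CreateMatDescr) expands to cusparseCreateMatDescr or
// hipsparseCreateMatDescr, respectively.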

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/public/version.h"

// Macro that specializes a sparse method for all 4 standard
// numeric types.
// TODO: reuse with cuda_solvers
#define TF_CALL_LAPACK_TYPES(m) \
  m(float, S) m(double, D) m(std::complex<float>, C) m(std::complex<double>, Z)
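
// For example, an implementation file can expand a per-type macro for all four
// types at once (a sketch; INSTANTIATE_GTSV2 is a hypothetical helper macro):
//
//   #define INSTANTIATE_GTSV2(Scalar, sparse_prefix)                         \
//     template Status GpuSparse::Gtsv2<Scalar>(int, int, const Scalar*,      \
//                                              const Scalar*, const Scalar*, \
//                                              Scalar*, int, void*) const;
//   TF_CALL_LAPACK_TYPES(INSTANTIATE_GTSV2);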

namespace tensorflow {

inline std::string ConvertGPUSparseErrorToString(
    const gpusparseStatus_t status) {
  switch (status) {
#define STRINGIZE(q) #q
#define RETURN_IF_STATUS(err) \
  case err:                   \
    return STRINGIZE(err);

#if GOOGLE_CUDA

    RETURN_IF_STATUS(CUSPARSE_STATUS_SUCCESS)
    RETURN_IF_STATUS(CUSPARSE_STATUS_NOT_INITIALIZED)
    RETURN_IF_STATUS(CUSPARSE_STATUS_ALLOC_FAILED)
    RETURN_IF_STATUS(CUSPARSE_STATUS_INVALID_VALUE)
    RETURN_IF_STATUS(CUSPARSE_STATUS_ARCH_MISMATCH)
    RETURN_IF_STATUS(CUSPARSE_STATUS_MAPPING_ERROR)
    RETURN_IF_STATUS(CUSPARSE_STATUS_EXECUTION_FAILED)
    RETURN_IF_STATUS(CUSPARSE_STATUS_INTERNAL_ERROR)
    RETURN_IF_STATUS(CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED)

    default:
      return strings::StrCat("Unknown CUSPARSE error: ",
                             static_cast<int>(status));
#elif TENSORFLOW_USE_ROCM

    RETURN_IF_STATUS(HIPSPARSE_STATUS_SUCCESS)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_NOT_INITIALIZED)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_ALLOC_FAILED)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_INVALID_VALUE)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_ARCH_MISMATCH)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_MAPPING_ERROR)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_EXECUTION_FAILED)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_INTERNAL_ERROR)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_ZERO_PIVOT)

    default:
      return strings::StrCat("Unknown hipSPARSE error: ",
                             static_cast<int>(status));
#endif

#undef RETURN_IF_STATUS
#undef STRINGIZE
  }
}

#if GOOGLE_CUDA

#define TF_RETURN_IF_GPUSPARSE_ERROR(expr)                                 \
  do {                                                                     \
    auto status = (expr);                                                  \
    if (TF_PREDICT_FALSE(status != CUSPARSE_STATUS_SUCCESS)) {             \
      return errors::Internal(__FILE__, ":", __LINE__, " (", TF_STR(expr), \
                              "): cuSparse call failed with status ",      \
                              ConvertGPUSparseErrorToString(status));      \
    }                                                                      \
  } while (0)

#elif TENSORFLOW_USE_ROCM

#define TF_RETURN_IF_GPUSPARSE_ERROR(expr)                                 \
  do {                                                                     \
    auto status = (expr);                                                  \
    if (TF_PREDICT_FALSE(status != HIPSPARSE_STATUS_SUCCESS)) {            \
      return errors::Internal(__FILE__, ":", __LINE__, " (", TF_STR(expr), \
                              "): hipSPARSE call failed with status ",     \
                              ConvertGPUSparseErrorToString(status));      \
    }                                                                      \
  } while (0)

#endif
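
// Example use inside a Status-returning function (a sketch; `handle` and
// `stream` are assumed to be a valid gpusparseHandle_t and gpuStream_t):
//
//   TF_RETURN_IF_GPUSPARSE_ERROR(gpusparse(SetStream)(handle, stream));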

inline gpusparseOperation_t TransposeAndConjugateToGpuSparseOp(bool transpose,
                                                               bool conjugate,
                                                               Status* status) {
#if GOOGLE_CUDA
  if (transpose) {
    return conjugate ? CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE
                     : CUSPARSE_OPERATION_TRANSPOSE;
  } else {
    if (conjugate) {
      DCHECK(status != nullptr);
      *status = errors::InvalidArgument(
          "Conjugate == True and transpose == False is not supported.");
    }
    return CUSPARSE_OPERATION_NON_TRANSPOSE;
  }
#elif TENSORFLOW_USE_ROCM
  if (transpose) {
    return conjugate ? HIPSPARSE_OPERATION_CONJUGATE_TRANSPOSE
                     : HIPSPARSE_OPERATION_TRANSPOSE;
  } else {
    if (conjugate) {
      DCHECK(status != nullptr);
      *status = errors::InvalidArgument(
          "Conjugate == True and transpose == False is not supported.");
    }
    return HIPSPARSE_OPERATION_NON_TRANSPOSE;
  }
#endif
}
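
// Example use inside an op (a sketch; transpose_a, conjugate_a, and ctx are
// assumed to come from the surrounding kernel):
//
//   Status status;
//   const gpusparseOperation_t transA =
//       TransposeAndConjugateToGpuSparseOp(transpose_a, conjugate_a, &status);
//   OP_REQUIRES_OK(ctx, status);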

// The GpuSparse class provides a simplified templated API for cuSparse
// (http://docs.nvidia.com/cuda/cusparse/index.html).
// An object of this class wraps static cuSparse instances,
// and will launch Cuda kernels on the stream wrapped by the GPU device
// in the OpKernelContext provided to the constructor.
//
// Notice: All the computational member functions are asynchronous and simply
// launch one or more Cuda kernels on the Cuda stream wrapped by the GpuSparse
// object.

class GpuSparse {
 public:
  // This object stores a pointer to context, which must outlive it.
  explicit GpuSparse(OpKernelContext* context);
  virtual ~GpuSparse() {}

  // Initializes the GpuSparse object if it has not been initialized yet. All
  // following public methods require that Initialize() has succeeded. Can be
  // called multiple times; all calls after the first have no effect.
  Status Initialize();  // Move to constructor?
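
  // Typical use inside a kernel's Compute method (a sketch; `context` is the
  // OpKernelContext* passed to Compute):
  //
  //   GpuSparse cuda_sparse(context);
  //   OP_REQUIRES_OK(context, cuda_sparse.Initialize());
  //   // ... call the wrappers declared below ...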

  // ====================================================================
  // Wrappers for cuSparse start here.
  //

  // Solves a tridiagonal system of equations.
  // See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2
  template <typename Scalar>
  Status Gtsv2(int m, int n, const Scalar* dl, const Scalar* d,
               const Scalar* du, Scalar* B, int ldb, void* pBuffer) const;

  // Computes the size of a temporary buffer used by Gtsv2.
  // See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2_bufferSize
  template <typename Scalar>
  Status Gtsv2BufferSizeExt(int m, int n, const Scalar* dl, const Scalar* d,
                            const Scalar* du, const Scalar* B, int ldb,
                            size_t* bufferSizeInBytes) const;

  // Solves a tridiagonal system of equations without partial pivoting.
  // See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2_nopivot
  template <typename Scalar>
  Status Gtsv2NoPivot(int m, int n, const Scalar* dl, const Scalar* d,
                      const Scalar* du, Scalar* B, int ldb,
                      void* pBuffer) const;

  // Computes the size of a temporary buffer used by Gtsv2NoPivot.
  // See:
  // https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2_nopivot_bufferSize
  template <typename Scalar>
  Status Gtsv2NoPivotBufferSizeExt(int m, int n, const Scalar* dl,
                                   const Scalar* d, const Scalar* du,
                                   const Scalar* B, int ldb,
                                   size_t* bufferSizeInBytes) const;

  // Solves a batch of tridiagonal systems of equations. Doesn't support
  // multiple right-hand sides per system. Doesn't do pivoting.
  // See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2stridedbatch
  template <typename Scalar>
  Status Gtsv2StridedBatch(int m, const Scalar* dl, const Scalar* d,
                           const Scalar* du, Scalar* x, int batchCount,
                           int batchStride, void* pBuffer) const;

  // Computes the size of a temporary buffer used by Gtsv2StridedBatch.
  // See:
  // https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2stridedbatch_bufferSize
  template <typename Scalar>
  Status Gtsv2StridedBatchBufferSizeExt(int m, const Scalar* dl,
                                        const Scalar* d, const Scalar* du,
                                        const Scalar* x, int batchCount,
                                        int batchStride,
                                        size_t* bufferSizeInBytes) const;
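
  // Typical use of the Gtsv2 wrappers above (a sketch; dl, d, du, and B are
  // assumed to be device pointers and `solver` an initialized GpuSparse):
  //
  //   size_t buffer_size;
  //   TF_RETURN_IF_ERROR(solver.Gtsv2BufferSizeExt(m, n, dl, d, du, B, ldb,
  //                                                &buffer_size));
  //   Tensor temp;
  //   TF_RETURN_IF_ERROR(context->allocate_temp(
  //       DT_INT8, TensorShape({static_cast<int64_t>(buffer_size)}), &temp));
  //   TF_RETURN_IF_ERROR(solver.Gtsv2(m, n, dl, d, du, B, ldb,
  //                                   temp.flat<int8>().data()));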

  // Uncompresses the indices of rows or columns. It can be interpreted as a
  // conversion from CSR to COO sparse storage format. See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csr2coo.
  Status Csr2coo(const int* CsrRowPtr, int nnz, int m, int* cooRowInd) const;

  // Compresses the indices of rows or columns. It can be interpreted as a
  // conversion from COO to CSR sparse storage format. See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-coo2csr.
  Status Coo2csr(const int* cooRowInd, int nnz, int m, int* csrRowPtr) const;
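
  // Example (a sketch; all pointers are assumed to be device memory):
  //
  //   // CSR -> COO: expand csr_row_ptr (m + 1 entries) into coo_row_ind
  //   // (nnz entries).
  //   TF_RETURN_IF_ERROR(solver.Csr2coo(csr_row_ptr, nnz, m, coo_row_ind));
  //   // COO -> CSR: compress coo_row_ind back into csr_row_ptr.
  //   TF_RETURN_IF_ERROR(solver.Coo2csr(coo_row_ind, nnz, m, csr_row_ptr));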

#if (GOOGLE_CUDA && (CUDA_VERSION < 10020)) || \
    (TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 40200)
  // Sparse-dense matrix multiplication C = alpha * op(A) * op(B) + beta * C,
  // where A is a sparse matrix in CSR format, B and C are dense tall
  // matrices. This routine allows transposition of matrix B, which
  // may improve performance. See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrmm2
  //
  // **NOTE** Matrices B and C are expected to be in column-major
  // order; to make them consistent with TensorFlow they
  // must be transposed (or the matmul op's pre/post-processing must take this
  // into account).
  //
  // **NOTE** This is an in-place operation for data in C.
  template <typename Scalar>
  Status Csrmm(gpusparseOperation_t transA, gpusparseOperation_t transB, int m,
               int n, int k, int nnz, const Scalar* alpha_host,
               const gpusparseMatDescr_t descrA, const Scalar* csrSortedValA,
               const int* csrSortedRowPtrA, const int* csrSortedColIndA,
               const Scalar* B, int ldb, const Scalar* beta_host, Scalar* C,
               int ldc) const;
#else  // CUDA_VERSION >= 10020 || TF_ROCM_VERSION >= 40200
  // Workspace size query for sparse-dense matrix multiplication. Helper
  // function for SpMM, which computes C = alpha * op(A) * op(B) + beta * C,
  // where A is a sparse matrix in CSR format, B and C are dense matrices in
  // column-major format. Returns the needed workspace size in bytes.
  template <typename Scalar>
  Status SpMMBufferSize(gpusparseOperation_t transA,
                        gpusparseOperation_t transB, const Scalar* alpha,
                        const gpusparseSpMatDescr_t matA,
                        const gpusparseDnMatDescr_t matB, const Scalar* beta,
                        gpusparseDnMatDescr_t matC, gpusparseSpMMAlg_t alg,
                        size_t* bufferSize) const;

  // Sparse-dense matrix multiplication C = alpha * op(A) * op(B) + beta * C,
  // where A is a sparse matrix in CSR format, B and C are dense matrices in
  // column-major format. Buffer is assumed to be at least as large as the
  // workspace size returned by SpMMBufferSize().
  //
  // **NOTE** This is an in-place operation for data in C.
  template <typename Scalar>
  Status SpMM(gpusparseOperation_t transA, gpusparseOperation_t transB,
              const Scalar* alpha, const gpusparseSpMatDescr_t matA,
              const gpusparseDnMatDescr_t matB, const Scalar* beta,
              gpusparseDnMatDescr_t matC, gpusparseSpMMAlg_t alg,
              int8* buffer) const;
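
  // Typical use of the SpMM wrappers above (a sketch; matA, matB, and matC are
  // assumed to be descriptors created via the cuSparse/hipSPARSE generic API):
  //
  //   size_t buffer_size;
  //   TF_RETURN_IF_ERROR(solver.SpMMBufferSize(transA, transB, &alpha, matA,
  //                                            matB, &beta, matC, alg,
  //                                            &buffer_size));
  //   Tensor temp;
  //   TF_RETURN_IF_ERROR(context->allocate_temp(
  //       DT_INT8, TensorShape({static_cast<int64_t>(buffer_size)}), &temp));
  //   TF_RETURN_IF_ERROR(solver.SpMM(transA, transB, &alpha, matA, matB, &beta,
  //                                  matC, alg, temp.flat<int8>().data()));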
#endif

  // Sparse matrix-dense vector multiplication
  // y = alpha * op(A) * x + beta * y,
  // where A is a sparse matrix in CSR format and x and y are dense vectors.
  // See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrmv_mergepath
  //
  // **NOTE** This is an in-place operation for data in y.
#if (GOOGLE_CUDA && (CUDA_VERSION < 10020)) || TENSORFLOW_USE_ROCM
  template <typename Scalar>
  Status Csrmv(gpusparseOperation_t transA, int m, int n, int nnz,
               const Scalar* alpha_host, const gpusparseMatDescr_t descrA,
               const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
               const int* csrSortedColIndA, const Scalar* x,
               const Scalar* beta_host, Scalar* y) const;
#else
  template <typename Scalar>
  Status Csrmv(gpusparseOperation_t transA, int m, int n, int nnz,
               const Scalar* alpha_host, const Scalar* csrSortedValA,
               const int* csrSortedRowPtrA, const int* csrSortedColIndA,
               const Scalar* x, const Scalar* beta_host, Scalar* y) const;
#endif  // CUDA_VERSION < 10020

  // Computes workspace size for sparse-sparse matrix addition of matrices
  // stored in CSR format.
  template <typename Scalar>
  Status CsrgeamBufferSizeExt(
      int m, int n, const Scalar* alpha, const gpusparseMatDescr_t descrA,
      int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
      const int* csrSortedColIndA, const Scalar* beta,
      const gpusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB,
      const int* csrSortedRowPtrB, const int* csrSortedColIndB,
      const gpusparseMatDescr_t descrC, Scalar* csrSortedValC,
      int* csrSortedRowPtrC, int* csrSortedColIndC, size_t* bufferSize);

  // Computes sparse-sparse matrix addition of matrices
  // stored in CSR format. This is part one: calculate nnz of the
  // output. csrSortedRowPtrC must be preallocated on device with
  // m + 1 entries. See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgeam.
  Status CsrgeamNnz(int m, int n, const gpusparseMatDescr_t descrA, int nnzA,
                    const int* csrSortedRowPtrA, const int* csrSortedColIndA,
                    const gpusparseMatDescr_t descrB, int nnzB,
                    const int* csrSortedRowPtrB, const int* csrSortedColIndB,
                    const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC,
                    int* nnzTotalDevHostPtr, void* workspace);

  // Computes sparse-sparse matrix addition of matrices
  // stored in CSR format. This is part two: perform sparse-sparse
  // addition. csrValC and csrColIndC must be allocated on the device
  // with nnzTotalDevHostPtr entries (as calculated by CsrgeamNnz). See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgeam.
  template <typename Scalar>
  Status Csrgeam(int m, int n, const Scalar* alpha,
                 const gpusparseMatDescr_t descrA, int nnzA,
                 const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
                 const int* csrSortedColIndA, const Scalar* beta,
                 const gpusparseMatDescr_t descrB, int nnzB,
                 const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
                 const int* csrSortedColIndB, const gpusparseMatDescr_t descrC,
                 Scalar* csrSortedValC, int* csrSortedRowPtrC,
                 int* csrSortedColIndC, void* workspace);
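
  // Typical use of the Csrgeam wrappers above (a sketch; `workspace` is a
  // device scratch buffer whose size was obtained from CsrgeamBufferSizeExt
  // with the same arguments):
  //
  //   int nnzC;
  //   TF_RETURN_IF_ERROR(solver.CsrgeamNnz(m, n, descrA, nnzA, rowPtrA,
  //                                        colIndA, descrB, nnzB, rowPtrB,
  //                                        colIndB, descrC, rowPtrC, &nnzC,
  //                                        workspace));
  //   // Allocate valC and colIndC with nnzC entries, then:
  //   TF_RETURN_IF_ERROR(solver.Csrgeam(m, n, &alpha, descrA, nnzA, valA,
  //                                     rowPtrA, colIndA, &beta, descrB, nnzB,
  //                                     valB, rowPtrB, colIndB, descrC, valC,
  //                                     rowPtrC, colIndC, workspace));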

#if GOOGLE_CUDA && (CUDA_VERSION >= 10000)
  // Computes sparse-sparse matrix multiplication of matrices
  // stored in CSR format. This is part zero: calculate required workspace
  // size.
  template <typename Scalar>
  Status CsrgemmBufferSize(
      int m, int n, int k, const gpusparseMatDescr_t descrA, int nnzA,
      const int* csrSortedRowPtrA, const int* csrSortedColIndA,
      const gpusparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB,
      const int* csrSortedColIndB, csrgemm2Info_t info, size_t* workspaceBytes);
#endif

  // Computes sparse-sparse matrix multiplication of matrices
  // stored in CSR format. This is part one: calculate nnz of the
  // output. csrSortedRowPtrC must be preallocated on device with
  // m + 1 entries. See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgemm.
#if (GOOGLE_CUDA && (CUDA_VERSION < 10000)) || TENSORFLOW_USE_ROCM
  Status CsrgemmNnz(gpusparseOperation_t transA, gpusparseOperation_t transB,
                    int m, int k, int n, const gpusparseMatDescr_t descrA,
                    int nnzA, const int* csrSortedRowPtrA,
                    const int* csrSortedColIndA,
                    const gpusparseMatDescr_t descrB, int nnzB,
                    const int* csrSortedRowPtrB, const int* csrSortedColIndB,
                    const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC,
                    int* nnzTotalDevHostPtr);
#else
  Status CsrgemmNnz(int m, int n, int k, const gpusparseMatDescr_t descrA,
                    int nnzA, const int* csrSortedRowPtrA,
                    const int* csrSortedColIndA,
                    const gpusparseMatDescr_t descrB, int nnzB,
                    const int* csrSortedRowPtrB, const int* csrSortedColIndB,
                    const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC,
                    int* nnzTotalDevHostPtr, csrgemm2Info_t info,
                    void* workspace);
#endif

  // Computes sparse-sparse matrix multiplication of matrices
  // stored in CSR format. This is part two: perform the sparse-sparse
  // multiplication. csrValC and csrColIndC must be allocated on the device
  // with nnzTotalDevHostPtr entries (as calculated by CsrgemmNnz). See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgemm.
#if (GOOGLE_CUDA && (CUDA_VERSION < 10000)) || TENSORFLOW_USE_ROCM
  template <typename Scalar>
  Status Csrgemm(gpusparseOperation_t transA, gpusparseOperation_t transB,
                 int m, int k, int n, const gpusparseMatDescr_t descrA,
                 int nnzA, const Scalar* csrSortedValA,
                 const int* csrSortedRowPtrA, const int* csrSortedColIndA,
                 const gpusparseMatDescr_t descrB, int nnzB,
                 const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
                 const int* csrSortedColIndB, const gpusparseMatDescr_t descrC,
                 Scalar* csrSortedValC, int* csrSortedRowPtrC,
                 int* csrSortedColIndC);
#else
  template <typename Scalar>
  Status Csrgemm(int m, int n, int k, const gpusparseMatDescr_t descrA,
                 int nnzA, const Scalar* csrSortedValA,
                 const int* csrSortedRowPtrA, const int* csrSortedColIndA,
                 const gpusparseMatDescr_t descrB, int nnzB,
                 const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
                 const int* csrSortedColIndB, const gpusparseMatDescr_t descrC,
                 Scalar* csrSortedValC, int* csrSortedRowPtrC,
                 int* csrSortedColIndC, const csrgemm2Info_t info,
                 void* workspace);
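
  // Typical use on this code path (a sketch; `info` is a csrgemm2Info_t
  // created beforehand, e.g. with cusparseCreateCsrgemm2Info, and `workspace`
  // is a device buffer sized by CsrgemmBufferSize above):
  //
  //   int nnzC;
  //   TF_RETURN_IF_ERROR(solver.CsrgemmNnz(m, n, k, descrA, nnzA, rowPtrA,
  //                                        colIndA, descrB, nnzB, rowPtrB,
  //                                        colIndB, descrC, rowPtrC, &nnzC,
  //                                        info, workspace));
  //   // Allocate valC and colIndC with nnzC entries, then:
  //   TF_RETURN_IF_ERROR(solver.Csrgemm(m, n, k, descrA, nnzA, valA, rowPtrA,
  //                                     colIndA, descrB, nnzB, valB, rowPtrB,
  //                                     colIndB, descrC, valC, rowPtrC,
  //                                     colIndC, info, workspace));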
#endif

  // In-place reordering of unsorted CSR to sorted CSR.
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csru2csr
  template <typename Scalar>
  Status Csru2csr(int m, int n, int nnz, const gpusparseMatDescr_t descrA,
                  Scalar* csrVal, const int* csrRowPtr, int* csrColInd);

  // Converts from CSR to CSC format (equivalently, transpose).
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-csr2cscEx
  template <typename Scalar>
  Status Csr2csc(int m, int n, int nnz, const Scalar* csrVal,
                 const int* csrRowPtr, const int* csrColInd, Scalar* cscVal,
                 int* cscRowInd, int* cscColPtr,
                 const gpusparseAction_t copyValues);
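
  // Example (a sketch; the output arrays cscVal, cscRowInd, and cscColPtr are
  // assumed to be preallocated on device with nnz, nnz, and n + 1 entries):
  //
  //   TF_RETURN_IF_ERROR(solver.Csr2csc(m, n, nnz, csrVal, csrRowPtr,
  //                                     csrColInd, cscVal, cscRowInd,
  //                                     cscColPtr, GPUSPARSE(ACTION_NUMERIC)));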

 private:
  bool initialized_;
  OpKernelContext* context_;             // not owned.
  gpuStream_t gpu_stream_;
  gpusparseHandle_t* gpusparse_handle_;  // not owned.

  TF_DISALLOW_COPY_AND_ASSIGN(GpuSparse);
};

// A wrapper class to ensure that a CUDA sparse matrix descriptor is initialized
// only once. For more details on the descriptor (gpusparseMatDescr_t), see:
// https://docs.nvidia.com/cuda/cusparse/index.html#cusparsematdescrt
class GpuSparseMatrixDescriptor {
 public:
  explicit GpuSparseMatrixDescriptor() : initialized_(false) {}

  GpuSparseMatrixDescriptor(GpuSparseMatrixDescriptor&& rhs)
      : initialized_(rhs.initialized_), descr_(std::move(rhs.descr_)) {
    rhs.initialized_ = false;
  }

  GpuSparseMatrixDescriptor& operator=(GpuSparseMatrixDescriptor&& rhs) {
    if (this == &rhs) return *this;
    Release();
    initialized_ = rhs.initialized_;
    descr_ = std::move(rhs.descr_);
    rhs.initialized_ = false;
    return *this;
  }

  ~GpuSparseMatrixDescriptor() { Release(); }

  // Initializes the underlying descriptor. Must not be called more than once.
  Status Initialize() {
    DCHECK(!initialized_);
#if GOOGLE_CUDA
    TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descr_));
#elif TENSORFLOW_USE_ROCM
    TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateMatDescr(&descr_));
#endif
    initialized_ = true;
    return OkStatus();
  }

  gpusparseMatDescr_t& descr() {
    DCHECK(initialized_);
    return descr_;
  }

  const gpusparseMatDescr_t& descr() const {
    DCHECK(initialized_);
    return descr_;
  }

 private:
  void Release() {
    if (initialized_) {
#if GOOGLE_CUDA
      cusparseDestroyMatDescr(descr_);
#elif TENSORFLOW_USE_ROCM
      wrap::hipsparseDestroyMatDescr(descr_);
#endif
      initialized_ = false;
    }
  }

  bool initialized_;
  gpusparseMatDescr_t descr_;

  TF_DISALLOW_COPY_AND_ASSIGN(GpuSparseMatrixDescriptor);
};
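
// Example use (a sketch):
//
//   GpuSparseMatrixDescriptor descrA;
//   TF_RETURN_IF_ERROR(descrA.Initialize());
//   // descrA.descr() can now be passed to the GpuSparse wrappers that take a
//   // gpusparseMatDescr_t, e.g. Csrgeam.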

#if GOOGLE_CUDA

// A wrapper class to ensure that an unsorted/sorted CSR conversion information
// struct (csru2csrInfo_t) is initialized only once. See:
// https://docs.nvidia.com/cuda/cusparse/index.html#csru2csr
class GpuSparseCsrSortingConversionInfo {
 public:
  explicit GpuSparseCsrSortingConversionInfo() : initialized_(false) {}

  GpuSparseCsrSortingConversionInfo(GpuSparseCsrSortingConversionInfo&& rhs)
      : initialized_(rhs.initialized_), info_(std::move(rhs.info_)) {
    rhs.initialized_ = false;
  }

  GpuSparseCsrSortingConversionInfo& operator=(
      GpuSparseCsrSortingConversionInfo&& rhs) {
    if (this == &rhs) return *this;
    Release();
    initialized_ = rhs.initialized_;
    info_ = std::move(rhs.info_);
    rhs.initialized_ = false;
    return *this;
  }

  ~GpuSparseCsrSortingConversionInfo() { Release(); }

  // Initializes the underlying info. Must not be called more than once.
  Status Initialize() {
    DCHECK(!initialized_);
    TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsru2csrInfo(&info_));
    initialized_ = true;
    return OkStatus();
  }

  csru2csrInfo_t& info() {
    DCHECK(initialized_);
    return info_;
  }

  const csru2csrInfo_t& info() const {
    DCHECK(initialized_);
    return info_;
  }

 private:
  void Release() {
    if (initialized_) {
      cusparseDestroyCsru2csrInfo(info_);
      initialized_ = false;
    }
  }

  bool initialized_;
  csru2csrInfo_t info_;

  TF_DISALLOW_COPY_AND_ASSIGN(GpuSparseCsrSortingConversionInfo);
};
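
// Example use with the csru2csr sorting conversion (a sketch):
//
//   GpuSparseCsrSortingConversionInfo info;
//   TF_RETURN_IF_ERROR(info.Initialize());
//   // info.info() yields the csru2csrInfo_t handle consumed by the
//   // underlying cusparse csru2csr routines.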

#endif  // GOOGLE_CUDA

}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif  // TENSORFLOW_CORE_KERNELS_LINALG_CUDA_SPARSE_H_