/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef GOOGLE_CUDA

#include "tensorflow/core/util/cuda_sparse.h"

#include <complex>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

#include "third_party/gpus/cuda/include/cusparse.h"
#include "third_party/gpus/cuda/include/library_types.h"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/gpu_solvers.h"

// TODO(rmlarsen,penporn): Investigate using newer kernels in CUDA 10.1+.

namespace tensorflow {
namespace {

// Type traits to get CUDA complex types from std::complex<>.
// TODO: reuse with cuda_solvers
template <typename T>
struct CudaComplexT {
  typedef T type;
};
template <>
struct CudaComplexT<std::complex<float>> {
  typedef cuComplex type;
};
template <>
struct CudaComplexT<std::complex<double>> {
  typedef cuDoubleComplex type;
};
// Converts pointers of std::complex<> to pointers of
// cuComplex/cuDoubleComplex. No type conversion for non-complex types.
template <typename T>
inline const typename CudaComplexT<T>::type* AsCudaComplex(const T* p) {
  return reinterpret_cast<const typename CudaComplexT<T>::type*>(p);
}
template <typename T>
inline typename CudaComplexT<T>::type* AsCudaComplex(T* p) {
  return reinterpret_cast<typename CudaComplexT<T>::type*>(p);
}
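// For example, AsCudaComplex applied to a const std::complex<float>* yields a
// const cuComplex*, while real-valued float*/double* pointers pass through
// unchanged. Illustrative compile-time check (not part of the build):
//   static_assert(std::is_same<CudaComplexT<std::complex<double>>::type,
//                              cuDoubleComplex>::value, "");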

// A set of initialized handles to the underlying Cuda libraries used by
// GpuSparse. We maintain one such set of handles per unique stream.
class CudaSparseHandles {
 public:
  explicit CudaSparseHandles(cudaStream_t stream)
      : initialized_(false), stream_(stream) {}

  CudaSparseHandles(CudaSparseHandles&& rhs)
      : initialized_(rhs.initialized_),
        stream_(std::move(rhs.stream_)),
        cusparse_handle_(rhs.cusparse_handle_) {
    rhs.initialized_ = false;
  }

  CudaSparseHandles& operator=(CudaSparseHandles&& rhs) {
    if (this == &rhs) return *this;
    Release();
    stream_ = std::move(rhs.stream_);
    cusparse_handle_ = std::move(rhs.cusparse_handle_);
    initialized_ = rhs.initialized_;
    rhs.initialized_ = false;
    return *this;
  }

  ~CudaSparseHandles() { Release(); }

  Status Initialize() {
    if (initialized_) return OkStatus();
    TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreate(&cusparse_handle_));
    TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSetStream(cusparse_handle_, stream_));
    initialized_ = true;
    return OkStatus();
  }

  cusparseHandle_t& handle() {
    DCHECK(initialized_);
    return cusparse_handle_;
  }

  const cusparseHandle_t& handle() const {
    DCHECK(initialized_);
    return cusparse_handle_;
  }

 private:
  void Release() {
    if (initialized_) {
      // This should never return anything other than success.
      auto err = cusparseDestroy(cusparse_handle_);
      DCHECK(err == CUSPARSE_STATUS_SUCCESS)
          << "Failed to destroy cuSparse instance.";
      initialized_ = false;
    }
  }
  bool initialized_;
  cudaStream_t stream_;
  cusparseHandle_t cusparse_handle_;

  TF_DISALLOW_COPY_AND_ASSIGN(CudaSparseHandles);
};

// TODO(ebrevdo): Replace global mutex guarding CudaSparseHandles
// lookup with one of:
//   1. Adding the handle to the CudaStream structure; do the lookup there.
//   2. Adding a thread-local cusparse handle, set to the current stream
//      upon each call.
// #1 seems like the cleanest option but will need to wait until this
// is moved into TF core.
static mutex handle_map_mutex(LINKER_INITIALIZED);

using HandleMap = std::unordered_map<cudaStream_t, CudaSparseHandles>;

// Returns a singleton map used for storing initialized handles for each unique
// cuda stream.
HandleMap* GetHandleMapSingleton() {
  static HandleMap* cm = new HandleMap;
  return cm;
}

}  // namespace

GpuSparse::GpuSparse(OpKernelContext* context)
    : initialized_(false), context_(context) {
  auto cuda_stream_ptr =
      reinterpret_cast<const cudaStream_t*>(context->op_device_context()
                                                ->stream()
                                                ->implementation()
                                                ->GpuStreamMemberHack());
  DCHECK(cuda_stream_ptr);
  gpu_stream_ = *cuda_stream_ptr;
}

Status GpuSparse::Initialize() {
  HandleMap* handle_map = GetHandleMapSingleton();
  DCHECK(handle_map);
  mutex_lock lock(handle_map_mutex);
  auto it = handle_map->find(gpu_stream_);
  if (it == handle_map->end()) {
    LOG(INFO) << "Creating CudaSparse handles for stream " << gpu_stream_;
    // Previously unseen Cuda stream. Initialize a set of Cuda sparse library
    // handles for it.
    CudaSparseHandles new_handles(gpu_stream_);
    TF_RETURN_IF_ERROR(new_handles.Initialize());
    it = handle_map->insert(std::make_pair(gpu_stream_, std::move(new_handles)))
             .first;
  }
  gpusparse_handle_ = &it->second.handle();
  initialized_ = true;
  return OkStatus();
}
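
// Typical usage from a GPU op kernel (illustrative sketch only; the variable
// names below are placeholders, not part of this file's API):
//
//   GpuSparse cuda_sparse(context);
//   OP_REQUIRES_OK(context, cuda_sparse.Initialize());
//   OP_REQUIRES_OK(context,
//                  cuda_sparse.Csr2coo(csr_row_ptr, nnz, m, coo_row_ind));
//
// The wrappers below assume Initialize() has been called; they only DCHECK
// initialized_ in debug builds.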

#define TF_CALL_CUSPARSE_DTYPES(m) \
  m(float, CUDA_R_32F) m(double, CUDA_R_64F) \
      m(std::complex<float>, CUDA_C_32F) m(std::complex<double>, CUDA_C_64F)

// Macro that specializes a sparse method for all 4 standard
// numeric types.
// TODO: reuse with cuda_solvers
#define TF_CALL_LAPACK_TYPES(m) \
  m(float, S) m(double, D) m(std::complex<float>, C) m(std::complex<double>, Z)

// Macros to construct cusparse method names.
#define SPARSE_FN(method, sparse_prefix) cusparse##sparse_prefix##method
#define SPARSE_NAME(method, sparse_prefix) "cusparse" #sparse_prefix #method
#define BUFSIZE_FN(method, sparse_prefix) \
  cusparse##sparse_prefix##method##_bufferSizeExt
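// For example, SPARSE_FN(gtsv2, S) expands to cusparseSgtsv2 and
// BUFSIZE_FN(gtsv2, S) to cusparseSgtsv2_bufferSizeExt, while
// TF_CALL_LAPACK_TYPES(GTSV2_INSTANCE) stamps out one GpuSparse::Gtsv2
// specialization per supported scalar type.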

//=============================================================================
// Wrappers of cuSparse computational methods begin here.
//
// WARNING to implementers: The function signatures listed in the online docs
// are sometimes inaccurate, e.g., are missing 'const' on pointers
// to immutable arguments, while the actual headers have them as expected.
// Check the actual declarations in the cusparse.h header file.
//=============================================================================

template <typename Scalar, typename SparseFn>
static inline Status Gtsv2Impl(SparseFn op, cusparseHandle_t cusparse_handle,
                               int m, int n, const Scalar* dl, const Scalar* d,
                               const Scalar* du, Scalar* B, int ldb,
                               void* pBuffer) {
  TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, n, AsCudaComplex(dl),
                                  AsCudaComplex(d), AsCudaComplex(du),
                                  AsCudaComplex(B), ldb, pBuffer));
  return OkStatus();
}

#define GTSV2_INSTANCE(Scalar, sparse_prefix) \
  template <> \
  Status GpuSparse::Gtsv2<Scalar>(int m, int n, const Scalar* dl, \
                                  const Scalar* d, const Scalar* du, \
                                  Scalar* B, int ldb, void* pBuffer) const { \
    DCHECK(initialized_); \
    return Gtsv2Impl(SPARSE_FN(gtsv2, sparse_prefix), *gpusparse_handle_, m, \
                     n, dl, d, du, B, ldb, pBuffer); \
  }

TF_CALL_LAPACK_TYPES(GTSV2_INSTANCE);

#define GTSV2_NO_PIVOT_INSTANCE(Scalar, sparse_prefix) \
  template <> \
  Status GpuSparse::Gtsv2NoPivot<Scalar>( \
      int m, int n, const Scalar* dl, const Scalar* d, const Scalar* du, \
      Scalar* B, int ldb, void* pBuffer) const { \
    DCHECK(initialized_); \
    return Gtsv2Impl(SPARSE_FN(gtsv2_nopivot, sparse_prefix), \
                     *gpusparse_handle_, m, n, dl, d, du, B, ldb, pBuffer); \
  }

TF_CALL_LAPACK_TYPES(GTSV2_NO_PIVOT_INSTANCE);

template <typename Scalar, typename SparseFn>
static inline Status Gtsv2BufferSizeExtImpl(SparseFn op,
                                            cusparseHandle_t cusparse_handle,
                                            int m, int n, const Scalar* dl,
                                            const Scalar* d, const Scalar* du,
                                            const Scalar* B, int ldb,
                                            size_t* bufferSizeInBytes) {
  TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, n, AsCudaComplex(dl),
                                  AsCudaComplex(d), AsCudaComplex(du),
                                  AsCudaComplex(B), ldb, bufferSizeInBytes));
  return OkStatus();
}

#define GTSV2_BUFFER_SIZE_INSTANCE(Scalar, sparse_prefix) \
  template <> \
  Status GpuSparse::Gtsv2BufferSizeExt<Scalar>( \
      int m, int n, const Scalar* dl, const Scalar* d, const Scalar* du, \
      const Scalar* B, int ldb, size_t* bufferSizeInBytes) const { \
    DCHECK(initialized_); \
    return Gtsv2BufferSizeExtImpl( \
        SPARSE_FN(gtsv2_bufferSizeExt, sparse_prefix), *gpusparse_handle_, m, \
        n, dl, d, du, B, ldb, bufferSizeInBytes); \
  }

TF_CALL_LAPACK_TYPES(GTSV2_BUFFER_SIZE_INSTANCE);
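// Callers are expected to query Gtsv2BufferSizeExt (or, below,
// Gtsv2NoPivotBufferSizeExt) first, allocate a device scratch buffer of at
// least that many bytes, and pass it as pBuffer to the matching Gtsv2 /
// Gtsv2NoPivot call.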

#define GTSV2_NO_PIVOT_BUFFER_SIZE_INSTANCE(Scalar, sparse_prefix) \
  template <> \
  Status GpuSparse::Gtsv2NoPivotBufferSizeExt<Scalar>( \
      int m, int n, const Scalar* dl, const Scalar* d, const Scalar* du, \
      const Scalar* B, int ldb, size_t* bufferSizeInBytes) const { \
    DCHECK(initialized_); \
    return Gtsv2BufferSizeExtImpl( \
        SPARSE_FN(gtsv2_nopivot_bufferSizeExt, sparse_prefix), \
        *gpusparse_handle_, m, n, dl, d, du, B, ldb, bufferSizeInBytes); \
  }

TF_CALL_LAPACK_TYPES(GTSV2_NO_PIVOT_BUFFER_SIZE_INSTANCE);

template <typename Scalar, typename SparseFn>
static inline Status Gtsv2StridedBatchImpl(SparseFn op,
                                           cusparseHandle_t cusparse_handle,
                                           int m, const Scalar* dl,
                                           const Scalar* d, const Scalar* du,
                                           Scalar* x, int batchCount,
                                           int batchStride, void* pBuffer) {
  TF_RETURN_IF_GPUSPARSE_ERROR(op(
      cusparse_handle, m, AsCudaComplex(dl), AsCudaComplex(d),
      AsCudaComplex(du), AsCudaComplex(x), batchCount, batchStride, pBuffer));
  return OkStatus();
}

#define GTSV2_STRIDED_BATCH_INSTANCE(Scalar, sparse_prefix) \
  template <> \
  Status GpuSparse::Gtsv2StridedBatch<Scalar>( \
      int m, const Scalar* dl, const Scalar* d, const Scalar* du, Scalar* x, \
      int batchCount, int batchStride, void* pBuffer) const { \
    DCHECK(initialized_); \
    return Gtsv2StridedBatchImpl(SPARSE_FN(gtsv2StridedBatch, sparse_prefix), \
                                 *gpusparse_handle_, m, dl, d, du, x, \
                                 batchCount, batchStride, pBuffer); \
  }

TF_CALL_LAPACK_TYPES(GTSV2_STRIDED_BATCH_INSTANCE);

template <typename Scalar, typename SparseFn>
static inline Status Gtsv2StridedBatchBufferSizeImpl(
    SparseFn op, cusparseHandle_t cusparse_handle, int m, const Scalar* dl,
    const Scalar* d, const Scalar* du, const Scalar* x, int batchCount,
    int batchStride, size_t* bufferSizeInBytes) {
  TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, AsCudaComplex(dl),
                                  AsCudaComplex(d), AsCudaComplex(du),
                                  AsCudaComplex(x), batchCount, batchStride,
                                  bufferSizeInBytes));
  return OkStatus();
}

#define GTSV2_STRIDED_BATCH_BUFFER_SIZE_INSTANCE(Scalar, sparse_prefix) \
  template <> \
  Status GpuSparse::Gtsv2StridedBatchBufferSizeExt<Scalar>( \
      int m, const Scalar* dl, const Scalar* d, const Scalar* du, \
      const Scalar* x, int batchCount, int batchStride, \
      size_t* bufferSizeInBytes) const { \
    DCHECK(initialized_); \
    return Gtsv2StridedBatchBufferSizeImpl( \
        SPARSE_FN(gtsv2StridedBatch_bufferSizeExt, sparse_prefix), \
        *gpusparse_handle_, m, dl, d, du, x, batchCount, batchStride, \
        bufferSizeInBytes); \
  }

TF_CALL_LAPACK_TYPES(GTSV2_STRIDED_BATCH_BUFFER_SIZE_INSTANCE);

Status GpuSparse::Coo2csr(const int* cooRowInd, int nnz, int m,
                          int* csrRowPtr) const {
  // cusparseStatus_t CUSPARSEAPI cusparseXcoo2csr(cusparseHandle_t handle,
  //                                               const int *cooRowInd,
  //                                               int nnz,
  //                                               int m,
  //                                               int *csrSortedRowPtr,
  //                                               cusparseIndexBase_t
  //                                               idxBase);
  DCHECK(initialized_);
  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcoo2csr(*gpusparse_handle_, cooRowInd,
                                                nnz, m, csrRowPtr,
                                                CUSPARSE_INDEX_BASE_ZERO));
  return OkStatus();
}
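
// Example (illustrative): with m = 3, nnz = 4 and row-sorted COO row indices
// cooRowInd = [0, 0, 1, 2], the resulting zero-based CSR row pointers are
// csrRowPtr = [0, 2, 3, 4].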

Status GpuSparse::Csr2coo(const int* csrRowPtr, int nnz, int m,
                          int* cooRowInd) const {
  // cusparseStatus_t CUSPARSEAPI cusparseXcsr2coo(cusparseHandle_t handle,
  //                                               const int *csrRowPtr,
  //                                               int nnz,
  //                                               int m,
  //                                               int *cooRowInd,
  //                                               cusparseIndexBase_t
  //                                               idxBase);
  DCHECK(initialized_);
  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsr2coo(*gpusparse_handle_, csrRowPtr,
                                                nnz, m, cooRowInd,
                                                CUSPARSE_INDEX_BASE_ZERO));
  return OkStatus();
}

Status GpuSparse::CsrgeamNnz(
    int m, int n, const cusparseMatDescr_t descrA, int nnzA,
    const int* csrSortedRowPtrA, const int* csrSortedColIndA,
    const cusparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB,
    const int* csrSortedColIndB, const cusparseMatDescr_t descrC,
    int* csrSortedRowPtrC, int* nnzTotalDevHostPtr, void* workspace) {
  DCHECK(initialized_);
  DCHECK(nnzTotalDevHostPtr != nullptr);
#if CUDA_VERSION >= 10000
  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsrgeam2Nnz(
      *gpusparse_handle_, m, n, descrA, nnzA, csrSortedRowPtrA,
      csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, csrSortedColIndB,
      descrC, csrSortedRowPtrC, nnzTotalDevHostPtr, workspace));
#else
  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsrgeamNnz(
      *gpusparse_handle_, m, n, descrA, nnzA, csrSortedRowPtrA,
      csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, csrSortedColIndB,
      descrC, csrSortedRowPtrC, nnzTotalDevHostPtr));
#endif
  return OkStatus();
}

#if CUDA_VERSION < 10020

template <typename Scalar, typename SparseFnT>
static inline Status CsrmmImpl(
    SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
    cusparseOperation_t transA, cusparseOperation_t transB, int m, int n, int k,
    int nnz, const Scalar* alpha_host, const cusparseMatDescr_t descrA,
    const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
    const int* csrSortedColIndA, const Scalar* B, int ldb,
    const Scalar* beta_host, Scalar* C, int ldc) {
  // cusparseStatus_t CUSPARSEAPI cusparseScsrmm2(
  //     cusparseHandle_t handle, cusparseOperation_t transA,
  //     cusparseOperation_t transB, int m, int n, int k, int nnz,
  //     const float* alpha, const cusparseMatDescr_t descrA,
  //     const float* csrSortedValA, const int* csrSortedRowPtrA,
  //     const int* csrSortedColIndA, const float* B, int ldb,
  //     const float* beta, float* C, int ldc);
  TF_RETURN_IF_GPUSPARSE_ERROR(op(
      cusparse_handle, transA, transB, m, n, k, nnz, AsCudaComplex(alpha_host),
      descrA, AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
      AsCudaComplex(B), ldb, AsCudaComplex(beta_host), AsCudaComplex(C), ldc));
  return OkStatus();
}

#define CSRMM_INSTANCE(Scalar, sparse_prefix) \
  template <> \
  Status GpuSparse::Csrmm<Scalar>( \
      cusparseOperation_t transA, cusparseOperation_t transB, int m, int n, \
      int k, int nnz, const Scalar* alpha_host, \
      const cusparseMatDescr_t descrA, const Scalar* csrSortedValA, \
      const int* csrSortedRowPtrA, const int* csrSortedColIndA, \
      const Scalar* B, int ldb, const Scalar* beta_host, Scalar* C, int ldc) \
      const { \
    DCHECK(initialized_); \
    return CsrmmImpl(SPARSE_FN(csrmm2, sparse_prefix), context_, \
                     *gpusparse_handle_, transA, transB, m, n, k, nnz, \
                     alpha_host, descrA, csrSortedValA, csrSortedRowPtrA, \
                     csrSortedColIndA, B, ldb, beta_host, C, ldc); \
  }

TF_CALL_LAPACK_TYPES(CSRMM_INSTANCE);

#else

#define SPMM_BUFFERSIZE_INSTANCE(Scalar, dtype) \
  template <> \
  Status GpuSparse::SpMMBufferSize<Scalar>( \
      cusparseOperation_t transA, cusparseOperation_t transB, \
      const Scalar* alpha, const cusparseSpMatDescr_t matA, \
      const gpusparseDnMatDescr_t matB, const Scalar* beta, \
      gpusparseDnMatDescr_t matC, cusparseSpMMAlg_t alg, size_t* bufferSize) \
      const { \
    DCHECK(initialized_); \
    TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSpMM_bufferSize( \
        *gpusparse_handle_, transA, transB, alpha, matA, matB, beta, matC, \
        dtype, alg, bufferSize)); \
    return OkStatus(); \
  }

TF_CALL_CUSPARSE_DTYPES(SPMM_BUFFERSIZE_INSTANCE);

#define SPMM_INSTANCE(Scalar, dtype) \
  template <> \
  Status GpuSparse::SpMM<Scalar>( \
      cusparseOperation_t transA, cusparseOperation_t transB, \
      const Scalar* alpha, const cusparseSpMatDescr_t matA, \
      const gpusparseDnMatDescr_t matB, const Scalar* beta, \
      gpusparseDnMatDescr_t matC, cusparseSpMMAlg_t alg, int8* buffer) const { \
    DCHECK(initialized_); \
    TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSpMM(*gpusparse_handle_, transA, \
                                              transB, alpha, matA, matB, beta, \
                                              matC, dtype, alg, buffer)); \
    return OkStatus(); \
  }

TF_CALL_CUSPARSE_DTYPES(SPMM_INSTANCE);
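
// On CUDA >= 10.2 the wrappers above use the generic cusparseSpMM API:
// callers first build the sparse descriptor for A and dense descriptors for
// B and C, query SpMMBufferSize, allocate that many bytes of device scratch,
// and then invoke SpMM with the scratch buffer.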

#endif

#if CUDA_VERSION < 10020

template <typename Scalar, typename SparseFnT>
static inline Status CsrmvImpl(
    SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
    cusparseOperation_t transA, int m, int n, int nnz, const Scalar* alpha_host,
    const cusparseMatDescr_t descrA, const Scalar* csrSortedValA,
    const int* csrSortedRowPtrA, const int* csrSortedColIndA, const Scalar* x,
    const Scalar* beta_host, Scalar* y) {
  TF_RETURN_IF_GPUSPARSE_ERROR(
      op(cusparse_handle, transA, m, n, nnz, AsCudaComplex(alpha_host), descrA,
         AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
         AsCudaComplex(x), AsCudaComplex(beta_host), AsCudaComplex(y)));
  return OkStatus();
}

// TODO(ebrevdo,rmlarsen): Use csrmv_mp for all cases when available in CUDA 9.
#define CSRMV_INSTANCE(Scalar, sparse_prefix) \
  template <> \
  Status GpuSparse::Csrmv<Scalar>( \
      cusparseOperation_t transA, int m, int n, int nnz, \
      const Scalar* alpha_host, const cusparseMatDescr_t descrA, \
      const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \
      const int* csrSortedColIndA, const Scalar* x, const Scalar* beta_host, \
      Scalar* y) const { \
    DCHECK(initialized_); \
    if (transA == CUSPARSE_OPERATION_NON_TRANSPOSE) { \
      return CsrmvImpl(SPARSE_FN(csrmv_mp, sparse_prefix), context_, \
                       *gpusparse_handle_, transA, m, n, nnz, alpha_host, \
                       descrA, csrSortedValA, csrSortedRowPtrA, \
                       csrSortedColIndA, x, beta_host, y); \
    } else { \
      return CsrmvImpl(SPARSE_FN(csrmv, sparse_prefix), context_, \
                       *gpusparse_handle_, transA, m, n, nnz, alpha_host, \
                       descrA, csrSortedValA, csrSortedRowPtrA, \
                       csrSortedColIndA, x, beta_host, y); \
    } \
  }

TF_CALL_LAPACK_TYPES(CSRMV_INSTANCE);

#else

template <typename Scalar>
static inline Status CsrmvExImpl(cudaDataType_t dtype, OpKernelContext* context,
                                 cusparseHandle_t cusparse_handle,
                                 cusparseOperation_t transA, int m, int n,
                                 int nnz, const Scalar* alpha_host,
                                 const Scalar* csrSortedValA,
                                 const int* csrSortedRowPtrA,
                                 const int* csrSortedColIndA, const Scalar* x,
                                 const Scalar* beta_host, Scalar* y) {
  cusparseMatDescr_t descrA;
  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
  TF_RETURN_IF_GPUSPARSE_ERROR(
      cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
  TF_RETURN_IF_GPUSPARSE_ERROR(
      cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));
  // The CUSPARSE_ALG_MERGE_PATH algo only supports non-transposed matrices.
  DCHECK(transA == CUSPARSE_OPERATION_NON_TRANSPOSE);

  size_t bufferSize;
  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCsrmvEx_bufferSize(
      cusparse_handle, CUSPARSE_ALG_MERGE_PATH, transA, m, n, nnz, alpha_host,
      dtype, descrA, csrSortedValA, dtype, csrSortedRowPtrA, csrSortedColIndA,
      x, dtype, beta_host, dtype, y, dtype, dtype, &bufferSize));

  Tensor buffer;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DT_INT8, TensorShape({static_cast<int64_t>(bufferSize)}), &buffer));
  auto pBuffer = buffer.flat<int8>();
  DCHECK(pBuffer.data() != nullptr);

  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCsrmvEx(
      cusparse_handle, CUSPARSE_ALG_MERGE_PATH, transA, m, n, nnz, alpha_host,
      dtype, descrA, csrSortedValA, dtype, csrSortedRowPtrA, csrSortedColIndA,
      x, dtype, beta_host, dtype, y, dtype, dtype, pBuffer.data()));

  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyMatDescr(descrA));
  return OkStatus();
}

template <typename Scalar>
static inline Status SpMVImpl(cudaDataType_t dtype, OpKernelContext* context,
                              cusparseHandle_t cusparse_handle,
                              cusparseOperation_t transA, int m, int n, int nnz,
                              const Scalar* alpha_host,
                              const Scalar* csrSortedValA,
                              const int* csrSortedRowPtrA,
                              const int* csrSortedColIndA, const Scalar* x,
                              const Scalar* beta_host, Scalar* y) {
  cusparseSpMatDescr_t matA;
  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsr(
      &matA, m, n, nnz, const_cast<int*>(csrSortedRowPtrA),
      const_cast<int*>(csrSortedColIndA), const_cast<Scalar*>(csrSortedValA),
      CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO, dtype));

  cusparseDnVecDescr_t vecX, vecY;
  int sizeX = (transA == CUSPARSE_OPERATION_NON_TRANSPOSE) ? n : m;
  int sizeY = (transA == CUSPARSE_OPERATION_NON_TRANSPOSE) ? m : n;
  TF_RETURN_IF_GPUSPARSE_ERROR(
      cusparseCreateDnVec(&vecX, sizeX, const_cast<Scalar*>(x), dtype));
  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateDnVec(&vecY, sizeY, y, dtype));

  size_t bufferSize;
  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSpMV_bufferSize(
      cusparse_handle, transA, alpha_host, matA, vecX, beta_host, vecY, dtype,
      CUSPARSE_CSRMV_ALG1, &bufferSize));

  Tensor buffer;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DT_INT8, TensorShape({static_cast<int64_t>(bufferSize)}), &buffer));
  auto pBuffer = buffer.flat<int8>();
  DCHECK(pBuffer.data() != nullptr);

  TF_RETURN_IF_GPUSPARSE_ERROR(
      cusparseSpMV(cusparse_handle, transA, alpha_host, matA, vecX, beta_host,
                   vecY, dtype, CUSPARSE_CSRMV_ALG1, pBuffer.data()));

  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnVec(vecY));
  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnVec(vecX));
  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroySpMat(matA));
  return OkStatus();
}

#define CSRMV_INSTANCE(Scalar, cudaDataType) \
  template <> \
  Status GpuSparse::Csrmv<Scalar>( \
      cusparseOperation_t transA, int m, int n, int nnz, \
      const Scalar* alpha_host, const Scalar* csrSortedValA, \
      const int* csrSortedRowPtrA, const int* csrSortedColIndA, \
      const Scalar* x, const Scalar* beta_host, Scalar* y) const { \
    DCHECK(initialized_); \
    if (transA == CUSPARSE_OPERATION_NON_TRANSPOSE) { \
      return CsrmvExImpl(cudaDataType, context_, *gpusparse_handle_, transA, \
                         m, n, nnz, alpha_host, csrSortedValA, \
                         csrSortedRowPtrA, csrSortedColIndA, x, beta_host, y); \
    } else { \
      return SpMVImpl(cudaDataType, context_, *gpusparse_handle_, transA, m, \
                      n, nnz, alpha_host, csrSortedValA, csrSortedRowPtrA, \
                      csrSortedColIndA, x, beta_host, y); \
    } \
  }

TF_CALL_CUSPARSE_DTYPES(CSRMV_INSTANCE);
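
// Note: the merge-path CsrmvEx kernel above only supports a non-transposed A,
// so transposed and conjugate-transposed products dispatch to the generic
// cusparseSpMV path instead.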

#endif  // CUDA_VERSION < 10020

#if CUDA_VERSION < 10000

template <typename Scalar, typename SparseFnT>
static inline Status CsrgeamImpl(
    SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
    int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA,
    int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
    const int* csrSortedColIndA, const Scalar* beta,
    const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB,
    const int* csrSortedRowPtrB, const int* csrSortedColIndB,
    const cusparseMatDescr_t descrC, Scalar* csrSortedValC,
    int* csrSortedRowPtrC, int* csrSortedColIndC) {
  TF_RETURN_IF_GPUSPARSE_ERROR(
      op(cusparse_handle, m, n, AsCudaComplex(alpha), descrA, nnzA,
         AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
         AsCudaComplex(beta), descrB, nnzB, AsCudaComplex(csrSortedValB),
         csrSortedRowPtrB, csrSortedColIndB, descrC,
         AsCudaComplex(csrSortedValC), csrSortedRowPtrC, csrSortedColIndC));
  return OkStatus();
}

#define CSRGEAM_INSTANCE(Scalar, sparse_prefix) \
  template <> \
  Status GpuSparse::Csrgeam<Scalar>( \
      int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA, \
      int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \
      const int* csrSortedColIndA, const Scalar* beta, \
      const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB, \
      const int* csrSortedRowPtrB, const int* csrSortedColIndB, \
      const cusparseMatDescr_t descrC, Scalar* csrSortedValC, \
      int* csrSortedRowPtrC, int* csrSortedColIndC, void* workspace) { \
    DCHECK(initialized_); \
    return CsrgeamImpl(SPARSE_FN(csrgeam, sparse_prefix), context_, \
                       *gpusparse_handle_, m, n, alpha, descrA, nnzA, \
                       csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, \
                       beta, descrB, nnzB, csrSortedValB, csrSortedRowPtrB, \
                       csrSortedColIndB, descrC, csrSortedValC, \
                       csrSortedRowPtrC, csrSortedColIndC); \
  }

#else

template <typename Scalar, typename SparseFnT>
static inline Status Csrgeam2Impl(
    SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
    int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA,
    int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
    const int* csrSortedColIndA, const Scalar* beta,
    const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB,
    const int* csrSortedRowPtrB, const int* csrSortedColIndB,
    const cusparseMatDescr_t descrC, Scalar* csrSortedValC,
    int* csrSortedRowPtrC, int* csrSortedColIndC, void* workspace) {
  TF_RETURN_IF_GPUSPARSE_ERROR(op(
      cusparse_handle, m, n, AsCudaComplex(alpha), descrA, nnzA,
      AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
      AsCudaComplex(beta), descrB, nnzB, AsCudaComplex(csrSortedValB),
      csrSortedRowPtrB, csrSortedColIndB, descrC, AsCudaComplex(csrSortedValC),
      csrSortedRowPtrC, csrSortedColIndC, workspace));
  return OkStatus();
}

#define CSRGEAM_INSTANCE(Scalar, sparse_prefix) \
  template <> \
  Status GpuSparse::Csrgeam<Scalar>( \
      int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA, \
      int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \
      const int* csrSortedColIndA, const Scalar* beta, \
      const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB, \
      const int* csrSortedRowPtrB, const int* csrSortedColIndB, \
      const cusparseMatDescr_t descrC, Scalar* csrSortedValC, \
      int* csrSortedRowPtrC, int* csrSortedColIndC, void* workspace) { \
    DCHECK(initialized_); \
    return Csrgeam2Impl(SPARSE_FN(csrgeam2, sparse_prefix), context_, \
                        *gpusparse_handle_, m, n, alpha, descrA, nnzA, \
                        csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, \
                        beta, descrB, nnzB, csrSortedValB, csrSortedRowPtrB, \
                        csrSortedColIndB, descrC, csrSortedValC, \
                        csrSortedRowPtrC, csrSortedColIndC, workspace); \
  }

#endif

TF_CALL_LAPACK_TYPES(CSRGEAM_INSTANCE);

#if CUDA_VERSION < 10000

#define CSRGEAM_BUFFERSIZE_INSTANCE(Scalar, sparse_prefix) \
  template <> \
  Status GpuSparse::CsrgeamBufferSizeExt<Scalar>( \
      int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA, \
      int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \
      const int* csrSortedColIndA, const Scalar* beta, \
      const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB, \
      const int* csrSortedRowPtrB, const int* csrSortedColIndB, \
      const cusparseMatDescr_t descrC, Scalar* csrSortedValC, \
      int* csrSortedRowPtrC, int* csrSortedColIndC, size_t* bufferSize) { \
    DCHECK(initialized_); \
    *bufferSize = 0; \
    return OkStatus(); \
  }

#else

template <typename Scalar, typename SparseFnT>
static inline Status CsrgeamBufferSizeExtImpl(
    SparseFnT op, OpKernelContext* context, cusparseHandle_t sparse_handle,
    int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA,
    int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
    const int* csrSortedColIndA, const Scalar* beta,
    const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB,
    const int* csrSortedRowPtrB, const int* csrSortedColIndB,
    const cusparseMatDescr_t descrC, Scalar* csrSortedValC,
    int* csrSortedRowPtrC, int* csrSortedColIndC, size_t* bufferSize) {
  TF_RETURN_IF_GPUSPARSE_ERROR(op(
      sparse_handle, m, n, AsCudaComplex(alpha), descrA, nnzA,
      AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
      AsCudaComplex(beta), descrB, nnzB, AsCudaComplex(csrSortedValB),
      csrSortedRowPtrB, csrSortedColIndB, descrC, AsCudaComplex(csrSortedValC),
      csrSortedRowPtrC, csrSortedColIndC, bufferSize));
  return OkStatus();
}

#define CSRGEAM_BUFFERSIZE_INSTANCE(Scalar, sparse_prefix) \
  template <> \
  Status GpuSparse::CsrgeamBufferSizeExt<Scalar>( \
      int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA, \
      int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \
      const int* csrSortedColIndA, const Scalar* beta, \
      const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB, \
      const int* csrSortedRowPtrB, const int* csrSortedColIndB, \
      const cusparseMatDescr_t descrC, Scalar* csrSortedValC, \
      int* csrSortedRowPtrC, int* csrSortedColIndC, size_t* bufferSize) { \
    DCHECK(initialized_); \
    return CsrgeamBufferSizeExtImpl( \
        SPARSE_FN(csrgeam2_bufferSizeExt, sparse_prefix), context_, \
        *gpusparse_handle_, m, n, alpha, descrA, nnzA, csrSortedValA, \
        csrSortedRowPtrA, csrSortedColIndA, beta, descrB, nnzB, csrSortedValB, \
        csrSortedRowPtrB, csrSortedColIndB, descrC, csrSortedValC, \
        csrSortedRowPtrC, csrSortedColIndC, bufferSize); \
  }

#endif

TF_CALL_LAPACK_TYPES(CSRGEAM_BUFFERSIZE_INSTANCE);

#if CUDA_VERSION < 10000

Status GpuSparse::CsrgemmNnz(
    cusparseOperation_t transA, cusparseOperation_t transB, int m, int k, int n,
    const cusparseMatDescr_t descrA, int nnzA, const int* csrSortedRowPtrA,
    const int* csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB,
    const int* csrSortedRowPtrB, const int* csrSortedColIndB,
    const cusparseMatDescr_t descrC, int* csrSortedRowPtrC,
    int* nnzTotalDevHostPtr) {
  DCHECK(initialized_);
  DCHECK(nnzTotalDevHostPtr != nullptr);
  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsrgemmNnz(
      *gpusparse_handle_, transA, transB, m, k, n, descrA, nnzA,
      csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB,
      csrSortedColIndB, descrC, csrSortedRowPtrC, nnzTotalDevHostPtr));
  return OkStatus();
}

template <typename Scalar, typename SparseFnT>
static inline Status CsrgemmImpl(
    SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
    cusparseOperation_t transA, cusparseOperation_t transB, int m, int k, int n,
    const cusparseMatDescr_t descrA, int nnzA, const Scalar* csrSortedValA,
    const int* csrSortedRowPtrA, const int* csrSortedColIndA,
    const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB,
    const int* csrSortedRowPtrB, const int* csrSortedColIndB,
    const cusparseMatDescr_t descrC, Scalar* csrSortedValC,
    int* csrSortedRowPtrC, int* csrSortedColIndC) {
  TF_RETURN_IF_GPUSPARSE_ERROR(
      op(cusparse_handle, transA, transB, m, k, n, descrA, nnzA,
         AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
         descrB, nnzB, AsCudaComplex(csrSortedValB), csrSortedRowPtrB,
         csrSortedColIndB, descrC, AsCudaComplex(csrSortedValC),
         csrSortedRowPtrC, csrSortedColIndC));
  return OkStatus();
}

#define CSRGEMM_INSTANCE(Scalar, sparse_prefix) \
  template <> \
  Status GpuSparse::Csrgemm<Scalar>( \
      cusparseOperation_t transA, cusparseOperation_t transB, int m, int k, \
      int n, const cusparseMatDescr_t descrA, int nnzA, \
      const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \
      const int* csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB, \
      const Scalar* csrSortedValB, const int* csrSortedRowPtrB, \
      const int* csrSortedColIndB, const cusparseMatDescr_t descrC, \
      Scalar* csrSortedValC, int* csrSortedRowPtrC, int* csrSortedColIndC) { \
    DCHECK(initialized_); \
    return CsrgemmImpl(SPARSE_FN(csrgemm, sparse_prefix), context_, \
                       *gpusparse_handle_, transA, transB, m, k, n, descrA, \
                       nnzA, csrSortedValA, csrSortedRowPtrA, \
                       csrSortedColIndA, descrB, nnzB, csrSortedValB, \
                       csrSortedRowPtrB, csrSortedColIndB, descrC, \
                       csrSortedValC, csrSortedRowPtrC, csrSortedColIndC); \
  }

TF_CALL_LAPACK_TYPES(CSRGEMM_INSTANCE);

#else

template <typename T>
static const T* one_ptr() {
  static const T one = static_cast<T>(1);
  return &one;
}

template <typename T>
static const T* null_ptr() {
  return nullptr;
}
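
// cusparse csrgemm2 computes C = alpha * A * B + beta * D. The wrappers below
// pass alpha == 1 and beta == nullptr with an empty D (descrA reused with zero
// nonzeros and null index pointers), so they effectively compute C = A * B;
// one_ptr() and null_ptr() above supply those constants.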

#define CSRGEMM_BUFFERSIZE_INSTANCE(Scalar, sparse_prefix) \
  template <> \
  Status GpuSparse::CsrgemmBufferSize<Scalar>( \
      int m, int n, int k, const cusparseMatDescr_t descrA, int nnzA, \
      const int* csrSortedRowPtrA, const int* csrSortedColIndA, \
      const cusparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB, \
      const int* csrSortedColIndB, csrgemm2Info_t info, \
      size_t* workspaceBytes) { \
    DCHECK(initialized_); \
    TF_RETURN_IF_GPUSPARSE_ERROR(SPARSE_FN(csrgemm2_bufferSizeExt, \
                                           sparse_prefix)( \
        *gpusparse_handle_, m, n, k, AsCudaComplex(one_ptr<Scalar>()), descrA, \
        nnzA, csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, \
        csrSortedRowPtrB, csrSortedColIndB, AsCudaComplex(null_ptr<Scalar>()), \
        descrA, 0, null_ptr<int>(), null_ptr<int>(), info, workspaceBytes)); \
    return OkStatus(); \
  }

TF_CALL_LAPACK_TYPES(CSRGEMM_BUFFERSIZE_INSTANCE);

Status GpuSparse::CsrgemmNnz(
    int m, int n, int k, const cusparseMatDescr_t descrA, int nnzA,
    const int* csrSortedRowPtrA, const int* csrSortedColIndA,
    const cusparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB,
    const int* csrSortedColIndB, const cusparseMatDescr_t descrC,
    int* csrSortedRowPtrC, int* nnzTotalDevHostPtr, csrgemm2Info_t info,
    void* workspace) {
  DCHECK(initialized_);
  DCHECK(nnzTotalDevHostPtr != nullptr);

  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsrgemm2Nnz(
      *gpusparse_handle_, m, n, k, descrA, nnzA, csrSortedRowPtrA,
      csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, csrSortedColIndB,
      descrA, 0, null_ptr<int>(), null_ptr<int>(), descrC, csrSortedRowPtrC,
      nnzTotalDevHostPtr, info, workspace));
  return OkStatus();
}

template <typename Scalar, typename SparseFnT>
static inline Status CsrgemmImpl(
    SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
    int m, int n, int k, const cusparseMatDescr_t descrA, int nnzA,
    const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
    const int* csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB,
    const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
    const int* csrSortedColIndB, const cusparseMatDescr_t descrC,
    Scalar* csrSortedValC, int* csrSortedRowPtrC, int* csrSortedColIndC,
    const csrgemm2Info_t info, void* workspace) {
  TF_RETURN_IF_GPUSPARSE_ERROR(
      op(cusparse_handle, m, n, k, AsCudaComplex(one_ptr<Scalar>()), descrA,
         nnzA, AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
         descrB, nnzB, AsCudaComplex(csrSortedValB), csrSortedRowPtrB,
         csrSortedColIndB, AsCudaComplex(null_ptr<Scalar>()), descrA, 0,
         AsCudaComplex(null_ptr<Scalar>()), null_ptr<int>(), null_ptr<int>(),
         descrC, AsCudaComplex(csrSortedValC), csrSortedRowPtrC,
         csrSortedColIndC, info, workspace));
  return OkStatus();
}

#define CSRGEMM_INSTANCE(Scalar, sparse_prefix) \
  template <> \
  Status GpuSparse::Csrgemm<Scalar>( \
      int m, int n, int k, const cusparseMatDescr_t descrA, int nnzA, \
      const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \
      const int* csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB, \
      const Scalar* csrSortedValB, const int* csrSortedRowPtrB, \
      const int* csrSortedColIndB, const cusparseMatDescr_t descrC, \
      Scalar* csrSortedValC, int* csrSortedRowPtrC, int* csrSortedColIndC, \
      const csrgemm2Info_t info, void* workspace) { \
    DCHECK(initialized_); \
    return CsrgemmImpl(SPARSE_FN(csrgemm2, sparse_prefix), context_, \
                       *gpusparse_handle_, m, n, k, descrA, nnzA, \
                       csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, \
                       descrB, nnzB, csrSortedValB, csrSortedRowPtrB, \
                       csrSortedColIndB, descrC, csrSortedValC, \
                       csrSortedRowPtrC, csrSortedColIndC, info, workspace); \
  }

TF_CALL_LAPACK_TYPES(CSRGEMM_INSTANCE);

#endif  // CUDA_VERSION < 10000

template <typename Scalar, typename BufferSizeFnT, typename SparseFnT>
static inline Status Csru2csrImpl(SparseFnT op, BufferSizeFnT buffer_size_op,
                                  OpKernelContext* context,
                                  cusparseHandle_t cusparse_handle, int m,
                                  int n, int nnz,
                                  const cusparseMatDescr_t descrA,
                                  Scalar* csrVal, const int* csrRowPtr,
                                  int* csrColInd) {
  GpuSparseCsrSortingConversionInfo info;
  TF_RETURN_IF_ERROR(info.Initialize());

  size_t pBufferSizeInBytes = 0;

  TF_RETURN_IF_GPUSPARSE_ERROR(
      buffer_size_op(cusparse_handle, m, n, nnz, AsCudaComplex(csrVal),
                     csrRowPtr, csrColInd, info.info(), &pBufferSizeInBytes));

  Tensor pBuffer_t;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DT_INT8, TensorShape({static_cast<int64_t>(pBufferSizeInBytes)}),
      &pBuffer_t));
  auto pBuffer = pBuffer_t.flat<int8>();
  DCHECK(pBuffer.data() != nullptr);

  TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, n, nnz, descrA,
                                  AsCudaComplex(csrVal), csrRowPtr, csrColInd,
                                  info.info(), pBuffer.data()));

  return OkStatus();
}

#define CSRU2CSR_INSTANCE(Scalar, sparse_prefix) \
  template <> \
  Status GpuSparse::Csru2csr<Scalar>( \
      int m, int n, int nnz, const cusparseMatDescr_t descrA, Scalar* csrVal, \
      const int* csrRowPtr, int* csrColInd) { \
    DCHECK(initialized_); \
    return Csru2csrImpl(SPARSE_FN(csru2csr, sparse_prefix), \
                        BUFSIZE_FN(csru2csr, sparse_prefix), context_, \
                        *gpusparse_handle_, m, n, nnz, descrA, csrVal, \
                        csrRowPtr, csrColInd); \
  }

TF_CALL_LAPACK_TYPES(CSRU2CSR_INSTANCE);
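
// Csru2csr sorts the column indices within each row of an "unsorted CSR"
// matrix in place (permuting the values to match), allocating the required
// cusparse workspace as a temporary DT_INT8 tensor on the kernel's context.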

#if CUDA_VERSION < 10010

template <typename Scalar, typename SparseFnT>
static inline Status Csr2cscImpl(SparseFnT op, OpKernelContext* context,
                                 cusparseHandle_t cusparse_handle, int m, int n,
                                 int nnz, const Scalar* csrVal,
                                 const int* csrRowPtr, const int* csrColInd,
                                 Scalar* cscVal, int* cscRowInd, int* cscColPtr,
                                 const cusparseAction_t copyValues) {
  TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, n, nnz,
                                  AsCudaComplex(csrVal), csrRowPtr, csrColInd,
                                  AsCudaComplex(cscVal), cscRowInd, cscColPtr,
                                  copyValues, CUSPARSE_INDEX_BASE_ZERO));
  return OkStatus();
}

#define CSR2CSC_INSTANCE(Scalar, sparse_prefix) \
  template <> \
  Status GpuSparse::Csr2csc<Scalar>( \
      int m, int n, int nnz, const Scalar* csrVal, const int* csrRowPtr, \
      const int* csrColInd, Scalar* cscVal, int* cscRowInd, int* cscColPtr, \
      const cusparseAction_t copyValues) { \
    DCHECK(initialized_); \
    return Csr2cscImpl(SPARSE_FN(csr2csc, sparse_prefix), context_, \
                       *gpusparse_handle_, m, n, nnz, csrVal, csrRowPtr, \
                       csrColInd, cscVal, cscRowInd, cscColPtr, copyValues); \
  }

TF_CALL_LAPACK_TYPES(CSR2CSC_INSTANCE);

#else

template <typename Scalar>
static inline Status Csr2cscImpl(cudaDataType_t dtype, OpKernelContext* context,
                                 cusparseHandle_t cusparse_handle, int m, int n,
                                 int nnz, const Scalar* csrVal,
                                 const int* csrRowPtr, const int* csrColInd,
                                 Scalar* cscVal, int* cscRowInd, int* cscColPtr,
                                 const cusparseAction_t copyValues) {
  size_t bufferSize;
  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCsr2cscEx2_bufferSize(
      cusparse_handle, m, n, nnz, AsCudaComplex(csrVal), csrRowPtr, csrColInd,
      AsCudaComplex(cscVal), cscColPtr, cscRowInd, dtype, copyValues,
      CUSPARSE_INDEX_BASE_ZERO, CUSPARSE_CSR2CSC_ALG2, &bufferSize));

  Tensor buffer;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DataTypeToEnum<Scalar>::value,
      TensorShape({static_cast<int64_t>(bufferSize)}), &buffer));

  DCHECK(buffer.flat<Scalar>().data() != nullptr);

  TF_RETURN_IF_GPUSPARSE_ERROR(
      cusparseCsr2cscEx2(cusparse_handle, m, n, nnz, AsCudaComplex(csrVal),
                         csrRowPtr, csrColInd, AsCudaComplex(cscVal), cscColPtr,
                         cscRowInd, dtype, copyValues, CUSPARSE_INDEX_BASE_ZERO,
                         CUSPARSE_CSR2CSC_ALG2, buffer.flat<Scalar>().data()));

  return OkStatus();
}

#define CSR2CSC_INSTANCE(Scalar, cudaDataType) \
  template <> \
  Status GpuSparse::Csr2csc<Scalar>( \
      int m, int n, int nnz, const Scalar* csrVal, const int* csrRowPtr, \
      const int* csrColInd, Scalar* cscVal, int* cscRowInd, int* cscColPtr, \
      const cusparseAction_t copyValues) { \
    DCHECK(initialized_); \
    return Csr2cscImpl(cudaDataType, context_, *gpusparse_handle_, m, n, nnz, \
                       csrVal, csrRowPtr, csrColInd, cscVal, cscRowInd, \
                       cscColPtr, copyValues); \
  }

TF_CALL_CUSPARSE_DTYPES(CSR2CSC_INSTANCE);

#endif  // CUDA_VERSION < 10010

}  // namespace tensorflow

#endif  // GOOGLE_CUDA