/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef GOOGLE_CUDA

#include "tensorflow/core/util/cuda_sparse.h"

#include <complex>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

#include "third_party/gpus/cuda/include/cusparse.h"
#include "third_party/gpus/cuda/include/library_types.h"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/gpu_solvers.h"

// TODO(rmlarsen,penporn): Investigate using newer kernels in CUDA 10.1+.

namespace tensorflow {
namespace {

// Type traits to get CUDA complex types from std::complex<>.
// TODO: reuse with cuda_solvers
template <typename T>
struct CudaComplexT {
  typedef T type;
};
template <>
struct CudaComplexT<std::complex<float>> {
  typedef cuComplex type;
};
template <>
struct CudaComplexT<std::complex<double>> {
  typedef cuDoubleComplex type;
};
// Converts pointers of std::complex<> to pointers of
// cuComplex/cuDoubleComplex. No type conversion for non-complex types.
template <typename T>
inline const typename CudaComplexT<T>::type* AsCudaComplex(const T* p) {
  return reinterpret_cast<const typename CudaComplexT<T>::type*>(p);
}
template <typename T>
inline typename CudaComplexT<T>::type* AsCudaComplex(T* p) {
  return reinterpret_cast<typename CudaComplexT<T>::type*>(p);
}
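
// Illustrative sketch (not part of the original file): AsCudaComplex only
// reinterprets the pointer type, so std::complex<> buffers can be handed
// directly to cuSPARSE entry points that expect cuComplex/cuDoubleComplex:
//
//   std::vector<std::complex<float>> vals = {{1.f, 2.f}, {3.f, 4.f}};
//   const cuComplex* cu_vals = AsCudaComplex(vals.data());  // same memory
//
// This relies on std::complex<float> and cuComplex sharing the same layout
// (two packed floats); for non-complex types the cast is the identity.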

// A set of initialized handles to the underlying Cuda libraries used by
// GpuSparse. We maintain one such set of handles per unique stream.
class CudaSparseHandles {
 public:
  explicit CudaSparseHandles(cudaStream_t stream)
      : initialized_(false), stream_(stream) {}

  CudaSparseHandles(CudaSparseHandles&& rhs)
      : initialized_(rhs.initialized_),
        stream_(std::move(rhs.stream_)),
        cusparse_handle_(rhs.cusparse_handle_) {
    rhs.initialized_ = false;
  }

  CudaSparseHandles& operator=(CudaSparseHandles&& rhs) {
    if (this == &rhs) return *this;
    Release();
    stream_ = std::move(rhs.stream_);
    cusparse_handle_ = std::move(rhs.cusparse_handle_);
    initialized_ = rhs.initialized_;
    rhs.initialized_ = false;
    return *this;
  }

  ~CudaSparseHandles() { Release(); }

  Status Initialize() {
    if (initialized_) return OkStatus();
    TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreate(&cusparse_handle_));
    TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSetStream(cusparse_handle_, stream_));
    initialized_ = true;
    return OkStatus();
  }

  cusparseHandle_t& handle() {
    DCHECK(initialized_);
    return cusparse_handle_;
  }

  const cusparseHandle_t& handle() const {
    DCHECK(initialized_);
    return cusparse_handle_;
  }

 private:
  void Release() {
    if (initialized_) {
      // This should never return anything other than success
      auto err = cusparseDestroy(cusparse_handle_);
      DCHECK(err == CUSPARSE_STATUS_SUCCESS)
          << "Failed to destroy cuSparse instance.";
      initialized_ = false;
    }
  }
  bool initialized_;
  cudaStream_t stream_;
  cusparseHandle_t cusparse_handle_;

  TF_DISALLOW_COPY_AND_ASSIGN(CudaSparseHandles);
};

// TODO(ebrevdo): Replace global mutex guarding CudaSparseHandles
// lookup with one of:
//    1. Adding the handle to the CudaStream structure; do the lookup there.
//    2. Add a thread-local cusparse, set it to the current stream
//       upon each call.
// #1 seems like the cleanest option but will need to wait until this
// is moved into TF core.
static mutex handle_map_mutex(LINKER_INITIALIZED);

using HandleMap = std::unordered_map<cudaStream_t, CudaSparseHandles>;

// Returns a singleton map used for storing initialized handles for each unique
// cuda stream.
HandleMap* GetHandleMapSingleton() {
  static HandleMap* cm = new HandleMap;
  return cm;
}

}  // namespace

GpuSparse::GpuSparse(OpKernelContext* context)
    : initialized_(false), context_(context) {
  auto cuda_stream_ptr =
      reinterpret_cast<const cudaStream_t*>(context->op_device_context()
                                                ->stream()
                                                ->implementation()
                                                ->GpuStreamMemberHack());
  DCHECK(cuda_stream_ptr);
  gpu_stream_ = *cuda_stream_ptr;
}

Status GpuSparse::Initialize() {
  HandleMap* handle_map = GetHandleMapSingleton();
  DCHECK(handle_map);
  mutex_lock lock(handle_map_mutex);
  auto it = handle_map->find(gpu_stream_);
  if (it == handle_map->end()) {
    LOG(INFO) << "Creating CudaSparse handles for stream " << gpu_stream_;
    // Previously unseen Cuda stream. Initialize a set of Cuda sparse library
    // handles for it.
    CudaSparseHandles new_handles(gpu_stream_);
    TF_RETURN_IF_ERROR(new_handles.Initialize());
    it = handle_map->insert(std::make_pair(gpu_stream_, std::move(new_handles)))
             .first;
  }
  gpusparse_handle_ = &it->second.handle();
  initialized_ = true;
  return OkStatus();
}
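
// Illustrative usage sketch (not part of the original file): a GPU op kernel
// typically constructs a GpuSparse wrapper per Compute() call, initializes it
// once, and then invokes the typed wrappers defined below. Names here are
// assumptions for illustration only.
//
//   void Compute(OpKernelContext* ctx) override {
//     GpuSparse cuda_sparse(ctx);
//     OP_REQUIRES_OK(ctx, cuda_sparse.Initialize());
//     OP_REQUIRES_OK(ctx, cuda_sparse.Csr2coo(csr_row_ptr, nnz, m, coo_rows));
//     ...
//   }
//
// Initialize() must be called before any other method; the cusparse handle is
// created once per CUDA stream and cached in the singleton map above.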

#define TF_CALL_CUSPARSE_DTYPES(m)           \
  m(float, CUDA_R_32F) m(double, CUDA_R_64F) \
      m(std::complex<float>, CUDA_C_32F) m(std::complex<double>, CUDA_C_64F)

// Macro that specializes a sparse method for all 4 standard
// numeric types.
// TODO: reuse with cuda_solvers
#define TF_CALL_LAPACK_TYPES(m) \
  m(float, S) m(double, D) m(std::complex<float>, C) m(std::complex<double>, Z)

// Macros to construct cusparse method names.
#define SPARSE_FN(method, sparse_prefix) cusparse##sparse_prefix##method
#define SPARSE_NAME(method, sparse_prefix) "cusparse" #sparse_prefix #method
#define BUFSIZE_FN(method, sparse_prefix) \
  cusparse##sparse_prefix##method##_bufferSizeExt
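
// For example (illustrative, not part of the original file):
//   SPARSE_FN(gtsv2, S)      expands to cusparseSgtsv2
//   SPARSE_NAME(gtsv2, S)    expands to "cusparse" "S" "gtsv2" == "cusparseSgtsv2"
//   BUFSIZE_FN(csru2csr, D)  expands to cusparseDcsru2csr_bufferSizeExt
// and TF_CALL_LAPACK_TYPES(GTSV2_INSTANCE) stamps out the four specializations
// GpuSparse::Gtsv2<float>, <double>, <std::complex<float>>, and
// <std::complex<double>>.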

//=============================================================================
// Wrappers of cuSparse computational methods begin here.
//
// WARNING to implementers: The function signatures listed in the online docs
// are sometimes inaccurate, e.g., are missing 'const' on pointers
// to immutable arguments, while the actual headers have them as expected.
// Check the actual declarations in the cusparse.h header file.
//=============================================================================

template <typename Scalar, typename SparseFn>
static inline Status Gtsv2Impl(SparseFn op, cusparseHandle_t cusparse_handle,
                               int m, int n, const Scalar* dl, const Scalar* d,
                               const Scalar* du, Scalar* B, int ldb,
                               void* pBuffer) {
  TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, n, AsCudaComplex(dl),
                                  AsCudaComplex(d), AsCudaComplex(du),
                                  AsCudaComplex(B), ldb, pBuffer));
  return OkStatus();
}

#define GTSV2_INSTANCE(Scalar, sparse_prefix)                                \
  template <>                                                                \
  Status GpuSparse::Gtsv2<Scalar>(int m, int n, const Scalar* dl,            \
                                  const Scalar* d, const Scalar* du,         \
                                  Scalar* B, int ldb, void* pBuffer) const { \
    DCHECK(initialized_);                                                    \
    return Gtsv2Impl(SPARSE_FN(gtsv2, sparse_prefix), *gpusparse_handle_, m, \
                     n, dl, d, du, B, ldb, pBuffer);                         \
  }

TF_CALL_LAPACK_TYPES(GTSV2_INSTANCE);

#define GTSV2_NO_PIVOT_INSTANCE(Scalar, sparse_prefix)                      \
  template <>                                                               \
  Status GpuSparse::Gtsv2NoPivot<Scalar>(                                   \
      int m, int n, const Scalar* dl, const Scalar* d, const Scalar* du,    \
      Scalar* B, int ldb, void* pBuffer) const {                            \
    DCHECK(initialized_);                                                   \
    return Gtsv2Impl(SPARSE_FN(gtsv2_nopivot, sparse_prefix),               \
                     *gpusparse_handle_, m, n, dl, d, du, B, ldb, pBuffer); \
  }

TF_CALL_LAPACK_TYPES(GTSV2_NO_PIVOT_INSTANCE);

template <typename Scalar, typename SparseFn>
static inline Status Gtsv2BufferSizeExtImpl(SparseFn op,
                                            cusparseHandle_t cusparse_handle,
                                            int m, int n, const Scalar* dl,
                                            const Scalar* d, const Scalar* du,
                                            const Scalar* B, int ldb,
                                            size_t* bufferSizeInBytes) {
  TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, n, AsCudaComplex(dl),
                                  AsCudaComplex(d), AsCudaComplex(du),
                                  AsCudaComplex(B), ldb, bufferSizeInBytes));
  return OkStatus();
}

#define GTSV2_BUFFER_SIZE_INSTANCE(Scalar, sparse_prefix)                     \
  template <>                                                                 \
  Status GpuSparse::Gtsv2BufferSizeExt<Scalar>(                               \
      int m, int n, const Scalar* dl, const Scalar* d, const Scalar* du,      \
      const Scalar* B, int ldb, size_t* bufferSizeInBytes) const {            \
    DCHECK(initialized_);                                                     \
    return Gtsv2BufferSizeExtImpl(                                            \
        SPARSE_FN(gtsv2_bufferSizeExt, sparse_prefix), *gpusparse_handle_, m, \
        n, dl, d, du, B, ldb, bufferSizeInBytes);                             \
  }

TF_CALL_LAPACK_TYPES(GTSV2_BUFFER_SIZE_INSTANCE);

#define GTSV2_NO_PIVOT_BUFFER_SIZE_INSTANCE(Scalar, sparse_prefix)       \
  template <>                                                            \
  Status GpuSparse::Gtsv2NoPivotBufferSizeExt<Scalar>(                   \
      int m, int n, const Scalar* dl, const Scalar* d, const Scalar* du, \
      const Scalar* B, int ldb, size_t* bufferSizeInBytes) const {       \
    DCHECK(initialized_);                                                \
    return Gtsv2BufferSizeExtImpl(                                       \
        SPARSE_FN(gtsv2_nopivot_bufferSizeExt, sparse_prefix),           \
        *gpusparse_handle_, m, n, dl, d, du, B, ldb, bufferSizeInBytes); \
  }

TF_CALL_LAPACK_TYPES(GTSV2_NO_PIVOT_BUFFER_SIZE_INSTANCE);
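
// Illustrative call sequence (not part of the original file): callers query
// the *BufferSizeExt wrapper first, allocate that many bytes of scratch
// (e.g. a DT_INT8 temp tensor), and pass the buffer to the solver:
//
//   size_t bytes = 0;
//   TF_RETURN_IF_ERROR(cuda_sparse.Gtsv2BufferSizeExt<float>(m, n, dl, d, du,
//                                                            B, ldb, &bytes));
//   Tensor scratch;
//   TF_RETURN_IF_ERROR(context->allocate_temp(
//       DT_INT8, TensorShape({static_cast<int64_t>(bytes)}), &scratch));
//   TF_RETURN_IF_ERROR(cuda_sparse.Gtsv2<float>(
//       m, n, dl, d, du, B, ldb, scratch.flat<int8>().data()));
//
// The variable names (cuda_sparse, context, dl, d, du, B) are placeholders.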

template <typename Scalar, typename SparseFn>
static inline Status Gtsv2StridedBatchImpl(SparseFn op,
                                           cusparseHandle_t cusparse_handle,
                                           int m, const Scalar* dl,
                                           const Scalar* d, const Scalar* du,
                                           Scalar* x, int batchCount,
                                           int batchStride, void* pBuffer) {
  TF_RETURN_IF_GPUSPARSE_ERROR(op(
      cusparse_handle, m, AsCudaComplex(dl), AsCudaComplex(d),
      AsCudaComplex(du), AsCudaComplex(x), batchCount, batchStride, pBuffer));
  return OkStatus();
}

#define GTSV2_STRIDED_BATCH_INSTANCE(Scalar, sparse_prefix)                   \
  template <>                                                                 \
  Status GpuSparse::Gtsv2StridedBatch<Scalar>(                                \
      int m, const Scalar* dl, const Scalar* d, const Scalar* du, Scalar* x,  \
      int batchCount, int batchStride, void* pBuffer) const {                 \
    DCHECK(initialized_);                                                     \
    return Gtsv2StridedBatchImpl(SPARSE_FN(gtsv2StridedBatch, sparse_prefix), \
                                 *gpusparse_handle_, m, dl, d, du, x,         \
                                 batchCount, batchStride, pBuffer);           \
  }

TF_CALL_LAPACK_TYPES(GTSV2_STRIDED_BATCH_INSTANCE);

template <typename Scalar, typename SparseFn>
static inline Status Gtsv2StridedBatchBufferSizeImpl(
    SparseFn op, cusparseHandle_t cusparse_handle, int m, const Scalar* dl,
    const Scalar* d, const Scalar* du, const Scalar* x, int batchCount,
    int batchStride, size_t* bufferSizeInBytes) {
  TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, AsCudaComplex(dl),
                                  AsCudaComplex(d), AsCudaComplex(du),
                                  AsCudaComplex(x), batchCount, batchStride,
                                  bufferSizeInBytes));
  return OkStatus();
}

#define GTSV2_STRIDED_BATCH_BUFFER_SIZE_INSTANCE(Scalar, sparse_prefix) \
  template <>                                                           \
  Status GpuSparse::Gtsv2StridedBatchBufferSizeExt<Scalar>(             \
      int m, const Scalar* dl, const Scalar* d, const Scalar* du,       \
      const Scalar* x, int batchCount, int batchStride,                 \
      size_t* bufferSizeInBytes) const {                                \
    DCHECK(initialized_);                                               \
    return Gtsv2StridedBatchBufferSizeImpl(                             \
        SPARSE_FN(gtsv2StridedBatch_bufferSizeExt, sparse_prefix),      \
        *gpusparse_handle_, m, dl, d, du, x, batchCount, batchStride,   \
        bufferSizeInBytes);                                             \
  }

TF_CALL_LAPACK_TYPES(GTSV2_STRIDED_BATCH_BUFFER_SIZE_INSTANCE);

Status GpuSparse::Coo2csr(const int* cooRowInd, int nnz, int m,
                          int* csrRowPtr) const {
  // cusparseStatus_t CUSPARSEAPI cusparseXcoo2csr(cusparseHandle_t handle,
  //                                               const int *cooRowInd,
  //                                               int nnz,
  //                                               int m,
  //                                               int *csrSortedRowPtr,
  //                                               cusparseIndexBase_t
  //                                               idxBase);
  DCHECK(initialized_);
  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcoo2csr(*gpusparse_handle_, cooRowInd,
                                                nnz, m, csrRowPtr,
                                                CUSPARSE_INDEX_BASE_ZERO));
  return OkStatus();
}

Status GpuSparse::Csr2coo(const int* csrRowPtr, int nnz, int m,
                          int* cooRowInd) const {
  // cusparseStatus_t CUSPARSEAPI cusparseXcsr2coo(cusparseHandle_t handle,
  //                                               const int *csrRowPtr,
  //                                               int nnz,
  //                                               int m,
  //                                               int *cooRowInd,
  //                                               cusparseIndexBase_t
  //                                               idxBase);
  DCHECK(initialized_);
  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsr2coo(*gpusparse_handle_, csrRowPtr,
                                                nnz, m, cooRowInd,
                                                CUSPARSE_INDEX_BASE_ZERO));
  return OkStatus();
}
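
// Worked example (illustrative, not part of the original file): for m = 3 rows
// and nnz = 4 with sorted COO row indices cooRowInd = [0, 0, 1, 2], Coo2csr
// produces the m + 1 row offsets csrRowPtr = [0, 2, 3, 4], where
// csrRowPtr[i + 1] - csrRowPtr[i] is the number of nonzeros in row i.
// Csr2coo is the inverse, expanding the offsets back to per-nonzero row
// indices.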

Status GpuSparse::CsrgeamNnz(
    int m, int n, const cusparseMatDescr_t descrA, int nnzA,
    const int* csrSortedRowPtrA, const int* csrSortedColIndA,
    const cusparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB,
    const int* csrSortedColIndB, const cusparseMatDescr_t descrC,
    int* csrSortedRowPtrC, int* nnzTotalDevHostPtr, void* workspace) {
  DCHECK(initialized_);
  DCHECK(nnzTotalDevHostPtr != nullptr);
#if CUDA_VERSION >= 10000
  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsrgeam2Nnz(
      *gpusparse_handle_, m, n, descrA, nnzA, csrSortedRowPtrA,
      csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, csrSortedColIndB,
      descrC, csrSortedRowPtrC, nnzTotalDevHostPtr, workspace));
#else
  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsrgeamNnz(
      *gpusparse_handle_, m, n, descrA, nnzA, csrSortedRowPtrA,
      csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, csrSortedColIndB,
      descrC, csrSortedRowPtrC, nnzTotalDevHostPtr));
#endif
  return OkStatus();
}

#if CUDA_VERSION < 10020

template <typename Scalar, typename SparseFnT>
static inline Status CsrmmImpl(
    SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
    cusparseOperation_t transA, cusparseOperation_t transB, int m, int n, int k,
    int nnz, const Scalar* alpha_host, const cusparseMatDescr_t descrA,
    const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
    const int* csrSortedColIndA, const Scalar* B, int ldb,
    const Scalar* beta_host, Scalar* C, int ldc) {
  // cusparseStatus_t CUSPARSEAPI cusparseScsrmm2(
  //     cusparseHandle_t handle, cusparseOperation_t transA,
  //     cusparseOperation_t transB, int m, int n, int k, int nnz,
  //     const float* alpha, const cusparseMatDescr_t descrA,
  //     const float* csrSortedValA, const int* csrSortedRowPtrA,
  //     const int* csrSortedColIndA, const float* B, int ldb, const float*
  //     beta, float* C, int ldc);
  TF_RETURN_IF_GPUSPARSE_ERROR(op(
      cusparse_handle, transA, transB, m, n, k, nnz, AsCudaComplex(alpha_host),
      descrA, AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
      AsCudaComplex(B), ldb, AsCudaComplex(beta_host), AsCudaComplex(C), ldc));
  return OkStatus();
}

#define CSRMM_INSTANCE(Scalar, sparse_prefix)                                \
  template <>                                                                \
  Status GpuSparse::Csrmm<Scalar>(                                           \
      cusparseOperation_t transA, cusparseOperation_t transB, int m, int n,  \
      int k, int nnz, const Scalar* alpha_host,                              \
      const cusparseMatDescr_t descrA, const Scalar* csrSortedValA,          \
      const int* csrSortedRowPtrA, const int* csrSortedColIndA,              \
      const Scalar* B, int ldb, const Scalar* beta_host, Scalar* C, int ldc) \
      const {                                                                \
    DCHECK(initialized_);                                                    \
    return CsrmmImpl(SPARSE_FN(csrmm2, sparse_prefix), context_,             \
                     *gpusparse_handle_, transA, transB, m, n, k, nnz,       \
                     alpha_host, descrA, csrSortedValA, csrSortedRowPtrA,    \
                     csrSortedColIndA, B, ldb, beta_host, C, ldc);           \
  }

TF_CALL_LAPACK_TYPES(CSRMM_INSTANCE);

#else

#define SPMM_BUFFERSIZE_INSTANCE(Scalar, dtype)                              \
  template <>                                                                \
  Status GpuSparse::SpMMBufferSize<Scalar>(                                  \
      cusparseOperation_t transA, cusparseOperation_t transB,                \
      const Scalar* alpha, const cusparseSpMatDescr_t matA,                  \
      const gpusparseDnMatDescr_t matB, const Scalar* beta,                  \
      gpusparseDnMatDescr_t matC, cusparseSpMMAlg_t alg, size_t* bufferSize) \
      const {                                                                \
    DCHECK(initialized_);                                                    \
    TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSpMM_bufferSize(                    \
        *gpusparse_handle_, transA, transB, alpha, matA, matB, beta, matC,   \
        dtype, alg, bufferSize));                                            \
    return Status::OK();                                                     \
  }

TF_CALL_CUSPARSE_DTYPES(SPMM_BUFFERSIZE_INSTANCE);

#define SPMM_INSTANCE(Scalar, dtype)                                           \
  template <>                                                                  \
  Status GpuSparse::SpMM<Scalar>(                                              \
      cusparseOperation_t transA, cusparseOperation_t transB,                  \
      const Scalar* alpha, const cusparseSpMatDescr_t matA,                    \
      const gpusparseDnMatDescr_t matB, const Scalar* beta,                    \
      gpusparseDnMatDescr_t matC, cusparseSpMMAlg_t alg, int8* buffer) const { \
    DCHECK(initialized_);                                                      \
    TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSpMM(*gpusparse_handle_, transA,      \
                                              transB, alpha, matA, matB, beta, \
                                              matC, dtype, alg, buffer));      \
    return Status::OK();                                                       \
  }

TF_CALL_CUSPARSE_DTYPES(SPMM_INSTANCE);

#endif
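
// Illustrative sketch (not part of the original file): with the generic API
// used above (CUDA >= 10.2), callers are expected to build the descriptors
// themselves and then pair the buffer-size query with SpMM, roughly:
//
//   cusparseSpMatDescr_t matA;
//   cusparseCreateCsr(&matA, m, k, nnz, row_ptr, col_ind, vals,
//                     CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I,
//                     CUSPARSE_INDEX_BASE_ZERO, CUDA_R_32F);
//   cusparseDnMatDescr_t matB, matC;
//   cusparseCreateDnMat(&matB, k, n, ldb, B, CUDA_R_32F, CUSPARSE_ORDER_COL);
//   cusparseCreateDnMat(&matC, m, n, ldc, C, CUDA_R_32F, CUSPARSE_ORDER_COL);
//   size_t bytes = 0;
//   TF_RETURN_IF_ERROR(cuda_sparse.SpMMBufferSize<float>(
//       transA, transB, &alpha, matA, matB, &beta, matC, alg, &bytes));
//   // ... allocate `bytes` of scratch, then:
//   TF_RETURN_IF_ERROR(cuda_sparse.SpMM<float>(
//       transA, transB, &alpha, matA, matB, &beta, matC, alg, scratch_ptr));
//
// The names (cuda_sparse, alg, scratch_ptr) and the column-major ordering are
// assumptions for illustration; see the cuSPARSE generic API documentation
// for the exact contract.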

#if CUDA_VERSION < 10020

template <typename Scalar, typename SparseFnT>
static inline Status CsrmvImpl(
    SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
    cusparseOperation_t transA, int m, int n, int nnz, const Scalar* alpha_host,
    const cusparseMatDescr_t descrA, const Scalar* csrSortedValA,
    const int* csrSortedRowPtrA, const int* csrSortedColIndA, const Scalar* x,
    const Scalar* beta_host, Scalar* y) {
  TF_RETURN_IF_GPUSPARSE_ERROR(
      op(cusparse_handle, transA, m, n, nnz, AsCudaComplex(alpha_host), descrA,
         AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
         AsCudaComplex(x), AsCudaComplex(beta_host), AsCudaComplex(y)));
  return OkStatus();
}

// TODO(ebrevdo,rmlarsen): Use csrmv_mp for all cases when available in CUDA 9.
#define CSRMV_INSTANCE(Scalar, sparse_prefix)                                \
  template <>                                                                \
  Status GpuSparse::Csrmv<Scalar>(                                           \
      cusparseOperation_t transA, int m, int n, int nnz,                     \
      const Scalar* alpha_host, const cusparseMatDescr_t descrA,             \
      const Scalar* csrSortedValA, const int* csrSortedRowPtrA,              \
      const int* csrSortedColIndA, const Scalar* x, const Scalar* beta_host, \
      Scalar* y) const {                                                     \
    DCHECK(initialized_);                                                    \
    if (transA == CUSPARSE_OPERATION_NON_TRANSPOSE) {                        \
      return CsrmvImpl(SPARSE_FN(csrmv_mp, sparse_prefix), context_,         \
                       *gpusparse_handle_, transA, m, n, nnz, alpha_host,    \
                       descrA, csrSortedValA, csrSortedRowPtrA,              \
                       csrSortedColIndA, x, beta_host, y);                   \
    } else {                                                                 \
      return CsrmvImpl(SPARSE_FN(csrmv, sparse_prefix), context_,            \
                       *gpusparse_handle_, transA, m, n, nnz, alpha_host,    \
                       descrA, csrSortedValA, csrSortedRowPtrA,              \
                       csrSortedColIndA, x, beta_host, y);                   \
    }                                                                        \
  }

TF_CALL_LAPACK_TYPES(CSRMV_INSTANCE);

#else

template <typename Scalar>
static inline Status CsrmvExImpl(cudaDataType_t dtype, OpKernelContext* context,
                                 cusparseHandle_t cusparse_handle,
                                 cusparseOperation_t transA, int m, int n,
                                 int nnz, const Scalar* alpha_host,
                                 const Scalar* csrSortedValA,
                                 const int* csrSortedRowPtrA,
                                 const int* csrSortedColIndA, const Scalar* x,
                                 const Scalar* beta_host, Scalar* y) {
  cusparseMatDescr_t descrA;
  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
  TF_RETURN_IF_GPUSPARSE_ERROR(
      cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
  TF_RETURN_IF_GPUSPARSE_ERROR(
      cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));
  // CUSPARSE_ALG_MERGE_PATH algo only supports non-transpose matrix.
  DCHECK(transA == CUSPARSE_OPERATION_NON_TRANSPOSE);

  size_t bufferSize;
  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCsrmvEx_bufferSize(
      cusparse_handle, CUSPARSE_ALG_MERGE_PATH, transA, m, n, nnz, alpha_host,
      dtype, descrA, csrSortedValA, dtype, csrSortedRowPtrA, csrSortedColIndA,
      x, dtype, beta_host, dtype, y, dtype, dtype, &bufferSize));

  Tensor buffer;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DT_INT8, TensorShape({static_cast<int64_t>(bufferSize)}), &buffer));
  auto pBuffer = buffer.flat<int8>();
  DCHECK(pBuffer.data() != nullptr);

  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCsrmvEx(
      cusparse_handle, CUSPARSE_ALG_MERGE_PATH, transA, m, n, nnz, alpha_host,
      dtype, descrA, csrSortedValA, dtype, csrSortedRowPtrA, csrSortedColIndA,
      x, dtype, beta_host, dtype, y, dtype, dtype, pBuffer.data()));

  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyMatDescr(descrA));
  return Status::OK();
}

template <typename Scalar>
static inline Status SpMVImpl(cudaDataType_t dtype, OpKernelContext* context,
                              cusparseHandle_t cusparse_handle,
                              cusparseOperation_t transA, int m, int n, int nnz,
                              const Scalar* alpha_host,
                              const Scalar* csrSortedValA,
                              const int* csrSortedRowPtrA,
                              const int* csrSortedColIndA, const Scalar* x,
                              const Scalar* beta_host, Scalar* y) {
  cusparseSpMatDescr_t matA;
  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsr(
      &matA, m, n, nnz, const_cast<int*>(csrSortedRowPtrA),
      const_cast<int*>(csrSortedColIndA), const_cast<Scalar*>(csrSortedValA),
      CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO, dtype));

  cusparseDnVecDescr_t vecX, vecY;
  int sizeX = (transA == CUSPARSE_OPERATION_NON_TRANSPOSE) ? n : m;
  int sizeY = (transA == CUSPARSE_OPERATION_NON_TRANSPOSE) ? m : n;
  TF_RETURN_IF_GPUSPARSE_ERROR(
      cusparseCreateDnVec(&vecX, sizeX, const_cast<Scalar*>(x), dtype));
  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateDnVec(&vecY, sizeY, y, dtype));

  size_t bufferSize;
  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSpMV_bufferSize(
      cusparse_handle, transA, alpha_host, matA, vecX, beta_host, vecY, dtype,
      CUSPARSE_CSRMV_ALG1, &bufferSize));

  Tensor buffer;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DT_INT8, TensorShape({static_cast<int64_t>(bufferSize)}), &buffer));
  auto pBuffer = buffer.flat<int8>();
  DCHECK(pBuffer.data() != nullptr);

  TF_RETURN_IF_GPUSPARSE_ERROR(
      cusparseSpMV(cusparse_handle, transA, alpha_host, matA, vecX, beta_host,
                   vecY, dtype, CUSPARSE_CSRMV_ALG1, pBuffer.data()));

  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnVec(vecY));
  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnVec(vecX));
  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroySpMat(matA));
  return Status::OK();
}

#define CSRMV_INSTANCE(Scalar, cudaDataType)                                   \
  template <>                                                                  \
  Status GpuSparse::Csrmv<Scalar>(                                             \
      cusparseOperation_t transA, int m, int n, int nnz,                       \
      const Scalar* alpha_host, const Scalar* csrSortedValA,                   \
      const int* csrSortedRowPtrA, const int* csrSortedColIndA,                \
      const Scalar* x, const Scalar* beta_host, Scalar* y) const {             \
    DCHECK(initialized_);                                                      \
    if (transA == CUSPARSE_OPERATION_NON_TRANSPOSE) {                          \
      return CsrmvExImpl(cudaDataType, context_, *gpusparse_handle_, transA,   \
                         m, n, nnz, alpha_host, csrSortedValA,                 \
                         csrSortedRowPtrA, csrSortedColIndA, x, beta_host, y); \
    } else {                                                                   \
      return SpMVImpl(cudaDataType, context_, *gpusparse_handle_, transA, m,   \
                      n, nnz, alpha_host, csrSortedValA, csrSortedRowPtrA,     \
                      csrSortedColIndA, x, beta_host, y);                      \
    }                                                                          \
  }

TF_CALL_CUSPARSE_DTYPES(CSRMV_INSTANCE);

#endif  // CUDA_VERSION < 10020

#if CUDA_VERSION < 10000

template <typename Scalar, typename SparseFnT>
static inline Status CsrgeamImpl(
    SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
    int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA,
    int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
    const int* csrSortedColIndA, const Scalar* beta,
    const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB,
    const int* csrSortedRowPtrB, const int* csrSortedColIndB,
    const cusparseMatDescr_t descrC, Scalar* csrSortedValC,
    int* csrSortedRowPtrC, int* csrSortedColIndC) {
  TF_RETURN_IF_GPUSPARSE_ERROR(
      op(cusparse_handle, m, n, AsCudaComplex(alpha), descrA, nnzA,
         AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
         AsCudaComplex(beta), descrB, nnzB, AsCudaComplex(csrSortedValB),
         csrSortedRowPtrB, csrSortedColIndB, descrC,
         AsCudaComplex(csrSortedValC), csrSortedRowPtrC, csrSortedColIndC));
  return Status::OK();
}

#define CSRGEAM_INSTANCE(Scalar, sparse_prefix)                               \
  template <>                                                                 \
  Status GpuSparse::Csrgeam<Scalar>(                                          \
      int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA,     \
      int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA,     \
      const int* csrSortedColIndA, const Scalar* beta,                        \
      const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB, \
      const int* csrSortedRowPtrB, const int* csrSortedColIndB,               \
      const cusparseMatDescr_t descrC, Scalar* csrSortedValC,                 \
      int* csrSortedRowPtrC, int* csrSortedColIndC, void* workspace) {        \
    DCHECK(initialized_);                                                     \
    return CsrgeamImpl(SPARSE_FN(csrgeam, sparse_prefix), context_,           \
                       *gpusparse_handle_, m, n, alpha, descrA, nnzA,         \
                       csrSortedValA, csrSortedRowPtrA, csrSortedColIndA,     \
                       beta, descrB, nnzB, csrSortedValB, csrSortedRowPtrB,   \
                       csrSortedColIndB, descrC, csrSortedValC,               \
                       csrSortedRowPtrC, csrSortedColIndC);                   \
  }

#else

template <typename Scalar, typename SparseFnT>
static inline Status Csrgeam2Impl(
    SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
    int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA,
    int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
    const int* csrSortedColIndA, const Scalar* beta,
    const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB,
    const int* csrSortedRowPtrB, const int* csrSortedColIndB,
    const cusparseMatDescr_t descrC, Scalar* csrSortedValC,
    int* csrSortedRowPtrC, int* csrSortedColIndC, void* workspace) {
  TF_RETURN_IF_GPUSPARSE_ERROR(op(
      cusparse_handle, m, n, AsCudaComplex(alpha), descrA, nnzA,
      AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
      AsCudaComplex(beta), descrB, nnzB, AsCudaComplex(csrSortedValB),
      csrSortedRowPtrB, csrSortedColIndB, descrC, AsCudaComplex(csrSortedValC),
      csrSortedRowPtrC, csrSortedColIndC, workspace));
  return OkStatus();
}

#define CSRGEAM_INSTANCE(Scalar, sparse_prefix)                               \
  template <>                                                                 \
  Status GpuSparse::Csrgeam<Scalar>(                                          \
      int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA,     \
      int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA,     \
      const int* csrSortedColIndA, const Scalar* beta,                        \
      const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB, \
      const int* csrSortedRowPtrB, const int* csrSortedColIndB,               \
      const cusparseMatDescr_t descrC, Scalar* csrSortedValC,                 \
      int* csrSortedRowPtrC, int* csrSortedColIndC, void* workspace) {        \
    DCHECK(initialized_);                                                     \
    return Csrgeam2Impl(SPARSE_FN(csrgeam2, sparse_prefix), context_,         \
                        *gpusparse_handle_, m, n, alpha, descrA, nnzA,        \
                        csrSortedValA, csrSortedRowPtrA, csrSortedColIndA,    \
                        beta, descrB, nnzB, csrSortedValB, csrSortedRowPtrB,  \
                        csrSortedColIndB, descrC, csrSortedValC,              \
                        csrSortedRowPtrC, csrSortedColIndC, workspace);       \
  }

#endif

TF_CALL_LAPACK_TYPES(CSRGEAM_INSTANCE);

#if CUDA_VERSION < 10000

#define CSRGEAM_BUFFERSIZE_INSTANCE(Scalar, sparse_prefix)                    \
  template <>                                                                 \
  Status GpuSparse::CsrgeamBufferSizeExt<Scalar>(                             \
      int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA,     \
      int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA,     \
      const int* csrSortedColIndA, const Scalar* beta,                        \
      const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB, \
      const int* csrSortedRowPtrB, const int* csrSortedColIndB,               \
      const cusparseMatDescr_t descrC, Scalar* csrSortedValC,                 \
      int* csrSortedRowPtrC, int* csrSortedColIndC, size_t* bufferSize) {     \
    DCHECK(initialized_);                                                     \
    *bufferSize = 0;                                                          \
    return Status::OK();                                                      \
  }

#else

template <typename Scalar, typename SparseFnT>
static inline Status CsrgeamBufferSizeExtImpl(
    SparseFnT op, OpKernelContext* context, cusparseHandle_t sparse_handle,
    int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA,
    int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
    const int* csrSortedColIndA, const Scalar* beta,
    const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB,
    const int* csrSortedRowPtrB, const int* csrSortedColIndB,
    const cusparseMatDescr_t descrC, Scalar* csrSortedValC,
    int* csrSortedRowPtrC, int* csrSortedColIndC, size_t* bufferSize) {
  TF_RETURN_IF_GPUSPARSE_ERROR(op(
      sparse_handle, m, n, AsCudaComplex(alpha), descrA, nnzA,
      AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
      AsCudaComplex(beta), descrB, nnzB, AsCudaComplex(csrSortedValB),
      csrSortedRowPtrB, csrSortedColIndB, descrC, AsCudaComplex(csrSortedValC),
      csrSortedRowPtrC, csrSortedColIndC, bufferSize));
  return OkStatus();
}

#define CSRGEAM_BUFFERSIZE_INSTANCE(Scalar, sparse_prefix)                     \
  template <>                                                                  \
  Status GpuSparse::CsrgeamBufferSizeExt<Scalar>(                              \
      int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA,      \
      int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA,      \
      const int* csrSortedColIndA, const Scalar* beta,                         \
      const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB,  \
      const int* csrSortedRowPtrB, const int* csrSortedColIndB,                \
      const cusparseMatDescr_t descrC, Scalar* csrSortedValC,                  \
      int* csrSortedRowPtrC, int* csrSortedColIndC, size_t* bufferSize) {      \
    DCHECK(initialized_);                                                      \
    return CsrgeamBufferSizeExtImpl(                                           \
        SPARSE_FN(csrgeam2_bufferSizeExt, sparse_prefix), context_,            \
        *gpusparse_handle_, m, n, alpha, descrA, nnzA, csrSortedValA,          \
        csrSortedRowPtrA, csrSortedColIndA, beta, descrB, nnzB, csrSortedValB, \
        csrSortedRowPtrB, csrSortedColIndB, descrC, csrSortedValC,             \
        csrSortedRowPtrC, csrSortedColIndC, bufferSize);                       \
  }

#endif

TF_CALL_LAPACK_TYPES(CSRGEAM_BUFFERSIZE_INSTANCE);

#if CUDA_VERSION < 10000

Status GpuSparse::CsrgemmNnz(
    cusparseOperation_t transA, cusparseOperation_t transB, int m, int k, int n,
    const cusparseMatDescr_t descrA, int nnzA, const int* csrSortedRowPtrA,
    const int* csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB,
    const int* csrSortedRowPtrB, const int* csrSortedColIndB,
    const cusparseMatDescr_t descrC, int* csrSortedRowPtrC,
    int* nnzTotalDevHostPtr) {
  DCHECK(initialized_);
  DCHECK(nnzTotalDevHostPtr != nullptr);
  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsrgemmNnz(
      *gpusparse_handle_, transA, transB, m, k, n, descrA, nnzA,
      csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB,
      csrSortedColIndB, descrC, csrSortedRowPtrC, nnzTotalDevHostPtr));
  return Status::OK();
}

template <typename Scalar, typename SparseFnT>
static inline Status CsrgemmImpl(
    SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
    cusparseOperation_t transA, cusparseOperation_t transB, int m, int k, int n,
    const cusparseMatDescr_t descrA, int nnzA, const Scalar* csrSortedValA,
    const int* csrSortedRowPtrA, const int* csrSortedColIndA,
    const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB,
    const int* csrSortedRowPtrB, const int* csrSortedColIndB,
    const cusparseMatDescr_t descrC, Scalar* csrSortedValC,
    int* csrSortedRowPtrC, int* csrSortedColIndC) {
  TF_RETURN_IF_GPUSPARSE_ERROR(
      op(cusparse_handle, transA, transB, m, k, n, descrA, nnzA,
         AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
         descrB, nnzB, AsCudaComplex(csrSortedValB), csrSortedRowPtrB,
         csrSortedColIndB, descrC, AsCudaComplex(csrSortedValC),
         csrSortedRowPtrC, csrSortedColIndC));
  return Status::OK();
}

#define CSRGEMM_INSTANCE(Scalar, sparse_prefix)                               \
  template <>                                                                 \
  Status GpuSparse::Csrgemm<Scalar>(                                          \
      cusparseOperation_t transA, cusparseOperation_t transB, int m, int k,   \
      int n, const cusparseMatDescr_t descrA, int nnzA,                       \
      const Scalar* csrSortedValA, const int* csrSortedRowPtrA,               \
      const int* csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB, \
      const Scalar* csrSortedValB, const int* csrSortedRowPtrB,               \
      const int* csrSortedColIndB, const cusparseMatDescr_t descrC,           \
      Scalar* csrSortedValC, int* csrSortedRowPtrC, int* csrSortedColIndC) {  \
    DCHECK(initialized_);                                                     \
    return CsrgemmImpl(SPARSE_FN(csrgemm, sparse_prefix), context_,           \
                       *gpusparse_handle_, transA, transB, m, k, n, descrA,   \
                       nnzA, csrSortedValA, csrSortedRowPtrA,                 \
                       csrSortedColIndA, descrB, nnzB, csrSortedValB,         \
                       csrSortedRowPtrB, csrSortedColIndB, descrC,            \
                       csrSortedValC, csrSortedRowPtrC, csrSortedColIndC);    \
  }

TF_CALL_LAPACK_TYPES(CSRGEMM_INSTANCE);

#else

template <typename T>
static const T* one_ptr() {
  static const T one = static_cast<T>(1);
  return &one;
}

template <typename T>
static const T* null_ptr() {
  return nullptr;
}
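
// Note for readers (not in the original file): cusparse csrgemm2 computes
// C = alpha * A * B + beta * D. The wrappers below only need C = A * B, so
// they pass alpha = one_ptr<Scalar>(), beta = null_ptr<Scalar>(), and describe
// D as an empty matrix (descrA reused with nnz = 0 and null index arrays).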

#define CSRGEMM_BUFFERSIZE_INSTANCE(Scalar, sparse_prefix)                     \
  template <>                                                                  \
  Status GpuSparse::CsrgemmBufferSize<Scalar>(                                 \
      int m, int n, int k, const cusparseMatDescr_t descrA, int nnzA,          \
      const int* csrSortedRowPtrA, const int* csrSortedColIndA,                \
      const cusparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB,  \
      const int* csrSortedColIndB, csrgemm2Info_t info,                        \
      size_t* workspaceBytes) {                                                \
    DCHECK(initialized_);                                                      \
    TF_RETURN_IF_GPUSPARSE_ERROR(SPARSE_FN(csrgemm2_bufferSizeExt,             \
                                           sparse_prefix)(                     \
        *gpusparse_handle_, m, n, k, AsCudaComplex(one_ptr<Scalar>()), descrA, \
        nnzA, csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB,                \
        csrSortedRowPtrB, csrSortedColIndB, AsCudaComplex(null_ptr<Scalar>()), \
        descrA, 0, null_ptr<int>(), null_ptr<int>(), info, workspaceBytes));   \
    return OkStatus();                                                         \
  }

TF_CALL_LAPACK_TYPES(CSRGEMM_BUFFERSIZE_INSTANCE);

Status GpuSparse::CsrgemmNnz(
    int m, int n, int k, const cusparseMatDescr_t descrA, int nnzA,
    const int* csrSortedRowPtrA, const int* csrSortedColIndA,
    const cusparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB,
    const int* csrSortedColIndB, const cusparseMatDescr_t descrC,
    int* csrSortedRowPtrC, int* nnzTotalDevHostPtr, csrgemm2Info_t info,
    void* workspace) {
  DCHECK(initialized_);
  DCHECK(nnzTotalDevHostPtr != nullptr);

  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsrgemm2Nnz(
      *gpusparse_handle_, m, n, k, descrA, nnzA, csrSortedRowPtrA,
      csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, csrSortedColIndB,
      descrA, 0, null_ptr<int>(), null_ptr<int>(), descrC, csrSortedRowPtrC,
      nnzTotalDevHostPtr, info, workspace));
  return OkStatus();
}

template <typename Scalar, typename SparseFnT>
static inline Status CsrgemmImpl(
    SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
    int m, int n, int k, const cusparseMatDescr_t descrA, int nnzA,
    const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
    const int* csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB,
    const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
    const int* csrSortedColIndB, const cusparseMatDescr_t descrC,
    Scalar* csrSortedValC, int* csrSortedRowPtrC, int* csrSortedColIndC,
    const csrgemm2Info_t info, void* workspace) {
  TF_RETURN_IF_GPUSPARSE_ERROR(
      op(cusparse_handle, m, n, k, AsCudaComplex(one_ptr<Scalar>()), descrA,
         nnzA, AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
         descrB, nnzB, AsCudaComplex(csrSortedValB), csrSortedRowPtrB,
         csrSortedColIndB, AsCudaComplex(null_ptr<Scalar>()), descrA, 0,
         AsCudaComplex(null_ptr<Scalar>()), null_ptr<int>(), null_ptr<int>(),
         descrC, AsCudaComplex(csrSortedValC), csrSortedRowPtrC,
         csrSortedColIndC, info, workspace));
  return OkStatus();
}

#define CSRGEMM_INSTANCE(Scalar, sparse_prefix)                               \
  template <>                                                                 \
  Status GpuSparse::Csrgemm<Scalar>(                                          \
      int m, int n, int k, const cusparseMatDescr_t descrA, int nnzA,         \
      const Scalar* csrSortedValA, const int* csrSortedRowPtrA,               \
      const int* csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB, \
      const Scalar* csrSortedValB, const int* csrSortedRowPtrB,               \
      const int* csrSortedColIndB, const cusparseMatDescr_t descrC,           \
      Scalar* csrSortedValC, int* csrSortedRowPtrC, int* csrSortedColIndC,    \
      const csrgemm2Info_t info, void* workspace) {                           \
    DCHECK(initialized_);                                                     \
    return CsrgemmImpl(SPARSE_FN(csrgemm2, sparse_prefix), context_,          \
                       *gpusparse_handle_, m, n, k, descrA, nnzA,             \
                       csrSortedValA, csrSortedRowPtrA, csrSortedColIndA,     \
                       descrB, nnzB, csrSortedValB, csrSortedRowPtrB,         \
                       csrSortedColIndB, descrC, csrSortedValC,               \
                       csrSortedRowPtrC, csrSortedColIndC, info, workspace);  \
  }

TF_CALL_LAPACK_TYPES(CSRGEMM_INSTANCE);

#endif  // CUDA_VERSION < 10000

template <typename Scalar, typename BufferSizeFnT, typename SparseFnT>
static inline Status Csru2csrImpl(SparseFnT op, BufferSizeFnT buffer_size_op,
                                  OpKernelContext* context,
                                  cusparseHandle_t cusparse_handle, int m,
                                  int n, int nnz,
                                  const cusparseMatDescr_t descrA,
                                  Scalar* csrVal, const int* csrRowPtr,
                                  int* csrColInd) {
  GpuSparseCsrSortingConversionInfo info;
  TF_RETURN_IF_ERROR(info.Initialize());

  size_t pBufferSizeInBytes = 0;

  TF_RETURN_IF_GPUSPARSE_ERROR(
      buffer_size_op(cusparse_handle, m, n, nnz, AsCudaComplex(csrVal),
                     csrRowPtr, csrColInd, info.info(), &pBufferSizeInBytes));

  Tensor pBuffer_t;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DT_INT8, TensorShape({static_cast<int64_t>(pBufferSizeInBytes)}),
      &pBuffer_t));
  auto pBuffer = pBuffer_t.flat<int8>();
  DCHECK(pBuffer.data() != nullptr);

  TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, n, nnz, descrA,
                                  AsCudaComplex(csrVal), csrRowPtr, csrColInd,
                                  info.info(), pBuffer.data()));

  return OkStatus();
}

#define CSRU2CSR_INSTANCE(Scalar, sparse_prefix)                              \
  template <>                                                                 \
  Status GpuSparse::Csru2csr<Scalar>(                                         \
      int m, int n, int nnz, const cusparseMatDescr_t descrA, Scalar* csrVal, \
      const int* csrRowPtr, int* csrColInd) {                                 \
    DCHECK(initialized_);                                                     \
    return Csru2csrImpl(SPARSE_FN(csru2csr, sparse_prefix),                   \
                        BUFSIZE_FN(csru2csr, sparse_prefix), context_,        \
                        *gpusparse_handle_, m, n, nnz, descrA, csrVal,        \
                        csrRowPtr, csrColInd);                                \
  }

TF_CALL_LAPACK_TYPES(CSRU2CSR_INSTANCE);

#if CUDA_VERSION < 10010

template <typename Scalar, typename SparseFnT>
static inline Status Csr2cscImpl(SparseFnT op, OpKernelContext* context,
                                 cusparseHandle_t cusparse_handle, int m, int n,
                                 int nnz, const Scalar* csrVal,
                                 const int* csrRowPtr, const int* csrColInd,
                                 Scalar* cscVal, int* cscRowInd, int* cscColPtr,
                                 const cusparseAction_t copyValues) {
  TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, n, nnz,
                                  AsCudaComplex(csrVal), csrRowPtr, csrColInd,
                                  AsCudaComplex(cscVal), cscRowInd, cscColPtr,
                                  copyValues, CUSPARSE_INDEX_BASE_ZERO));
  return Status::OK();
}

#define CSR2CSC_INSTANCE(Scalar, sparse_prefix)                              \
  template <>                                                                \
  Status GpuSparse::Csr2csc<Scalar>(                                         \
      int m, int n, int nnz, const Scalar* csrVal, const int* csrRowPtr,     \
      const int* csrColInd, Scalar* cscVal, int* cscRowInd, int* cscColPtr,  \
      const cusparseAction_t copyValues) {                                   \
    DCHECK(initialized_);                                                    \
    return Csr2cscImpl(SPARSE_FN(csr2csc, sparse_prefix), context_,          \
                       *gpusparse_handle_, m, n, nnz, csrVal, csrRowPtr,     \
                       csrColInd, cscVal, cscRowInd, cscColPtr, copyValues); \
  }

TF_CALL_LAPACK_TYPES(CSR2CSC_INSTANCE);

#else

template <typename Scalar>
static inline Status Csr2cscImpl(cudaDataType_t dtype, OpKernelContext* context,
                                 cusparseHandle_t cusparse_handle, int m, int n,
                                 int nnz, const Scalar* csrVal,
                                 const int* csrRowPtr, const int* csrColInd,
                                 Scalar* cscVal, int* cscRowInd, int* cscColPtr,
                                 const cusparseAction_t copyValues) {
  size_t bufferSize;
  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCsr2cscEx2_bufferSize(
      cusparse_handle, m, n, nnz, AsCudaComplex(csrVal), csrRowPtr, csrColInd,
      AsCudaComplex(cscVal), cscColPtr, cscRowInd, dtype, copyValues,
      CUSPARSE_INDEX_BASE_ZERO, CUSPARSE_CSR2CSC_ALG2, &bufferSize));

  Tensor buffer;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DataTypeToEnum<Scalar>::value,
      TensorShape({static_cast<int64_t>(bufferSize)}), &buffer));

  DCHECK(buffer.flat<Scalar>().data() != nullptr);

  TF_RETURN_IF_GPUSPARSE_ERROR(
      cusparseCsr2cscEx2(cusparse_handle, m, n, nnz, AsCudaComplex(csrVal),
                         csrRowPtr, csrColInd, AsCudaComplex(cscVal), cscColPtr,
                         cscRowInd, dtype, copyValues, CUSPARSE_INDEX_BASE_ZERO,
                         CUSPARSE_CSR2CSC_ALG2, buffer.flat<Scalar>().data()));

  return OkStatus();
}

#define CSR2CSC_INSTANCE(Scalar, cudaDataType)                                \
  template <>                                                                 \
  Status GpuSparse::Csr2csc<Scalar>(                                          \
      int m, int n, int nnz, const Scalar* csrVal, const int* csrRowPtr,      \
      const int* csrColInd, Scalar* cscVal, int* cscRowInd, int* cscColPtr,   \
      const cusparseAction_t copyValues) {                                    \
    DCHECK(initialized_);                                                     \
    return Csr2cscImpl(cudaDataType, context_, *gpusparse_handle_, m, n, nnz, \
                       csrVal, csrRowPtr, csrColInd, cscVal, cscRowInd,       \
                       cscColPtr, copyValues);                                \
  }

TF_CALL_CUSPARSE_DTYPES(CSR2CSC_INSTANCE);

#endif  // CUDA_VERSION < 10010

}  // namespace tensorflow

#endif  // GOOGLE_CUDA