xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/sparse/mat_mul_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #define EIGEN_USE_THREADS
17 
18 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
19 #define EIGEN_USE_GPU
20 #endif
21 
22 #include "third_party/eigen3/Eigen/Core"
23 #include "third_party/eigen3/Eigen/SparseCore"
24 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
25 #include "tensorflow/core/framework/op.h"
26 #include "tensorflow/core/framework/op_kernel.h"
27 #include "tensorflow/core/framework/tensor_types.h"
28 #include "tensorflow/core/framework/type_traits.h"
29 #include "tensorflow/core/framework/variant_op_registry.h"
30 #include "tensorflow/core/kernels/cwise_ops_common.h"
31 #include "tensorflow/core/kernels/dense_update_functor.h"
32 #include "tensorflow/core/kernels/fill_functor.h"
33 #include "tensorflow/core/kernels/sparse/kernels.h"
34 #include "tensorflow/core/kernels/sparse/sparse_matrix.h"
35 #include "tensorflow/core/kernels/sparse/transpose_op.h"
36 #include "tensorflow/core/kernels/transpose_functor.h"
37 #include "tensorflow/core/lib/gtl/inlined_vector.h"
38 #include "tensorflow/core/platform/threadpool.h"
39 
40 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
41 #include "tensorflow/core/util/cuda_sparse.h"
42 #include "tensorflow/core/util/gpu_solvers.h"
43 #endif
44 
45 namespace tensorflow {
46 
47 // TODO(anudhyan): These constants may be tuned based on the performance of
48 // 'benchmark_sparse_matrix_mat_vec_mul'. We would like to find constants
49 // which work across hardware platforms for typical matrix sizes. It should be
50 // possible to observe at least 30-50% improvement as we increase the number
51 // of threads by 1. If not, then it may be worth increasing kMaxShards and
52 // kNumShardsPerThread. However, once we have too many shards, latency may be
53 // dominated by per-shard overhead.
54 //
55 // Maximum number of shards into which to divide the computation for each CSR
56 // Sparse Matrix instance.
57 static constexpr int32_t kMaxShards = 20;
58 // Number of shards allocated to each thread.
59 static constexpr int32_t kNumShardsPerThread = 3;
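//
// For example, with 8 intra-op threads the CPU kernel below shards the work
// into std::max(kMaxShards, kNumShardsPerThread * 8) = 24 pieces, so each
// shard covers roughly num_rows / 24 consecutive rows of the sparse matrix.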
60 
61 typedef Eigen::ThreadPoolDevice CPUDevice;
62 typedef Eigen::GpuDevice GPUDevice;
63 
64 // Abstract OpKernel to compute sparse-dense matrix multiplication.
65 //
66 // Implements a kernel which, given a SparseMatrix `a` and dense Tensor `b`,
67 // computes a dense Tensor `c` satisfying `c = a * b` where * denotes matrix
68 // multiplication.
69 //
70 // The boolean attributes `transpose_a` and `adjoint_a` will transpose or
71 // adjoint `a` before multiplication, respectively. At most one of these
72 // attributes may be set to True. Corresponding attributes will transpose or
73 // adjoint `b` or the output (after multiplication).
74 //
75 // The ranks of `a` and `b` must be equal and their shapes must be
76 // compatible for matrix multiplication. Otherwise, InvalidArgument runtime
77 // errors will be thrown. Only rank 2 or rank 3 inputs are supported.
78 //
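// For example, with rank-3 inputs and transpose_a/transpose_b both false, a
// CSR SparseMatrix `a` with dense shape [B, M, K] and a dense `b` of shape
// [B, K, N] produce a dense output of shape [B, M, N] (or [B, N, M] when
// `transpose_output` is true).
//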
79 template <typename Device, typename T>
80 class CSRMatMulOp : public OpKernel {
81  public:
82   explicit CSRMatMulOp(OpKernelConstruction* c) : OpKernel(c) {
83     OP_REQUIRES_OK(c, c->GetAttr("transpose_a", &transpose_a_));
84     OP_REQUIRES_OK(c, c->GetAttr("transpose_b", &transpose_b_));
85     bool adjoint_a;
86     OP_REQUIRES_OK(c, c->GetAttr("adjoint_a", &adjoint_a));
87     OP_REQUIRES(c, !(adjoint_a && transpose_a_),
88                 errors::InvalidArgument(
89                     "Only one of adjoint_a and transpose_a may be true."));
90     bool adjoint_b;
91     OP_REQUIRES_OK(c, c->GetAttr("adjoint_b", &adjoint_b));
92     OP_REQUIRES(c, !(adjoint_b && transpose_b_),
93                 errors::InvalidArgument(
94                     "Only one of adjoint_b and transpose_b may be true."));
95     OP_REQUIRES_OK(c, c->GetAttr("transpose_output", &transpose_output_));
96     OP_REQUIRES_OK(c, c->GetAttr("conjugate_output", &conjugate_output_));
97     transpose_a_ |= adjoint_a;
98     transpose_b_ |= adjoint_b;
99     if (is_complex<T>::value) {
100       conjugate_a_ = adjoint_a;
101       conjugate_b_ = adjoint_b;
102     } else {
103       conjugate_a_ = false;
104       conjugate_b_ = false;
105     }
106   }
107 
108   ~CSRMatMulOp() override {}
109 
110   Status ValidateInputs(const CSRSparseMatrix& sparse_matrix_a,
111                         const Tensor& dense_tensor_b, int* rank,
112                         int64_t* batch_size) {
113     if (sparse_matrix_a.dtype() != dense_tensor_b.dtype()) {
114       return errors::InvalidArgument(
115           "Input types don't match.  a.dtype == ",
116           DataTypeString(sparse_matrix_a.dtype()),
117           " vs. b.dtype == ", DataTypeString(dense_tensor_b.dtype()));
118     }
119     *rank = sparse_matrix_a.dims();
120     // TODO(ebrevdo): Add support for broadcasting matmul.
121     if (*rank != dense_tensor_b.dims()) {
122       return errors::InvalidArgument("Ranks of a and b must match, saw: ", *rank,
123                                      " vs. ", dense_tensor_b.dims(), ".");
124     }
125     // A valid CSR SparseMatrix has rank 2 or rank 3.
126     *batch_size = (*rank == 2) ? 1 : dense_tensor_b.dim_size(0);
127     if (sparse_matrix_a.batch_size() != *batch_size) {
128       return errors::InvalidArgument("Batch sizes of a and b must match, saw: ",
129                                      sparse_matrix_a.batch_size(), " vs. ",
130                                      *batch_size, ".");
131     }
132     const auto& a_dense_shape = sparse_matrix_a.dense_shape().vec<int64_t>();
133     const int64_t a_inner_dim =
134         a_dense_shape(this->transpose_a_ ? *rank - 2 : *rank - 1);
135     const int64_t b_inner_dim =
136         dense_tensor_b.dim_size(this->transpose_b_ ? *rank - 1 : *rank - 2);
137     if (a_inner_dim != b_inner_dim) {
138       return errors::InvalidArgument(
139           "Inner product dimensions of A and B do not agree.  Shapes are: ",
140           TensorShape(a_dense_shape), " vs. ",
141           dense_tensor_b.shape().DebugString());
142     }
143     return OkStatus();
144   }
145 
146  public:
147   bool transpose_a_;
148   bool transpose_b_;
149   bool conjugate_a_;
150   bool conjugate_b_;
151   bool transpose_output_;
152   bool conjugate_output_;
153 };
154 
155 // CPU Kernel to compute sparse-dense matrix multiplication.
156 //
157 // Uses Eigen SparseMatrix to compute the sparse-dense multiplication between
158 // a CSR SparseMatrix `a` and dense Tensor `b`. If intra-op parallelism is
159 // available, the implementation parallelizes the computation across each row
160 // of the sparse matrix.
161 template <typename T>
162 class CSRMatMulCPUOp : public CSRMatMulOp<CPUDevice, T> {
163   using SparseMatrix = Eigen::SparseMatrix<T, Eigen::RowMajor>;
164   using Matrix =
165       Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
166   using ConstMatrixMap = Eigen::Map<const Matrix>;
167   using MatrixMap = Eigen::Map<Matrix>;
168 
169  public:
170   explicit CSRMatMulCPUOp(OpKernelConstruction* c)
171       : CSRMatMulOp<CPUDevice, T>(c) {}
172 
173   ~CSRMatMulCPUOp() override {}
174 
175   void Compute(OpKernelContext* ctx) final {
176     const CSRSparseMatrix* sparse_matrix_a;
177     OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &sparse_matrix_a));
178     const Tensor& matrix_b = ctx->input(1);
179 
180     int rank;
181     int64_t batch_size;
182     OP_REQUIRES_OK(ctx, this->ValidateInputs(*sparse_matrix_a, matrix_b, &rank,
183                                              &batch_size));
184 
185     const auto dense_shape = sparse_matrix_a->dense_shape().vec<int64_t>();
186     int64_t num_lhs_rows = dense_shape(rank - 2);
187     int64_t num_lhs_cols = dense_shape(rank - 1);
188     int64_t num_rhs_rows = matrix_b.dim_size(rank - 2);
189     int64_t num_rhs_cols = matrix_b.dim_size(rank - 1);
190 
191     if (this->transpose_a_) {
192       std::swap(num_lhs_rows, num_lhs_cols);
193     }
194 
195     // Possibly transpose the dense Tensor b.
196     const Tensor* rhs = &matrix_b;
197     Tensor b_transposed;
198     if (this->transpose_b_) {
199       OP_REQUIRES_OK(
200           ctx, TransposeAndConjugateTensor(ctx, matrix_b, this->conjugate_b_,
201                                            &b_transposed));
202       rhs = &b_transposed;
203       std::swap(num_rhs_rows, num_rhs_cols);
204     }
205 
206     // If we're transposing the output, then allocate a temporary buffer to
207     // store the output. Otherwise allocate the output directly.
208     Tensor* output = nullptr;
209     Tensor* matmul_result = nullptr;
210     Tensor output_transposed;
211     OP_REQUIRES_OK(
212         ctx, AllocateOutput(ctx, rank, batch_size, num_lhs_rows, num_rhs_cols,
213                             this->transpose_output_, &output,
214                             &output_transposed, &matmul_result));
215 
216     if (!this->transpose_a_) {
217       SparseDenseMatMulWithoutTransposedLHS(
218           ctx, batch_size, num_lhs_rows, *sparse_matrix_a, *rhs, matmul_result);
219     } else {  // transpose_a_ == true
220       SparseDenseMatMulWithTransposedLHS(ctx, batch_size, num_lhs_rows,
221                                          num_lhs_cols, *sparse_matrix_a, *rhs,
222                                          matmul_result);
223     }
224 
225     // Transpose (and conjugate) the output if necessary. Conjugation is
226     // folded into the transpose when transposing; otherwise it is done in place.
227     if (this->transpose_output_) {
228       OP_REQUIRES_OK(
229           ctx, TransposeAndConjugateAllocatedTensor(
230                    ctx, output_transposed, this->conjugate_output_, output));
231     } else if (this->conjugate_output_) {
232       functor::maybe_conj_inplace<CPUDevice, T>::run(
233           ctx->eigen_device<CPUDevice>(), output);
234     }
235   }
236 
237  private:
238   // Allocates the output with the appropriate shape. Additionally, if
239   // transpose_output is True, allocates a temporary buffer with the transposed
240   // output. 'matmul_result' points to either output or output_transposed, based
241   // on whether transpose_output is True.
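  // For example, with rank 3, batch_size 4, num_rows 5 and num_cols 6:
  // transpose_output == false allocates the output as [4, 5, 6]; with
  // transpose_output == true the output is [4, 6, 5] and the temporary
  // output_transposed buffer (which receives the matmul result) is [4, 5, 6].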
242   Status AllocateOutput(OpKernelContext* ctx, const int32_t rank,
243                         const int64_t batch_size, const int64_t num_rows,
244                         const int64_t num_cols, const bool transpose_output,
245                         Tensor** output, Tensor* output_transposed,
246                         Tensor** matmul_result) {
247     TensorShape output_shape;
248     if (rank == 3) output_shape.AddDim(batch_size);
249 
250     if (!transpose_output) {
251       output_shape.AppendShape({num_rows, num_cols});
252       TF_RETURN_IF_ERROR(ctx->allocate_output(0, output_shape, output));
253       *matmul_result = *output;
254     } else {
255       TensorShape output_transposed_shape = output_shape;
256       output_transposed_shape.AppendShape({num_rows, num_cols});
257       output_shape.AppendShape({num_cols, num_rows});
258       TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value,
259                                             output_transposed_shape,
260                                             output_transposed));
261       TF_RETURN_IF_ERROR(ctx->allocate_output(0, output_shape, output));
262       *matmul_result = output_transposed;
263     }
264     return OkStatus();
265   }
266 
267   // Returns an Eigen::Ref expression of a sparse sub-matrix from the given
268   // contiguous segment of rows of the CSR Sparse Matrix.
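  //
  // For example, if a batch component has row pointers [0, 2, 5, 7, 9] and the
  // shard covers rows [1, 3), then row_offset == 2 and the returned sub-matrix
  // has row pointers [0, 3, 5], i.e. 5 nonzeros read from the parent's value
  // and column-index arrays starting at offset 2.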
269   Eigen::Ref<const SparseMatrix> GetSparseMatrixRef(
270       const CSRSparseMatrix& csr_matrix, const int batch_index,
271       const int64_t row_begin, const int64_t num_shard_rows,
272       std::vector<int32>* row_ptrs) {
273     // Compute the row pointers of the sparse sub-matrix.
274     row_ptrs->resize(num_shard_rows + 1);
275     const int64_t row_offset =
276         csr_matrix.row_pointers_vec(batch_index)(row_begin);
277     for (int64_t row_idx = 0; row_idx <= num_shard_rows; ++row_idx) {
278       row_ptrs->at(row_idx) =
279           csr_matrix.row_pointers_vec(batch_index)(row_begin + row_idx) -
280           row_offset;
281     }
282     const int64_t num_cols =
283         csr_matrix.dense_shape().vec<int64_t>()(csr_matrix.dims() - 1);
284     return Eigen::Map<const SparseMatrix>(
285         num_shard_rows /* num_rows */, num_cols /* num_cols */,
286         row_ptrs->at(num_shard_rows) /* total_nnz */, row_ptrs->data(),
287         csr_matrix.col_indices_vec(batch_index).data() + row_offset,
288         csr_matrix.values_vec<T>(batch_index).data() + row_offset);
289   }
290 
291   // Sparse-Dense Matrix Multiplication between a CSRSparseMatrix (LHS) and a
292   // dense Tensor (RHS).
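  // Each shard computes rows [row_begin, row_end) of the per-batch result:
  //   output[batch, row_begin:row_end, :] = A[batch, row_begin:row_end, :] * B[batch].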
293   void SparseDenseMatMulWithoutTransposedLHS(OpKernelContext* ctx,
294                                              const int64_t batch_size,
295                                              const int64_t num_lhs_rows,
296                                              const CSRSparseMatrix& lhs,
297                                              const Tensor& rhs,
298                                              Tensor* output) {
299     // Parallelize matrix multiplication across batch dimensions and across
300     // rows in each batch.
301     auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
302     const int32_t num_threads = worker_threads.num_threads;
303     const int64_t block_size =
304         num_lhs_rows / std::max(kMaxShards, kNumShardsPerThread * num_threads);
305     const int64_t num_rhs_rows = rhs.dim_size(rhs.dims() - 2);
306     const int64_t num_rhs_cols = rhs.dim_size(rhs.dims() - 1);
307     worker_threads.workers->ParallelFor(
308         batch_size * num_lhs_rows /* total */,
309         thread::ThreadPool::SchedulingParams(
310             thread::ThreadPool::SchedulingStrategy::
311                 kFixedBlockSize /* strategy */,
312             absl::nullopt /* cost_per_unit */, block_size),
313         [&](int64_t batch_and_row_begin, int64_t batch_and_row_end) {
314           HandleBatchAndRowRange(
315               num_lhs_rows, batch_and_row_begin, batch_and_row_end,
316               [&](int64_t batch_idx, int64_t row_begin, int64_t row_end) {
317                 const int64_t num_shard_rows = row_end - row_begin;
318 
319                 // Define an Eigen::SparseMatrix over the row range:
320                 // [row_begin, row_end) of the CSR SparseMatrix A.
321                 std::vector<int32> row_ptrs;
322                 auto sparse_matrix = GetSparseMatrixRef(
323                     lhs, batch_idx, row_begin, num_shard_rows, &row_ptrs);
324 
325                 // Map the rhs matrix for the current batch.
326                 ConstMatrixMap rhs_map(rhs.flat<T>().data() + batch_idx *
327                                                                   num_rhs_rows *
328                                                                   num_rhs_cols,
329                                        num_rhs_rows, num_rhs_cols);
330 
331                 // Write to the corresponding rows of the output matrix.
332                 MatrixMap output_map(
333                     output->flat<T>().data() +
334                         batch_idx * num_lhs_rows * num_rhs_cols +
335                         row_begin * num_rhs_cols,
336                     num_shard_rows, num_rhs_cols);
337                 output_map.noalias() = sparse_matrix * rhs_map;
338               });
339         });
340   }
341 
342   // Sparse-Dense Matrix Multiplication assuming the CSRSparseMatrix (LHS) is
343   // to be transposed before the operation.
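  // Rather than materializing A^T, each shard accumulates a slice of
  // C^T = B^T * A into a per-thread buffer, and the per-thread buffers are
  // summed at the end (see the identity discussed in the function body).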
344   void SparseDenseMatMulWithTransposedLHS(OpKernelContext* ctx,
345                                           const int64_t batch_size,
346                                           const int64_t num_lhs_rows,
347                                           const int64_t num_lhs_cols,
348                                           const CSRSparseMatrix& lhs,
349                                           const Tensor& rhs, Tensor* output) {
350     auto device = ctx->eigen_device<CPUDevice>();
351     auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
352     const int32_t num_threads = worker_threads.num_threads;
353     const int64_t num_rhs_rows = rhs.dim_size(rhs.dims() - 2);
354     const int64_t num_rhs_cols = rhs.dim_size(rhs.dims() - 1);
355     // Usually, we want to avoid transposing the sparse matrix A since it may be
356     // an expensive operation. Instead, we use the identity (A^T B) = (B^T A)^T.
357     // We don't actually transpose B or the output because it is more convenient
358     // to have them in column major form.
359     //
360     // However, if A is hypersparse and B and C are huge, transposing A will be
361     // cheaper. In the future, we should have a cost model estimating the cost
362     // of transposing all matrices (A, B, C) to decide which variant to use.
363 
364     // Each thread writes to its own copy of the matrix product. These
365     // `num_threads` copies are summed together to obtain the final result.
366     Tensor matmul_result_buffer;
367     OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
368                                            TensorShape({num_threads + 1,
369                                                         output->NumElements()}),
370                                            &matmul_result_buffer));
371     functor::SetZeroFunctor<CPUDevice, T> set_zero;
372     set_zero(device, matmul_result_buffer.flat<T>());
373 
374     // Parallelize matrix multiplication across batch dimensions and across
375     // columns of A^T in each batch. These correspond to rows of A.
376     const int64_t block_size =
377         num_lhs_cols / std::max(kMaxShards, kNumShardsPerThread * num_threads);
378     worker_threads.workers->ParallelForWithWorkerId(
379         batch_size * num_lhs_cols /* total */,
380         thread::ThreadPool::SchedulingParams(
381             thread::ThreadPool::SchedulingStrategy::
382                 kFixedBlockSize /* strategy */,
383             absl::nullopt /* cost_per_unit */, block_size),
384         [&](int64_t batch_and_row_begin, int64_t batch_and_row_end, int tid) {
385           HandleBatchAndRowRange(
386               num_lhs_cols, batch_and_row_begin, batch_and_row_end,
387               [&](int64_t batch_idx, int64_t row_begin, int64_t row_end) {
388                 const int64_t num_shard_rows = row_end - row_begin;
389 
390                 // Define a new sparse sub-matrix from the row range
391                 // [row_begin, row_end) of the sparse matrix A.
392                 std::vector<int32> row_ptrs;
393                 auto sparse_matrix = GetSparseMatrixRef(
394                     lhs, batch_idx, row_begin, num_shard_rows, &row_ptrs);
395 
396                 // Map the corresponding `num_shard_rows` columns of B^T.
397                 // This is the same as taking the `num_shard_rows` rows of B.
398                 ConstMatrixMap b_dense_map(
399                     rhs.flat<T>().data() +
400                         batch_idx * num_rhs_rows * num_rhs_cols +
401                         row_begin * num_rhs_cols,
402                     num_shard_rows, num_rhs_cols);
403 
404                 // Map this thread's copy of the output matrix for this batch.
405                 MatrixMap output_map(
406                     matmul_result_buffer.flat<T>().data() +
407                         tid * batch_size * num_lhs_rows * num_rhs_cols +
408                         batch_idx * num_lhs_rows * num_rhs_cols,
409                     num_lhs_rows, num_rhs_cols);
410 
411                 // Compute the product C^T = B^T * A; restricted to the row
412                 // range in the current shard.
413                 if (this->conjugate_a_) {
414                   output_map.transpose().noalias() +=
415                       b_dense_map.transpose() * sparse_matrix.conjugate();
416                 } else {
417                   output_map.transpose().noalias() +=
418                       b_dense_map.transpose() * sparse_matrix;
419                 }
420               });
421         });
422 
423     // Sum across each thread's matmul result.
424     using Reducer = Eigen::internal::SumReducer<T>;
425     using Index = typename TTypes<T>::Tensor::Index;
426     output->flat<T>().device(device) = matmul_result_buffer.matrix<T>().reduce(
427         Eigen::array<Index, 1>({0}), Reducer());
428   }
429 
430   // Given a range [batch_and_row_begin, batch_and_row_end) which is a
431   // contiguous subset of [0, num_rows * batch_size), calls the function
432   // fn(batch_idx, row_begin, row_end) for each batch index
433   // and the row range [row_begin, row_end) contained in the batch.
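  //
  // For example, with num_rows == 10, the range [8, 23) spans three batches
  // and results in the calls fn(0, 8, 10), fn(1, 0, 10) and fn(2, 0, 3).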
434   void HandleBatchAndRowRange(
435       const int64_t num_rows, const int64_t batch_and_row_begin,
436       const int64_t batch_and_row_end,
437       const std::function<void(int64_t, int64_t, int64_t)>& fn) {
438     // Obtain the batch indices overlapping with the current shard.
439     const int64_t batch_begin = batch_and_row_begin / num_rows;
440     const int64_t batch_end_inclusive = batch_and_row_end / num_rows;
441 
442     for (int64_t batch_idx = batch_begin; batch_idx <= batch_end_inclusive;
443          ++batch_idx) {
444       // Find the contiguous set of rows which are contained in this shard as
445       // well as the current batch. We intersect with interval [batch_idx *
446       // num_rows, (batch_idx + 1) * num_rows) which denotes the current batch.
447       const int64_t current_batch_row_begin =
448           std::max(batch_and_row_begin, batch_idx * num_rows);
449       const int64_t current_batch_row_end =
450           std::min(batch_and_row_end, (batch_idx + 1) * num_rows);
451 
452       const int64_t row_begin = current_batch_row_begin % num_rows;
453       const int64_t num_shard_rows =
454           current_batch_row_end - current_batch_row_begin;
455       // Edge case for when current_batch_row_end is the first index of a new
456       // batch.
457       if (num_shard_rows == 0) continue;
458 
459       fn(batch_idx, row_begin, row_begin + num_shard_rows);
460     }
461   }
462 
463   // Transposes (and optionally, conjugates) a given Tensor. Also allocates the
464   // required memory for the output Tensor.
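  // For a rank-3 input of shape [B, M, N], the output is allocated with shape
  // [B, N, M] before the transpose is performed.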
465   Status TransposeAndConjugateTensor(OpKernelContext* ctx, const Tensor& input,
466                                      bool conjugate, Tensor* output) {
467     TensorShape transposed_shape = input.shape();
468     transposed_shape.set_dim(input.dims() - 1,
469                              input.dim_size(input.dims() - 2));
470     transposed_shape.set_dim(input.dims() - 2,
471                              input.dim_size(input.dims() - 1));
472     TF_RETURN_IF_ERROR(
473         ctx->allocate_temp(DataTypeToEnum<T>::value, transposed_shape, output));
474     return TransposeAndConjugateAllocatedTensor(ctx, input, conjugate, output);
475   }
476 
477   // Transposes (and optionally, conjugates) a given Tensor. The output should
478   // be already allocated.
479   Status TransposeAndConjugateAllocatedTensor(OpKernelContext* ctx,
480                                               const Tensor& input,
481                                               bool conjugate, Tensor* output) {
482     if (conjugate) {
483       TF_RETURN_IF_ERROR(DoConjugateMatrixTranspose(
484           ctx->eigen_device<CPUDevice>(), input, output));
485     } else {
486       TF_RETURN_IF_ERROR(
487           DoMatrixTranspose(ctx->eigen_device<CPUDevice>(), input, output));
488     }
489     return OkStatus();
490   }
491 };
492 
493 // GPU Kernel to compute sparse-dense matrix multiplication.
494 template <typename T>
495 class CSRMatMulGPUOp : public CSRMatMulOp<GPUDevice, T> {
496   using SparseMatrix = Eigen::SparseMatrix<T, Eigen::RowMajor>;
497   using Matrix =
498       Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
499   using ConstMatrixMap = Eigen::Map<const Matrix>;
500   using MatrixMap = Eigen::Map<Matrix>;
501 
502  public:
503   explicit CSRMatMulGPUOp(OpKernelConstruction* c)
504       : CSRMatMulOp<GPUDevice, T>(c) {}
505 
506   ~CSRMatMulGPUOp() override {}
507 
508   void Compute(OpKernelContext* ctx) final {
509     const CSRSparseMatrix* a_matrix;
510     OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &a_matrix));
511     const Tensor& b_t = ctx->input(1);
512 
513     int rank;
514     int64_t batch_size;
515     OP_REQUIRES_OK(ctx,
516                    this->ValidateInputs(*a_matrix, b_t, &rank, &batch_size));
517 
518     const Tensor& a_dense_shape_t = a_matrix->dense_shape();
519     TensorShape a_dense_tensor_shape;
520     auto a_dense_shape = a_dense_shape_t.vec<int64_t>();
521     OP_REQUIRES_OK(
522         ctx, TensorShapeUtils::MakeShape(a_dense_shape, &a_dense_tensor_shape));
523 
524     const int row_dim = (rank == 2) ? 0 : 1;
525     const int64_t a_outer_dim = a_dense_tensor_shape.dim_size(
526         this->transpose_a_ ? row_dim + 1 : row_dim);
527     const int64_t b_inner_dim =
528         b_t.shape().dim_size(this->transpose_b_ ? row_dim + 1 : row_dim);
529     const int64_t b_outer_dim =
530         b_t.dim_size(this->transpose_b_ ? row_dim : row_dim + 1);
531     const int64_t b_slice_size = b_inner_dim * b_outer_dim;
532 
533     TensorShape c_shape;
534     if (rank == 3) c_shape.AddDim(batch_size);
535     if (this->transpose_output_) {
536       c_shape.AddDim(b_outer_dim);
537       c_shape.AddDim(a_outer_dim);
538     } else {
539       c_shape.AddDim(a_outer_dim);
540       c_shape.AddDim(b_outer_dim);
541     }
542 
543     const int64_t c_matrix_lhs = c_shape.dim_size(row_dim);
544     const int64_t c_matrix_rhs = c_shape.dim_size(row_dim + 1);
545     const int64_t c_slice_size = c_matrix_lhs * c_matrix_rhs;
546     Tensor* c_t;
547     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, c_shape, &c_t));
548 
549     const GPUDevice& d = ctx->eigen_device<GPUDevice>();
550     bool use_matrix_vector_multiply = (b_outer_dim == 1);
551 #if TENSORFLOW_USE_ROCM
552     // ROCm hipsparse does not implement csrmv with transposed input a
553     use_matrix_vector_multiply =
554         use_matrix_vector_multiply && !this->transpose_a_;
555 #endif
556     if (use_matrix_vector_multiply) {
557       // Call matrix-vector multiply if b is a vector.
558       TTypes<int64_t>::ConstVec a_dense_shape_comp(
559           a_dense_shape.data() + row_dim, 2);
560       Tensor b_conj_t;
561       const T* b_base_ptr = b_t.template flat<T>().data();
562       bool conjugate_a = this->conjugate_a_;
563       bool conjugate_output = this->conjugate_output_;
564       if (this->conjugate_b_) {
565         if (conjugate_a) {
566           // In this case we can use the identity
567           //   conj(a) * conj(b) = conj(a * b)
568           // instead of creating a conjugated copy of b.
569           conjugate_a = false;
570           conjugate_output = !conjugate_output;
571         } else {
572           OP_REQUIRES_OK(
573               ctx, ctx->forward_input_or_allocate_temp(
574                        {1}, DataTypeToEnum<T>::value, b_t.shape(), &b_conj_t));
575           functor::maybe_conj<GPUDevice, T>::run(d, b_t, &b_conj_t);
576           b_base_ptr = b_conj_t.template flat<T>().data();
577         }
578       }
579 
580       functor::CSRSparseMatrixMatVec<GPUDevice, T> csr_spmv(this->transpose_a_,
581                                                             conjugate_a);
582       for (int i = 0; i < batch_size; ++i) {
583         auto a_row_ptr = a_matrix->row_pointers_vec(i);
584         auto a_col_ind = a_matrix->col_indices_vec(i);
585         auto a_values = a_matrix->values_vec<T>(i);
586         ConstCSRComponent<T> a_comp{a_row_ptr, a_col_ind, a_values,
587                                     a_dense_shape_comp};
588         const T* b_i = b_base_ptr + i * b_slice_size;
589         T* c_i = &c_t->template flat<T>()(i * c_slice_size);
590         Status s = csr_spmv.Compute(ctx, a_comp, b_i, c_i);
591         OP_REQUIRES_OK(ctx, s);
592       }
593       if (conjugate_output) {
594         functor::maybe_conj_inplace<GPUDevice, T>::run(d, c_t);
595       }
596       return;
597     }
598 
599     functor::CSRSparseMatrixMatMul<GPUDevice, T> csr_spmmadd(
600         this->transpose_output_);
601 
602     Tensor c_mat_col_major_t;
603     if (!this->transpose_output_) {
604       // If transpose_output is false, we'll need to transpose the (col
605       // major) output of the csrgemm call to get proper (row-major)
606       // output, which means we need to keep a temporary buffer to
607       // store the intermediate gemm output.
608       TensorShape c_mat_col_major_shape;
609       if (rank == 2) {
610         c_mat_col_major_shape = TensorShape({c_matrix_rhs, c_matrix_lhs});
611       } else {
612         c_mat_col_major_shape =
613             TensorShape({batch_size, c_matrix_rhs, c_matrix_lhs});
614       }
615       OP_REQUIRES_OK(
616           ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
617                                   c_mat_col_major_shape, &c_mat_col_major_t));
618     }
619 
620     // If transpose_output is true, return the direct (column-major i.e.,
621     // transposed) output of the csrgemm call.  Otherwise we'll need
622     // to transpose it to row major format.
623     auto c_mat_col_major = (this->transpose_output_)
624                                ? c_t->flat<T>()
625                                : c_mat_col_major_t.flat<T>();
626 
627     // Possibly transpose a.
628     const CSRSparseMatrix* a_input_matrix;
629     // If we need to transpose a, we will store the result temporarily
630     // in the object below.
631     CSRSparseMatrix a_matrix_transposed;
632     if (!this->transpose_a_) {
633       a_input_matrix = a_matrix;
634     } else {
635       functor::CSRSparseMatrixTranspose<GPUDevice, T> transpose;
636       OP_REQUIRES_OK(ctx, transpose(ctx, this->conjugate_a_, *a_matrix,
637                                     &a_matrix_transposed));
638       a_input_matrix = &a_matrix_transposed;
639     }
640 
641     auto a_input_dense_shape = a_input_matrix->dense_shape().vec<int64_t>();
642 
643     // Possibly transpose b.
644     Tensor b_t_input;
645     if (!this->transpose_b_) {
646       b_t_input = b_t;
647     } else {
648       TensorShape b_t_transposed_shape;
649       if (rank == 3) {
650         b_t_transposed_shape.AddDim(batch_size);
651       }
652       b_t_transposed_shape.AddDim(b_t.dim_size(row_dim + 1));
653       b_t_transposed_shape.AddDim(b_t.dim_size(row_dim));
654       OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
655                                              b_t_transposed_shape, &b_t_input));
656       const GPUDevice& d = ctx->eigen_device<GPUDevice>();
657       if (this->conjugate_b_) {
658         OP_REQUIRES_OK(ctx, DoConjugateMatrixTranspose(d, b_t /*input*/,
659                                                        &b_t_input /*output*/));
660       } else {
661         OP_REQUIRES_OK(
662             ctx, DoMatrixTranspose(d, b_t /*input*/, &b_t_input /*output*/));
663       }
664     }
665 
666     // Dense shape of a batch component of A.
667     TTypes<int64_t>::ConstVec a_input_dense_shape_comp(
668         a_input_dense_shape.data() + row_dim, 2);
669 
670     auto b = b_t_input.flat<T>();
671 
672     for (int i = 0; i < batch_size; ++i) {
673       auto a_row_ptr = a_input_matrix->row_pointers_vec(i);
674       auto a_col_ind = a_input_matrix->col_indices_vec(i);
675       auto a_values = a_input_matrix->values_vec<T>(i);
676       typename TTypes<T>::UnalignedConstMatrix b_i(b.data() + i * b_slice_size,
677                                                    {b_inner_dim, b_outer_dim});
678       typename TTypes<T>::UnalignedMatrix c_mat_col_major_i(
679           c_mat_col_major.data() + i * c_slice_size,
680           {c_matrix_lhs, c_matrix_rhs});
681       ConstCSRComponent<T> a_comp{a_row_ptr, a_col_ind, a_values,
682                                   a_input_dense_shape_comp};
683       Status s = csr_spmmadd.Compute(ctx, a_comp, b_i, c_mat_col_major_i);
684       OP_REQUIRES_OK(ctx, s);
685     }
686 
687     if (!this->transpose_output_) {
688       // We need to return values in row major format, so transpose
689       // the column-major values in c_mat_col_major_t to row-major output c_t.
690       OP_REQUIRES_OK(ctx, DoMatrixTranspose(d, /*input=*/c_mat_col_major_t,
691                                             /*output=*/c_t));
692     }
693     if (this->conjugate_output_) {
694       functor::maybe_conj_inplace<GPUDevice, T>::run(d, c_t);
695     }
696   }
697 };
698 
699 #define REGISTER_CPU(T)                                                     \
700   REGISTER_KERNEL_BUILDER(                                                  \
701       Name("SparseMatrixMatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
702       CSRMatMulCPUOp<T>);
703 
704 REGISTER_CPU(float)
705 REGISTER_CPU(double)
706 REGISTER_CPU(complex64)
707 REGISTER_CPU(complex128)
708 
709 #undef REGISTER_CPU
710 
711 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
712 
713 #define REGISTER_GPU(T)                                                     \
714   REGISTER_KERNEL_BUILDER(                                                  \
715       Name("SparseMatrixMatMul").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
716       CSRMatMulGPUOp<T>);
717 
718 REGISTER_GPU(float)
719 REGISTER_GPU(double)
720 REGISTER_GPU(complex64)
721 REGISTER_GPU(complex128)
722 
723 #undef REGISTER_GPU
724 
725 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
726 
727 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
728 
729 namespace functor {
730 
731 namespace {
732 
733 // GPUDataType<T>::type translates from a C++ type (e.g. float) to a
734 // GPUDataType_t (e.g. CUDA_R_32F).
735 template <typename T>
736 struct GPUDataType;
737 
738 // GPUDataType templates are currently not instantiated in the ROCm flow,
739 // so the #elif TENSORFLOW_USE_ROCM blocks are left out for now. The hipblas
740 // library is not (yet) being pulled in via rocm_configure.bzl, so types from
741 // the hipblas headers cannot be referenced here.
742 template <>
743 struct GPUDataType<Eigen::half> {
744 #if GOOGLE_CUDA
745   static constexpr cudaDataType_t type = CUDA_R_16F;
746 #endif
747 };
748 
749 template <>
750 struct GPUDataType<float> {
751 #if GOOGLE_CUDA
752   static constexpr cudaDataType_t type = CUDA_R_32F;
753 #endif
754 };
755 
756 template <>
757 struct GPUDataType<std::complex<float>> {
758 #if GOOGLE_CUDA
759   static constexpr cudaDataType_t type = CUDA_C_32F;
760 #endif
761 };
762 
763 template <>
764 struct GPUDataType<double> {
765 #if GOOGLE_CUDA
766   static constexpr cudaDataType_t type = CUDA_R_64F;
767 #endif
768 };
769 
770 template <>
771 struct GPUDataType<std::complex<double>> {
772 #if GOOGLE_CUDA
773   static constexpr cudaDataType_t type = CUDA_C_64F;
774 #endif
775 };
776 
777 }  // namespace
778 
779 template <typename T>
780 class CSRSparseMatrixMatMul<GPUDevice, T> {
781  public:
782   explicit CSRSparseMatrixMatMul(const bool transpose_output)
783       : transpose_output_(transpose_output) {}
784 
785   Status Compute(OpKernelContext* ctx, const ConstCSRComponent<T>& a,
786                  typename TTypes<T>::UnalignedConstMatrix b,
787                  typename TTypes<T>::UnalignedMatrix c) {
788     GpuSparse cuda_sparse(ctx);
789     TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
790     {
791       // Use Csrmm/SpMM to calculate:
792       //   C = alpha * op(A) * op(B) + beta * C
793       // where alpha = 1.0, beta = 0.0, A is sparse and B and C are dense.
794       // Note that Csrmm/SpMM assumes B and C are in column-major form, so we
795       // use transB == true, and manually transpose the output in place
796       // using blas<t>geam.
797       // TODO(ebrevdo,rmlarsen): Add support for transposition and adjoint.
798 
799       // Create alpha and beta scalars; alpha = 1.0, beta = 0.0
800       // TODO(ebrevdo,rmlarsen): Add support for non-trivial alpha and beta.
801       const T alpha = 1;
802       const T beta = 0;
803 
804       // A is (m, k), Bt is (ldb, k) and Ct is (ldc, n)
805       const int k = b.dimension(0);
806       DCHECK_EQ(k, a.dense_shape_host(1));
807 
808       // If transpose_output_ is true, then the c matrix we receive
809       // here is the direct row major output (into which we will store
810       // csrgemm's col major output).  Otherwise it's a
811       // temporary tensor that will store the column major output that
812       // will eventually be transposed.
813       const int m = c.dimension(transpose_output_ ? 1 : 0);
814       const int n = c.dimension(transpose_output_ ? 0 : 1);
815       DCHECK_EQ(m, a.dense_shape_host(0));
816       DCHECK_EQ(n, b.dimension(1));
817       const int nnz = a.values.size();
818       DCHECK_EQ(nnz, a.col_ind.size());
819 
820       // ldb: leading dimension of B. If op(B)=B, it must be at least max(1, k)
821       // if op(A) = A and at least max(1, m) otherwise. If op(B) != B, it must
822       // be at least max(1, n).
823       const int ldb = n;
824       // ldc: leading dimension of C. It must be at least max(1, m) if
825       // op(A) = A and at least max(1, k) otherwise.
826       const int ldc = m;
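      //
      // For example, with m = 100, k = 50 and n = 20: the row-major (50, 20)
      // buffer of B is handed to the library as a column-major (20, 50) matrix
      // with ldb = 20, and C is written as a column-major (100, 20) matrix with
      // ldc = 100.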
827 
828       // transA must be non-transpose if transB is transpose (cusparse
829       // limitation).
830 #if GOOGLE_CUDA
831       const gpusparseOperation_t transA = CUSPARSE_OPERATION_NON_TRANSPOSE;
832 #elif TENSORFLOW_USE_ROCM
833       const gpusparseOperation_t transA = HIPSPARSE_OPERATION_NON_TRANSPOSE;
834 #endif
835 
836       // transB: b is row-major, and cusparse requires col-major b (or
837       // equivalently transB == transpose).  This version is actually more
838       // efficient.
839 #if GOOGLE_CUDA && CUDA_VERSION >= 10020
840 
841       const gpusparseOperation_t transB = CUSPARSE_OPERATION_TRANSPOSE;
842       gpusparseSpMatDescr_t matA;
843       gpusparseDnMatDescr_t matB, matC;
844 
845       TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsr(
846           &matA, m, k, nnz, const_cast<int*>(a.row_ptr.data()),
847           const_cast<int*>(a.col_ind.data()), const_cast<T*>(a.values.data()),
848           CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO,
849           GPUDataType<T>::type));
850 
851       TF_RETURN_IF_GPUSPARSE_ERROR(
852           cusparseCreateDnMat(&matB, n, k, ldb, const_cast<T*>(b.data()),
853                               GPUDataType<T>::type, CUSPARSE_ORDER_COL));
854 
855       TF_RETURN_IF_GPUSPARSE_ERROR(
856           cusparseCreateDnMat(&matC, m, n, ldc, c.data(), GPUDataType<T>::type,
857                               CUSPARSE_ORDER_COL));
858 
859       size_t bufferSize = 0;
860       TF_RETURN_IF_ERROR(cuda_sparse.SpMMBufferSize(
861           transA, transB, &alpha, matA, matB, &beta, matC,
862           CUSPARSE_MM_ALG_DEFAULT, &bufferSize));
863 
864       Tensor buffer;
865       TF_RETURN_IF_ERROR(ctx->allocate_temp(
866           DT_INT8, TensorShape({static_cast<int64_t>(bufferSize)}), &buffer));
867       DCHECK(buffer.flat<int8>().data() != nullptr);
868 
869       TF_RETURN_IF_ERROR(cuda_sparse.SpMM(transA, transB, &alpha, matA, matB,
870                                           &beta, matC, CUSPARSE_MM_ALG_DEFAULT,
871                                           buffer.flat<int8>().data()));
872 
873       TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnMat(matB));
874       TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnMat(matC));
875       TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroySpMat(matA));
876 
877 #elif TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 40200
878       // Use SPMM
879       const gpusparseOperation_t transB = HIPSPARSE_OPERATION_TRANSPOSE;
880       gpusparseSpMatDescr_t matA;
881       gpusparseDnMatDescr_t matB, matC;
882 
883       TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateCsr(
884           &matA, m, k, nnz, const_cast<int*>(a.row_ptr.data()),
885           const_cast<int*>(a.col_ind.data()), const_cast<T*>(a.values.data()),
886           CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I, HIPSPARSE_INDEX_BASE_ZERO,
887           GPUDataType<T>::type));
888 
889       TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateDnMat(
890           &matB, n, k, ldb, const_cast<T*>(b.data()), GPUDataType<T>::type,
891           HIPSPARSE_ORDER_COL));
892 
893       TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateDnMat(
894           &matC, m, n, ldc, c.data(), GPUDataType<T>::type,
895           HIPSPARSE_ORDER_COL));
896 
897       size_t bufferSize = 0;
898       TF_RETURN_IF_ERROR(cuda_sparse.SpMMBufferSize(
899           transA, transB, &alpha, matA, matB, &beta, matC,
900           HIPSPARSE_MM_ALG_DEFAULT, &bufferSize));
901 
902       Tensor buffer;
903       TF_RETURN_IF_ERROR(ctx->allocate_temp(
904           DT_INT8, TensorShape({static_cast<int64_t>(bufferSize)}), &buffer));
905       DCHECK(buffer.flat<int8>().data() != nullptr);
906 
907       TF_RETURN_IF_ERROR(cuda_sparse.SpMM(transA, transB, &alpha, matA, matB,
908                                           &beta, matC, CUSPARSE_MM_ALG_DEFAULT,
909                                           buffer.flat<int8>().data()));
910 
911       TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseDestroyDnMat(matB));
912       TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseDestroyDnMat(matC));
913       TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseDestroySpMat(matA));
914 
915 #else
916 
917 #if GOOGLE_CUDA
918 
919       const gpusparseOperation_t transB = CUSPARSE_OPERATION_TRANSPOSE;
920 
921       gpusparseMatDescr_t descrA;
922       TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
923       TF_RETURN_IF_GPUSPARSE_ERROR(
924           cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
925       TF_RETURN_IF_GPUSPARSE_ERROR(
926           cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));
927 
928 #elif TENSORFLOW_USE_ROCM
929 
930       const gpusparseOperation_t transB = HIPSPARSE_OPERATION_TRANSPOSE;
931 
932       gpusparseMatDescr_t descrA;
933       TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateMatDescr(&descrA));
934       TF_RETURN_IF_GPUSPARSE_ERROR(
935           wrap::hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
936       TF_RETURN_IF_GPUSPARSE_ERROR(
937           wrap::hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
938 #endif  // GOOGLE_CUDA
939 
940       TF_RETURN_IF_ERROR(
941           cuda_sparse.Csrmm(transA, transB, m, n, k, nnz, &alpha, descrA,
942                             a.values.data(), a.row_ptr.data(), a.col_ind.data(),
943                             b.data(), ldb, &beta, c.data(), ldc));
944 
945 #endif  // GOOGLE_CUDA && CUDA_VERSION >= 10020
946     }
947 
948     return OkStatus();
949   }
950 
951  private:
952   bool transpose_output_;
953 };
954 
955 template <typename T>
956 class CSRSparseMatrixMatVec<GPUDevice, T> {
957  public:
958   CSRSparseMatrixMatVec(bool transpose_a, bool conjugate_a)
959       : transA_(TransposeAndConjugateToGpuSparseOp(transpose_a, conjugate_a,
960                                                    &status_)) {}
961 
962   Status Compute(OpKernelContext* ctx, const ConstCSRComponent<T>& a,
963                  const T* x, T* y) {
964     TF_RETURN_IF_ERROR(status_);
965     GpuSparse cuda_sparse(ctx);
966     TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
967     {
968       // Use Csrmv to calculate:
969       //   y = alpha * op(A) * x + beta * y
970       // where alpha = 1.0, beta = 0.0, A is a sparse matrix and x and y are
971       // dense vectors.
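      //
      // op(A) is (m, n) when transA_ is non-transpose and (n, m) otherwise, so
      // x has length n (respectively m) and y has length m (respectively n).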
972 
973       // Create alpha and beta scalars; alpha = 1.0, beta = 0.0
974       // TODO(rmlarsen,ebrevdo): Add support for general alpha, beta.
975       const T alpha = 1;
976       const T beta = 0;
977 
978 #if GOOGLE_CUDA && CUDA_VERSION < 10020
979       gpusparseMatDescr_t descrA;
980       TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
981       TF_RETURN_IF_GPUSPARSE_ERROR(
982           cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
983       TF_RETURN_IF_GPUSPARSE_ERROR(
984           cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));
985 #elif TENSORFLOW_USE_ROCM
986       gpusparseMatDescr_t descrA;
987       TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateMatDescr(&descrA));
988       TF_RETURN_IF_GPUSPARSE_ERROR(
989           wrap::hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
990       TF_RETURN_IF_GPUSPARSE_ERROR(
991           wrap::hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
992 #endif  // (GOOGLE_CUDA && CUDA_VERSION < 10020) || TENSORFLOW_USE_ROCM
993 
994       const int m = a.dense_shape_host(0);
995       const int n = a.dense_shape_host(1);
996       const int nnz = a.values.size();
997       DCHECK_EQ(nnz, a.col_ind.size());
998 #if GOOGLE_CUDA && (CUDA_VERSION >= 10020)
999       TF_RETURN_IF_ERROR(cuda_sparse.Csrmv(transA_, m, n, nnz, &alpha,
1000                                            a.values.data(), a.row_ptr.data(),
1001                                            a.col_ind.data(), x, &beta, y));
1002 #else
1003       TF_RETURN_IF_ERROR(cuda_sparse.Csrmv(transA_, m, n, nnz, &alpha, descrA,
1004                                            a.values.data(), a.row_ptr.data(),
1005                                            a.col_ind.data(), x, &beta, y));
1006 #endif
1007     }
1008 
1009     return OkStatus();
1010   }
1011 
1012  private:
1013   Status status_;
1014   const gpusparseOperation_t transA_;
1015 };
1016 
1017 }  // namespace functor
1018 
1019 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1020 
1021 }  // namespace tensorflow
1022