/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif

#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/Eigen/SparseCore"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/cwise_ops_common.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#include "tensorflow/core/kernels/sparse/transpose_op.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/threadpool.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/util/cuda_sparse.h"
#include "tensorflow/core/util/gpu_solvers.h"
#endif

namespace tensorflow {

// TODO(anudhyan): These constants may be tuned based on the performance of
// 'benchmark_sparse_matrix_mat_vec_mul'. We would like to find constants
// which work across hardware platforms for typical matrix sizes. It should be
// possible to observe at least 30-50% improvement as we increase the number
// of threads by 1. If not, then it may be worth increasing kMaxShards and
// kNumShardsPerThread. However, once we have too many shards, latency may be
// dominated by per-shard overhead.
//
// Maximum number of shards into which to divide the computation for each CSR
// Sparse Matrix instance.
static constexpr int32_t kMaxShards = 20;
// Number of shards allocated to each thread.
static constexpr int32_t kNumShardsPerThread = 3;

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Abstract OpKernel to compute sparse-dense matrix multiplication.
//
// Implements a kernel which, given a SparseMatrix `a` and dense Tensor `b`,
// computes a dense Tensor `c` satisfying `c = a * b` where * denotes matrix
// multiplication.
//
// The boolean attributes `transpose_a` and `adjoint_a` will transpose or
// adjoint `a` before multiplication, respectively. At most one of these
// attributes may be set to True. Corresponding attributes will transpose or
// adjoint `b` or the output (after multiplication).
//
// The ranks of `a` and `b` must be equal and their shapes must be
// compatible for matrix multiplication. Otherwise, InvalidArgument runtime
// errors will be thrown.
// Only rank 2 or rank 3 inputs are supported.
//
template <typename Device, typename T>
class CSRMatMulOp : public OpKernel {
 public:
  explicit CSRMatMulOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("transpose_a", &transpose_a_));
    OP_REQUIRES_OK(c, c->GetAttr("transpose_b", &transpose_b_));
    bool adjoint_a;
    OP_REQUIRES_OK(c, c->GetAttr("adjoint_a", &adjoint_a));
    OP_REQUIRES(c, !(adjoint_a && transpose_a_),
                errors::InvalidArgument(
                    "Only one of adjoint_a and transpose_a may be true."));
    bool adjoint_b;
    OP_REQUIRES_OK(c, c->GetAttr("adjoint_b", &adjoint_b));
    OP_REQUIRES(c, !(adjoint_b && transpose_b_),
                errors::InvalidArgument(
                    "Only one of adjoint_b and transpose_b may be true."));
    OP_REQUIRES_OK(c, c->GetAttr("transpose_output", &transpose_output_));
    OP_REQUIRES_OK(c, c->GetAttr("conjugate_output", &conjugate_output_));
    transpose_a_ |= adjoint_a;
    transpose_b_ |= adjoint_b;
    if (is_complex<T>::value) {
      conjugate_a_ = adjoint_a;
      conjugate_b_ = adjoint_b;
    } else {
      conjugate_a_ = false;
      conjugate_b_ = false;
    }
  }

  ~CSRMatMulOp() override {}

  Status ValidateInputs(const CSRSparseMatrix& sparse_matrix_a,
                        const Tensor& dense_tensor_b, int* rank,
                        int64_t* batch_size) {
    if (sparse_matrix_a.dtype() != dense_tensor_b.dtype()) {
      return errors::InvalidArgument(
          "Input types don't match. a.dtype == ",
          DataTypeString(sparse_matrix_a.dtype()),
          " vs. b.dtype == ", DataTypeString(dense_tensor_b.dtype()));
    }
    *rank = sparse_matrix_a.dims();
    // TODO(ebrevdo): Add support for broadcasting matmul.
    if (*rank != dense_tensor_b.dims()) {
      return errors::InvalidArgument("Ranks of a and b must match, saw: ",
                                     *rank, " vs. ", dense_tensor_b.dims(),
                                     ".");
    }
    // A valid CSR SparseMatrix has rank 2 or rank 3.
    *batch_size = (*rank == 2) ? 1 : dense_tensor_b.dim_size(0);
    if (sparse_matrix_a.batch_size() != *batch_size) {
      return errors::InvalidArgument("Batch sizes of a and b must match, saw: ",
                                     sparse_matrix_a.batch_size(), " vs. ",
                                     *batch_size, ".");
    }
    const auto& a_dense_shape = sparse_matrix_a.dense_shape().vec<int64_t>();
    const int64_t a_inner_dim =
        a_dense_shape(this->transpose_a_ ? *rank - 2 : *rank - 1);
    const int64_t b_inner_dim =
        dense_tensor_b.dim_size(this->transpose_b_ ? *rank - 1 : *rank - 2);
    if (a_inner_dim != b_inner_dim) {
      return errors::InvalidArgument(
          "Inner product dimensions of A and B do not agree. Shapes are: ",
          TensorShape(a_dense_shape), " vs. ",
          dense_tensor_b.shape().DebugString());
    }
    return OkStatus();
  }

 public:
  bool transpose_a_;
  bool transpose_b_;
  bool conjugate_a_;
  bool conjugate_b_;
  bool transpose_output_;
  bool conjugate_output_;
};

// CPU Kernel to compute sparse-dense matrix multiplication.
//
// Uses Eigen SparseMatrix to compute the sparse-dense multiplication between
// a CSR SparseMatrix `a` and dense Tensor `b`. If intra-op parallelism is
// available, the implementation parallelizes the computation across each row
// of the sparse matrix.
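//
// For intuition only (not part of the kernel): the core of the CPU path is
// mapping raw CSR buffers into an Eigen sparse view and multiplying it by a
// dense block without copying the CSR data. A minimal sketch, with
// illustrative names (the real per-shard mapping is done in
// GetSparseMatrixRef below):
//
//   using SpMat = Eigen::SparseMatrix<float, Eigen::RowMajor>;
//   // row_ptrs: const int* with num_rows + 1 entries;
//   // col_inds: const int* and values: const float*, each with nnz entries.
//   Eigen::Map<const SpMat> a_view(num_rows, num_cols, nnz, row_ptrs,
//                                  col_inds, values);
//   Eigen::Map<const Eigen::MatrixXf> b_view(b_data, num_cols, n);
//   Eigen::MatrixXf c = a_view * b_view;  // sparse * dense, no CSR copy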
template <typename T>
class CSRMatMulCPUOp : public CSRMatMulOp<CPUDevice, T> {
  using SparseMatrix = Eigen::SparseMatrix<T, Eigen::RowMajor>;
  using Matrix =
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using ConstMatrixMap = Eigen::Map<const Matrix>;
  using MatrixMap = Eigen::Map<Matrix>;

 public:
  explicit CSRMatMulCPUOp(OpKernelConstruction* c)
      : CSRMatMulOp<CPUDevice, T>(c) {}

  ~CSRMatMulCPUOp() override {}

  void Compute(OpKernelContext* ctx) final {
    const CSRSparseMatrix* sparse_matrix_a;
    OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &sparse_matrix_a));
    const Tensor& matrix_b = ctx->input(1);

    int rank;
    int64_t batch_size;
    OP_REQUIRES_OK(ctx, this->ValidateInputs(*sparse_matrix_a, matrix_b, &rank,
                                             &batch_size));

    const auto dense_shape = sparse_matrix_a->dense_shape().vec<int64_t>();
    int64_t num_lhs_rows = dense_shape(rank - 2);
    int64_t num_lhs_cols = dense_shape(rank - 1);
    int64_t num_rhs_rows = matrix_b.dim_size(rank - 2);
    int64_t num_rhs_cols = matrix_b.dim_size(rank - 1);

    if (this->transpose_a_) {
      std::swap(num_lhs_rows, num_lhs_cols);
    }

    // Possibly transpose the dense Tensor b.
    const Tensor* rhs = &matrix_b;
    Tensor b_transposed;
    if (this->transpose_b_) {
      OP_REQUIRES_OK(
          ctx, TransposeAndConjugateTensor(ctx, matrix_b, this->conjugate_b_,
                                           &b_transposed));
      rhs = &b_transposed;
      std::swap(num_rhs_rows, num_rhs_cols);
    }

    // If we're transposing the output, then allocate a temporary buffer to
    // store the output. Otherwise allocate the output directly.
    Tensor* output = nullptr;
    Tensor* matmul_result = nullptr;
    Tensor output_transposed;
    OP_REQUIRES_OK(
        ctx, AllocateOutput(ctx, rank, batch_size, num_lhs_rows, num_rhs_cols,
                            this->transpose_output_, &output,
                            &output_transposed, &matmul_result));

    if (!this->transpose_a_) {
      SparseDenseMatMulWithoutTransposedLHS(
          ctx, batch_size, num_lhs_rows, *sparse_matrix_a, *rhs, matmul_result);
    } else {  // transpose_a_ == true
      SparseDenseMatMulWithTransposedLHS(ctx, batch_size, num_lhs_rows,
                                         num_lhs_cols, *sparse_matrix_a, *rhs,
                                         matmul_result);
    }

    // Transpose (and conjugate) the output if necessary.
    // Note that conjugate is only true if transpose is also true.
    if (this->transpose_output_) {
      OP_REQUIRES_OK(
          ctx, TransposeAndConjugateAllocatedTensor(
                   ctx, output_transposed, this->conjugate_output_, output));
    } else if (this->conjugate_output_) {
      functor::maybe_conj_inplace<CPUDevice, T>::run(
          ctx->eigen_device<CPUDevice>(), output);
    }
  }

 private:
  // Allocates the output with the appropriate shape. Additionally, if
  // transpose_output is True, allocates a temporary buffer with the
  // transposed output. 'matmul_result' points to either output or
  // output_transposed, based on whether transpose_output is True.
  Status AllocateOutput(OpKernelContext* ctx, const int32_t rank,
                        const int64_t batch_size, const int64_t num_rows,
                        const int64_t num_cols, const bool transpose_output,
                        Tensor** output, Tensor* output_transposed,
                        Tensor** matmul_result) {
    TensorShape output_shape;
    if (rank == 3) output_shape.AddDim(batch_size);

    if (!transpose_output) {
      output_shape.AppendShape({num_rows, num_cols});
      TF_RETURN_IF_ERROR(ctx->allocate_output(0, output_shape, output));
      *matmul_result = *output;
    } else {
      TensorShape output_transposed_shape = output_shape;
      output_transposed_shape.AppendShape({num_rows, num_cols});
      output_shape.AppendShape({num_cols, num_rows});
      TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value,
                                            output_transposed_shape,
                                            output_transposed));
      TF_RETURN_IF_ERROR(ctx->allocate_output(0, output_shape, output));
      *matmul_result = output_transposed;
    }
    return OkStatus();
  }

  // Returns an Eigen::Ref expression of a sparse sub-matrix from the given
  // contiguous segment of rows of the CSR Sparse Matrix.
  Eigen::Ref<const SparseMatrix> GetSparseMatrixRef(
      const CSRSparseMatrix& csr_matrix, const int batch_index,
      const int64_t row_begin, const int64_t num_shard_rows,
      std::vector<int32>* row_ptrs) {
    // Compute the row pointers of the sparse sub-matrix.
    row_ptrs->resize(num_shard_rows + 1);
    const int64_t row_offset =
        csr_matrix.row_pointers_vec(batch_index)(row_begin);
    for (int64_t row_idx = 0; row_idx <= num_shard_rows; ++row_idx) {
      row_ptrs->at(row_idx) =
          csr_matrix.row_pointers_vec(batch_index)(row_begin + row_idx) -
          row_offset;
    }
    const int64_t num_cols =
        csr_matrix.dense_shape().vec<int64_t>()(csr_matrix.dims() - 1);
    return Eigen::Map<const SparseMatrix>(
        num_shard_rows /* num_rows */, num_cols /* num_cols */,
        row_ptrs->at(num_shard_rows) /* total_nnz */, row_ptrs->data(),
        csr_matrix.col_indices_vec(batch_index).data() + row_offset,
        csr_matrix.values_vec<T>(batch_index).data() + row_offset);
  }

  // Sparse-Dense Matrix Multiplication between a CSRSparseMatrix (LHS) and a
  // dense Tensor (RHS).
  void SparseDenseMatMulWithoutTransposedLHS(OpKernelContext* ctx,
                                             const int64_t batch_size,
                                             const int64_t num_lhs_rows,
                                             const CSRSparseMatrix& lhs,
                                             const Tensor& rhs,
                                             Tensor* output) {
    // Parallelize matrix multiplication across batch dimensions and across
    // rows in each batch.
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    const int32_t num_threads = worker_threads.num_threads;
    const int64_t block_size =
        num_lhs_rows / std::max(kMaxShards, kNumShardsPerThread * num_threads);
    const int64_t num_rhs_rows = rhs.dim_size(rhs.dims() - 2);
    const int64_t num_rhs_cols = rhs.dim_size(rhs.dims() - 1);
    worker_threads.workers->ParallelFor(
        batch_size * num_lhs_rows /* total */,
        thread::ThreadPool::SchedulingParams(
            thread::ThreadPool::SchedulingStrategy::
                kFixedBlockSize /* strategy */,
            absl::nullopt /* cost_per_unit */, block_size),
        [&](int64_t batch_and_row_begin, int64_t batch_and_row_end) {
          HandleBatchAndRowRange(
              num_lhs_rows, batch_and_row_begin, batch_and_row_end,
              [&](int64_t batch_idx, int64_t row_begin, int64_t row_end) {
                const int64_t num_shard_rows = row_end - row_begin;

                // Define an Eigen::SparseMatrix over the row range:
                // [row_begin, row_end) of the CSR SparseMatrix A.
                std::vector<int32> row_ptrs;
                auto sparse_matrix = GetSparseMatrixRef(
                    lhs, batch_idx, row_begin, num_shard_rows, &row_ptrs);

                // Map the corresponding rows of the rhs.
                ConstMatrixMap rhs_map(
                    rhs.flat<T>().data() +
                        batch_idx * num_rhs_rows * num_rhs_cols,
                    num_rhs_rows, num_rhs_cols);

                // Write to the corresponding rows of the output matrix.
                MatrixMap output_map(
                    output->flat<T>().data() +
                        batch_idx * num_lhs_rows * num_rhs_cols +
                        row_begin * num_rhs_cols,
                    num_shard_rows, num_rhs_cols);
                output_map.noalias() = sparse_matrix * rhs_map;
              });
        });
  }

  // Sparse-Dense Matrix Multiplication assuming the CSRSparseMatrix (LHS) is
  // to be transposed before the operation.
  void SparseDenseMatMulWithTransposedLHS(OpKernelContext* ctx,
                                          const int64_t batch_size,
                                          const int64_t num_lhs_rows,
                                          const int64_t num_lhs_cols,
                                          const CSRSparseMatrix& lhs,
                                          const Tensor& rhs, Tensor* output) {
    auto device = ctx->eigen_device<CPUDevice>();
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    const int32_t num_threads = worker_threads.num_threads;
    const int64_t num_rhs_rows = rhs.dim_size(rhs.dims() - 2);
    const int64_t num_rhs_cols = rhs.dim_size(rhs.dims() - 1);
    // Usually, we want to avoid transposing the sparse matrix A since it may
    // be an expensive operation. Instead, we use the identity
    // (A^T B) = (B^T A)^T. We don't actually transpose B or the output
    // because it is more convenient to have them in column-major form.
    //
    // However, if A is hypersparse and B and C are huge, transposing A will
    // be cheaper. In the future, we should have a cost model estimating the
    // cost of transposing all matrices (A, B, C) to decide which variant to
    // use.

    // Each thread writes to its own copy of the matrix product. These
    // `num_threads` copies are summed together to obtain the final result.
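    //
    // Illustrative shape walk-through (the names here are not variables in
    // this function): with A of shape [p, q] and B of shape [p, n], the work
    // is sharded over the p rows of A, i.e. over the contraction dimension.
    // Each shard computes B_shard^T * A_shard, a full [n, q] contribution to
    // C^T = B^T A, and accumulates it into a thread-local buffer; the
    // buffers are reduced at the end. The buffer below has num_threads + 1
    // rows because ParallelForWithWorkerId may also run shards on the
    // calling thread.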
    Tensor matmul_result_buffer;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                            DataTypeToEnum<T>::value,
                            TensorShape({num_threads + 1,
                                         output->NumElements()}),
                            &matmul_result_buffer));
    functor::SetZeroFunctor<CPUDevice, T> set_zero;
    set_zero(device, matmul_result_buffer.flat<T>());

    // Parallelize matrix multiplication across batch dimensions and across
    // columns of A^T in each batch. These correspond to rows of A.
    const int64_t block_size =
        num_lhs_cols / std::max(kMaxShards, kNumShardsPerThread * num_threads);
    worker_threads.workers->ParallelForWithWorkerId(
        batch_size * num_lhs_cols /* total */,
        thread::ThreadPool::SchedulingParams(
            thread::ThreadPool::SchedulingStrategy::
                kFixedBlockSize /* strategy */,
            absl::nullopt /* cost_per_unit */, block_size),
        [&](int64_t batch_and_row_begin, int64_t batch_and_row_end, int tid) {
          HandleBatchAndRowRange(
              num_lhs_cols, batch_and_row_begin, batch_and_row_end,
              [&](int64_t batch_idx, int64_t row_begin, int64_t row_end) {
                const int64_t num_shard_rows = row_end - row_begin;

                // Define a new sparse sub-matrix from the row range
                // [row_begin, row_end) of the sparse matrix A.
                std::vector<int32> row_ptrs;
                auto sparse_matrix = GetSparseMatrixRef(
                    lhs, batch_idx, row_begin, num_shard_rows, &row_ptrs);

                // Map the corresponding `num_shard_rows` columns of B^T.
                // This is the same as taking the `num_shard_rows` rows of B.
                ConstMatrixMap b_dense_map(
                    rhs.flat<T>().data() +
                        batch_idx * num_rhs_rows * num_rhs_cols +
                        row_begin * num_rhs_cols,
                    num_shard_rows, num_rhs_cols);

                // Map to the corresponding rows of the output.
                MatrixMap output_map(
                    matmul_result_buffer.flat<T>().data() +
                        tid * batch_size * num_lhs_rows * num_rhs_cols +
                        batch_idx * num_lhs_rows * num_rhs_cols,
                    num_lhs_rows, num_rhs_cols);

                // Compute the product C^T = B^T * A; restricted to the row
                // range in the current shard.
                if (this->conjugate_a_) {
                  output_map.transpose().noalias() +=
                      b_dense_map.transpose() * sparse_matrix.conjugate();
                } else {
                  output_map.transpose().noalias() +=
                      b_dense_map.transpose() * sparse_matrix;
                }
              });
        });

    // Sum across each thread's matmul result.
    using Reducer = Eigen::internal::SumReducer<T>;
    using Index = typename TTypes<T>::Tensor::Index;
    output->flat<T>().device(device) = matmul_result_buffer.matrix<T>().reduce(
        Eigen::array<Index, 1>({0}), Reducer());
  }

  // Given a range [batch_and_row_begin, batch_and_row_end) which is a
  // contiguous subset of [0, num_rows * batch_size), calls the function
  // fn(batch_idx, row_begin, row_end) for each batch index
  // and the row range [row_begin, row_end) contained in the batch.
  void HandleBatchAndRowRange(
      const int64_t num_rows, const int64_t batch_and_row_begin,
      const int64_t batch_and_row_end,
      const std::function<void(int64_t, int64_t, int64_t)>& fn) {
    // Obtain the batch indices overlapping with the current shard.
    const int64_t batch_begin = batch_and_row_begin / num_rows;
    const int64_t batch_end_inclusive = batch_and_row_end / num_rows;

    for (int64_t batch_idx = batch_begin; batch_idx <= batch_end_inclusive;
         ++batch_idx) {
      // Find the contiguous set of rows which are contained in this shard as
      // well as the current batch. We intersect with interval [batch_idx *
      // num_rows, (batch_idx + 1) * num_rows) which denotes the current batch.
      const int64_t current_batch_row_begin =
          std::max(batch_and_row_begin, batch_idx * num_rows);
      const int64_t current_batch_row_end =
          std::min(batch_and_row_end, (batch_idx + 1) * num_rows);

      const int64_t row_begin = current_batch_row_begin % num_rows;
      const int64_t num_shard_rows =
          current_batch_row_end - current_batch_row_begin;
      // Edge case for when current_batch_row_end is the first index of a new
      // row.
      if (num_shard_rows == 0) continue;

      fn(batch_idx, row_begin, row_begin + num_shard_rows);
    }
  }

  // Transposes (and optionally, conjugates) a given Tensor. Also allocates
  // the required memory for the output Tensor.
  Status TransposeAndConjugateTensor(OpKernelContext* ctx, const Tensor& input,
                                     bool conjugate, Tensor* output) {
    TensorShape transposed_shape = input.shape();
    transposed_shape.set_dim(input.dims() - 1,
                             input.dim_size(input.dims() - 2));
    transposed_shape.set_dim(input.dims() - 2,
                             input.dim_size(input.dims() - 1));
    TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value,
                                          transposed_shape, output));
    return TransposeAndConjugateAllocatedTensor(ctx, input, conjugate, output);
  }

  // Transposes (and optionally, conjugates) a given Tensor. The output should
  // be already allocated.
  Status TransposeAndConjugateAllocatedTensor(OpKernelContext* ctx,
                                              const Tensor& input,
                                              bool conjugate, Tensor* output) {
    if (conjugate) {
      TF_RETURN_IF_ERROR(DoConjugateMatrixTranspose(
          ctx->eigen_device<CPUDevice>(), input, output));
    } else {
      TF_RETURN_IF_ERROR(
          DoMatrixTranspose(ctx->eigen_device<CPUDevice>(), input, output));
    }
    return OkStatus();
  }
};

// GPU Kernel to compute sparse-dense matrix multiplication.
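//
// Implementation note (summary of the code below): when the dense operand
// reduces to a single column, the kernel dispatches to the cuSPARSE /
// hipSPARSE matrix-vector path (CSRSparseMatrixMatVec); otherwise it uses the
// SpMM/csrmm-based CSRSparseMatrixMatMul functor, whose column-major output
// is transposed back to row-major unless transpose_output is set.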
template <typename T>
class CSRMatMulGPUOp : public CSRMatMulOp<GPUDevice, T> {
  using SparseMatrix = Eigen::SparseMatrix<T, Eigen::RowMajor>;
  using Matrix =
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using ConstMatrixMap = Eigen::Map<const Matrix>;
  using MatrixMap = Eigen::Map<Matrix>;

 public:
  explicit CSRMatMulGPUOp(OpKernelConstruction* c)
      : CSRMatMulOp<GPUDevice, T>(c) {}

  ~CSRMatMulGPUOp() override {}

  void Compute(OpKernelContext* ctx) final {
    const CSRSparseMatrix* a_matrix;
    OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &a_matrix));
    const Tensor& b_t = ctx->input(1);

    int rank;
    int64_t batch_size;
    OP_REQUIRES_OK(ctx,
                   this->ValidateInputs(*a_matrix, b_t, &rank, &batch_size));

    const Tensor& a_dense_shape_t = a_matrix->dense_shape();
    TensorShape a_dense_tensor_shape;
    auto a_dense_shape = a_dense_shape_t.vec<int64_t>();
    OP_REQUIRES_OK(
        ctx, TensorShapeUtils::MakeShape(a_dense_shape, &a_dense_tensor_shape));

    const int row_dim = (rank == 2) ? 0 : 1;
    const int64_t a_outer_dim = a_dense_tensor_shape.dim_size(
        this->transpose_a_ ? row_dim + 1 : row_dim);
    const int64_t b_inner_dim =
        b_t.shape().dim_size(this->transpose_b_ ? row_dim + 1 : row_dim);
    const int64_t b_outer_dim =
        b_t.dim_size(this->transpose_b_ ? row_dim : row_dim + 1);
    const int64_t b_slice_size = b_inner_dim * b_outer_dim;

    TensorShape c_shape;
    if (rank == 3) c_shape.AddDim(batch_size);
    if (this->transpose_output_) {
      c_shape.AddDim(b_outer_dim);
      c_shape.AddDim(a_outer_dim);
    } else {
      c_shape.AddDim(a_outer_dim);
      c_shape.AddDim(b_outer_dim);
    }

    const int64_t c_matrix_lhs = c_shape.dim_size(row_dim);
    const int64_t c_matrix_rhs = c_shape.dim_size(row_dim + 1);
    const int64_t c_slice_size = c_matrix_lhs * c_matrix_rhs;
    Tensor* c_t;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, c_shape, &c_t));

    const GPUDevice& d = ctx->eigen_device<GPUDevice>();
    bool use_matrix_vector_multiply = (b_outer_dim == 1);
#if TENSORFLOW_USE_ROCM
    // ROCm hipsparse does not implement csrmv with transposed input a.
    use_matrix_vector_multiply =
        use_matrix_vector_multiply && !this->transpose_a_;
#endif
    if (use_matrix_vector_multiply) {
      // Call matrix-vector multiply if b is a vector.
      TTypes<int64_t>::ConstVec a_dense_shape_comp(
          a_dense_shape.data() + row_dim, 2);
      Tensor b_conj_t;
      const T* b_base_ptr = b_t.template flat<T>().data();
      bool conjugate_a = this->conjugate_a_;
      bool conjugate_output = this->conjugate_output_;
      if (this->conjugate_b_) {
        if (conjugate_a) {
          // In this case we can use the identity
          //   conj(a) * conj(b) = conj(a * b)
          // instead of creating a conjugated copy of b.
          conjugate_a = false;
          conjugate_output = !conjugate_output;
        } else {
          OP_REQUIRES_OK(
              ctx, ctx->forward_input_or_allocate_temp(
                       {1}, DataTypeToEnum<T>::value, b_t.shape(), &b_conj_t));
          functor::maybe_conj<GPUDevice, T>::run(d, b_t, &b_conj_t);
          b_base_ptr = b_conj_t.template flat<T>().data();
        }
      }

      functor::CSRSparseMatrixMatVec<GPUDevice, T> csr_spmv(this->transpose_a_,
                                                            conjugate_a);
      for (int i = 0; i < batch_size; ++i) {
        auto a_row_ptr = a_matrix->row_pointers_vec(i);
        auto a_col_ind = a_matrix->col_indices_vec(i);
        auto a_values = a_matrix->values_vec<T>(i);
        ConstCSRComponent<T> a_comp{a_row_ptr, a_col_ind, a_values,
                                    a_dense_shape_comp};
        const T* b_i = b_base_ptr + i * b_slice_size;
        T* c_i = &c_t->template flat<T>()(i * c_slice_size);
        Status s = csr_spmv.Compute(ctx, a_comp, b_i, c_i);
        OP_REQUIRES_OK(ctx, s);
      }
      if (conjugate_output) {
        functor::maybe_conj_inplace<GPUDevice, T>::run(d, c_t);
      }
      return;
    }

    functor::CSRSparseMatrixMatMul<GPUDevice, T> csr_spmmadd(
        this->transpose_output_);

    Tensor c_mat_col_major_t;
    if (!this->transpose_output_) {
      // If transpose_output is false, we'll need to transpose the
      // (column-major) output of the csrgemm call to get proper (row-major)
      // output. Which means we need to keep a temporary buffer to
      // store the intermediate gemm output.
      TensorShape c_mat_col_major_shape;
      if (rank == 2) {
        c_mat_col_major_shape = TensorShape({c_matrix_rhs, c_matrix_lhs});
      } else {
        c_mat_col_major_shape =
            TensorShape({batch_size, c_matrix_rhs, c_matrix_lhs});
      }
      OP_REQUIRES_OK(
          ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                  c_mat_col_major_shape, &c_mat_col_major_t));
    }

    // If transpose_output is true, return the direct (column-major, i.e.,
    // transposed) output of the csrgemm call. Otherwise we'll need
    // to transpose it to row-major format.
    auto c_mat_col_major = (this->transpose_output_)
                               ? c_t->flat<T>()
                               : c_mat_col_major_t.flat<T>();

    // Possibly transpose a.
    const CSRSparseMatrix* a_input_matrix;
    // If we need to transpose a, we will store the result temporarily
    // in the object below.
    CSRSparseMatrix a_matrix_transposed;
    if (!this->transpose_a_) {
      a_input_matrix = a_matrix;
    } else {
      functor::CSRSparseMatrixTranspose<GPUDevice, T> transpose;
      OP_REQUIRES_OK(ctx, transpose(ctx, this->conjugate_a_, *a_matrix,
                                    &a_matrix_transposed));
      a_input_matrix = &a_matrix_transposed;
    }

    auto a_input_dense_shape = a_input_matrix->dense_shape().vec<int64_t>();

    // Possibly transpose b.
    Tensor b_t_input;
    if (!this->transpose_b_) {
      b_t_input = b_t;
    } else {
      TensorShape b_t_transposed_shape;
      if (rank == 3) {
        b_t_transposed_shape.AddDim(batch_size);
      }
      b_t_transposed_shape.AddDim(b_t.dim_size(row_dim + 1));
      b_t_transposed_shape.AddDim(b_t.dim_size(row_dim));
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             b_t_transposed_shape, &b_t_input));
      const GPUDevice& d = ctx->eigen_device<GPUDevice>();
      if (this->conjugate_b_) {
        OP_REQUIRES_OK(ctx, DoConjugateMatrixTranspose(d, b_t /*input*/,
                                                       &b_t_input /*output*/));
      } else {
        OP_REQUIRES_OK(
            ctx, DoMatrixTranspose(d, b_t /*input*/, &b_t_input /*output*/));
      }
    }

    // Dense shape of a batch component of A.
    TTypes<int64_t>::ConstVec a_input_dense_shape_comp(
        a_input_dense_shape.data() + row_dim, 2);

    auto b = b_t_input.flat<T>();

    for (int i = 0; i < batch_size; ++i) {
      auto a_row_ptr = a_input_matrix->row_pointers_vec(i);
      auto a_col_ind = a_input_matrix->col_indices_vec(i);
      auto a_values = a_input_matrix->values_vec<T>(i);
      typename TTypes<T>::UnalignedConstMatrix b_i(b.data() + i * b_slice_size,
                                                   {b_inner_dim, b_outer_dim});
      typename TTypes<T>::UnalignedMatrix c_mat_col_major_i(
          c_mat_col_major.data() + i * c_slice_size,
          {c_matrix_lhs, c_matrix_rhs});
      ConstCSRComponent<T> a_comp{a_row_ptr, a_col_ind, a_values,
                                  a_input_dense_shape_comp};
      Status s = csr_spmmadd.Compute(ctx, a_comp, b_i, c_mat_col_major_i);
      OP_REQUIRES_OK(ctx, s);
    }

    if (!this->transpose_output_) {
      // We need to return values in row-major format, so transpose
      // the column-major values in c_mat_col_major_t to row-major output c_t.
      OP_REQUIRES_OK(ctx, DoMatrixTranspose(d, /*input=*/c_mat_col_major_t,
                                            /*output=*/c_t));
    }
    if (this->conjugate_output_) {
      functor::maybe_conj_inplace<GPUDevice, T>::run(d, c_t);
    }
  }
};

#define REGISTER_CPU(T)                                                     \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("SparseMatrixMatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      CSRMatMulCPUOp<T>);

REGISTER_CPU(float)
REGISTER_CPU(double)
REGISTER_CPU(complex64)
REGISTER_CPU(complex128)

#undef REGISTER_CPU

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_GPU(T)                                                     \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("SparseMatrixMatMul").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      CSRMatMulGPUOp<T>);

REGISTER_GPU(float)
REGISTER_GPU(double)
REGISTER_GPU(complex64)
REGISTER_GPU(complex128)

#undef REGISTER_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace functor {

namespace {

// GPUDataType<T>::type translates from a C++ type (e.g. float) to a
// GPUDataType_t (e.g. CUDA_R_32F).
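// For example, per the specializations below, GPUDataType<float>::type is
// CUDA_R_32F and GPUDataType<std::complex<double>>::type is CUDA_C_64F on
// CUDA builds.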
template <typename T>
struct GPUDataType;

// GPUDataType templates are currently not instantiated in the ROCm flow,
// so the #elif TENSORFLOW_USE_ROCM blocks are left out for now.
// The hipblas library is not (yet) being pulled in via rocm_configure.bzl,
// so we cannot reference types from hipblas headers here.
template <>
struct GPUDataType<Eigen::half> {
#if GOOGLE_CUDA
  static constexpr cudaDataType_t type = CUDA_R_16F;
#endif
};

template <>
struct GPUDataType<float> {
#if GOOGLE_CUDA
  static constexpr cudaDataType_t type = CUDA_R_32F;
#endif
};

template <>
struct GPUDataType<std::complex<float>> {
#if GOOGLE_CUDA
  static constexpr cudaDataType_t type = CUDA_C_32F;
#endif
};

template <>
struct GPUDataType<double> {
#if GOOGLE_CUDA
  static constexpr cudaDataType_t type = CUDA_R_64F;
#endif
};

template <>
struct GPUDataType<std::complex<double>> {
#if GOOGLE_CUDA
  static constexpr cudaDataType_t type = CUDA_C_64F;
#endif
};

}  // namespace

template <typename T>
class CSRSparseMatrixMatMul<GPUDevice, T> {
 public:
  explicit CSRSparseMatrixMatMul(const bool transpose_output)
      : transpose_output_(transpose_output) {}

  Status Compute(OpKernelContext* ctx, const ConstCSRComponent<T>& a,
                 typename TTypes<T>::UnalignedConstMatrix b,
                 typename TTypes<T>::UnalignedMatrix c) {
    GpuSparse cuda_sparse(ctx);
    TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
    {
      // Use Csrmm/SpMM to calculate:
      //   C = alpha * op(A) * op(B) + beta * C
      // where alpha = 1.0, beta = 0.0, A is sparse and B and C are dense.
      // Note that Csrmm/SpMM assumes B and C are in column-major form; so we
      // use transB == true, and manually transpose the output in place
      // using blas<t>geam.
      // TODO(ebrevdo,rmlarsen): Add support for transposition and adjoint.

      // Create alpha and beta scalars; alpha = 1.0, beta = 0.0.
      // TODO(ebrevdo,rmlarsen): Add support for non-trivial alpha and beta.
      const T alpha = 1;
      const T beta = 0;

      // A is (m, k), Bt is (ldb, k) and Ct is (ldc, n).
      const int k = b.dimension(0);
      DCHECK_EQ(k, a.dense_shape_host(1));

      // If transpose_output_ is true, then the c matrix we receive
      // here is the direct row-major output (into which we will store
      // csrgemm's column-major output). Otherwise it's a
      // temporary tensor that will store the column-major output that
      // will eventually be transposed.
      const int m = c.dimension(transpose_output_ ? 1 : 0);
      const int n = c.dimension(transpose_output_ ? 0 : 1);
      DCHECK_EQ(m, a.dense_shape_host(0));
      DCHECK_EQ(n, b.dimension(1));
      const int nnz = a.values.size();
      DCHECK_EQ(nnz, a.col_ind.size());

      // ldb: leading dimension of B. If op(B) = B, it must be at least
      // max(1, k) if op(A) = A and at least max(1, m) otherwise. If
      // op(B) != B, it must be at least max(1, n).
      const int ldb = n;
      // ldc: leading dimension of C. It must be at least max(1, m) if
      // op(A) = A and at least max(1, k) otherwise.
      const int ldc = m;

      // transA must be non-transpose if transB is transpose (cusparse
      // limitation).
#if GOOGLE_CUDA
      const gpusparseOperation_t transA = CUSPARSE_OPERATION_NON_TRANSPOSE;
#elif TENSORFLOW_USE_ROCM
      const gpusparseOperation_t transA = HIPSPARSE_OPERATION_NON_TRANSPOSE;
#endif

      // transB: b is row-major, and cusparse requires col-major b (or
      // equivalently transB == transpose). This version is actually more
      // efficient.
#if GOOGLE_CUDA && CUDA_VERSION >= 10020

      const gpusparseOperation_t transB = CUSPARSE_OPERATION_TRANSPOSE;
      gpusparseSpMatDescr_t matA;
      gpusparseDnMatDescr_t matB, matC;

      TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsr(
          &matA, m, k, nnz, const_cast<int*>(a.row_ptr.data()),
          const_cast<int*>(a.col_ind.data()), const_cast<T*>(a.values.data()),
          CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO,
          GPUDataType<T>::type));

      TF_RETURN_IF_GPUSPARSE_ERROR(
          cusparseCreateDnMat(&matB, n, k, ldb, const_cast<T*>(b.data()),
                              GPUDataType<T>::type, CUSPARSE_ORDER_COL));

      TF_RETURN_IF_GPUSPARSE_ERROR(
          cusparseCreateDnMat(&matC, m, n, ldc, c.data(), GPUDataType<T>::type,
                              CUSPARSE_ORDER_COL));

      size_t bufferSize = 0;
      TF_RETURN_IF_ERROR(cuda_sparse.SpMMBufferSize(
          transA, transB, &alpha, matA, matB, &beta, matC,
          CUSPARSE_MM_ALG_DEFAULT, &bufferSize));

      Tensor buffer;
      TF_RETURN_IF_ERROR(ctx->allocate_temp(
          DT_INT8, TensorShape({static_cast<int64_t>(bufferSize)}), &buffer));
      DCHECK(buffer.flat<int8>().data() != nullptr);

      TF_RETURN_IF_ERROR(cuda_sparse.SpMM(transA, transB, &alpha, matA, matB,
                                          &beta, matC, CUSPARSE_MM_ALG_DEFAULT,
                                          buffer.flat<int8>().data()));

      TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnMat(matB));
      TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnMat(matC));
      TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroySpMat(matA));

#elif TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 40200
      // Use SpMM.
      const gpusparseOperation_t transB = HIPSPARSE_OPERATION_TRANSPOSE;
      gpusparseSpMatDescr_t matA;
      gpusparseDnMatDescr_t matB, matC;

      TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateCsr(
          &matA, m, k, nnz, const_cast<int*>(a.row_ptr.data()),
          const_cast<int*>(a.col_ind.data()), const_cast<T*>(a.values.data()),
          CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I, HIPSPARSE_INDEX_BASE_ZERO,
          GPUDataType<T>::type));

      TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateDnMat(
          &matB, n, k, ldb, const_cast<T*>(b.data()), GPUDataType<T>::type,
          HIPSPARSE_ORDER_COL));

      TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateDnMat(
          &matC, m, n, ldc, c.data(), GPUDataType<T>::type,
          HIPSPARSE_ORDER_COL));

      size_t bufferSize = 0;
      TF_RETURN_IF_ERROR(cuda_sparse.SpMMBufferSize(
          transA, transB, &alpha, matA, matB, &beta, matC,
          HIPSPARSE_MM_ALG_DEFAULT, &bufferSize));

      Tensor buffer;
      TF_RETURN_IF_ERROR(ctx->allocate_temp(
          DT_INT8, TensorShape({static_cast<int64_t>(bufferSize)}), &buffer));
      DCHECK(buffer.flat<int8>().data() != nullptr);

      TF_RETURN_IF_ERROR(cuda_sparse.SpMM(transA, transB, &alpha, matA, matB,
                                          &beta, matC, CUSPARSE_MM_ALG_DEFAULT,
                                          buffer.flat<int8>().data()));

      TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseDestroyDnMat(matB));
      TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseDestroyDnMat(matC));
      TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseDestroySpMat(matA));

#else

#if GOOGLE_CUDA

      const gpusparseOperation_t transB = CUSPARSE_OPERATION_TRANSPOSE;

      gpusparseMatDescr_t descrA;
      TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));

#elif TENSORFLOW_USE_ROCM

      const gpusparseOperation_t transB = HIPSPARSE_OPERATION_TRANSPOSE;

      gpusparseMatDescr_t descrA;
      TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateMatDescr(&descrA));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          wrap::hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          wrap::hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
#endif  // GOOGLE_CUDA

      TF_RETURN_IF_ERROR(
          cuda_sparse.Csrmm(transA, transB, m, n, k, nnz, &alpha, descrA,
                            a.values.data(), a.row_ptr.data(), a.col_ind.data(),
                            b.data(), ldb, &beta, c.data(), ldc));

#endif  // GOOGLE_CUDA && CUDA_VERSION >= 10020
    }

    return OkStatus();
  }

 private:
  bool transpose_output_;
};

template <typename T>
class CSRSparseMatrixMatVec<GPUDevice, T> {
 public:
  CSRSparseMatrixMatVec(bool transpose_a, bool conjugate_a)
      : transA_(TransposeAndConjugateToGpuSparseOp(transpose_a, conjugate_a,
                                                   &status_)) {}

  Status Compute(OpKernelContext* ctx, const ConstCSRComponent<T>& a,
                 const T* x, T* y) {
    TF_RETURN_IF_ERROR(status_);
    GpuSparse cuda_sparse(ctx);
    TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
    {
      // Use Csrmv to calculate:
      //   y = alpha * op(A) * x + beta * y
      // where alpha = 1.0, beta = 0.0, A is a sparse matrix and x and y are
      // dense vectors.

      // Create alpha and beta scalars; alpha = 1.0, beta = 0.0.
      // TODO(rmlarsen,ebrevdo): Add support for general alpha, beta.
      const T alpha = 1;
      const T beta = 0;

#if GOOGLE_CUDA && CUDA_VERSION < 10020
      gpusparseMatDescr_t descrA;
      TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));
#elif TENSORFLOW_USE_ROCM
      gpusparseMatDescr_t descrA;
      TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateMatDescr(&descrA));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          wrap::hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          wrap::hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

      const int m = a.dense_shape_host(0);
      const int n = a.dense_shape_host(1);
      const int nnz = a.values.size();
      DCHECK_EQ(nnz, a.col_ind.size());
#if GOOGLE_CUDA && (CUDA_VERSION >= 10020)
      TF_RETURN_IF_ERROR(cuda_sparse.Csrmv(transA_, m, n, nnz, &alpha,
                                           a.values.data(), a.row_ptr.data(),
                                           a.col_ind.data(), x, &beta, y));
#else
      TF_RETURN_IF_ERROR(cuda_sparse.Csrmv(transA_, m, n, nnz, &alpha, descrA,
                                           a.values.data(), a.row_ptr.data(),
                                           a.col_ind.data(), x, &beta, y));
#endif
    }

    return OkStatus();
  }

 private:
  Status status_;
  const gpusparseOperation_t transA_;
};

}  // namespace functor

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow