/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/util/cuda_sparse.h"
#include "tensorflow/core/util/gpu_solvers.h"
#endif

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace {

// Computes c = alpha * a + beta * b for (possibly batched) CSR sparse
// matrices a and b, delegating the per-batch work to
// functor::CSRSparseMatrixAdd<Device, T>.
template <typename Device, typename T>
class CSRSparseMatrixAddFunctor {
 public:
  explicit CSRSparseMatrixAddFunctor(OpKernelContext* ctx, const T alpha,
                                     const T beta)
      : ctx_(ctx), alpha_(alpha), beta_(beta) {}

  Status operator()(const CSRSparseMatrix& a, const CSRSparseMatrix& b,
                    CSRSparseMatrix* c) {
    TensorShape a_tensor_shape;
    TensorShape b_tensor_shape;
    TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(
        a.dense_shape().vec<int64_t>(), &a_tensor_shape));
    TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(
        b.dense_shape().vec<int64_t>(), &b_tensor_shape));

    if (a_tensor_shape.dims() == 3) {
      if ((a_tensor_shape.dims() != b_tensor_shape.dims()) ||
          (a_tensor_shape.dim_size(0) != b_tensor_shape.dim_size(0))) {
        return errors::InvalidArgument(
            "Incompatible shapes of a and b, a.shape == ",
            a_tensor_shape.DebugString(),
            ", b.shape == ", b_tensor_shape.DebugString());
      }
    }
    const int rank = a_tensor_shape.dims();
    if ((a_tensor_shape.dim_size(rank - 2) !=
         b_tensor_shape.dim_size(rank - 2)) ||
        (a_tensor_shape.dim_size(rank - 1) !=
         b_tensor_shape.dim_size(rank - 1))) {
      return errors::InvalidArgument(
          "Incompatible shapes of a and b, a.shape == ",
          a_tensor_shape.DebugString(),
          ", b.shape == ", b_tensor_shape.DebugString());
    }

    const int batch_size = a.batch_size();

    // TODO(ebrevdo): Add support for broadcasting at least in the
    // batch dimension.
    auto a_dense_shape = a.dense_shape().vec<int64_t>();
    auto b_dense_shape = b.dense_shape().vec<int64_t>();
    Tensor c_dense_shape_t = a.dense_shape();

    const int64_t rows = a_dense_shape((rank == 2) ? 0 : 1);
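
    // The addition below proceeds in three passes over the batch: query the
    // scratch space each batch member needs, compute the output sparsity
    // structure (row pointers and per-batch nnz counts), and finally fill in
    // the column indices and values.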
    functor::CSRSparseMatrixAdd<Device, T> csr_geam(ctx_, alpha_, beta_);
    TF_RETURN_IF_ERROR(csr_geam.Initialize());

    Tensor c_batch_ptr_t(cpu_allocator(), DT_INT32,
                         TensorShape({batch_size + 1}));
    auto c_batch_ptr = c_batch_ptr_t.vec<int32>();
    c_batch_ptr(0) = 0;

    Tensor c_row_ptr_t;
    TF_RETURN_IF_ERROR(ctx_->allocate_temp(
        DT_INT32, TensorShape({batch_size * (rows + 1)}), &c_row_ptr_t));
    auto c_row_ptr = c_row_ptr_t.vec<int32>();

    // Set the output row pointers to zero, in case we hit any empty
    // combinations of rows in a and b.
    functor::SetZeroFunctor<Device, int32> set_zero;
    const Device& d = ctx_->eigen_device<Device>();
    set_zero(d, c_row_ptr_t.flat<int32>());

    size_t maxWorkspaceSize = 0;
    for (int i = 0; i < batch_size; ++i) {
      ConstCSRComponent<T> a_comp{a.row_pointers_vec(i), a.col_indices_vec(i),
                                  a.values_vec<T>(i), a_dense_shape};
      ConstCSRComponent<T> b_comp{b.row_pointers_vec(i), b.col_indices_vec(i),
                                  b.values_vec<T>(i), b_dense_shape};

      size_t thisWorkspaceSize;
      TF_RETURN_IF_ERROR(
          csr_geam.GetWorkspaceSize(a_comp, b_comp, &thisWorkspaceSize));
      if (thisWorkspaceSize > maxWorkspaceSize) {
        maxWorkspaceSize = thisWorkspaceSize;
      }
    }

    Tensor temp;
    TF_RETURN_IF_ERROR(ctx_->allocate_temp(
        DT_INT8, TensorShape({static_cast<int64_t>(maxWorkspaceSize)}),
        &temp));
    void* workspace = temp.flat<int8>().data();

    for (int i = 0; i < batch_size; ++i) {
      // Calculate output sizes for all minibatch entries.
      // Store in c_batch_ptr and update c_row_ptrs.
      if (a.nnz(i) == 0 && b.nnz(i) == 0) {
        c_batch_ptr(i + 1) = c_batch_ptr(i);
        continue;
      }
      ConstCSRComponent<T> a_comp{a.row_pointers_vec(i), a.col_indices_vec(i),
                                  a.values_vec<T>(i), a_dense_shape};
      ConstCSRComponent<T> b_comp{b.row_pointers_vec(i), b.col_indices_vec(i),
                                  b.values_vec<T>(i), b_dense_shape};
      TTypes<int32>::UnalignedVec c_row_ptr_i(&c_row_ptr(i * (rows + 1)),
                                              rows + 1);
      int c_nnz_i;
      TF_RETURN_IF_ERROR(csr_geam.GetOutputStructure(
          a_comp, b_comp, c_row_ptr_i, &c_nnz_i, workspace));
      c_batch_ptr(i + 1) = c_batch_ptr(i) + c_nnz_i;
    }

    Tensor c_col_ind_t;
    Tensor c_values_t;

    const int total_nnz = c_batch_ptr(batch_size);

    TF_RETURN_IF_ERROR(
        ctx_->allocate_temp(DT_INT32, TensorShape({total_nnz}), &c_col_ind_t));
    TF_RETURN_IF_ERROR(ctx_->allocate_temp(
        DataTypeToEnum<T>::value, TensorShape({total_nnz}), &c_values_t));
    TF_RETURN_IF_ERROR(CSRSparseMatrix::CreateCSRSparseMatrix(
        DataTypeToEnum<T>::value, c_dense_shape_t, c_batch_ptr_t, c_row_ptr_t,
        c_col_ind_t, c_values_t, c));

    for (int i = 0; i < batch_size; ++i) {
      if (a.nnz(i) == 0 && b.nnz(i) == 0) {
        // Setting of c_row_pointers_vec(i) == 0 is already done.
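        // (SetZeroFunctor above already zeroed all of c_row_ptr, so batch
        // entries that are empty in both inputs need no further work here.)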
        continue;
      }
      ConstCSRComponent<T> a_comp{a.row_pointers_vec(i), a.col_indices_vec(i),
                                  a.values_vec<T>(i), a_dense_shape};
      ConstCSRComponent<T> b_comp{b.row_pointers_vec(i), b.col_indices_vec(i),
                                  b.values_vec<T>(i), b_dense_shape};
      CSRComponent<T> c_comp{c->row_pointers_vec(i), c->col_indices_vec(i),
                             c->values_vec<T>(i),
                             c_dense_shape_t.vec<int64_t>()};

      TF_RETURN_IF_ERROR(csr_geam.Compute(a_comp, b_comp, &c_comp, workspace));
    }

    return OkStatus();
  }

 private:
  OpKernelContext* ctx_;
  const T alpha_;
  const T beta_;
};

template <typename Device, typename T>
class CSRSparseMatrixSumFunctor : public CSRSparseMatrixAddFunctor<Device, T> {
 public:
  // Same as above, but with alpha = beta = 1.0, so C = 1.0 * A + 1.0 * B.
  explicit CSRSparseMatrixSumFunctor(OpKernelContext* ctx)
      : CSRSparseMatrixAddFunctor<Device, T>(ctx, 1, 1) {}
};

}  // namespace

template <typename Device, typename T>
class CSRAddOp : public OpKernel {
 public:
  explicit CSRAddOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* ctx) final {
    const CSRSparseMatrix* a_matrix;
    const CSRSparseMatrix* b_matrix;
    OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &a_matrix));
    OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 1, &b_matrix));

    OP_REQUIRES(
        ctx, a_matrix->dtype() == DataTypeToEnum<T>::value,
        errors::InvalidArgument("dtype of a is not equal to 'type': ",
                                DataTypeString(a_matrix->dtype()), " vs. ",
                                DataTypeString(DataTypeToEnum<T>::value)));
    OP_REQUIRES(
        ctx, b_matrix->dtype() == DataTypeToEnum<T>::value,
        errors::InvalidArgument("dtype of b is not equal to 'type': ",
                                DataTypeString(b_matrix->dtype()), " vs. ",
                                DataTypeString(DataTypeToEnum<T>::value)));

    // alpha and beta are scalar tensors; the kernel registration below pins
    // them to host memory, so they can be read directly here.
    const Tensor& alpha_t = ctx->input(2);
    const Tensor& beta_t = ctx->input(3);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(alpha_t.shape()),
        errors::InvalidArgument("Expected alpha to be a scalar, saw shape: ",
                                alpha_t.shape().DebugString()));
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(beta_t.shape()),
        errors::InvalidArgument("Expected beta to be a scalar, saw shape: ",
                                beta_t.shape().DebugString()));

    const T host_alpha = alpha_t.scalar<T>()();
    const T host_beta = beta_t.scalar<T>()();

    Tensor c_t(cpu_allocator(), DT_VARIANT, TensorShape({}));
    CSRSparseMatrix c_matrix;
    CSRSparseMatrixAddFunctor<Device, T> add_functor(ctx, host_alpha,
                                                     host_beta);
    OP_REQUIRES_OK(ctx, add_functor(*a_matrix, *b_matrix, &c_matrix));
    c_t.scalar<Variant>()() = std::move(c_matrix);
    ctx->set_output(0, c_t);
  }
};

#define REGISTER(DEV, T)                                  \
  REGISTER_KERNEL_BUILDER(Name("SparseMatrixAdd")         \
                              .Device(DEVICE_##DEV)       \
                              .TypeConstraint<T>("T")     \
                              .HostMemory("alpha")        \
                              .HostMemory("beta"),        \
                          CSRAddOp<DEV##Device, T>);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_GPU(T) REGISTER(GPU, T)

REGISTER_GPU(float)
REGISTER_GPU(double)
#if GOOGLE_CUDA
REGISTER_GPU(complex64)
REGISTER_GPU(complex128)
#endif

#undef REGISTER_GPU

REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(
    ADD_VARIANT_BINARY_OP, DEVICE_GPU, CSRSparseMatrix,
    (CSRSparseMatrixBinaryHelper<GPUDevice, CSRSparseMatrixSumFunctor>));

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef REGISTER

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace functor {
template <typename T>
struct CSRSparseMatrixAdd<GPUDevice, T>
    : public CSRStructureModifyingFunctor<GPUDevice, T> {
  explicit CSRSparseMatrixAdd(OpKernelContext* ctx, const T alpha,
                              const T beta)
      : ctx_(ctx),
        cuda_sparse_(ctx),
        alpha_(alpha),
        beta_(beta),
        initialized_(false) {}

  Status Initialize() {
    TF_RETURN_IF_ERROR(cuda_sparse_.Initialize());
    TF_RETURN_IF_ERROR(descrA_.Initialize());
    TF_RETURN_IF_ERROR(descrB_.Initialize());
    TF_RETURN_IF_ERROR(descrC_.Initialize());
    initialized_ = true;
    return OkStatus();
  }

  Status GetWorkspaceSize(const ConstCSRComponent<T>& a,
                          const ConstCSRComponent<T>& b, size_t* bufferSize) {
    DCHECK(initialized_);

    const int m = a.row_ptr.size() - 1;
    DCHECK_EQ(m, b.row_ptr.size() - 1);
    // For batched (rank-3) inputs dense_shape_host is [batch, rows, cols], so
    // the row dimension is at index 1; for rank-2 inputs it is at index 0.
    const int row_dim = a.dense_shape_host.size() == 2 ? 0 : 1;
    DCHECK_EQ(m, a.dense_shape_host(row_dim));
    DCHECK_EQ(m, b.dense_shape_host(row_dim));
    const int nnzA = a.col_ind.size();
    const int nnzB = b.col_ind.size();

    const int n = a.dense_shape_host(row_dim + 1);
    DCHECK_EQ(n, b.dense_shape_host(row_dim + 1));
    T* null_T = nullptr;
    int* null_int = nullptr;

    TF_RETURN_IF_ERROR(cuda_sparse_.CsrgeamBufferSizeExt(
        m, n, &alpha_, descrA_.descr(), nnzA, a.values.data(),
        a.row_ptr.data(), a.col_ind.data(), &beta_, descrB_.descr(), nnzB,
        b.values.data(), b.row_ptr.data(), b.col_ind.data(), descrC_.descr(),
        null_T, null_int, null_int, bufferSize));

    return OkStatus();
  }

  Status GetOutputStructure(const ConstCSRComponent<T>& a,
                            const ConstCSRComponent<T>& b,
                            TTypes<int32>::UnalignedVec c_row_ptr,
                            int* output_nnz, void* workspace) {
    DCHECK(initialized_);

    const int m = a.row_ptr.size() - 1;
    DCHECK_EQ(m, b.row_ptr.size() - 1);
    const int row_dim = a.dense_shape_host.size() == 2 ? 0 : 1;
    DCHECK_EQ(m, a.dense_shape_host(row_dim));
    DCHECK_EQ(m, b.dense_shape_host(row_dim));
    const int nnzA = a.col_ind.size();
    const int nnzB = b.col_ind.size();
    *output_nnz = -1;

    const int n = a.dense_shape_host(row_dim + 1);
    DCHECK_EQ(n, b.dense_shape_host(row_dim + 1));

    TF_RETURN_IF_ERROR(cuda_sparse_.CsrgeamNnz(
        m, n, descrA_.descr(), nnzA, a.row_ptr.data(), a.col_ind.data(),
        descrB_.descr(), nnzB, b.row_ptr.data(), b.col_ind.data(),
        descrC_.descr(), c_row_ptr.data(), output_nnz, workspace));

    if (*output_nnz < 0) {
      return errors::Internal(
          "CSRAdd: CsrgeamNnz returned nnzTotalDevHostPtr < 0: ", *output_nnz);
    }
    return OkStatus();
  }

  Status Compute(const ConstCSRComponent<T>& a, const ConstCSRComponent<T>& b,
                 CSRComponent<T>* c, void* workspace) {
    DCHECK(initialized_);

    const int m = a.row_ptr.size() - 1;
    DCHECK_EQ(m, b.row_ptr.size() - 1);
    const int row_dim = a.dense_shape_host.size() == 2 ? 0 : 1;
    DCHECK_EQ(m, a.dense_shape_host(row_dim));
    DCHECK_EQ(m, b.dense_shape_host(row_dim));
    const int nnzA = a.col_ind.size();
    const int nnzB = b.col_ind.size();

    const int n = a.dense_shape_host(row_dim + 1);
    DCHECK_EQ(n, b.dense_shape_host(row_dim + 1));

    // Adding alpha * a + beta * b.
    TF_RETURN_IF_ERROR(cuda_sparse_.Csrgeam(
        m, n, &alpha_, descrA_.descr(), nnzA, a.values.data(),
        a.row_ptr.data(), a.col_ind.data(), &beta_, descrB_.descr(), nnzB,
        b.values.data(), b.row_ptr.data(), b.col_ind.data(), descrC_.descr(),
        c->values.data(), c->row_ptr.data(), c->col_ind.data(), workspace));

    return OkStatus();
  }

 private:
  OpKernelContext* ctx_;
  GpuSparse cuda_sparse_;
  GpuSparseMatrixDescriptor descrA_;
  GpuSparseMatrixDescriptor descrB_;
  GpuSparseMatrixDescriptor descrC_;
  const T alpha_;
  const T beta_;
  bool initialized_;

  TF_DISALLOW_COPY_AND_ASSIGN(CSRSparseMatrixAdd);
};

}  // namespace functor

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow