/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/util/cuda_sparse.h"
#include "tensorflow/core/util/gpu_solvers.h"
#endif

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace {
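// Computes c = alpha * a + beta * b, where a, b, and c are (possibly batched)
// CSR matrices. The work happens in three passes over the batch: size a
// shared scratch buffer, compute c's sparsity structure, then fill in c's
// column indices and values.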
template <typename Device, typename T>
class CSRSparseMatrixAddFunctor {
 public:
  explicit CSRSparseMatrixAddFunctor(OpKernelContext* ctx, const T alpha,
                                     const T beta)
      : ctx_(ctx), alpha_(alpha), beta_(beta) {}

  Status operator()(const CSRSparseMatrix& a, const CSRSparseMatrix& b,
                    CSRSparseMatrix* c) {
    TensorShape a_tensor_shape;
    TensorShape b_tensor_shape;
    TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(
        a.dense_shape().vec<int64_t>(), &a_tensor_shape));
    TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(
        b.dense_shape().vec<int64_t>(), &b_tensor_shape));

    if (a_tensor_shape.dims() == 3) {
      if ((a_tensor_shape.dims() != b_tensor_shape.dims()) ||
          (a_tensor_shape.dim_size(0) != b_tensor_shape.dim_size(0))) {
        return errors::InvalidArgument(
            "Incompatible shapes of a and b, a.shape == ",
            a_tensor_shape.DebugString(),
            ", b.shape == ", b_tensor_shape.DebugString());
      }
    }
    const int rank = a_tensor_shape.dims();
    if ((a_tensor_shape.dim_size(rank - 2) !=
         b_tensor_shape.dim_size(rank - 2)) ||
        (a_tensor_shape.dim_size(rank - 1) !=
         b_tensor_shape.dim_size(rank - 1))) {
      return errors::InvalidArgument(
          "Incompatible shapes of a and b, a.shape == ",
          a_tensor_shape.DebugString(),
          ", b.shape == ", b_tensor_shape.DebugString());
    }

    const int batch_size = a.batch_size();

    // TODO(ebrevdo): Add support for broadcasting at least in the
    // batch dimension.
    auto a_dense_shape = a.dense_shape().vec<int64_t>();
    auto b_dense_shape = b.dense_shape().vec<int64_t>();
    Tensor c_dense_shape_t = a.dense_shape();

    const int64_t rows = a_dense_shape((rank == 2) ? 0 : 1);

    functor::CSRSparseMatrixAdd<Device, T> csr_geam(ctx_, alpha_, beta_);
    TF_RETURN_IF_ERROR(csr_geam.Initialize());

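    // c_batch_ptr accumulates the number of nonzeros of c across the batch:
    // entry i + 1 is filled in below once batch member i's structure is known.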
    Tensor c_batch_ptr_t(cpu_allocator(), DT_INT32,
                         TensorShape({batch_size + 1}));
    auto c_batch_ptr = c_batch_ptr_t.vec<int32>();
    c_batch_ptr(0) = 0;

    Tensor c_row_ptr_t;
    TF_RETURN_IF_ERROR(ctx_->allocate_temp(
        DT_INT32, TensorShape({batch_size * (rows + 1)}), &c_row_ptr_t));
    auto c_row_ptr = c_row_ptr_t.vec<int32>();

    // Set the output row pointers to zero, in case we hit any empty
    // combinations of rows in a and b.
    functor::SetZeroFunctor<Device, int32> set_zero;
    const Device& d = ctx_->eigen_device<Device>();
    set_zero(d, c_row_ptr_t.flat<int32>());

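    // First pass: ask the backend how much scratch space each batch member
    // needs and keep the maximum, so one workspace can serve the whole batch.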
    size_t maxWorkspaceSize = 0;
    for (int i = 0; i < batch_size; ++i) {
      ConstCSRComponent<T> a_comp{a.row_pointers_vec(i), a.col_indices_vec(i),
                                  a.values_vec<T>(i), a_dense_shape};
      ConstCSRComponent<T> b_comp{b.row_pointers_vec(i), b.col_indices_vec(i),
                                  b.values_vec<T>(i), b_dense_shape};

      size_t thisWorkspaceSize;
      TF_RETURN_IF_ERROR(
          csr_geam.GetWorkspaceSize(a_comp, b_comp, &thisWorkspaceSize));
      if (thisWorkspaceSize > maxWorkspaceSize) {
        maxWorkspaceSize = thisWorkspaceSize;
      }
    }

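    // Allocate a single scratch buffer large enough for every batch member;
    // it is reused by the structure and value passes below.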
    Tensor temp;
    TF_RETURN_IF_ERROR(ctx_->allocate_temp(
        DT_INT8, TensorShape({static_cast<int64_t>(maxWorkspaceSize)}), &temp));
    void* workspace = temp.flat<int8>().data();

    for (int i = 0; i < batch_size; ++i) {
      // Calculate output sizes for all minibatch entries.
      // Store in c_batch_ptr and update c_row_ptrs.
      if (a.nnz(i) == 0 && b.nnz(i) == 0) {
        c_batch_ptr(i + 1) = c_batch_ptr(i);
        continue;
      }
      ConstCSRComponent<T> a_comp{a.row_pointers_vec(i), a.col_indices_vec(i),
                                  a.values_vec<T>(i), a_dense_shape};
      ConstCSRComponent<T> b_comp{b.row_pointers_vec(i), b.col_indices_vec(i),
                                  b.values_vec<T>(i), b_dense_shape};
      TTypes<int32>::UnalignedVec c_row_ptr_i(&c_row_ptr(i * (rows + 1)),
                                              rows + 1);
      int c_nnz_i;
      TF_RETURN_IF_ERROR(csr_geam.GetOutputStructure(
          a_comp, b_comp, c_row_ptr_i, &c_nnz_i, workspace));
      c_batch_ptr(i + 1) = c_batch_ptr(i) + c_nnz_i;
    }

    Tensor c_col_ind_t;
    Tensor c_values_t;

    const int total_nnz = c_batch_ptr(batch_size);

    TF_RETURN_IF_ERROR(
        ctx_->allocate_temp(DT_INT32, TensorShape({total_nnz}), &c_col_ind_t));
    TF_RETURN_IF_ERROR(ctx_->allocate_temp(
        DataTypeToEnum<T>::value, TensorShape({total_nnz}), &c_values_t));
    TF_RETURN_IF_ERROR(CSRSparseMatrix::CreateCSRSparseMatrix(
        DataTypeToEnum<T>::value, c_dense_shape_t, c_batch_ptr_t, c_row_ptr_t,
        c_col_ind_t, c_values_t, c));

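    // Final pass: fill in c's column indices and values now that its
    // structure has been allocated.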
    for (int i = 0; i < batch_size; ++i) {
      if (a.nnz(i) == 0 && b.nnz(i) == 0) {
        // Setting of c_row_pointers_vec(i) == 0 is already done.
        continue;
      }
      ConstCSRComponent<T> a_comp{a.row_pointers_vec(i), a.col_indices_vec(i),
                                  a.values_vec<T>(i), a_dense_shape};
      ConstCSRComponent<T> b_comp{b.row_pointers_vec(i), b.col_indices_vec(i),
                                  b.values_vec<T>(i), b_dense_shape};
      CSRComponent<T> c_comp{c->row_pointers_vec(i), c->col_indices_vec(i),
                             c->values_vec<T>(i),
                             c_dense_shape_t.vec<int64_t>()};

      TF_RETURN_IF_ERROR(csr_geam.Compute(a_comp, b_comp, &c_comp, workspace));
    }

    return OkStatus();
  }

 private:
  OpKernelContext* ctx_;
  const T alpha_;
  const T beta_;
};

template <typename Device, typename T>
class CSRSparseMatrixSumFunctor : public CSRSparseMatrixAddFunctor<Device, T> {
 public:
  // Same as above, but with alpha = beta = 1.0, so C = 1.0 * A + 1.0 * B.
  explicit CSRSparseMatrixSumFunctor(OpKernelContext* ctx)
      : CSRSparseMatrixAddFunctor<Device, T>(ctx, 1, 1) {}
};

}  // namespace

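// Kernel for the SparseMatrixAdd op: computes alpha * a + beta * b for two
// CSRSparseMatrix variants of matching dtype and shape. An illustrative call
// from Python (a sketch via the generated raw op; argument names follow the
// op registration below):
//
//   c = tf.raw_ops.SparseMatrixAdd(a=a_csr, b=b_csr, alpha=1.0, beta=2.0)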
template <typename Device, typename T>
class CSRAddOp : public OpKernel {
 public:
  explicit CSRAddOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* ctx) final {
    const CSRSparseMatrix* a_matrix;
    const CSRSparseMatrix* b_matrix;
    OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &a_matrix));
    OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 1, &b_matrix));

    OP_REQUIRES(
        ctx, a_matrix->dtype() == DataTypeToEnum<T>::value,
        errors::InvalidArgument("dtype of a is not equal to 'type': ",
                                DataTypeString(a_matrix->dtype()), " vs. ",
                                DataTypeString(DataTypeToEnum<T>::value)));
    OP_REQUIRES(
        ctx, b_matrix->dtype() == DataTypeToEnum<T>::value,
        errors::InvalidArgument("dtype of b is not equal to 'type': ",
                                DataTypeString(b_matrix->dtype()), " vs. ",
                                DataTypeString(DataTypeToEnum<T>::value)));

    const Tensor& alpha_t = ctx->input(2);
    const Tensor& beta_t = ctx->input(3);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(alpha_t.shape()),
        errors::InvalidArgument("Expected alpha to be a scalar, saw shape: ",
                                alpha_t.shape().DebugString()));
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(beta_t.shape()),
        errors::InvalidArgument("Expected beta to be a scalar, saw shape: ",
                                beta_t.shape().DebugString()));

    const T host_alpha = alpha_t.scalar<T>()();
    const T host_beta = beta_t.scalar<T>()();

    Tensor c_t(cpu_allocator(), DT_VARIANT, TensorShape({}));
    CSRSparseMatrix c_matrix;
    CSRSparseMatrixAddFunctor<Device, T> add_functor(ctx, host_alpha,
                                                     host_beta);
    OP_REQUIRES_OK(ctx, add_functor(*a_matrix, *b_matrix, &c_matrix));
    c_t.scalar<Variant>()() = std::move(c_matrix);
    ctx->set_output(0, c_t);
  }
};

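// alpha and beta are read on the host (see host_alpha/host_beta above), so
// the registrations below pin those inputs to host memory even on GPU.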
#define REGISTER(DEV, T)                              \
  REGISTER_KERNEL_BUILDER(Name("SparseMatrixAdd")     \
                              .Device(DEVICE_##DEV)   \
                              .TypeConstraint<T>("T") \
                              .HostMemory("alpha")    \
                              .HostMemory("beta"),    \
                          CSRAddOp<DEV##Device, T>);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_GPU(T) REGISTER(GPU, T)

REGISTER_GPU(float)
REGISTER_GPU(double)
#if GOOGLE_CUDA
REGISTER_GPU(complex64)
REGISTER_GPU(complex128)
#endif

#undef REGISTER_GPU

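// Also register CSRSparseMatrix under the generic variant ADD binary op, so
// that framework-level variant addition (e.g. when summing gradients) can
// add CSR matrices element-wise.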
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(
    ADD_VARIANT_BINARY_OP, DEVICE_GPU, CSRSparseMatrix,
    (CSRSparseMatrixBinaryHelper<GPUDevice, CSRSparseMatrixSumFunctor>));

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef REGISTER

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace functor {
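// GPU implementation of functor::CSRSparseMatrixAdd, backed by the csrgeam
// family of routines from cuSPARSE (hipSPARSE on ROCm): C = alpha*A + beta*B.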
template <typename T>
struct CSRSparseMatrixAdd<GPUDevice, T>
    : public CSRStructureModifyingFunctor<GPUDevice, T> {
  explicit CSRSparseMatrixAdd(OpKernelContext* ctx, const T alpha,
                              const T beta)
      : ctx_(ctx),
        cuda_sparse_(ctx),
        alpha_(alpha),
        beta_(beta),
        initialized_(false) {}

  Status Initialize() {
    TF_RETURN_IF_ERROR(cuda_sparse_.Initialize());
    TF_RETURN_IF_ERROR(descrA_.Initialize());
    TF_RETURN_IF_ERROR(descrB_.Initialize());
    TF_RETURN_IF_ERROR(descrC_.Initialize());
    initialized_ = true;
    return OkStatus();
  }

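  // Returns in *bufferSize the scratch space (in bytes) csrgeam needs for one
  // batch member with the given structure; C's arrays may be null here.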
  Status GetWorkspaceSize(const ConstCSRComponent<T>& a,
                          const ConstCSRComponent<T>& b, size_t* bufferSize) {
    DCHECK(initialized_);

    const int m = a.row_ptr.size() - 1;
    DCHECK_EQ(m, b.row_ptr.size() - 1);
    const int row_dim = a.dense_shape_host.size() == 2 ? 0 : 1;
    DCHECK_EQ(m, a.dense_shape_host(row_dim));
    DCHECK_EQ(m, b.dense_shape_host(row_dim));
    const int nnzA = a.col_ind.size();
    const int nnzB = b.col_ind.size();

    const int n = a.dense_shape_host(row_dim + 1);
    DCHECK_EQ(n, b.dense_shape_host(row_dim + 1));
    T* null_T = nullptr;
    int* null_int = nullptr;

    TF_RETURN_IF_ERROR(cuda_sparse_.CsrgeamBufferSizeExt(
        m, n, &alpha_, descrA_.descr(), nnzA, a.values.data(), a.row_ptr.data(),
        a.col_ind.data(), &beta_, descrB_.descr(), nnzB, b.values.data(),
        b.row_ptr.data(), b.col_ind.data(), descrC_.descr(), null_T, null_int,
        null_int, bufferSize));

    return OkStatus();
  }

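  // Computes C's row pointers and total nonzero count (the csrgeam nnz pass)
  // without touching C's column indices or values.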
  Status GetOutputStructure(const ConstCSRComponent<T>& a,
                            const ConstCSRComponent<T>& b,
                            TTypes<int32>::UnalignedVec c_row_ptr,
                            int* output_nnz, void* workspace) {
    DCHECK(initialized_);

    const int m = a.row_ptr.size() - 1;
    DCHECK_EQ(m, b.row_ptr.size() - 1);
    const int row_dim = a.dense_shape_host.size() == 2 ? 0 : 1;
    DCHECK_EQ(m, a.dense_shape_host(row_dim));
    DCHECK_EQ(m, b.dense_shape_host(row_dim));
    const int nnzA = a.col_ind.size();
    const int nnzB = b.col_ind.size();
    *output_nnz = -1;

    const int n = a.dense_shape_host(row_dim + 1);
    DCHECK_EQ(n, b.dense_shape_host(row_dim + 1));

    TF_RETURN_IF_ERROR(cuda_sparse_.CsrgeamNnz(
        m, n, descrA_.descr(), nnzA, a.row_ptr.data(), a.col_ind.data(),
        descrB_.descr(), nnzB, b.row_ptr.data(), b.col_ind.data(),
        descrC_.descr(), c_row_ptr.data(), output_nnz, workspace));

    if (*output_nnz < 0) {
      return errors::Internal(
          "CSRAdd: CsrgeamNnz returned nnzTotalDevHostPtr < 0: ", *output_nnz);
    }
    return OkStatus();
  }

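  // Fills in C's column indices and values; C's buffers must already be sized
  // according to the nnz count produced by GetOutputStructure.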
  Status Compute(const ConstCSRComponent<T>& a, const ConstCSRComponent<T>& b,
                 CSRComponent<T>* c, void* workspace) {
    DCHECK(initialized_);

    const int m = a.row_ptr.size() - 1;
    DCHECK_EQ(m, b.row_ptr.size() - 1);
    const int row_dim = a.dense_shape_host.size() == 2 ? 0 : 1;
    DCHECK_EQ(m, a.dense_shape_host(row_dim));
    DCHECK_EQ(m, b.dense_shape_host(row_dim));
    const int nnzA = a.col_ind.size();
    const int nnzB = b.col_ind.size();

    const int n = a.dense_shape_host(row_dim + 1);
    DCHECK_EQ(n, b.dense_shape_host(row_dim + 1));

    // Adding alpha * a + beta * b.
    TF_RETURN_IF_ERROR(cuda_sparse_.Csrgeam(
        m, n, &alpha_, descrA_.descr(), nnzA, a.values.data(), a.row_ptr.data(),
        a.col_ind.data(), &beta_, descrB_.descr(), nnzB, b.values.data(),
        b.row_ptr.data(), b.col_ind.data(), descrC_.descr(), c->values.data(),
        c->row_ptr.data(), c->col_ind.data(), workspace));

    return OkStatus();
  }

 private:
  OpKernelContext* ctx_;
  GpuSparse cuda_sparse_;
  GpuSparseMatrixDescriptor descrA_;
  GpuSparseMatrixDescriptor descrB_;
  GpuSparseMatrixDescriptor descrC_;
  const T alpha_;
  const T beta_;
  bool initialized_;

  TF_DISALLOW_COPY_AND_ASSIGN(CSRSparseMatrixAdd);
};

}  // namespace functor

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow