/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements the kernel for the SparseMatrixTranspose op, which transposes
// the two innermost dimensions of a CSRSparseMatrix object stored in a
// DT_VARIANT.
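//
// For example, an input with dense shape [B, M, N] produces an output with
// dense shape [B, N, M]; a rank-2 input of shape [M, N] becomes [N, M]. When
// the "conjugate" attribute is set, the output values are also conjugated.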

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/util/cuda_sparse.h"
#define EIGEN_USE_GPU
#endif

#include <numeric>
#include <utility>
#include <vector>

#include "third_party/eigen3/Eigen/SparseCore"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/cwise_ops.h"
#include "tensorflow/core/kernels/cwise_ops_common.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/slice_op.h"
#include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#include "tensorflow/core/kernels/sparse/transpose_op.h"
#include "tensorflow/core/lib/core/threadpool.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace {

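// Checks that `input` and `output` describe shape- and size-consistent
// operands for a transpose of the two innermost dimensions.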
template <typename T>
Status ValidateTransposeInputs(const ConstCSRComponent<T>& input,
                               const CSRComponent<T>& output) {
  const int rank = input.dense_shape_host.size();
  const int64_t nnz = input.col_ind.size();
  const int num_rows = input.row_ptr.size() - 1;
  const int num_cols = input.dense_shape_host(rank - 1);

  if (nnz != input.values.size()) {
    return errors::InvalidArgument(
        "Input nnz should equal the input values size. Got ", nnz, " vs. ",
        input.values.size());
  }
  if (num_cols + 1 != output.row_ptr.size()) {
    return errors::InvalidArgument(
        "Input num_cols should be equal to output num_rows. Got ", num_cols,
        " vs. ", output.row_ptr.size() - 1);
  }
  if (rank != output.dense_shape_host.size()) {
    return errors::InvalidArgument(
        "Input rank should be equal to the output rank. Got ", rank, " vs. ",
        output.dense_shape_host.size());
  }
  if (num_rows != output.dense_shape_host(rank - 1)) {
    return errors::InvalidArgument(
        "Input num_rows should be equal to the output num_cols. Got ", num_rows,
        " vs. ", output.dense_shape_host(rank - 1));
  }
  if (nnz != output.col_ind.size()) {
    return errors::InvalidArgument(
        "Input nnz should equal the output col_ind size. Got ", nnz, " vs. ",
        output.col_ind.size());
  }
  if (nnz != output.values.size()) {
    return errors::InvalidArgument(
        "Input nnz should equal the output values size. Got ", nnz, " vs. ",
        output.values.size());
  }
  return OkStatus();
}
}  // namespace

template <typename Device, typename T>
class CSRTransposeOp : public OpKernel {
 public:
  explicit CSRTransposeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("conjugate", &conjugate_));
  }

  void Compute(OpKernelContext* ctx) override {
    const CSRSparseMatrix* input_matrix;
    OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &input_matrix));
    OP_REQUIRES(
        ctx, input_matrix->dtype() == DataTypeToEnum<T>::value,
        errors::InvalidArgument("dtype of input is not equal to 'type': ",
                                DataTypeString(input_matrix->dtype()), " vs. ",
                                DataTypeString(DataTypeToEnum<T>::value)));

    // Transpose the matrix and wrap the result in a new DT_VARIANT scalar.
    functor::CSRSparseMatrixTranspose<Device, T> transpose;
    CSRSparseMatrix output_matrix;
    OP_REQUIRES_OK(ctx,
                   transpose(ctx, conjugate_, *input_matrix, &output_matrix));
    Tensor output_t(cpu_allocator(), DT_VARIANT, TensorShape({}));
    output_t.scalar<Variant>()() = std::move(output_matrix);
    ctx->set_output(0, output_t);
  }

 private:
  bool conjugate_;
};

#define REGISTER_TRANSPOSE(DEV, T)                        \
  REGISTER_KERNEL_BUILDER(Name("SparseMatrixTranspose")   \
                              .Device(DEVICE_##DEV)       \
                              .TypeConstraint<T>("type"), \
                          CSRTransposeOp<DEV##Device, T>);

REGISTER_TRANSPOSE(CPU, float)
REGISTER_TRANSPOSE(CPU, double)
REGISTER_TRANSPOSE(CPU, complex64)
REGISTER_TRANSPOSE(CPU, complex128)

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_TRANSPOSE(GPU, float)
REGISTER_TRANSPOSE(GPU, double)
REGISTER_TRANSPOSE(GPU, complex64)
REGISTER_TRANSPOSE(GPU, complex128)
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef REGISTER_TRANSPOSE

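// Illustrative usage sketch (not part of the original file): the registered
// op is reachable from Python through its raw-op binding, e.g.
//   out = tf.raw_ops.SparseMatrixTranspose(
//       input=csr_matrix, conjugate=True, type=tf.complex64)
// where `csr_matrix` is a DT_VARIANT scalar holding a CSRSparseMatrix; the
// argument names here are assumed from the kernel/op registration above.
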
namespace functor {

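// Transposes the two innermost dimensions of `input_matrix` into
// `output_matrix`: swaps the dense shape, allocates the output buffers, and
// dispatches one CSRSparseMatrixTransposeComponent call per batch.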
template <typename Device, typename T>
Status CSRSparseMatrixTranspose<Device, T>::operator()(
    OpKernelContext* ctx, bool conjugate, const CSRSparseMatrix& input_matrix,
    CSRSparseMatrix* output_matrix) {
  const int rank = input_matrix.dims();
  Tensor output_dense_shape_t(cpu_allocator(), DT_INT64, TensorShape({rank}));
  const Tensor& input_dense_shape_t = input_matrix.dense_shape();
  auto input_dense_shape = input_dense_shape_t.vec<int64_t>();
  auto output_dense_shape = output_dense_shape_t.vec<int64_t>();
  const int64_t batch_size = input_matrix.batch_size();
  if (rank == 3) {
    output_dense_shape(0) = batch_size;
  }
  output_dense_shape(rank - 2) = input_dense_shape(rank - 1);
  output_dense_shape(rank - 1) = input_dense_shape(rank - 2);
  const int64_t output_rows = output_dense_shape(rank - 2);

  // nnzs per batch do not change with matrix transposition.
  Tensor batch_ptr_t = input_matrix.batch_pointers();
  const int total_nnz = input_matrix.total_nnz();

  Tensor output_row_ptr_t;
  Tensor output_col_ind_t;
  Tensor output_values_t;

  TF_RETURN_IF_ERROR(ctx->allocate_temp(
      DT_INT32, TensorShape({batch_size * (output_rows + 1)}),
      &output_row_ptr_t));
  TF_RETURN_IF_ERROR(ctx->allocate_temp(DT_INT32, TensorShape({total_nnz}),
                                        &output_col_ind_t));
  TF_RETURN_IF_ERROR(ctx->allocate_temp(
      DataTypeToEnum<T>::value, TensorShape({total_nnz}), &output_values_t));

  TF_RETURN_IF_ERROR(CSRSparseMatrix::CreateCSRSparseMatrix(
      DataTypeToEnum<T>::value, output_dense_shape_t, batch_ptr_t,
      output_row_ptr_t, output_col_ind_t, output_values_t, output_matrix));

  // Set the output row pointers to zero, in case we hit any empty
  // input batches.
  functor::SetZeroFunctor<Device, int32> set_zero;
  const Device& d = ctx->eigen_device<Device>();
  set_zero(d, output_row_ptr_t.flat<int32>());

  functor::CSRSparseMatrixTransposeComponent<Device, T> transpose_component;
  for (int i = 0; i < batch_size; ++i) {
    if (output_matrix->nnz(i) == 0) {
      continue;
    }
    ConstCSRComponent<T> input_comp{
        input_matrix.row_pointers_vec(i), input_matrix.col_indices_vec(i),
        input_matrix.values_vec<T>(i), input_dense_shape};
    CSRComponent<T> output_comp{
        output_matrix->row_pointers_vec(i), output_matrix->col_indices_vec(i),
        output_matrix->values_vec<T>(i), output_dense_shape};

    TF_RETURN_IF_ERROR(transpose_component(ctx, input_comp, &output_comp));
  }
  if (conjugate) {
    // Conjugate all values with a single kernel launch.
    maybe_conj_inplace<Device, T>::run(d, &output_values_t);
  }

  return OkStatus();
}

// CPU kernel for transposing a single component of a CSR SparseMatrix.
template <typename T>
struct CSRSparseMatrixTransposeComponent<CPUDevice, T> {
  using SparseMatrix = Eigen::SparseMatrix<T, Eigen::RowMajor>;

  Status operator()(OpKernelContext* ctx, const ConstCSRComponent<T>& input,
                    CSRComponent<T>* output) {
    TF_RETURN_IF_ERROR(ValidateTransposeInputs(input, *output));

    const int rank = input.dense_shape_host.size();
    const int num_rows = input.row_ptr.size() - 1;
    const int num_cols = input.dense_shape_host(rank - 1);
    const int64_t nnz = input.col_ind.size();

    // Compute the column counts, whose prefix sums make up the output row
    // pointers.
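    //
    // Worked example (illustrative): for the 2x3 input
    //     [[1, 0, 2],
    //      [0, 3, 0]]
    // with row_ptr = [0, 2, 3], col_ind = [0, 2, 1], values = [1, 2, 3],
    // the counting pass leaves output->row_ptr = [0, 1, 1, 1], and the
    // prefix sum turns it into [0, 1, 2, 3], the row pointers of the 3x2
    // transpose.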
    for (int64_t i = 0; i < nnz; ++i) {
      output->row_ptr(input.col_ind(i) + 1) += 1;
    }
    std::partial_sum(output->row_ptr.data(),
                     output->row_ptr.data() + num_cols + 1,
                     output->row_ptr.data());

    // Iterate through each row of the input, and place each non-zero element
    // into the target output row (based on the current column count).
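    //
    // Continuing the example above: entry (0, 0) lands at offset 0, entry
    // (0, 2) at offset 2, and entry (1, 1) at offset 1, producing
    // col_ind = [0, 1, 0] and values = [1, 3, 2] for the transpose.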
    std::vector<int> current_col_count(num_cols);
    for (int row_idx = 0; row_idx < num_rows; ++row_idx) {
      const int64_t row_begin = input.row_ptr(row_idx);
      const int64_t row_end = input.row_ptr(row_idx + 1);
      for (int64_t i = row_begin; i < row_end; ++i) {
        const int col_idx = input.col_ind(i);
        const int64_t offset =
            output->row_ptr(col_idx) + current_col_count[col_idx];
        output->col_ind(offset) = row_idx;
        output->values(offset) = input.values(i);
        current_col_count[col_idx] += 1;
      }
    }
    return OkStatus();
  }
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

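// GPU kernel for transposing a single component of a CSR SparseMatrix. The
// CSC representation produced by csr2csc is exactly the CSR representation
// of the transpose, so a single cuSPARSE/hipSPARSE call suffices.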
template <typename T>
struct CSRSparseMatrixTransposeComponent<GPUDevice, T> {
  Status operator()(OpKernelContext* ctx, const ConstCSRComponent<T>& x,
                    CSRComponent<T>* y) {
    TF_RETURN_IF_ERROR(ValidateTransposeInputs(x, *y));
    GpuSparse cuda_sparse(ctx);
    TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
    const gpusparseAction_t copyValues = GPUSPARSE(ACTION_NUMERIC);
    const int rank = x.dense_shape_host.size();
    const int m = x.row_ptr.size() - 1;
    const int n = x.dense_shape_host(rank - 1);
    const int nnz = x.col_ind.size();
    DCHECK_EQ(nnz, x.values.size());
    DCHECK_EQ(n, y->row_ptr.size() - 1);
    DCHECK_EQ(rank, y->dense_shape_host.size());
    DCHECK_EQ(m, y->dense_shape_host(rank - 1));
    DCHECK_EQ(nnz, y->col_ind.size());
    DCHECK_EQ(nnz, y->values.size());

    return cuda_sparse.Csr2csc(
        m, n, nnz, x.values.data() /*csrVal*/, x.row_ptr.data() /*csrRowPtr*/,
        x.col_ind.data() /*csrColInd*/, y->values.data() /*cscVal*/,
        y->col_ind.data() /*cscRowInd*/, y->row_ptr.data() /*cscColPtr*/,
        copyValues);
  }
};
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
}  // namespace functor

}  // namespace tensorflow