/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements the kernel for the CSRTranspose op, which transposes the
// two innermost dimensions of a CSRSparseMatrix object stored in a
// DT_VARIANT.
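//
// A small worked example (illustrative, not from the op's documentation):
// for the 2x3 matrix
//
//   A = [1 0 2]        values  = [1, 2, 3]
//       [0 3 0]        col_ind = [0, 2, 1]
//                      row_ptr = [0, 2, 3]
//
// the transpose A^T (3x2) has the CSR components
//
//   values = [1, 3, 2],  col_ind = [0, 1, 0],  row_ptr = [0, 1, 2, 3],
//
// which is exactly the CSC representation of A.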

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/util/cuda_sparse.h"
#define EIGEN_USE_GPU
#endif

#include <numeric>

#include "third_party/eigen3/Eigen/SparseCore"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/cwise_ops.h"
#include "tensorflow/core/kernels/cwise_ops_common.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/slice_op.h"
#include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#include "tensorflow/core/kernels/sparse/transpose_op.h"
#include "tensorflow/core/lib/core/threadpool.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace {

template <typename T>
Status ValidateTransposeInputs(const ConstCSRComponent<T>& input,
                               const CSRComponent<T>& output) {
  const int rank = input.dense_shape_host.size();
  const int64_t nnz = input.col_ind.size();
  const int num_rows = input.row_ptr.size() - 1;
  const int num_cols = input.dense_shape_host(rank - 1);

  if (nnz != input.values.size()) {
    return errors::InvalidArgument(
        "Input nnz should equal the input values size. Got ", nnz, " vs. ",
        input.values.size());
  }
  if (num_cols + 1 != output.row_ptr.size()) {
    return errors::InvalidArgument(
        "Input num_cols should be equal to output num_rows. Got ", num_cols,
        " vs. ", output.row_ptr.size() - 1);
  }
  if (rank != output.dense_shape_host.size()) {
    return errors::InvalidArgument(
        "Input rank should be equal to the output rank. Got ", rank, " vs. ",
        output.dense_shape_host.size());
  }
  if (num_rows != output.dense_shape_host(rank - 1)) {
    return errors::InvalidArgument(
        "Input num_rows should be equal to the output num_cols. Got ",
        num_rows, " vs. ", output.dense_shape_host(rank - 1));
  }
  if (nnz != output.col_ind.size()) {
    return errors::InvalidArgument(
        "Input nnz should equal the output col_ind size. Got ", nnz, " vs. ",
        output.col_ind.size());
  }
  if (nnz != output.values.size()) {
    return errors::InvalidArgument(
        "Input nnz should equal the output values size. Got ", nnz, " vs. ",
        output.values.size());
  }
  return OkStatus();
}
}  // namespace

template <typename Device, typename T>
class CSRTransposeOp : public OpKernel {
 public:
  explicit CSRTransposeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("conjugate", &conjugate_));
  }

  void Compute(OpKernelContext* ctx) override {
    const CSRSparseMatrix* input_matrix;
    OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &input_matrix));
    OP_REQUIRES(
        ctx, input_matrix->dtype() == DataTypeToEnum<T>::value,
        errors::InvalidArgument("dtype of input is not equal to 'type': ",
                                DataTypeString(input_matrix->dtype()), " vs. ",
                                DataTypeString(DataTypeToEnum<T>::value)));

    // Transpose the input into a newly allocated CSRSparseMatrix; the
    // functor below allocates all output buffers.
    functor::CSRSparseMatrixTranspose<Device, T> transpose;
    CSRSparseMatrix output_matrix;
    OP_REQUIRES_OK(ctx,
                   transpose(ctx, conjugate_, *input_matrix, &output_matrix));
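    // Variant tensors are host-resident, so the scalar output wrapping the
    // result is allocated with the CPU allocator even for the GPU kernel.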
    Tensor output_t(cpu_allocator(), DT_VARIANT, TensorShape({}));
    output_t.scalar<Variant>()() = std::move(output_matrix);
    ctx->set_output(0, output_t);
  }

 private:
  bool conjugate_;
};

#define REGISTER_TRANSPOSE(DEV, T)                        \
  REGISTER_KERNEL_BUILDER(Name("SparseMatrixTranspose")   \
                              .Device(DEVICE_##DEV)       \
                              .TypeConstraint<T>("type"), \
                          CSRTransposeOp<DEV##Device, T>);

REGISTER_TRANSPOSE(CPU, float)
REGISTER_TRANSPOSE(CPU, double)
REGISTER_TRANSPOSE(CPU, complex64)
REGISTER_TRANSPOSE(CPU, complex128)

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_TRANSPOSE(GPU, float)
REGISTER_TRANSPOSE(GPU, double)
REGISTER_TRANSPOSE(GPU, complex64)
REGISTER_TRANSPOSE(GPU, complex128)
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef REGISTER_TRANSPOSE

namespace functor {

template <typename Device, typename T>
Status CSRSparseMatrixTranspose<Device, T>::operator()(
    OpKernelContext* ctx, bool conjugate, const CSRSparseMatrix& input_matrix,
    CSRSparseMatrix* output_matrix) {
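  // The output dense shape swaps the two innermost dimensions of the input;
  // a leading batch dimension (rank == 3) is carried over unchanged.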
  const int rank = input_matrix.dims();
  Tensor output_dense_shape_t(cpu_allocator(), DT_INT64, TensorShape({rank}));
  const Tensor& input_dense_shape_t = input_matrix.dense_shape();
  auto input_dense_shape = input_dense_shape_t.vec<int64_t>();
  auto output_dense_shape = output_dense_shape_t.vec<int64_t>();
  const int64_t batch_size = input_matrix.batch_size();
  if (rank == 3) {
    output_dense_shape(0) = batch_size;
  }
  output_dense_shape(rank - 2) = input_dense_shape(rank - 1);
  output_dense_shape(rank - 1) = input_dense_shape(rank - 2);
  const int64_t output_rows = output_dense_shape(rank - 2);

  // The nnz per batch does not change under transposition, so the input
  // batch pointers can be reused as-is.
  Tensor batch_ptr_t = input_matrix.batch_pointers();
  const int total_nnz = input_matrix.total_nnz();

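  // Allocate the output buffers: one (output_rows + 1)-long row-pointer
  // segment per batch, plus col_ind and values arrays covering all nnz.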
  Tensor output_row_ptr_t;
  Tensor output_col_ind_t;
  Tensor output_values_t;

  TF_RETURN_IF_ERROR(ctx->allocate_temp(
      DT_INT32, TensorShape({batch_size * (output_rows + 1)}),
      &output_row_ptr_t));
  TF_RETURN_IF_ERROR(ctx->allocate_temp(DT_INT32, TensorShape({total_nnz}),
                                        &output_col_ind_t));
  TF_RETURN_IF_ERROR(ctx->allocate_temp(
      DataTypeToEnum<T>::value, TensorShape({total_nnz}), &output_values_t));

  TF_RETURN_IF_ERROR(CSRSparseMatrix::CreateCSRSparseMatrix(
      DataTypeToEnum<T>::value, output_dense_shape_t, batch_ptr_t,
      output_row_ptr_t, output_col_ind_t, output_values_t, output_matrix));

  // Set the output row pointers to zero, in case we hit any empty
  // input batches.
  functor::SetZeroFunctor<Device, int32> set_zero;
  const Device& d = ctx->eigen_device<Device>();
  set_zero(d, output_row_ptr_t.flat<int32>());

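  // Transpose each batch component independently. Empty batches are skipped
  // and keep the all-zero row pointers written above.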
  functor::CSRSparseMatrixTransposeComponent<Device, T> transpose_component;
  for (int i = 0; i < batch_size; ++i) {
    if (output_matrix->nnz(i) == 0) {
      continue;
    }
    ConstCSRComponent<T> input_comp{
        input_matrix.row_pointers_vec(i), input_matrix.col_indices_vec(i),
        input_matrix.values_vec<T>(i), input_dense_shape};
    CSRComponent<T> output_comp{
        output_matrix->row_pointers_vec(i), output_matrix->col_indices_vec(i),
        output_matrix->values_vec<T>(i), output_dense_shape};

    TF_RETURN_IF_ERROR(transpose_component(ctx, input_comp, &output_comp));
  }
  if (conjugate) {
    // Conjugate all values with a single kernel launch.
    maybe_conj_inplace<Device, T>::run(d, &output_values_t);
  }

  return OkStatus();
}

// CPU kernel for transposing a single component of a CSR SparseMatrix.
template <typename T>
struct CSRSparseMatrixTransposeComponent<CPUDevice, T> {
  using SparseMatrix = Eigen::SparseMatrix<T, Eigen::RowMajor>;
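
  // The transpose below is a two-pass counting sort over the nonzeros,
  // running in O(nnz + num_cols) time: the first pass histograms the column
  // indices (whose prefix sums are the output row pointers), and the second
  // pass scatters each entry into its output row. For the 2x3 example in the
  // file header, pass one turns col_ind = [0, 2, 1] into counts [0, 1, 1, 1]
  // with prefix sums row_ptr = [0, 1, 2, 3]; pass two then scatters the three
  // entries to offsets 0, 2, and 1, respectively.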
  Status operator()(OpKernelContext* ctx, const ConstCSRComponent<T>& input,
                    CSRComponent<T>* output) {
    TF_RETURN_IF_ERROR(ValidateTransposeInputs(input, *output));

    const int rank = input.dense_shape_host.size();
    const int num_rows = input.row_ptr.size() - 1;
    const int num_cols = input.dense_shape_host(rank - 1);
    const int64_t nnz = input.col_ind.size();

    // Compute the column counts, whose prefix sums make up the output row
    // pointers.
    for (int64_t i = 0; i < nnz; ++i) {
      output->row_ptr(input.col_ind(i) + 1) += 1;
    }
    std::partial_sum(output->row_ptr.data(),
                     output->row_ptr.data() + num_cols + 1,
                     output->row_ptr.data());

    // Iterate through each row of the input, and place each non-zero element
    // into the target output row (based on the current column count).
    std::vector<int> current_col_count(num_cols);
    for (int row_idx = 0; row_idx < num_rows; ++row_idx) {
      const int64_t row_begin = input.row_ptr(row_idx);
      const int64_t row_end = input.row_ptr(row_idx + 1);
      for (int64_t i = row_begin; i < row_end; ++i) {
        const int col_idx = input.col_ind(i);
        const int64_t offset =
            output->row_ptr(col_idx) + current_col_count[col_idx];
        output->col_ind(offset) = row_idx;
        output->values(offset) = input.values(i);
        current_col_count[col_idx] += 1;
      }
    }
    return OkStatus();
  }
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

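// On GPU, the transpose is delegated to the sparse library's csr2csc
// conversion: the CSC representation of a matrix uses the same buffers as
// the CSR representation of its transpose, so converting CSR -> CSC with
// numeric value copying yields the transposed matrix directly.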
template <typename T>
struct CSRSparseMatrixTransposeComponent<GPUDevice, T> {
  Status operator()(OpKernelContext* ctx, const ConstCSRComponent<T>& x,
                    CSRComponent<T>* y) {
    TF_RETURN_IF_ERROR(ValidateTransposeInputs(x, *y));
    GpuSparse cuda_sparse(ctx);
    TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
    const gpusparseAction_t copyValues = GPUSPARSE(ACTION_NUMERIC);
    const int rank = x.dense_shape_host.size();
    const int m = x.row_ptr.size() - 1;
    const int n = x.dense_shape_host(rank - 1);
    const int nnz = x.col_ind.size();
    DCHECK_EQ(nnz, x.values.size());
    DCHECK_EQ(n, y->row_ptr.size() - 1);
    DCHECK_EQ(rank, y->dense_shape_host.size());
    DCHECK_EQ(m, y->dense_shape_host(rank - 1));
    DCHECK_EQ(nnz, y->col_ind.size());
    DCHECK_EQ(nnz, y->values.size());

    return cuda_sparse.Csr2csc(
        m, n, nnz, x.values.data() /*csrVal*/, x.row_ptr.data() /*csrRowPtr*/,
        x.col_ind.data() /*csrColInd*/, y->values.data() /*cscVal*/,
        y->col_ind.data() /*cscRowInd*/, y->row_ptr.data() /*cscColPtr*/,
        copyValues);
  }
};
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
}  // namespace functor

}  // namespace tensorflow