/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/sparse_reorder_op.h"

#include <algorithm>
#include <numeric>
#include <unordered_map>
#include <utility>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

using CPUDevice = Eigen::ThreadPoolDevice;

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
using GPUDevice = Eigen::GpuDevice;
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace functor {

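// CPU specialization of SparseReorderFunctor. It first checks whether the
// input indices are already in standard row-major (lexicographic) order; if
// so, the inputs are forwarded unchanged. Otherwise the indices and values
// are deep-copied and reordered in place. For illustration: indices
// [[0, 3], [0, 1]] with values [a, b] and shape [2, 4] become indices
// [[0, 1], [0, 3]] with values [b, a].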
template <typename T>
struct SparseReorderFunctor<CPUDevice, T> {
  void operator()(OpKernelContext* context, const Tensor& input_ind,
                  const Tensor& input_val, const Tensor& input_shape_in) {
    gtl::ArraySlice<int64_t> input_shape(input_shape_in.vec<int64_t>().data(),
                                         input_shape_in.NumElements());

    gtl::InlinedVector<int64_t, 8> std_order(input_shape.size());
    std::iota(std_order.begin(), std_order.end(), 0);
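    // std_order is the identity permutation {0, 1, ..., rank - 1}, i.e. the
    // standard row-major ordering of the dimensions.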

    // Check if the sparse tensor is already ordered correctly
    sparse::SparseTensor input_sp;
    OP_REQUIRES_OK(
        context, sparse::SparseTensor::Create(input_ind, input_val, input_shape,
                                              std_order, &input_sp));

    if (input_sp.IndicesValid().ok()) {
      context->set_output(0, input_sp.indices());
      context->set_output(1, input_sp.values());
    } else {
      // Deep-copy the input Tensors, then reorder in-place
      sparse::SparseTensor reordered_sp;
      OP_REQUIRES_OK(context,
                     sparse::SparseTensor::Create(tensor::DeepCopy(input_ind),
                                                  tensor::DeepCopy(input_val),
                                                  input_shape, &reordered_sp));
      reordered_sp.Reorder<T>(std_order);
      context->set_output(0, reordered_sp.indices());
      context->set_output(1, reordered_sp.values());
    }
  }
};

}  // namespace functor

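// SparseReorderOp validates the three inputs (indices must be a matrix,
// values and shape must be vectors) and then dispatches to the
// device-specific SparseReorderFunctor.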
template <typename Device, typename T>
class SparseReorderOp : public OpKernel {
 public:
  explicit SparseReorderOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input_ind = context->input(0);
    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_ind.shape()),
                errors::InvalidArgument(
                    "Input indices should be a matrix but received shape ",
                    input_ind.shape().DebugString()));

    const Tensor& input_val = context->input(1);
    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_val.shape()),
                errors::InvalidArgument(
                    "Input values should be a vector but received shape ",
                    input_val.shape().DebugString()));

    const Tensor& input_shape_in = context->input(2);
    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape_in.shape()),
                errors::InvalidArgument(
                    "Input shape should be a vector but received shape ",
                    input_shape_in.shape().DebugString()));

    functor::SparseReorderFunctor<Device, T>()(context, input_ind, input_val,
                                               input_shape_in);
  }
};

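// Register the CPU kernel for every type covered by TF_CALL_ALL_TYPES.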
#define REGISTER_KERNELS(type)                                            \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("SparseReorder").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      SparseReorderOp<CPUDevice, type>)

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

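// The GPU registrations below rely on a GPUDevice specialization of
// SparseReorderFunctor, which is presumably provided by the corresponding
// CUDA/ROCm translation unit rather than this file.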
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_GPU_KERNELS(type)                                        \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("SparseReorder").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      SparseReorderOp<GPUDevice, type>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_INTEGRAL_TYPES(REGISTER_GPU_KERNELS);
REGISTER_GPU_KERNELS(bool);
#undef REGISTER_GPU_KERNELS

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow