xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/reduce_join_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/string_ops.cc.
17 
18 #include <string>
19 
20 #include "tensorflow/core/framework/kernel_def_builder.h"
21 #include "tensorflow/core/framework/op_kernel.h"
22 #include "tensorflow/core/framework/tensor.h"
23 #include "tensorflow/core/framework/tensor_shape.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/core/status.h"
26 #include "tensorflow/core/lib/core/stringpiece.h"
27 #include "tensorflow/core/lib/gtl/inlined_vector.h"
28 #include "tensorflow/core/lib/strings/str_util.h"
29 
30 namespace tensorflow {
31 
32 namespace {
33 
GetStrides(const TensorShape & shape)34 const gtl::InlinedVector<int64_t, 8> GetStrides(const TensorShape& shape) {
35   gtl::InlinedVector<int64_t, 8> result(shape.dims());
36   int64_t product = 1;
37   for (int32_t i = shape.dims() - 1; i >= 0; --i) {
38     result[i] = product;
39     product *= shape.dim_size(i);
40   }
41   return result;
42 }
43 
44 // Given a linear index to a subset of dimensions, full shape,
45 // precomputed list of running products of the full shape, and list of
46 // dimensions in the subset, outputs the linear index to the full shape with
47 // nonspecified dimensions set to 0.  Dimensions must be ordered from outer-most
48 // to inner-most with respect to the subset linear index.
LinearSubIndexToFullIndex(int64_t output_index,const gtl::InlinedVector<int32,8> & dim_list,const TensorShape & input_shape,const gtl::InlinedVector<int64_t,8> & strides)49 inline int64_t LinearSubIndexToFullIndex(
50     int64_t output_index, const gtl::InlinedVector<int32, 8>& dim_list,
51     const TensorShape& input_shape,
52     const gtl::InlinedVector<int64_t, 8>& strides) {
53   int64_t result = 0;
54   int64_t quotient = output_index;
55   for (int32_t i = dim_list.size() - 1; i >= 0; --i) {
56     int32_t dim = dim_list[i];
57     int64_t dim_value = quotient % input_shape.dim_size(dim);
58     quotient = quotient / input_shape.dim_size(dim);
59     result += strides[dim] * dim_value;
60   }
61   return result;
62 }
63 
64 // Computes the number of input elements reduced per output element.
GetReductionIterSize(const gtl::InlinedVector<int32,8> & reduced_indices,const TensorShape & input_shape)65 int64_t GetReductionIterSize(
66     const gtl::InlinedVector<int32, 8>& reduced_indices,
67     const TensorShape& input_shape) {
68   int64_t result = 1;
69   for (int32_t reduce_dim : reduced_indices) {
70     result *= input_shape.dim_size(reduce_dim);
71   }
72   return result;
73 }
74 
75 // Computes a list of all true reduced indices, accounting for negative
76 // indices.
GetReducedIndices(const Tensor & reduction_indices,int32_t input_dims)77 gtl::InlinedVector<int32, 8> GetReducedIndices(const Tensor& reduction_indices,
78                                                int32_t input_dims) {
79   const auto reduction_indices_flat = reduction_indices.flat<int32>();
80   const int32_t reduction_dims = reduction_indices_flat.size();
81 
82   gtl::InlinedVector<int32, 8> reduced_indices(reduction_dims);
83   for (int32_t i = 0; i < reduction_dims; ++i) {
84     reduced_indices[i] = reduction_indices_flat(reduction_dims - i - 1);
85     reduced_indices[i] += reduced_indices[i] < 0 ? input_dims : 0;
86   }
87 
88   return reduced_indices;
89 }
90 
91 // Appends all unreduced dimensions to the given vector.
MakeUnreducedIndices(gtl::InlinedVector<bool,8> index_is_reduced,int32_t input_dims,gtl::InlinedVector<int32,8> * unreduced_indices)92 void MakeUnreducedIndices(gtl::InlinedVector<bool, 8> index_is_reduced,
93                           int32_t input_dims,
94                           gtl::InlinedVector<int32, 8>* unreduced_indices) {
95   for (int32_t index = 0; index < input_dims; ++index) {
96     if (!index_is_reduced[index]) unreduced_indices->push_back(index);
97   }
98 }
99 
GetOutputShape(gtl::InlinedVector<bool,8> index_is_reduced,const TensorShape & input_shape,bool keep_dims)100 TensorShape GetOutputShape(gtl::InlinedVector<bool, 8> index_is_reduced,
101                            const TensorShape& input_shape, bool keep_dims) {
102   TensorShape output_shape;
103   for (size_t index = 0; index < index_is_reduced.size(); ++index) {
104     if (index_is_reduced[index]) {
105       if (keep_dims) output_shape.AddDim(1);
106     } else {
107       output_shape.AddDim(input_shape.dim_size(index));
108     }
109   }
110   return output_shape;
111 }
112 
113 }  // namespace
114 
115 class ReduceJoinOp : public OpKernel {
116  public:
117   using OpKernel::OpKernel;
118 
ReduceJoinOp(OpKernelConstruction * ctx)119   explicit ReduceJoinOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
120     OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_));
121     OP_REQUIRES_OK(ctx, ctx->GetAttr("separator", &separator_));
122   }
123 
Compute(OpKernelContext * context)124   void Compute(OpKernelContext* context) override {
125     const Tensor& input = context->input(0);
126     const auto input_flat = input.flat<tstring>();
127     const TensorShape& input_shape = input.shape();
128     const int32_t input_dims = input_shape.dims();
129 
130     const Tensor& reduction_indices = context->input(1);
131     const auto reduction_indices_flat = reduction_indices.flat<int32>();
132     const int32_t reduction_dims = reduction_indices_flat.size();
133 
134     gtl::InlinedVector<bool, 8> index_is_reduced(input_dims, false);
135     for (int32_t i = 0; i < reduction_dims; i++) {
136       int32_t reduce_index = reduction_indices_flat(i);
137       const int32_t true_reduce_index =
138           reduce_index < 0 ? reduce_index + input_dims : reduce_index;
139       OP_REQUIRES(
140           context, reduce_index >= -input_dims && reduce_index < input_dims,
141           errors::OutOfRange("Invalid reduction dimension ", reduce_index,
142                              " for input with ", input_dims, " dimension(s)"));
143       OP_REQUIRES(context, !index_is_reduced[true_reduce_index],
144                   errors::InvalidArgument("Duplicate reduction dimension ",
145                                           reduce_index));
146       index_is_reduced[true_reduce_index] = true;
147     }
148 
149     gtl::InlinedVector<int32, 8> reduced_indices =
150         GetReducedIndices(reduction_indices, input_dims);
151     gtl::InlinedVector<int32, 8> unreduced_indices;
152     MakeUnreducedIndices(index_is_reduced, input_dims, &unreduced_indices);
153     const auto strides = GetStrides(input_shape);
154 
155     Tensor* output_tensor = nullptr;
156     TensorShape output_shape =
157         GetOutputShape(index_is_reduced, input_shape, keep_dims_);
158     OP_REQUIRES_OK(context, context->allocate_output("output", output_shape,
159                                                      &output_tensor));
160     auto output_flat = output_tensor->flat<tstring>();
161 
162     const int64_t reduction_iter_size =
163         GetReductionIterSize(reduced_indices, input_shape);
164     gtl::InlinedVector<StringPiece, 8> curr_strings(reduction_iter_size);
165     for (int64_t output_index = 0; output_index < output_shape.num_elements();
166          ++output_index) {
167       int64_t output_full_index = LinearSubIndexToFullIndex(
168           output_index, unreduced_indices, input_shape, strides);
169       for (int64_t reduction_index = 0; reduction_index < reduction_iter_size;
170            ++reduction_index) {
171         int64_t reduction_full_index = LinearSubIndexToFullIndex(
172             reduction_index, reduced_indices, input_shape, strides);
173         curr_strings[reduction_index] =
174             input_flat(output_full_index + reduction_full_index);
175       }
176       output_flat(output_index) = absl::StrJoin(curr_strings, separator_);
177     }
178   }
179 
180  private:
181   bool keep_dims_;
182   string separator_;
183 };
184 
185 REGISTER_KERNEL_BUILDER(Name("ReduceJoin").Device(DEVICE_CPU), ReduceJoinOp);
186 
187 }  // namespace tensorflow
188