/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This is an internal header file intended to only be included as the
// front-matter in the implementation files of various reduction ops.  It
// is a header file because we split the various reduction ops into their
// own compilation units to get more parallelism in compilation.
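//
// For illustration, a consumer of this header (a sketch modeled on files
// like reduction_ops_sum.cc; the op name and attribute list here are
// illustrative, not verbatim) instantiates ReductionOp with a device, an
// element type, an axis-index type, and an Eigen reducer, then registers it:
//
//   #include "tensorflow/core/kernels/reduction_ops_common.h"
//
//   REGISTER_KERNEL_BUILDER(
//       Name("Sum")
//           .Device(DEVICE_CPU)
//           .TypeConstraint<float>("T")
//           .TypeConstraint<int32>("Tidx"),
//       ReductionOp<CPUDevice, float, int32,
//                   Eigen::internal::SumReducer<float>>);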

#ifndef TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_COMMON_H_
#define TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_COMMON_H_

#define EIGEN_USE_THREADS

#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/reduction_ops.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device>
struct Constants {
  // Derive the Index type: int (32-bit) or long (64-bit), depending on the
  // compile-time configuration. The "float" here is not relevant; any
  // element type yields the same Index.
  // TODO(zhifengc): Move the definition to TTypes.
  typedef TTypes<float>::Tensor::Index Index;
  Eigen::array<Index, 1> kZero;
  Eigen::array<Index, 1> kOne;
  Eigen::array<Index, 2> kZeroTwo;

  Constants() {
    kZero[0] = 0;
    kOne[0] = 1;
    kZeroTwo[0] = 0;
    kZeroTwo[1] = 2;
  }
};

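// On CPU, use Eigen::IndexList with compile-time type2index entries instead
// of runtime Eigen::array indices; the reduction dimensions are then known
// at compile time, which lets Eigen generate better-specialized code.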
struct ConstantsBase {
  const Eigen::IndexList<Eigen::type2index<0>> kZero;
  const Eigen::IndexList<Eigen::type2index<1>> kOne;
  const Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<2>> kZeroTwo;
};
template <>
struct Constants<CPUDevice> : ConstantsBase {};

class ReductionHelper {
 public:
  ReductionHelper() : reduce_first_axis_(false) {}

  Status Simplify(const Tensor& data, const Tensor& axis, const bool keep_dims);

  // We need to do roughly:
  //   tmp_out = allocate(out_reshape())
  //   tmp_out.reshape(out_reshape) = data.reshape(data_reshape).reduce(axes)
  //   out = tmp_out.reshape(out_shape)
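  //
  // Illustrative example (assuming Simplify collapses runs of adjacent
  // dimensions that are all reduced or all kept): reducing a [2, 3, 4]
  // input over axes {0, 1} collapses to data_reshape() == [6, 4] with
  // reduce_first_axis() == true, so the whole reduction becomes a single
  // matrix reduction along the first dimension, and out_shape() == [4]
  // (without keep_dims).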

  // The reduction result must be allocated with this shape.
  TensorShape out_reshape() const;

  // The final output shape must be allocated with this shape.
  TensorShape out_shape() const;

  // The reduction is on a reshaped tensor of this rank.
  int ndims() const { return data_reshape_.size(); }

  // True if the 0-th dimension needs to be reduced.
  bool reduce_first_axis() const { return reduce_first_axis_; }

  // The output is reshaped.
  template <typename T, int N>
  typename TTypes<T, N>::Tensor out(Tensor* out) {
    return out->shaped<T, N>(out_reshape_);
  }

  // The input is reshaped.
  template <typename T, int N>
  typename TTypes<T, N>::ConstTensor in(const Tensor& data) {
    return data.shaped<T, N>(data_reshape_);
  }

  // Shape of the reshaped input (before any shuffle).
  TensorShape data_reshape() const {
    TensorShape shape;
    for (auto s : data_reshape_) shape.AddDim(s);
    return shape;
  }

  // Shape with all reduction dimensions at the end.
  TensorShape shuffled_shape();

  // Permutation of the input dimensions that moves all reduction dimensions
  // to the end.
  gtl::InlinedVector<int32, 8> permutation();

 private:
  bool reduce_first_axis_;  // True if the 0-th dimension is reduced.
  gtl::InlinedVector<int64_t, 4>
      data_reshape_;                          // Reshape data before reduction.
  gtl::InlinedVector<int64_t, 4> out_shape_;  // The final output shape.
  gtl::InlinedVector<int64_t, 4> out_reshape_;  // Reshape output for reduction.
};

// For operations where the output is a reduction function along some
// dimensions of the input.
template <typename Device, class T, typename Tperm, typename Reducer>
class ReductionOp : public OpKernel {
 public:
  explicit ReductionOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    const DataType dt = DataTypeToEnum<T>::v();
    const DataType pt = DataTypeToEnum<Tperm>::v();
    OP_REQUIRES_OK(ctx, ctx->MatchSignature({dt, pt}, {dt}));

    OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& data = ctx->input(0);
    const Tensor& axes = ctx->input(1);
    VLOG(1) << "data shape: " << data.shape().DebugString();
    VLOG(1) << "axes      : " << axes.SummarizeValue(10);

    ReductionHelper helper;
    OP_REQUIRES_OK(ctx, helper.Simplify(data, axes, keep_dims_));
    CHECK_GE(helper.ndims(), 0);

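    // The simplified reduction is "trivial" when nothing is actually
    // reduced: either the reshaped input is a scalar, or it is a vector
    // whose only dimension is kept.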
    bool is_scalar_identity = functor::ReducerTraits<Reducer>::IsScalarIdentity;
    bool is_trivial = helper.ndims() == 0 ||
                      (helper.ndims() == 1 && !helper.reduce_first_axis());
    if (is_scalar_identity && is_trivial) {
      Tensor out;
      // Special case. Reduces nothing and does not alter the input values.
      if (!out.CopyFrom(data, helper.out_shape())) {
        ctx->SetStatus(errors::Internal("Error during reduction copy."));
      }
      ctx->set_output(0, out);
      return;
    }

    // We must allocate temp tensors using the same alloc attr as
    // output(0) because it is returned as output(0) in the end.
    const AllocatorAttributes alloc_attr = ctx->output_alloc_attr(0);

    Tensor tmp_out;
    typedef functor::ReduceFunctor<Device, Reducer> Functor;
    Constants<Device> constants;
    const Device& d = ctx->eigen_device<Device>();
    Reducer reducer;

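    // Trivial shape but a reducer that is not a scalar identity (i.e. one
    // whose reduction of a single element may differ from that element, as
    // with a norm-style reducer): apply the reducer elementwise by viewing
    // the data as a {1, N} matrix and reducing along the first dimension.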
    if (data.NumElements() > 0 && is_trivial && !is_scalar_identity) {
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(ctx->expected_output_dtype(0),
                                             TensorShape({data.NumElements()}),
                                             &tmp_out, alloc_attr));
      Functor::Reduce(ctx, tmp_out.flat<T>(),
                      data.shaped<T, 2>({1, data.NumElements()}),
                      constants.kZero, reducer);
    } else {
      // A temporary tensor whose size matches the size of the reduced
      // output.
      OP_REQUIRES_OK(
          ctx, ctx->allocate_temp(ctx->expected_output_dtype(0),
                                  helper.out_reshape(), &tmp_out, alloc_attr));

      if (tmp_out.NumElements() == 0) {
        // Nothing to do, fall through to final reshaping.
      } else if (data.NumElements() == 0) {
        // Degenerate reduction where the input is empty but the output is
        // nonempty (thus tmp_out.NumElements() > 0), and we must fill the
        // output with identity elements.  Example: tf.reduce_sum(tf.zeros((0,
        // 3)), [0]). Eigen sometimes crashes in this case, so we do it
        // manually.
        Functor::FillIdentity(d, tmp_out.flat<T>(), reducer);
      } else if ((helper.ndims() == 1) && helper.reduce_first_axis()) {
        // Reduce to a scalar.
        Functor::Reduce(ctx, helper.out<T, 0>(&tmp_out), helper.in<T, 1>(data),
                        constants.kZero, reducer);
      } else if ((helper.ndims() == 2) && helper.reduce_first_axis()) {
        // Can be viewed as a reduction of a matrix along the 1st dimension.
        Functor::Reduce(ctx, helper.out<T, 1>(&tmp_out), helper.in<T, 2>(data),
                        constants.kZero, reducer);
      } else if ((helper.ndims() == 2) && !helper.reduce_first_axis()) {
        // Can be viewed as a reduction of a matrix along the 2nd dimension.
        Functor::Reduce(ctx, helper.out<T, 1>(&tmp_out), helper.in<T, 2>(data),
                        constants.kOne, reducer);
      } else if ((helper.ndims() == 3) && helper.reduce_first_axis()) {
        // Can be viewed as a reduction of a 3D tensor along the 1st and 3rd
        // dimensions.
        Functor::Reduce(ctx, helper.out<T, 1>(&tmp_out), helper.in<T, 3>(data),
                        constants.kZeroTwo, reducer);
      } else if ((helper.ndims() == 3) && !helper.reduce_first_axis()) {
        // Can be viewed as a reduction of a 3D tensor along the 2nd dimension.
        Functor::Reduce(ctx, helper.out<T, 2>(&tmp_out), helper.in<T, 3>(data),
                        constants.kOne, reducer);
      } else {
        // If we don't hit one of the cases above, transpose the data so that
        // all reduced dimensions are last and reuse the 2-D -> 1-D case.
        Tensor data_reshaped;
        OP_REQUIRES(ctx, data_reshaped.CopyFrom(data, helper.data_reshape()),
                    errors::Internal("Error during reduction copy."));
        Tensor shuffled;
        OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                               helper.shuffled_shape(),
                                               &shuffled, alloc_attr));
        OP_REQUIRES_OK(ctx, DoTranspose(d, data_reshaped, helper.permutation(),
                                        &shuffled));
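        // View the shuffled data as an {unreduced, reduced} matrix: all kept
        // dimensions are flattened into the rows and all reduced dimensions
        // into the columns, so reducing along the 2nd dimension produces one
        // output element per row.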
        const int64_t unreduced = tmp_out.NumElements();
        const int64_t reduced = shuffled.NumElements() / unreduced;
        const Tensor& const_shuffled = shuffled;
        Functor::Reduce(ctx, tmp_out.flat<T>(),
                        const_shuffled.shaped<T, 2>({unreduced, reduced}),
                        constants.kOne, reducer);
      }
    }

    // Set the real output using the contents of the reduction result but
    // with the real expected output shape.  The number of elements must
    // match between the two shapes.
    Tensor out;
    OP_REQUIRES(ctx, out.CopyFrom(tmp_out, helper.out_shape()),
                errors::Internal("Error during reduction copy."));
    ctx->set_output(0, out);
  }

 private:
  // True if the number of dimensions should be maintained.
  bool keep_dims_;
};

namespace functor {

template <typename Device, typename Reducer>
struct ReduceFunctorBase {
  template <typename OUT_T, typename IN_T, typename ReductionAxes>
  static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
                     const ReductionAxes& reduction_axes,
                     const Reducer& reducer) {
    const Device& d = ctx->eigen_device<Device>();
    ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes, Reducer> reducer_impl;
    reducer_impl(d, out, in, reduction_axes, reducer);
  }

  template <typename OUT_T>
  static void FillIdentity(const Device& d, OUT_T out, const Reducer& reducer) {
    FillIdentityEigenImpl(d, out, reducer);
  }
};

template <typename Reducer>
struct ReduceFunctor<CPUDevice, Reducer>
    : ReduceFunctorBase<CPUDevice, Reducer> {};
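
// The corresponding GPUDevice specializations of ReduceFunctor are expected
// to live in the CUDA-specific reduction kernel sources (e.g. the
// reduction_ops_gpu_*.cu.cc files), not in this header.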

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_COMMON_H_