/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <limits>
#include <string>

#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"

namespace tflite {
namespace ops {
namespace custom {
namespace max_pool_with_argmax {
namespace {
// TODO(b/175003241): Move this logic to lite/kernels/internal when promoting
// this op to a builtin op.
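// Computes max pooling over an NHWC input and, for each output element, also
// records the position of the maximum as an index flattened over the input's
// (height, width, channel) dimensions, excluding batch.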
template <typename T>
inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
                    const RuntimeShape& output_shape, const T* input_data,
                    T* output_data, int32_t* indices_data) {
  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);

  const int32_t batches = MatchingDim(input_shape, 0, output_shape, 0);
  const int32_t depth = MatchingDim(input_shape, 3, output_shape, 3);
  const int32_t input_height = input_shape.Dims(1);
  const int32_t input_width = input_shape.Dims(2);
  const int32_t output_height = output_shape.Dims(1);
  const int32_t output_width = output_shape.Dims(2);
  const int32_t stride_height = params.stride_height;
  const int32_t stride_width = params.stride_width;
  for (int32_t batch = 0; batch < batches; ++batch) {
    for (int32_t out_y = 0; out_y < output_height; ++out_y) {
      for (int32_t out_x = 0; out_x < output_width; ++out_x) {
        for (int32_t channel = 0; channel < depth; ++channel) {
          const int32_t in_x_origin =
              (out_x * stride_width) - params.padding_values.width;
          const int32_t in_y_origin =
              (out_y * stride_height) - params.padding_values.height;
          // Compute the boundaries of the filter region, clamped so that the
          // filter window fits within the input array.
          const int32_t filter_x_start = std::max(0, -in_x_origin);
          const int32_t filter_x_end =
              std::min(params.filter_width, input_width - in_x_origin);
          const int32_t filter_y_start = std::max(0, -in_y_origin);
          const int32_t filter_y_end =
              std::min(params.filter_height, input_height - in_y_origin);
          float max = std::numeric_limits<float>::lowest();
          int32_t max_x = 0;
          int32_t max_y = 0;

          for (int32_t filter_y = filter_y_start; filter_y < filter_y_end;
               ++filter_y) {
            for (int32_t filter_x = filter_x_start; filter_x < filter_x_end;
                 ++filter_x) {
              const int32_t in_x = in_x_origin + filter_x;
              const int32_t in_y = in_y_origin + filter_y;
              float cur =
                  input_data[Offset(input_shape, batch, in_y, in_x, channel)];
              if (cur > max) {
                max = cur;
                max_x = in_x;
                max_y = in_y;
              }
            }
          }
          int32_t output_idx =
              Offset(output_shape, batch, out_y, out_x, channel);
          output_data[output_idx] = ActivationFunctionWithMinMax(
              max, params.float_activation_min, params.float_activation_max);
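          // The argmax is flattened over the input's (height, width, channel)
          // dims, e.g. a max at (y=1, x=2) in channel 0 of a 4-wide,
          // 1-channel input flattens to (1 * 4 + 2) * 1 + 0 = 6. The batch
          // dimension is excluded (include_batch_in_index is rejected in
          // Prepare).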
          indices_data[output_idx] =
              (max_y * input_width + max_x) * depth + channel;
        }
      }
    }
  }
}

}  // namespace

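// Tensor indices: one data input, plus two outputs (the pooled values and the
// flattened argmax indices).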
constexpr int kDataInputTensor = 0;
constexpr int kDataOutputTensor = 0;
constexpr int kIndicesOutputTensor = 1;

constexpr const char kIncludeBatchStr[] = "include_batch_in_index";
constexpr const char kPoolSizeStr[] = "ksize";
constexpr const char kStridesStr[] = "strides";
constexpr const char kPaddingStr[] = "padding";
constexpr const char kPaddingSameStr[] = "SAME";
constexpr const char kPaddingValidStr[] = "VALID";

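// Parameters parsed once from the flexbuffer custom options in Init() and
// reused by Prepare() and Eval() via node->user_data.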
struct OpData {
  TfLitePoolParams params;
  bool include_batch_in_index;
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  const flexbuffers::Map& m =
      flexbuffers::GetRoot(reinterpret_cast<const uint8_t*>(buffer), length)
          .AsMap();
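  // The custom options are a flexbuffer map; based on the keys read below it
  // has the shape {"ksize": [1, h, w, 1], "strides": [1, sh, sw, 1],
  // "padding": "SAME"|"VALID", "include_batch_in_index": bool}.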

  OpData* op_data = new OpData;
  op_data->params.computed.padding = TfLitePaddingValues{0, 0, 0, 0};
  op_data->include_batch_in_index = m[kIncludeBatchStr].AsBool();
  op_data->params.activation = kTfLiteActNone;

  const std::string padding = m[kPaddingStr].AsString().str();
  if (padding == kPaddingValidStr) {
    op_data->params.padding = kTfLitePaddingValid;
  } else if (padding == kPaddingSameStr) {
    op_data->params.padding = kTfLitePaddingSame;
  } else {
    op_data->params.padding = kTfLitePaddingUnknown;
  }

  // The first and last elements of pool_size are always 1.
  const auto pool_size = m[kPoolSizeStr].AsTypedVector();
  TFLITE_CHECK_EQ(pool_size.size(), 4);
  TFLITE_CHECK_EQ(pool_size[0].AsInt32(), 1);
  TFLITE_CHECK_EQ(pool_size[3].AsInt32(), 1);
  op_data->params.filter_height = pool_size[1].AsInt32();
  op_data->params.filter_width = pool_size[2].AsInt32();

  // The first and last elements of strides are always 1.
  const auto strides = m[kStridesStr].AsTypedVector();
  TFLITE_CHECK_EQ(strides.size(), 4);
  TFLITE_CHECK_EQ(strides[0].AsInt32(), 1);
  TFLITE_CHECK_EQ(strides[3].AsInt32(), 1);
  op_data->params.stride_height = strides[1].AsInt32();
  op_data->params.stride_width = strides[2].AsInt32();

  return op_data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);

  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
  TfLiteTensor *output, *indices;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kDataOutputTensor, &output));
  TF_LITE_ENSURE_OK(
      context, GetOutputSafe(context, node, kIndicesOutputTensor, &indices));
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kDataInputTensor, &input));
  TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
  TF_LITE_ENSURE(context, indices->type == kTfLiteInt32);
  TF_LITE_ENSURE(context, op_data->params.padding != kTfLitePaddingUnknown);
  TF_LITE_ENSURE_MSG(
      context, !op_data->include_batch_in_index,
      "Including the batch dimension in the flattened index is not yet "
      "supported.");

  int batches = input->dims->data[0];
  int height = input->dims->data[1];
  int width = input->dims->data[2];
  int channels_out = input->dims->data[3];

  // Matching GetWindowedOutputSize in TensorFlow.
  int out_width, out_height;
  op_data->params.computed.padding = ComputePaddingHeightWidth(
      op_data->params.stride_height, op_data->params.stride_width, 1, 1, height,
      width, op_data->params.filter_height, op_data->params.filter_width,
      op_data->params.padding, &out_height, &out_width);
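  // With dilation 1 this gives out = ceil(in / stride) for SAME padding and
  // out = ceil((in - filter + 1) / stride) for VALID padding.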

  TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
  output_size->data[0] = batches;
  output_size->data[1] = out_height;
  output_size->data[2] = out_width;
  output_size->data[3] = channels_out;
  TfLiteIntArray* indices_size = TfLiteIntArrayCopy(output_size);

  TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, indices, indices_size));
  return context->ResizeTensor(context, output, output_size);
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);

  float activation_min, activation_max;
  CalculateActivationRange(op_data->params.activation, &activation_min,
                           &activation_max);

  tflite::PoolParams op_params;
  op_params.stride_height = op_data->params.stride_height;
  op_params.stride_width = op_data->params.stride_width;
  op_params.filter_height = op_data->params.filter_height;
  op_params.filter_width = op_data->params.filter_width;
  op_params.padding_values.height = op_data->params.computed.padding.height;
  op_params.padding_values.width = op_data->params.computed.padding.width;
  op_params.float_activation_min = activation_min;
  op_params.float_activation_max = activation_max;
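  // Activation is fixed to kTfLiteActNone in Init(), so this range is
  // effectively unbounded and the clamp in MaxPool is a no-op.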

  TfLiteTensor *output, *indices;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kDataOutputTensor, &output));
  TF_LITE_ENSURE_OK(
      context, GetOutputSafe(context, node, kIndicesOutputTensor, &indices));
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kDataInputTensor, &input));

  switch (input->type) {
    case kTfLiteFloat32:
      MaxPool<float>(op_params, GetTensorShape(input), GetTensorShape(output),
                     GetTensorData<float>(input), GetTensorData<float>(output),
                     GetTensorData<int32_t>(indices));
      break;
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.",
                         TfLiteTypeGetName(input->type));
      return kTfLiteError;
  }
  return kTfLiteOk;
}
}  // namespace max_pool_with_argmax

TfLiteRegistration* RegisterMaxPoolWithArgmax() {
  static TfLiteRegistration r = {
      max_pool_with_argmax::Init, max_pool_with_argmax::Free,
      max_pool_with_argmax::Prepare, max_pool_with_argmax::Eval};
  return &r;
}

// Alias for selective build.
TfLiteRegistration* Register_MAX_POOL_WITH_ARGMAX() {
  return RegisterMaxPoolWithArgmax();
}
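
// A minimal usage sketch (not part of this file): registering the op with an
// op resolver before constructing the interpreter. The custom-op name
// "MaxPoolWithArgmax" is an assumption here; it must match the name recorded
// in the converted model.
//
//   tflite::MutableOpResolver resolver;
//   resolver.AddCustom("MaxPoolWithArgmax",
//                      tflite::ops::custom::RegisterMaxPoolWithArgmax());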

}  // namespace custom
}  // namespace ops
}  // namespace tflite