1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <algorithm>
#include <cstdint>
#include <limits>
#include <string>

#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"
27
28 namespace tflite {
29 namespace ops {
30 namespace custom {
31 namespace max_pool_with_argmax {
32 namespace {
33 // TODO(b/175003241): Move this logic to lite/kernels/internal when promoting
34 // this op to a builtin op.
35 template <typename T>
MaxPool(const PoolParams & params,const RuntimeShape & input_shape,const RuntimeShape & output_shape,const T * input_data,T * output_data,int32_t * indices_data)36 inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
37 const RuntimeShape& output_shape, const T* input_data,
38 T* output_data, int32_t* indices_data) {
39 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
40 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
41
42 const int32_t batches = MatchingDim(input_shape, 0, output_shape, 0);
43 const int32_t depth = MatchingDim(input_shape, 3, output_shape, 3);
44 const int32_t input_height = input_shape.Dims(1);
45 const int32_t input_width = input_shape.Dims(2);
46 const int32_t output_height = output_shape.Dims(1);
47 const int32_t output_width = output_shape.Dims(2);
48 const int32_t stride_height = params.stride_height;
49 const int32_t stride_width = params.stride_width;
50 for (int32_t batch = 0; batch < batches; ++batch) {
51 for (int32_t out_y = 0; out_y < output_height; ++out_y) {
52 for (int32_t out_x = 0; out_x < output_width; ++out_x) {
53 for (int32_t channel = 0; channel < depth; ++channel) {
54 const int32_t in_x_origin =
55 (out_x * stride_width) - params.padding_values.width;
56 const int32_t in_y_origin =
57 (out_y * stride_height) - params.padding_values.height;
58 // Compute the boundaries of the filter region clamped so as to
59 // ensure that the filter window fits in the input array.
60 const int32_t filter_x_start = std::max(0, -in_x_origin);
61 const int32_t filter_x_end =
62 std::min(params.filter_width, input_width - in_x_origin);
63 const int32_t filter_y_start = std::max(0, -in_y_origin);
64 const int32_t filter_y_end =
65 std::min(params.filter_height, input_height - in_y_origin);
66 float max = std::numeric_limits<float>::lowest();
67 int32_t max_x = 0;
68 int32_t max_y = 0;
69
70 for (int32_t filter_y = filter_y_start; filter_y < filter_y_end;
71 ++filter_y) {
72 for (int32_t filter_x = filter_x_start; filter_x < filter_x_end;
73 ++filter_x) {
74 const int32_t in_x = in_x_origin + filter_x;
75 const int32_t in_y = in_y_origin + filter_y;
76 float cur =
77 input_data[Offset(input_shape, batch, in_y, in_x, channel)];
78 if (cur > max) {
79 max = cur;
80 max_x = in_x;
81 max_y = in_y;
82 }
83 }
84 }
85 int32_t output_idx =
86 Offset(output_shape, batch, out_y, out_x, channel);
87 output_data[output_idx] = ActivationFunctionWithMinMax(
88 max, params.float_activation_min, params.float_activation_max);
89 indices_data[output_idx] =
90 (max_y * input_width + max_x) * depth + channel;
91 }
92 }
93 }
94 }
95 }
96
97 } // namespace
98
// Tensor indices for this op's inputs and outputs.
constexpr int kDataInputTensor = 0;
constexpr int kDataOutputTensor = 0;
constexpr int kIndicesOutputTensor = 1;

// Keys of the flexbuffer custom-options map parsed in Init().
constexpr const char kIncludeBatchStr[] = "include_batch_in_index";
constexpr const char kPoolSizeStr[] = "ksize";
constexpr const char kStridesStr[] = "strides";
constexpr const char kPaddingStr[] = "padding";
constexpr const char kPaddingSameStr[] = "SAME";
constexpr const char kPaddingValidStr[] = "VALID";

// Per-node state: allocated in Init(), released in Free().
struct OpData {
  TfLitePoolParams params;
  // Mirrors the TF attribute; only `false` is accepted (see Prepare()).
  bool include_batch_in_index;
};
114
Init(TfLiteContext * context,const char * buffer,size_t length)115 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
116 const flexbuffers::Map& m =
117 flexbuffers::GetRoot(reinterpret_cast<const uint8_t*>(buffer), length)
118 .AsMap();
119
120 OpData* op_data = new OpData;
121 op_data->params.computed.padding = TfLitePaddingValues{0, 0, 0, 0};
122 op_data->include_batch_in_index = m[kIncludeBatchStr].AsBool();
123 op_data->params.activation = kTfLiteActNone;
124
125 const std::string padding = m[kPaddingStr].AsString().str();
126 if (padding == kPaddingValidStr) {
127 op_data->params.padding = kTfLitePaddingValid;
128 } else if (padding == kPaddingSameStr) {
129 op_data->params.padding = kTfLitePaddingSame;
130 } else {
131 op_data->params.padding = kTfLitePaddingUnknown;
132 }
133
134 // The first and last element of pool_size are always 1.
135 const auto pool_size = m[kPoolSizeStr].AsTypedVector();
136 TFLITE_CHECK_EQ(pool_size.size(), 4);
137 TFLITE_CHECK_EQ(pool_size[0].AsInt32(), 1);
138 TFLITE_CHECK_EQ(pool_size[3].AsInt32(), 1);
139 op_data->params.filter_height = pool_size[1].AsInt32();
140 op_data->params.filter_width = pool_size[2].AsInt32();
141
142 // The first and last element of strides are always 1.
143 const auto strides = m[kStridesStr].AsTypedVector();
144 TFLITE_CHECK_EQ(strides.size(), 4);
145 TFLITE_CHECK_EQ(strides[0].AsInt32(), 1);
146 TFLITE_CHECK_EQ(strides[3].AsInt32(), 1);
147 op_data->params.stride_height = strides[1].AsInt32();
148 op_data->params.stride_width = strides[2].AsInt32();
149
150 return op_data;
151 }
152
Free(TfLiteContext * context,void * buffer)153 void Free(TfLiteContext* context, void* buffer) {
154 delete reinterpret_cast<OpData*>(buffer);
155 }
156
Prepare(TfLiteContext * context,TfLiteNode * node)157 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
158 OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
159
160 TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
161 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
162 TfLiteTensor *output, *indices;
163 TF_LITE_ENSURE_OK(context,
164 GetOutputSafe(context, node, kDataOutputTensor, &output));
165 TF_LITE_ENSURE_OK(
166 context, GetOutputSafe(context, node, kIndicesOutputTensor, &indices));
167 const TfLiteTensor* input;
168 TF_LITE_ENSURE_OK(context,
169 GetInputSafe(context, node, kDataInputTensor, &input));
170 TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
171 TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
172 TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
173 TF_LITE_ENSURE(context, indices->type == kTfLiteInt32);
174 TF_LITE_ENSURE(context, op_data->params.padding != kTfLitePaddingUnknown);
175 TF_LITE_ENSURE_MSG(
176 context, !op_data->include_batch_in_index,
177 "Include batch dimension in flattened index is not yet supported.");
178
179 int batches = input->dims->data[0];
180 int height = input->dims->data[1];
181 int width = input->dims->data[2];
182 int channels_out = input->dims->data[3];
183
184 // Matching GetWindowedOutputSize in TensorFlow.
185 int out_width, out_height;
186 op_data->params.computed.padding = ComputePaddingHeightWidth(
187 op_data->params.stride_height, op_data->params.stride_width, 1, 1, height,
188 width, op_data->params.filter_height, op_data->params.filter_width,
189 op_data->params.padding, &out_height, &out_width);
190
191 TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
192 output_size->data[0] = batches;
193 output_size->data[1] = out_height;
194 output_size->data[2] = out_width;
195 output_size->data[3] = channels_out;
196 TfLiteIntArray* indices_size = TfLiteIntArrayCopy(output_size);
197
198 TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, indices, indices_size));
199 return context->ResizeTensor(context, output, output_size);
200 }
201
Eval(TfLiteContext * context,TfLiteNode * node)202 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
203 OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
204
205 float activation_min, activation_max;
206 CalculateActivationRange(op_data->params.activation, &activation_min,
207 &activation_max);
208
209 tflite::PoolParams op_params;
210 op_params.stride_height = op_data->params.stride_height;
211 op_params.stride_width = op_data->params.stride_width;
212 op_params.filter_height = op_data->params.filter_height;
213 op_params.filter_width = op_data->params.filter_width;
214 op_params.padding_values.height = op_data->params.computed.padding.height;
215 op_params.padding_values.width = op_data->params.computed.padding.width;
216 op_params.float_activation_min = activation_min;
217 op_params.float_activation_max = activation_max;
218
219 TfLiteTensor *output, *indices;
220 TF_LITE_ENSURE_OK(context,
221 GetOutputSafe(context, node, kDataOutputTensor, &output));
222 TF_LITE_ENSURE_OK(
223 context, GetOutputSafe(context, node, kIndicesOutputTensor, &indices));
224 const TfLiteTensor* input;
225 TF_LITE_ENSURE_OK(context,
226 GetInputSafe(context, node, kDataInputTensor, &input));
227
228 switch (input->type) {
229 case kTfLiteFloat32:
230 MaxPool<float>(op_params, GetTensorShape(input), GetTensorShape(output),
231 GetTensorData<float>(input), GetTensorData<float>(output),
232 GetTensorData<int32_t>(indices));
233 break;
234 default:
235 TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.",
236 TfLiteTypeGetName(input->type));
237 return kTfLiteError;
238 }
239 return kTfLiteOk;
240 }
241 } // namespace max_pool_with_argmax
242
RegisterMaxPoolWithArgmax()243 TfLiteRegistration* RegisterMaxPoolWithArgmax() {
244 static TfLiteRegistration r = {
245 max_pool_with_argmax::Init, max_pool_with_argmax::Free,
246 max_pool_with_argmax::Prepare, max_pool_with_argmax::Eval};
247 return &r;
248 }
249
250 // Alias for selective build.
Register_MAX_POOL_WITH_ARGMAX()251 TfLiteRegistration* Register_MAX_POOL_WITH_ARGMAX() {
252 return RegisterMaxPoolWithArgmax();
253 }
254
255 } // namespace custom
256 } // namespace ops
257 } // namespace tflite
258