1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/hexagon/utils.h"
16
17 #include <vector>
18
19 #include "tensorflow/lite/builtin_ops.h"
20 #include "tensorflow/lite/c/builtin_op_data.h"
21 #include "tensorflow/lite/c/common.h"
22 #include "tensorflow/lite/kernels/kernel_util.h"
23
24 namespace tflite {
25 namespace {
26
IsActivationReluOrNone(TfLiteFusedActivation activation)27 bool IsActivationReluOrNone(TfLiteFusedActivation activation) {
28 return (activation == kTfLiteActRelu || activation == kTfLiteActRelu6 ||
29 activation == kTfLiteActReluN1To1 || activation == kTfLiteActNone);
30 }
31
TensorTypeMatch(int tensor_id,TfLiteContext * context,TfLiteType tensor_type)32 bool TensorTypeMatch(int tensor_id, TfLiteContext* context,
33 TfLiteType tensor_type) {
34 const auto& tensor = context->tensors[tensor_id];
35 return tensor.type == tensor_type;
36 }
37
38 // For each input tensor i, checks if the type matches one of the possibilities
39 // in per_input_possible_types[i].
InputsWithCorrectTypes(const TfLiteNode * node,TfLiteContext * context,const std::vector<std::vector<TfLiteType>> & per_input_possible_types)40 bool InputsWithCorrectTypes(
41 const TfLiteNode* node, TfLiteContext* context,
42 const std::vector<std::vector<TfLiteType>>& per_input_possible_types) {
43 if (node->inputs->size != per_input_possible_types.size()) return false;
44 for (int i = 0; i < per_input_possible_types.size(); ++i) {
45 // Skip optional tensor.
46 if (node->inputs->data[i] == -1) continue;
47 bool type_found = false;
48 for (auto possible_type : per_input_possible_types[i]) {
49 if (TensorTypeMatch(node->inputs->data[i], context, possible_type)) {
50 type_found = true;
51 break;
52 }
53 }
54 if (!type_found) return false;
55 }
56 return true;
57 }
58
59 } // namespace
60
Get4DShape(unsigned int * batch_size,unsigned int * height_size,unsigned int * width_size,unsigned int * depth_size,TfLiteIntArray * dims)61 TfLiteStatus Get4DShape(unsigned int* batch_size, unsigned int* height_size,
62 unsigned int* width_size, unsigned int* depth_size,
63 TfLiteIntArray* dims) {
64 if (dims->size > 4) return kTfLiteError;
65 unsigned int* dim[] = {batch_size, height_size, width_size, depth_size};
66 for (int i = 0; i < 4; ++i) *(dim[i]) = 1;
67 for (int i = 4 - dims->size; i < 4; ++i) {
68 *dim[i] = dims->data[i - (4 - dims->size)];
69 }
70 return kTfLiteOk;
71 }
72
73 // We maintain an op-version allowlist here to ensure we don't accept unintended
74 // ops.
CheckOpVersion(const TfLiteRegistration * registration)75 bool CheckOpVersion(const TfLiteRegistration* registration) {
76 switch (registration->builtin_code) {
77 case kTfLiteBuiltinAdd:
78 case kTfLiteBuiltinArgMax:
79 case kTfLiteBuiltinArgMin:
80 case kTfLiteBuiltinAveragePool2d:
81 case kTfLiteBuiltinConcatenation:
82 case kTfLiteBuiltinL2Normalization:
83 case kTfLiteBuiltinLogistic:
84 case kTfLiteBuiltinMaximum:
85 case kTfLiteBuiltinMaxPool2d:
86 case kTfLiteBuiltinMean:
87 case kTfLiteBuiltinMinimum:
88 case kTfLiteBuiltinMirrorPad:
89 case kTfLiteBuiltinMul:
90 case kTfLiteBuiltinPack:
91 case kTfLiteBuiltinPad:
92 case kTfLiteBuiltinQuantize:
93 case kTfLiteBuiltinRelu6:
94 case kTfLiteBuiltinSlice:
95 case kTfLiteBuiltinSoftmax:
96 case kTfLiteBuiltinSpaceToDepth:
97 case kTfLiteBuiltinDepthToSpace:
98 case kTfLiteBuiltinSplit:
99 case kTfLiteBuiltinStridedSlice:
100 case kTfLiteBuiltinSub:
101 case kTfLiteBuiltinTanh:
102 case kTfLiteBuiltinTranspose:
103 return registration->version <= 2;
104 case kTfLiteBuiltinSquaredDifference:
105 case kTfLiteBuiltinRelu:
106 case kTfLiteBuiltinRsqrt:
107 return registration->version == 2;
108 case kTfLiteBuiltinConv2d:
109 case kTfLiteBuiltinDepthwiseConv2d:
110 case kTfLiteBuiltinResizeBilinear:
111 case kTfLiteBuiltinResizeNearestNeighbor:
112 case kTfLiteBuiltinTransposeConv:
113 return registration->version <= 3;
114 case kTfLiteBuiltinFullyConnected:
115 return registration->version <= 4;
116 default:
117 return registration->version == 1;
118 }
119 }
120
IsNodeSupportedByHexagon(const TfLiteRegistration * registration,const TfLiteNode * node,TfLiteContext * context)121 bool IsNodeSupportedByHexagon(const TfLiteRegistration* registration,
122 const TfLiteNode* node, TfLiteContext* context) {
123 // Ensure all inputs & outputs have dim <= 4.
124 int tensor_id;
125 for (int i = 0; i < node->inputs->size; ++i) {
126 tensor_id = node->inputs->data[i];
127 // Skip optional tensors. Builders should handle optional tensors
128 // not available.
129 if (tensor_id == -1) continue;
130 const auto& tensor = context->tensors[tensor_id];
131 if (tensor.dims->size > 4) return false;
132 }
133 for (int i = 0; i < node->outputs->size; ++i) {
134 tensor_id = node->outputs->data[i];
135 const auto& tensor = context->tensors[tensor_id];
136 if (tensor.dims->size > 4) return false;
137 }
138
139 if (!CheckOpVersion(registration)) return false;
140
141 switch (registration->builtin_code) {
142 case kTfLiteBuiltinAdd: {
143 if (!InputsWithCorrectTypes(
144 node, context,
145 {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteUInt8, kTfLiteInt8}}))
146 return false;
147 const TfLiteAddParams* add_params =
148 reinterpret_cast<const TfLiteAddParams*>(node->builtin_data);
149 return IsActivationReluOrNone(add_params->activation);
150 }
151 case kTfLiteBuiltinMul: {
152 if (!InputsWithCorrectTypes(
153 node, context,
154 {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteUInt8, kTfLiteInt8}}))
155 return false;
156 const TfLiteMulParams* mul_params =
157 reinterpret_cast<const TfLiteMulParams*>(node->builtin_data);
158 // TODO(b/129276536): Add support for activation on Mul node.
159 return IsActivationReluOrNone(mul_params->activation);
160 }
161 case kTfLiteBuiltinSub: {
162 if (!InputsWithCorrectTypes(
163 node, context,
164 {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteUInt8, kTfLiteInt8}}))
165 return false;
166 const TfLiteSubParams* sub_params =
167 reinterpret_cast<const TfLiteSubParams*>(node->builtin_data);
168 return IsActivationReluOrNone(sub_params->activation);
169 }
170 case kTfLiteBuiltinSum:
171 // TODO(b/139277813): Enable these when they pass unit tests. These seem
172 // to recompute the output min/max instead of taking them as inputs, which
173 // causes an unexpected shift in dequantized values.
174 return false;
175 case kTfLiteBuiltinMean: {
176 return InputsWithCorrectTypes(
177 node, context,
178 {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteInt32}}) &&
179 IsConstantTensor(GetInput(context, node, 1));
180 }
181 case kTfLiteBuiltinMirrorPad: {
182 if (!InputsWithCorrectTypes(
183 node, context, {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteInt32}}) ||
184 !IsConstantTensor(GetInput(context, node, 1)))
185 return false;
186 const TfLiteMirrorPaddingParams* params =
187 reinterpret_cast<const TfLiteMirrorPaddingParams*>(
188 node->builtin_data);
189 return params->mode == kTfLiteMirrorPaddingReflect ||
190 params->mode == kTfLiteMirrorPaddingSymmetric;
191 }
192 case kTfLiteBuiltinPad: {
193 // TODO(b/139277813): Currently we only support padding with the default
194 // of 0. Add support for user-defined constant if required.
195 return (
196 node->inputs->size == 2 &&
197 InputsWithCorrectTypes(
198 node, context, {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteInt32}}) &&
199 IsConstantTensor(GetInput(context, node, 1)));
200 }
201 case kTfLiteBuiltinFullyConnected: {
202 if (!InputsWithCorrectTypes(node, context,
203 {{kTfLiteUInt8, kTfLiteInt8},
204 {kTfLiteUInt8, kTfLiteInt8},
205 {kTfLiteInt32, kTfLiteNoType}})) {
206 return false;
207 }
208
209 bool bias_const_or_no_bias = true;
210 if (node->inputs->data[2] != -1) {
211 const auto& bias_tensor = context->tensors[node->inputs->data[2]];
212 bias_const_or_no_bias = bias_tensor.allocation_type == kTfLiteMmapRo;
213 }
214
215 const TfLiteFullyConnectedParams* matmul_params =
216 reinterpret_cast<const TfLiteFullyConnectedParams*>(
217 node->builtin_data);
218 return (bias_const_or_no_bias &&
219 IsActivationReluOrNone(matmul_params->activation) &&
220 matmul_params->keep_num_dims == false &&
221 matmul_params->weights_format ==
222 kTfLiteFullyConnectedWeightsFormatDefault);
223 }
224 case kTfLiteBuiltinConcatenation: {
225 // All concatenated tensors must be 8-bit.
226 for (int i = 0; i < node->inputs->size; ++i) {
227 if (!TensorTypeMatch(node->inputs->data[i], context, kTfLiteUInt8) &&
228 !TensorTypeMatch(node->inputs->data[i], context, kTfLiteInt8))
229 return false;
230 }
231 return true;
232 }
233 case kTfLiteBuiltinMaxPool2d: {
234 if (!InputsWithCorrectTypes(node, context, {{kTfLiteUInt8, kTfLiteInt8}}))
235 return false;
236 // TODO(b/129276536): Add support for activation here.
237 const TfLitePoolParams* pool_params =
238 reinterpret_cast<const TfLitePoolParams*>(node->builtin_data);
239 // Disable max pool on delegate with activation SAME when filter is > 12.
240 if (pool_params->padding == kTfLitePaddingSame &&
241 (pool_params->filter_height >= 13 ||
242 pool_params->filter_width >= 13)) {
243 return false;
244 }
245 return pool_params->activation == kTfLiteActNone;
246 }
247 case kTfLiteBuiltinAveragePool2d: {
248 if (!InputsWithCorrectTypes(node, context, {{kTfLiteUInt8, kTfLiteInt8}}))
249 return false;
250 const TfLitePoolParams* pool_params =
251 reinterpret_cast<const TfLitePoolParams*>(node->builtin_data);
252 return (node->inputs->size == 1 &&
253 pool_params->activation == kTfLiteActNone);
254 }
255 case kTfLiteBuiltinTransposeConv: {
256 if (NumInputs(node) == 3) {
257 if (!InputsWithCorrectTypes(node, context,
258 {{kTfLiteInt32},
259 {kTfLiteUInt8, kTfLiteInt8},
260 {kTfLiteUInt8, kTfLiteInt8}}))
261 return false;
262 } else if (NumInputs(node) == 4) {
263 if (!InputsWithCorrectTypes(node, context,
264 {{kTfLiteInt32},
265 {kTfLiteUInt8, kTfLiteInt8},
266 {kTfLiteUInt8, kTfLiteInt8},
267 {kTfLiteInt32}}))
268 return false;
269 } else {
270 return false;
271 }
272 const TfLiteTransposeConvParams* params =
273 reinterpret_cast<const TfLiteTransposeConvParams*>(
274 node->builtin_data);
275 return (params->stride_height <= 3 && params->stride_width <= 3 &&
276 (params->padding == kTfLitePaddingSame ||
277 params->padding == kTfLitePaddingValid));
278 }
279 case kTfLiteBuiltinConv2d: {
280 if (!InputsWithCorrectTypes(node, context,
281 {{kTfLiteUInt8, kTfLiteInt8},
282 {kTfLiteUInt8, kTfLiteInt8},
283 {kTfLiteInt32}}))
284 return false;
285 const TfLiteConvParams* conv_params =
286 reinterpret_cast<const TfLiteConvParams*>(node->builtin_data);
287 return (IsActivationReluOrNone(conv_params->activation) &&
288 conv_params->stride_height <= 3 &&
289 conv_params->stride_width <= 3 &&
290 conv_params->dilation_height_factor == 1 &&
291 conv_params->dilation_width_factor == 1);
292 }
293 case kTfLiteBuiltinDepthwiseConv2d: {
294 if (!InputsWithCorrectTypes(node, context,
295 {{kTfLiteUInt8, kTfLiteInt8},
296 {kTfLiteUInt8, kTfLiteInt8},
297 {kTfLiteInt32}}))
298 return false;
299
300 // Check dilation.
301 const TfLiteDepthwiseConvParams* conv_params =
302 reinterpret_cast<const TfLiteDepthwiseConvParams*>(
303 node->builtin_data);
304 const bool dilation = conv_params->dilation_height_factor != 1 ||
305 conv_params->dilation_width_factor != 1;
306 if (dilation) {
307 // We only support dilations when stride == 1.
308 if (conv_params->stride_height != 1 || conv_params->stride_width != 1)
309 return false;
310 }
311
312 // We currently only support depth_multiplier > 1 when:
313 // 1. dilation_factor == 1 AND
314 // 2. input_depth == 1
315 // TODO(b/143759564): Add support for general case.
316 const auto& input = context->tensors[node->inputs->data[0]];
317 const bool supported_depth_multiplier =
318 conv_params->depth_multiplier == 1 ||
319 (!dilation && input.dims->size == 4 && input.dims->data[3] == 1);
320
321 // Hexagon only supports filter height >= 2.
322 const auto& weights = context->tensors[node->inputs->data[1]];
323 const bool filter_height_not_supported =
324 (weights.dims->size >= 2 && weights.dims->data[1] < 2);
325
326 return (IsActivationReluOrNone(conv_params->activation) &&
327 conv_params->stride_height <= 3 &&
328 conv_params->stride_width <= 3 && supported_depth_multiplier &&
329 !filter_height_not_supported);
330 }
331 case kTfLiteBuiltinReshape: {
332 if (node->inputs->size > 2 ||
333 (!TensorTypeMatch(node->inputs->data[0], context, kTfLiteUInt8) &&
334 !TensorTypeMatch(node->inputs->data[0], context, kTfLiteInt8)))
335 return false;
336 return true;
337 }
338 case kTfLiteBuiltinSoftmax: {
339 return (
340 InputsWithCorrectTypes(node, context, {{kTfLiteUInt8, kTfLiteInt8}}));
341 }
342 case kTfLiteBuiltinHardSwish:
343 case kTfLiteBuiltinRelu:
344 case kTfLiteBuiltinRelu6:
345 case kTfLiteBuiltinTanh:
346 case kTfLiteBuiltinLogistic: {
347 return InputsWithCorrectTypes(node, context,
348 {{kTfLiteUInt8, kTfLiteInt8}});
349 }
350 case kTfLiteBuiltinResizeNearestNeighbor: {
351 return InputsWithCorrectTypes(
352 node, context,
353 {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteInt32}}) &&
354 IsConstantTensor(GetInput(context, node, 1));
355 }
356 case kTfLiteBuiltinL2Normalization: {
357 if (!InputsWithCorrectTypes(node, context, {{kTfLiteUInt8, kTfLiteInt8}}))
358 return false;
359 const TfLiteL2NormParams* norm_params =
360 reinterpret_cast<const TfLiteL2NormParams*>(node->builtin_data);
361 return (norm_params->activation == kTfLiteActNone);
362 }
363 case kTfLiteBuiltinArgMax:
364 case kTfLiteBuiltinArgMin:
365 return InputsWithCorrectTypes(
366 node, context, {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteInt32}});
367 case kTfLiteBuiltinSplit: {
368 if (!InputsWithCorrectTypes(
369 node, context, {{kTfLiteInt32}, {kTfLiteUInt8, kTfLiteInt8}}))
370 return false;
371 const auto& input_tensor = context->tensors[node->inputs->data[1]];
372 const bool is_four_dim_or_less = input_tensor.dims->size < 5;
373 // We need splitting axis to be constant, so Hexagon knows output
374 // shapes.
375 return is_four_dim_or_less &&
376 IsConstantTensor(GetInput(context, node, 0));
377 }
378 case kTfLiteBuiltinResizeBilinear: {
379 if (!InputsWithCorrectTypes(
380 node, context, {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteInt32}}) ||
381 !IsConstantTensor(GetInput(context, node, 1))) {
382 return false;
383 }
384 const auto& size_tensor = context->tensors[node->inputs->data[1]];
385 return NumElements(&size_tensor) == 2;
386 }
387 case kTfLiteBuiltinNeg: {
388 return InputsWithCorrectTypes(node, context,
389 {{kTfLiteUInt8, kTfLiteInt8}});
390 }
391 case kTfLiteBuiltinTranspose: {
392 return InputsWithCorrectTypes(
393 node, context, {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteInt32}});
394 }
395 case kTfLiteBuiltinSpaceToDepth:
396 case kTfLiteBuiltinDepthToSpace: {
397 return InputsWithCorrectTypes(node, context,
398 {{kTfLiteUInt8, kTfLiteInt8}});
399 }
400 case kTfLiteBuiltinQuantize: {
401 return InputsWithCorrectTypes(node, context,
402 {{kTfLiteUInt8, kTfLiteInt8}});
403 }
404 case kTfLiteBuiltinMinimum: {
405 return InputsWithCorrectTypes(
406 node, context,
407 {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteUInt8, kTfLiteInt8}});
408 }
409 case kTfLiteBuiltinMaximum: {
410 return InputsWithCorrectTypes(
411 node, context,
412 {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteUInt8, kTfLiteInt8}});
413 }
414 case kTfLiteBuiltinSlice: {
415 const auto& begins_tensor = context->tensors[node->inputs->data[1]];
416 const auto& sizes_tensor = context->tensors[node->inputs->data[2]];
417 if (!IsConstantTensor(&begins_tensor) || !IsConstantTensor(&sizes_tensor))
418 return false;
419 return InputsWithCorrectTypes(node, context,
420 {{kTfLiteUInt8, kTfLiteInt8},
421 {kTfLiteInt32, kTfLiteInt64},
422 {kTfLiteInt32, kTfLiteInt64}});
423 }
424 case kTfLiteBuiltinPack: {
425 // All tensors must be 8-bit.
426 for (int i = 0; i < node->inputs->size; ++i) {
427 if (!TensorTypeMatch(node->inputs->data[i], context, kTfLiteUInt8) &&
428 !TensorTypeMatch(node->inputs->data[i], context, kTfLiteInt8))
429 return false;
430 }
431 return true;
432 }
433 case kTfLiteBuiltinStridedSlice: {
434 if (!InputsWithCorrectTypes(node, context,
435 {{kTfLiteUInt8, kTfLiteInt8},
436 {kTfLiteInt32},
437 {kTfLiteInt32},
438 {kTfLiteInt32}}))
439 return false;
440 const auto& begins_tensor = context->tensors[node->inputs->data[1]];
441 const auto& ends_tensor = context->tensors[node->inputs->data[2]];
442 const auto& step_tensor = context->tensors[node->inputs->data[3]];
443 if (!IsConstantTensor(&begins_tensor) ||
444 !IsConstantTensor(&ends_tensor) || !IsConstantTensor(&step_tensor))
445 return false;
446 const TfLiteStridedSliceParams* params =
447 reinterpret_cast<const TfLiteStridedSliceParams*>(node->builtin_data);
448 // Hexagon doesn't support ellipsis/new-axis masks.
449 return (params->ellipsis_mask == 0 && params->new_axis_mask == 0);
450 }
451 case kTfLiteBuiltinSquaredDifference: {
452 return InputsWithCorrectTypes(node, context,
453 {{kTfLiteInt8}, {kTfLiteInt8}});
454 }
455 case kTfLiteBuiltinRsqrt: {
456 return InputsWithCorrectTypes(node, context, {{kTfLiteInt8}});
457 }
458 default:
459 return false;
460 }
461 return false;
462 }
463
464 } // namespace tflite
465