/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/hexagon/utils.h"

#include <vector>

#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace {

bool IsActivationReluOrNone(TfLiteFusedActivation activation) {
  return (activation == kTfLiteActRelu || activation == kTfLiteActRelu6 ||
          activation == kTfLiteActReluN1To1 || activation == kTfLiteActNone);
}

bool TensorTypeMatch(int tensor_id, TfLiteContext* context,
                     TfLiteType tensor_type) {
  const auto& tensor = context->tensors[tensor_id];
  return tensor.type == tensor_type;
}

// For each input tensor i, checks if the type matches one of the possibilities
// in per_input_possible_types[i].
bool InputsWithCorrectTypes(
    const TfLiteNode* node, TfLiteContext* context,
    const std::vector<std::vector<TfLiteType>>& per_input_possible_types) {
  if (node->inputs->size != per_input_possible_types.size()) return false;
  for (int i = 0; i < per_input_possible_types.size(); ++i) {
    // Skip optional tensor.
    if (node->inputs->data[i] == -1) continue;
    bool type_found = false;
    for (auto possible_type : per_input_possible_types[i]) {
      if (TensorTypeMatch(node->inputs->data[i], context, possible_type)) {
        type_found = true;
        break;
      }
    }
    if (!type_found) return false;
  }
  return true;
}

}  // namespace

TfLiteStatus Get4DShape(unsigned int* batch_size, unsigned int* height_size,
                        unsigned int* width_size, unsigned int* depth_size,
                        TfLiteIntArray* dims) {
  if (dims->size > 4) return kTfLiteError;
  unsigned int* dim[] = {batch_size, height_size, width_size, depth_size};
  for (int i = 0; i < 4; ++i) *(dim[i]) = 1;
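  // Copy the given dims right-aligned into {batch, height, width, depth};
  // leading dimensions that are not present keep their default value of 1.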
  for (int i = 4 - dims->size; i < 4; ++i) {
    *dim[i] = dims->data[i - (4 - dims->size)];
  }
  return kTfLiteOk;
}

// We maintain an op-version allowlist here to ensure we don't accept unintended
// ops.
bool CheckOpVersion(const TfLiteRegistration* registration) {
  switch (registration->builtin_code) {
    case kTfLiteBuiltinAdd:
    case kTfLiteBuiltinArgMax:
    case kTfLiteBuiltinArgMin:
    case kTfLiteBuiltinAveragePool2d:
    case kTfLiteBuiltinConcatenation:
    case kTfLiteBuiltinL2Normalization:
    case kTfLiteBuiltinLogistic:
    case kTfLiteBuiltinMaximum:
    case kTfLiteBuiltinMaxPool2d:
    case kTfLiteBuiltinMean:
    case kTfLiteBuiltinMinimum:
    case kTfLiteBuiltinMirrorPad:
    case kTfLiteBuiltinMul:
    case kTfLiteBuiltinPack:
    case kTfLiteBuiltinPad:
    case kTfLiteBuiltinQuantize:
    case kTfLiteBuiltinRelu6:
    case kTfLiteBuiltinSlice:
    case kTfLiteBuiltinSoftmax:
    case kTfLiteBuiltinSpaceToDepth:
    case kTfLiteBuiltinDepthToSpace:
    case kTfLiteBuiltinSplit:
    case kTfLiteBuiltinStridedSlice:
    case kTfLiteBuiltinSub:
    case kTfLiteBuiltinTanh:
    case kTfLiteBuiltinTranspose:
      return registration->version <= 2;
    case kTfLiteBuiltinSquaredDifference:
    case kTfLiteBuiltinRelu:
    case kTfLiteBuiltinRsqrt:
      return registration->version == 2;
    case kTfLiteBuiltinConv2d:
    case kTfLiteBuiltinDepthwiseConv2d:
    case kTfLiteBuiltinResizeBilinear:
    case kTfLiteBuiltinResizeNearestNeighbor:
    case kTfLiteBuiltinTransposeConv:
      return registration->version <= 3;
    case kTfLiteBuiltinFullyConnected:
      return registration->version <= 4;
    default:
      return registration->version == 1;
  }
}

bool IsNodeSupportedByHexagon(const TfLiteRegistration* registration,
                              const TfLiteNode* node, TfLiteContext* context) {
  // Ensure all inputs & outputs have dim <= 4.
  int tensor_id;
  for (int i = 0; i < node->inputs->size; ++i) {
    tensor_id = node->inputs->data[i];
    // Skip optional tensors. Builders should handle optional tensors that
    // are not available.
    if (tensor_id == -1) continue;
    const auto& tensor = context->tensors[tensor_id];
    if (tensor.dims->size > 4) return false;
  }
  for (int i = 0; i < node->outputs->size; ++i) {
    tensor_id = node->outputs->data[i];
    const auto& tensor = context->tensors[tensor_id];
    if (tensor.dims->size > 4) return false;
  }

  if (!CheckOpVersion(registration)) return false;

  switch (registration->builtin_code) {
    case kTfLiteBuiltinAdd: {
      if (!InputsWithCorrectTypes(
              node, context,
              {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteUInt8, kTfLiteInt8}}))
        return false;
      const TfLiteAddParams* add_params =
          reinterpret_cast<const TfLiteAddParams*>(node->builtin_data);
      return IsActivationReluOrNone(add_params->activation);
    }
    case kTfLiteBuiltinMul: {
      if (!InputsWithCorrectTypes(
              node, context,
              {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteUInt8, kTfLiteInt8}}))
        return false;
      const TfLiteMulParams* mul_params =
          reinterpret_cast<const TfLiteMulParams*>(node->builtin_data);
      // TODO(b/129276536): Add support for activation on Mul node.
      return IsActivationReluOrNone(mul_params->activation);
    }
    case kTfLiteBuiltinSub: {
      if (!InputsWithCorrectTypes(
              node, context,
              {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteUInt8, kTfLiteInt8}}))
        return false;
      const TfLiteSubParams* sub_params =
          reinterpret_cast<const TfLiteSubParams*>(node->builtin_data);
      return IsActivationReluOrNone(sub_params->activation);
    }
    case kTfLiteBuiltinSum:
      // TODO(b/139277813): Enable these when they pass unit tests. These seem
      // to recompute the output min/max instead of taking them as inputs, which
      // causes an unexpected shift in dequantized values.
      return false;
    case kTfLiteBuiltinMean: {
      return InputsWithCorrectTypes(
                 node, context,
                 {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteInt32}}) &&
             IsConstantTensor(GetInput(context, node, 1));
    }
    case kTfLiteBuiltinMirrorPad: {
      if (!InputsWithCorrectTypes(
              node, context, {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteInt32}}) ||
          !IsConstantTensor(GetInput(context, node, 1)))
        return false;
      const TfLiteMirrorPaddingParams* params =
          reinterpret_cast<const TfLiteMirrorPaddingParams*>(
              node->builtin_data);
      return params->mode == kTfLiteMirrorPaddingReflect ||
             params->mode == kTfLiteMirrorPaddingSymmetric;
    }
    case kTfLiteBuiltinPad: {
      // TODO(b/139277813): Currently we only support padding with the default
      // of 0. Add support for user-defined constant if required.
      return (
          node->inputs->size == 2 &&
          InputsWithCorrectTypes(
              node, context, {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteInt32}}) &&
          IsConstantTensor(GetInput(context, node, 1)));
    }
    case kTfLiteBuiltinFullyConnected: {
      if (!InputsWithCorrectTypes(node, context,
                                  {{kTfLiteUInt8, kTfLiteInt8},
                                   {kTfLiteUInt8, kTfLiteInt8},
                                   {kTfLiteInt32, kTfLiteNoType}})) {
        return false;
      }

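      // A bias, if provided as the third input, must be a constant tensor
      // (allocation type kTfLiteMmapRo).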
      bool bias_const_or_no_bias = true;
      if (node->inputs->data[2] != -1) {
        const auto& bias_tensor = context->tensors[node->inputs->data[2]];
        bias_const_or_no_bias = bias_tensor.allocation_type == kTfLiteMmapRo;
      }

      const TfLiteFullyConnectedParams* matmul_params =
          reinterpret_cast<const TfLiteFullyConnectedParams*>(
              node->builtin_data);
      return (bias_const_or_no_bias &&
              IsActivationReluOrNone(matmul_params->activation) &&
              matmul_params->keep_num_dims == false &&
              matmul_params->weights_format ==
                  kTfLiteFullyConnectedWeightsFormatDefault);
    }
    case kTfLiteBuiltinConcatenation: {
      // All concatenated tensors must be 8-bit.
      for (int i = 0; i < node->inputs->size; ++i) {
        if (!TensorTypeMatch(node->inputs->data[i], context, kTfLiteUInt8) &&
            !TensorTypeMatch(node->inputs->data[i], context, kTfLiteInt8))
          return false;
      }
      return true;
    }
    case kTfLiteBuiltinMaxPool2d: {
      if (!InputsWithCorrectTypes(node, context, {{kTfLiteUInt8, kTfLiteInt8}}))
        return false;
      // TODO(b/129276536): Add support for activation here.
      const TfLitePoolParams* pool_params =
          reinterpret_cast<const TfLitePoolParams*>(node->builtin_data);
      // Reject max pool on the delegate when padding is SAME and the filter
      // height or width is > 12.
      if (pool_params->padding == kTfLitePaddingSame &&
          (pool_params->filter_height >= 13 ||
           pool_params->filter_width >= 13)) {
        return false;
      }
      return pool_params->activation == kTfLiteActNone;
    }
    case kTfLiteBuiltinAveragePool2d: {
      if (!InputsWithCorrectTypes(node, context, {{kTfLiteUInt8, kTfLiteInt8}}))
        return false;
      const TfLitePoolParams* pool_params =
          reinterpret_cast<const TfLitePoolParams*>(node->builtin_data);
      return (node->inputs->size == 1 &&
              pool_params->activation == kTfLiteActNone);
    }
    case kTfLiteBuiltinTransposeConv: {
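      // TFLite TransposeConv inputs are {output shape, weights, input}, plus
      // an optional int32 bias as the fourth input.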
      if (NumInputs(node) == 3) {
        if (!InputsWithCorrectTypes(node, context,
                                    {{kTfLiteInt32},
                                     {kTfLiteUInt8, kTfLiteInt8},
                                     {kTfLiteUInt8, kTfLiteInt8}}))
          return false;
      } else if (NumInputs(node) == 4) {
        if (!InputsWithCorrectTypes(node, context,
                                    {{kTfLiteInt32},
                                     {kTfLiteUInt8, kTfLiteInt8},
                                     {kTfLiteUInt8, kTfLiteInt8},
                                     {kTfLiteInt32}}))
          return false;
      } else {
        return false;
      }
      const TfLiteTransposeConvParams* params =
          reinterpret_cast<const TfLiteTransposeConvParams*>(
              node->builtin_data);
      return (params->stride_height <= 3 && params->stride_width <= 3 &&
              (params->padding == kTfLitePaddingSame ||
               params->padding == kTfLitePaddingValid));
    }
    case kTfLiteBuiltinConv2d: {
      if (!InputsWithCorrectTypes(node, context,
                                  {{kTfLiteUInt8, kTfLiteInt8},
                                   {kTfLiteUInt8, kTfLiteInt8},
                                   {kTfLiteInt32}}))
        return false;
      const TfLiteConvParams* conv_params =
          reinterpret_cast<const TfLiteConvParams*>(node->builtin_data);
      return (IsActivationReluOrNone(conv_params->activation) &&
              conv_params->stride_height <= 3 &&
              conv_params->stride_width <= 3 &&
              conv_params->dilation_height_factor == 1 &&
              conv_params->dilation_width_factor == 1);
    }
    case kTfLiteBuiltinDepthwiseConv2d: {
      if (!InputsWithCorrectTypes(node, context,
                                  {{kTfLiteUInt8, kTfLiteInt8},
                                   {kTfLiteUInt8, kTfLiteInt8},
                                   {kTfLiteInt32}}))
        return false;

      // Check dilation.
      const TfLiteDepthwiseConvParams* conv_params =
          reinterpret_cast<const TfLiteDepthwiseConvParams*>(
              node->builtin_data);
      const bool dilation = conv_params->dilation_height_factor != 1 ||
                            conv_params->dilation_width_factor != 1;
      if (dilation) {
        // We only support dilations when stride == 1.
        if (conv_params->stride_height != 1 || conv_params->stride_width != 1)
          return false;
      }

      // We currently only support depth_multiplier > 1 when:
      // 1. dilation_factor == 1 AND
      // 2. input_depth == 1
      // TODO(b/143759564): Add support for general case.
      const auto& input = context->tensors[node->inputs->data[0]];
      const bool supported_depth_multiplier =
          conv_params->depth_multiplier == 1 ||
          (!dilation && input.dims->size == 4 && input.dims->data[3] == 1);

      // Hexagon only supports filter height >= 2.
      const auto& weights = context->tensors[node->inputs->data[1]];
      const bool filter_height_not_supported =
          (weights.dims->size >= 2 && weights.dims->data[1] < 2);

      return (IsActivationReluOrNone(conv_params->activation) &&
              conv_params->stride_height <= 3 &&
              conv_params->stride_width <= 3 && supported_depth_multiplier &&
              !filter_height_not_supported);
    }
    case kTfLiteBuiltinReshape: {
      if (node->inputs->size > 2 ||
          (!TensorTypeMatch(node->inputs->data[0], context, kTfLiteUInt8) &&
           !TensorTypeMatch(node->inputs->data[0], context, kTfLiteInt8)))
        return false;
      return true;
    }
    case kTfLiteBuiltinSoftmax: {
      return (
          InputsWithCorrectTypes(node, context, {{kTfLiteUInt8, kTfLiteInt8}}));
    }
    case kTfLiteBuiltinHardSwish:
    case kTfLiteBuiltinRelu:
    case kTfLiteBuiltinRelu6:
    case kTfLiteBuiltinTanh:
    case kTfLiteBuiltinLogistic: {
      return InputsWithCorrectTypes(node, context,
                                    {{kTfLiteUInt8, kTfLiteInt8}});
    }
    case kTfLiteBuiltinResizeNearestNeighbor: {
      return InputsWithCorrectTypes(
                 node, context,
                 {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteInt32}}) &&
             IsConstantTensor(GetInput(context, node, 1));
    }
    case kTfLiteBuiltinL2Normalization: {
      if (!InputsWithCorrectTypes(node, context, {{kTfLiteUInt8, kTfLiteInt8}}))
        return false;
      const TfLiteL2NormParams* norm_params =
          reinterpret_cast<const TfLiteL2NormParams*>(node->builtin_data);
      return (norm_params->activation == kTfLiteActNone);
    }
    case kTfLiteBuiltinArgMax:
    case kTfLiteBuiltinArgMin:
      return InputsWithCorrectTypes(
          node, context, {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteInt32}});
    case kTfLiteBuiltinSplit: {
      if (!InputsWithCorrectTypes(
              node, context, {{kTfLiteInt32}, {kTfLiteUInt8, kTfLiteInt8}}))
        return false;
      const auto& input_tensor = context->tensors[node->inputs->data[1]];
      const bool is_four_dim_or_less = input_tensor.dims->size < 5;
      // The splitting axis must be constant, so Hexagon knows the output
      // shapes.
      return is_four_dim_or_less &&
             IsConstantTensor(GetInput(context, node, 0));
    }
    case kTfLiteBuiltinResizeBilinear: {
      if (!InputsWithCorrectTypes(
              node, context, {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteInt32}}) ||
          !IsConstantTensor(GetInput(context, node, 1))) {
        return false;
      }
      const auto& size_tensor = context->tensors[node->inputs->data[1]];
      return NumElements(&size_tensor) == 2;
    }
    case kTfLiteBuiltinNeg: {
      return InputsWithCorrectTypes(node, context,
                                    {{kTfLiteUInt8, kTfLiteInt8}});
    }
    case kTfLiteBuiltinTranspose: {
      return InputsWithCorrectTypes(
          node, context, {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteInt32}});
    }
    case kTfLiteBuiltinSpaceToDepth:
    case kTfLiteBuiltinDepthToSpace: {
      return InputsWithCorrectTypes(node, context,
                                    {{kTfLiteUInt8, kTfLiteInt8}});
    }
    case kTfLiteBuiltinQuantize: {
      return InputsWithCorrectTypes(node, context,
                                    {{kTfLiteUInt8, kTfLiteInt8}});
    }
    case kTfLiteBuiltinMinimum: {
      return InputsWithCorrectTypes(
          node, context,
          {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteUInt8, kTfLiteInt8}});
    }
    case kTfLiteBuiltinMaximum: {
      return InputsWithCorrectTypes(
          node, context,
          {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteUInt8, kTfLiteInt8}});
    }
    case kTfLiteBuiltinSlice: {
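      // The begin and size tensors must be constant so the output shape is
      // known when the Hexagon graph is built.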
      const auto& begins_tensor = context->tensors[node->inputs->data[1]];
      const auto& sizes_tensor = context->tensors[node->inputs->data[2]];
      if (!IsConstantTensor(&begins_tensor) || !IsConstantTensor(&sizes_tensor))
        return false;
      return InputsWithCorrectTypes(node, context,
                                    {{kTfLiteUInt8, kTfLiteInt8},
                                     {kTfLiteInt32, kTfLiteInt64},
                                     {kTfLiteInt32, kTfLiteInt64}});
    }
    case kTfLiteBuiltinPack: {
      // All tensors must be 8-bit.
      for (int i = 0; i < node->inputs->size; ++i) {
        if (!TensorTypeMatch(node->inputs->data[i], context, kTfLiteUInt8) &&
            !TensorTypeMatch(node->inputs->data[i], context, kTfLiteInt8))
          return false;
      }
      return true;
    }
    case kTfLiteBuiltinStridedSlice: {
      if (!InputsWithCorrectTypes(node, context,
                                  {{kTfLiteUInt8, kTfLiteInt8},
                                   {kTfLiteInt32},
                                   {kTfLiteInt32},
                                   {kTfLiteInt32}}))
        return false;
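      // As with Slice, the begin/end/stride tensors must be constant.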
      const auto& begins_tensor = context->tensors[node->inputs->data[1]];
      const auto& ends_tensor = context->tensors[node->inputs->data[2]];
      const auto& step_tensor = context->tensors[node->inputs->data[3]];
      if (!IsConstantTensor(&begins_tensor) ||
          !IsConstantTensor(&ends_tensor) || !IsConstantTensor(&step_tensor))
        return false;
      const TfLiteStridedSliceParams* params =
          reinterpret_cast<const TfLiteStridedSliceParams*>(node->builtin_data);
      // Hexagon doesn't support ellipsis/new-axis masks.
      return (params->ellipsis_mask == 0 && params->new_axis_mask == 0);
    }
    case kTfLiteBuiltinSquaredDifference: {
      return InputsWithCorrectTypes(node, context,
                                    {{kTfLiteInt8}, {kTfLiteInt8}});
    }
    case kTfLiteBuiltinRsqrt: {
      return InputsWithCorrectTypes(node, context, {{kTfLiteInt8}});
    }
    default:
      return false;
  }
  return false;
}

}  // namespace tflite