xref: /aosp_15_r20/external/armnn/delegate/classic/src/ElementwiseBinary.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <ClassicDelegateUtils.hpp>
9 #include "MultiLayerFacade.hpp"
10 #include "SharedFunctions.hpp"
11 
12 #include <tensorflow/lite/builtin_ops.h>
13 #include <tensorflow/lite/c/builtin_op_data.h>
14 #include <tensorflow/lite/c/common.h>
15 #include <tensorflow/lite/minimal_logging.h>
16 #include "tensorflow/lite/delegates/utils.h"
17 
18 namespace armnnDelegate
19 {
20 
ValidateAddOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,const armnn::TensorInfo & inputInfo1,const armnn::TensorInfo & inputInfo2,const armnn::TensorInfo & outputInfo)21 TfLiteStatus ValidateAddOperator(DelegateData& delegateData,
22                                  TfLiteContext* tfLiteContext,
23                                  const armnn::TensorInfo& inputInfo1,
24                                  const armnn::TensorInfo& inputInfo2,
25                                  const armnn::TensorInfo& outputInfo)
26 {
27     bool isSupported = false;
28     auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
29     {
30         std::vector<armnn::TensorInfo> infos { inputInfo1, inputInfo2, outputInfo };
31         FORWARD_LAYER_SUPPORT_FUNC("ADD",
32                                    tfLiteContext,
33                                    IsElementwiseBinarySupported,
34                                    delegateData.m_Backends,
35                                    isSupported,
36                                    armnn::BackendId(),
37                                    inputInfo1,
38                                    inputInfo2,
39                                    outputInfo,
40                                    armnn::BinaryOperation::Add);
41     };
42 
43     validateFunc(outputInfo, isSupported);
44     return isSupported ? kTfLiteOk : kTfLiteError;
45 }
46 
47 
ValidateDivOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,const armnn::TensorInfo & inputInfo1,const armnn::TensorInfo & inputInfo2,const armnn::TensorInfo & outputInfo)48 TfLiteStatus ValidateDivOperator(DelegateData& delegateData,
49                                  TfLiteContext* tfLiteContext,
50                                  const armnn::TensorInfo& inputInfo1,
51                                  const armnn::TensorInfo& inputInfo2,
52                                  const armnn::TensorInfo& outputInfo)
53 {
54     bool isSupported = false;
55     auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
56     {
57         FORWARD_LAYER_SUPPORT_FUNC("DIV",
58                                    tfLiteContext,
59                                    IsElementwiseBinarySupported,
60                                    delegateData.m_Backends,
61                                    isSupported,
62                                    armnn::BackendId(),
63                                    inputInfo1,
64                                    inputInfo2,
65                                    outputTensorInfo,
66                                    armnn::BinaryOperation::Div);
67     };
68 
69     validateFunc(outputInfo, isSupported);
70     return isSupported ? kTfLiteOk : kTfLiteError;
71 }
72 
ValidateFloorDivOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,const armnn::TensorInfo & inputInfo1,const armnn::TensorInfo & inputInfo2,const armnn::TensorInfo & outputInfo)73 TfLiteStatus ValidateFloorDivOperator(DelegateData& delegateData,
74                                       TfLiteContext* tfLiteContext,
75                                       const armnn::TensorInfo& inputInfo1,
76                                       const armnn::TensorInfo& inputInfo2,
77                                       const armnn::TensorInfo& outputInfo)
78 {
79     // need first to validate that the div operator is supported
80     // then that the floor operator is supported
81     TfLiteStatus status = ValidateDivOperator(delegateData, tfLiteContext, inputInfo1, inputInfo2, outputInfo);
82     if (status != kTfLiteOk)
83     {
84         return status;
85     }
86     // if the inputs and output of the div are all Signed32 we don't need to add the floor operator afterward.
87     if (AreAllSigned32(inputInfo1, inputInfo2, outputInfo))
88     {
89         return status;
90     }
91     // in case broadcasting is being done from one of the inputs to the div
92     // choose the full sized input tensor to pass to the floor validation routine
93     armnn::TensorInfo floorInputInfo = inputInfo1;
94     if (inputInfo1.GetNumDimensions() < inputInfo2.GetNumDimensions())
95     {
96         floorInputInfo = inputInfo2;
97     }
98     status = ValidateFloorOperator(delegateData, tfLiteContext, floorInputInfo, outputInfo);
99     return status;
100 }
101 
ValidateMaximumOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,const armnn::TensorInfo & inputInfo1,const armnn::TensorInfo & inputInfo2,const armnn::TensorInfo & outputInfo)102 TfLiteStatus ValidateMaximumOperator(DelegateData& delegateData,
103                                      TfLiteContext* tfLiteContext,
104                                      const armnn::TensorInfo& inputInfo1,
105                                      const armnn::TensorInfo& inputInfo2,
106                                      const armnn::TensorInfo& outputInfo)
107 {
108     bool isSupported = false;
109     auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
110     {
111         FORWARD_LAYER_SUPPORT_FUNC("MAXIMUM",
112                                    tfLiteContext,
113                                    IsElementwiseBinarySupported,
114                                    delegateData.m_Backends,
115                                    isSupported,
116                                    armnn::BackendId(),
117                                    inputInfo1,
118                                    inputInfo2,
119                                    outputTensorInfo,
120                                    armnn::BinaryOperation::Maximum);
121     };
122 
123     validateFunc(outputInfo, isSupported);
124     return isSupported ? kTfLiteOk : kTfLiteError;
125 }
126 
ValidateMinimumOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,const armnn::TensorInfo & inputInfo1,const armnn::TensorInfo & inputInfo2,const armnn::TensorInfo & outputInfo)127 TfLiteStatus ValidateMinimumOperator(DelegateData& delegateData,
128                                      TfLiteContext* tfLiteContext,
129                                      const armnn::TensorInfo& inputInfo1,
130                                      const armnn::TensorInfo& inputInfo2,
131                                      const armnn::TensorInfo& outputInfo)
132 {
133     bool isSupported = false;
134     auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
135     {
136         FORWARD_LAYER_SUPPORT_FUNC("MINIMUM",
137                                    tfLiteContext,
138                                    IsElementwiseBinarySupported,
139                                    delegateData.m_Backends,
140                                    isSupported,
141                                    armnn::BackendId(),
142                                    inputInfo1,
143                                    inputInfo2,
144                                    outputTensorInfo,
145                                    armnn::BinaryOperation::Minimum);
146     };
147 
148     validateFunc(outputInfo, isSupported);
149     return isSupported ? kTfLiteOk : kTfLiteError;
150 }
151 
ValidateMulOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,const armnn::TensorInfo & inputInfo1,const armnn::TensorInfo & inputInfo2,const armnn::TensorInfo & outputInfo)152 TfLiteStatus ValidateMulOperator(DelegateData& delegateData,
153                                  TfLiteContext* tfLiteContext,
154                                  const armnn::TensorInfo& inputInfo1,
155                                  const armnn::TensorInfo& inputInfo2,
156                                  const armnn::TensorInfo& outputInfo)
157 {
158     bool isSupported = false;
159     auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
160     {
161         FORWARD_LAYER_SUPPORT_FUNC("MUL",
162                                    tfLiteContext,
163                                    IsElementwiseBinarySupported,
164                                    delegateData.m_Backends,
165                                    isSupported,
166                                    armnn::BackendId(),
167                                    inputInfo1,
168                                    inputInfo2,
169                                    outputTensorInfo,
170                                    armnn::BinaryOperation::Mul);
171     };
172 
173     validateFunc(outputInfo, isSupported);
174     return isSupported ? kTfLiteOk : kTfLiteError;
175 }
176 
ValidateSubOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,const armnn::TensorInfo & inputInfo1,const armnn::TensorInfo & inputInfo2,const armnn::TensorInfo & outputInfo)177 TfLiteStatus ValidateSubOperator(DelegateData& delegateData,
178                                  TfLiteContext* tfLiteContext,
179                                  const armnn::TensorInfo& inputInfo1,
180                                  const armnn::TensorInfo& inputInfo2,
181                                  const armnn::TensorInfo& outputInfo)
182 {
183     bool isSupported = false;
184     auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
185     {
186         FORWARD_LAYER_SUPPORT_FUNC("SUB",
187                                    tfLiteContext,
188                                    IsElementwiseBinarySupported,
189                                    delegateData.m_Backends,
190                                    isSupported,
191                                    armnn::BackendId(),
192                                    inputInfo1,
193                                    inputInfo2,
194                                    outputTensorInfo,
195                                    armnn::BinaryOperation::Sub);
196     };
197 
198     validateFunc(outputInfo, isSupported);
199     return isSupported ? kTfLiteOk : kTfLiteError;
200 }
201 
AddFloorDivLayer(DelegateData & delegateData,const armnn::TensorInfo & outputTensorInfo)202 std::pair<armnn::IConnectableLayer*, armnn::IConnectableLayer*> AddFloorDivLayer(
203     DelegateData& delegateData,
204     const armnn::TensorInfo& outputTensorInfo)
205 {
206     armnn::IConnectableLayer* divisionLayer = delegateData.m_Network->AddElementwiseBinaryLayer(
207             armnn::BinaryOperation::Div);
208     // if the output of the div is Signed32 the Floor layer is not required
209     if (armnn::DataType::Signed32 == outputTensorInfo.GetDataType())
210     {
211         return std::make_pair(divisionLayer, divisionLayer);
212     }
213     armnn::IOutputSlot& outputSlot = divisionLayer->GetOutputSlot(0);
214     outputSlot.SetTensorInfo(outputTensorInfo);
215     armnn::IConnectableLayer* floorLayer = delegateData.m_Network->AddFloorLayer();
216     outputSlot.Connect(floorLayer->GetInputSlot(0));
217     return std::make_pair(divisionLayer, floorLayer);
218 }
219 
VisitElementwiseBinaryOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t elementwiseBinaryOperatorCode)220 TfLiteStatus VisitElementwiseBinaryOperator(DelegateData& delegateData,
221                                             TfLiteContext* tfLiteContext,
222                                             TfLiteNode* tfLiteNode,
223                                             int nodeIndex,
224                                             int32_t elementwiseBinaryOperatorCode)
225 {
226     TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
227     TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
228 
229     const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
230     const TfLiteTensor& tfLiteInputTensor0 = tfLiteTensors[tfLiteNode->inputs->data[0]];
231     if (IsDynamicTensor(tfLiteInputTensor0))
232     {
233         TF_LITE_MAYBE_KERNEL_LOG(
234             tfLiteContext,
235             "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
236             elementwiseBinaryOperatorCode, nodeIndex);
237         return kTfLiteError;
238     }
239 
240     const TfLiteTensor& tfLiteInputTensor1 = tfLiteTensors[tfLiteNode->inputs->data[1]];
241     if (IsDynamicTensor(tfLiteInputTensor1))
242     {
243         TF_LITE_MAYBE_KERNEL_LOG(
244             tfLiteContext,
245             "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
246             elementwiseBinaryOperatorCode, nodeIndex);
247         return kTfLiteError;
248     }
249 
250     const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
251     if (IsDynamicTensor(tfLiteOutputTensor))
252     {
253         TF_LITE_MAYBE_KERNEL_LOG(
254             tfLiteContext,
255             "TfLiteArmnnDelegate: Dynamic output tensors are not supported in operator #%d node #%d: ",
256             elementwiseBinaryOperatorCode, nodeIndex);
257         return kTfLiteError;
258     }
259 
260     armnn::TensorInfo inputTensorInfo0 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor0);
261     armnn::TensorInfo inputTensorInfo1 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor1);
262 
263     const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
264 
265     // Check if we need to expand the dims of the input tensor infos.
266     // This is required for a few of the backends.
267     if(inputTensorInfo0.GetNumDimensions() != inputTensorInfo1.GetNumDimensions())
268     {
269         ExpandTensorRankToEqual(inputTensorInfo0, inputTensorInfo1);
270     }
271 
272     auto* tfLiteNodeParameters = reinterpret_cast<TfLiteAddParams*>(tfLiteNode->builtin_data);
273     TfLiteFusedActivation activationType = kTfLiteActNone;
274     if (tfLiteNodeParameters)
275     {
276         activationType = tfLiteNodeParameters->activation;
277         TfLiteStatus activationStatus = ValidateFusedActivationOperator(delegateData, tfLiteContext, outputTensorInfo,
278                                                                         outputTensorInfo, activationType);
279         if(activationStatus != kTfLiteOk)
280         {
281             return kTfLiteError;
282         }
283     }
284 
285     if (!delegateData.m_Network)
286     {
287         switch(elementwiseBinaryOperatorCode)
288         {
289             case kTfLiteBuiltinAdd:
290                 return ValidateAddOperator(delegateData,
291                                            tfLiteContext,
292                                            inputTensorInfo0,
293                                            inputTensorInfo1,
294                                            outputTensorInfo);
295             case kTfLiteBuiltinDiv:
296                 return ValidateDivOperator(delegateData,
297                                            tfLiteContext,
298                                            inputTensorInfo0,
299                                            inputTensorInfo1,
300                                            outputTensorInfo);
301             case kTfLiteBuiltinFloorDiv:
302                 return ValidateFloorDivOperator(delegateData,
303                                                 tfLiteContext,
304                                                 inputTensorInfo0,
305                                                 inputTensorInfo1,
306                                                 outputTensorInfo);
307             case kTfLiteBuiltinMaximum:
308                 return ValidateMaximumOperator(delegateData,
309                                                tfLiteContext,
310                                                inputTensorInfo0,
311                                                inputTensorInfo1,
312                                                outputTensorInfo);
313             case kTfLiteBuiltinMinimum:
314                 return ValidateMinimumOperator(delegateData,
315                                                tfLiteContext,
316                                                inputTensorInfo0,
317                                                inputTensorInfo1,
318                                                outputTensorInfo);
319             case kTfLiteBuiltinMul:
320                 return ValidateMulOperator(delegateData,
321                                            tfLiteContext,
322                                            inputTensorInfo0,
323                                            inputTensorInfo1,
324                                            outputTensorInfo);
325             case kTfLiteBuiltinSub:
326                 return ValidateSubOperator(delegateData,
327                                            tfLiteContext,
328                                            inputTensorInfo0,
329                                            inputTensorInfo1,
330                                            outputTensorInfo);
331             default:
332                 return kTfLiteError;
333         }
334     }
335 
336     armnn::IConnectableLayer* elementwiseBinaryLayer = nullptr;
337     MultiLayerFacade multiLayer;
338     switch(elementwiseBinaryOperatorCode)
339     {
340         case kTfLiteBuiltinAdd:
341             elementwiseBinaryLayer = delegateData.m_Network->AddElementwiseBinaryLayer(
342                     armnn::BinaryOperation::Add);
343             break;
344         case kTfLiteBuiltinDiv:
345             elementwiseBinaryLayer = delegateData.m_Network->AddElementwiseBinaryLayer(
346                     armnn::BinaryOperation::Div);
347             break;
348         case kTfLiteBuiltinFloorDiv:
349             {
350                 auto layers = AddFloorDivLayer(delegateData, outputTensorInfo);
351                 multiLayer.AssignValues(layers.first, layers.second);
352                 elementwiseBinaryLayer = &multiLayer;
353             }
354             break;
355         case kTfLiteBuiltinMaximum:
356             elementwiseBinaryLayer = delegateData.m_Network->AddElementwiseBinaryLayer(
357                     armnn::BinaryOperation::Maximum);
358             break;
359         case kTfLiteBuiltinMinimum:
360             elementwiseBinaryLayer = delegateData.m_Network->AddElementwiseBinaryLayer(
361                     armnn::BinaryOperation::Minimum);
362             break;
363         case kTfLiteBuiltinMul:
364             elementwiseBinaryLayer = delegateData.m_Network->AddElementwiseBinaryLayer(
365                     armnn::BinaryOperation::Mul);
366             break;
367         case kTfLiteBuiltinSub:
368             elementwiseBinaryLayer = delegateData.m_Network->AddElementwiseBinaryLayer(
369                     armnn::BinaryOperation::Sub);
370             break;
371         default:
372             return kTfLiteError;
373     }
374     ARMNN_ASSERT(elementwiseBinaryLayer != nullptr);
375     armnn::IOutputSlot& outputSlot = elementwiseBinaryLayer->GetOutputSlot(0);
376     outputSlot.SetTensorInfo(outputTensorInfo);
377 
378     auto inputsTensorsProcess = ProcessInputs(elementwiseBinaryLayer,
379                                               delegateData,
380                                               tfLiteContext,
381                                               tfLiteNode);
382     if (inputsTensorsProcess == kTfLiteError)
383     {
384         return inputsTensorsProcess;
385     }
386 
387     if(Connect(elementwiseBinaryLayer, tfLiteNode, delegateData) != kTfLiteOk)
388     {
389         return kTfLiteError;
390     }
391 
392     if (!tfLiteNodeParameters)
393     {
394         // No Activation
395         return kTfLiteOk;
396     }
397     // Check and Create Activation
398     return FusedActivation(tfLiteContext, tfLiteNode, activationType, elementwiseBinaryLayer, 0, delegateData);
399 }
400 
401 } // namespace armnnDelegate
402