xref: /aosp_15_r20/external/armnn/delegate/classic/src/Activation.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <ClassicDelegateUtils.hpp>
9 
10 #include <tensorflow/lite/builtin_ops.h>
11 #include <tensorflow/lite/c/builtin_op_data.h>
12 #include <tensorflow/lite/c/common.h>
13 #include <tensorflow/lite/minimal_logging.h>
14 
15 namespace armnnDelegate
16 {
17 
ValidateActivationOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,const armnn::TensorInfo & inputInfo,const armnn::TensorInfo & outputInfo,armnn::ActivationDescriptor & activationDesc)18 TfLiteStatus ValidateActivationOperator(DelegateData& delegateData,
19                                         TfLiteContext* tfLiteContext,
20                                         const armnn::TensorInfo& inputInfo,
21                                         const armnn::TensorInfo& outputInfo,
22                                         armnn::ActivationDescriptor& activationDesc)
23 {
24     bool isSupported = false;
25     auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
26     {
27         FORWARD_LAYER_SUPPORT_FUNC("ACTIVATION",
28                                    tfLiteContext,
29                                    IsActivationSupported,
30                                    delegateData.m_Backends,
31                                    isSupported,
32                                    armnn::BackendId(),
33                                    inputInfo,
34                                    outputInfo,
35                                    activationDesc);
36     };
37 
38     validateFunc(outputInfo, isSupported);
39     return isSupported ? kTfLiteOk : kTfLiteError;
40 }
41 
VisitActivationOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t operatorCode)42 TfLiteStatus VisitActivationOperator(DelegateData& delegateData,
43                                      TfLiteContext* tfLiteContext,
44                                      TfLiteNode* tfLiteNode,
45                                      int nodeIndex,
46                                      int32_t operatorCode)
47 {
48     TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
49     TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
50 
51     const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
52     const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
53     if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
54     {
55         return kTfLiteError;
56     }
57 
58     const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
59     if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
60     {
61         return kTfLiteError;
62     }
63 
64     const armnn::TensorInfo& inputTensorInfo  = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
65     const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
66 
67     armnn::ActivationDescriptor activationDesc;
68     switch(operatorCode)
69     {
70         case kTfLiteBuiltinRelu:
71         {
72             activationDesc.m_Function = armnn::ActivationFunction::ReLu;
73             break;
74         }
75         case kTfLiteBuiltinRelu6:
76         {
77             activationDesc.m_Function = armnn::ActivationFunction::BoundedReLu;
78             activationDesc.m_A = 6.0f;
79             break;
80         }
81         case kTfLiteBuiltinLogistic:
82         {
83             activationDesc.m_Function = armnn::ActivationFunction::Sigmoid;
84             break;
85         }
86         case kTfLiteBuiltinTanh:
87         {
88             activationDesc.m_Function = armnn::ActivationFunction::TanH;
89             activationDesc.m_A = 1.0f;
90             activationDesc.m_B = 1.0f;
91             break;
92         }
93         case kTfLiteBuiltinElu:
94         {
95             activationDesc.m_Function = armnn::ActivationFunction::Elu;
96             activationDesc.m_A = 1.0f;
97             break;
98         }
99         case kTfLiteBuiltinHardSwish:
100         {
101             activationDesc.m_Function = armnn::ActivationFunction::HardSwish;
102             break;
103         }
104         default:
105         {
106             return kTfLiteError;
107         }
108     }
109     if (!delegateData.m_Network)
110     {
111         return ValidateActivationOperator(delegateData,
112                                           tfLiteContext,
113                                           inputTensorInfo,
114                                           outputTensorInfo,
115                                           activationDesc);
116     }
117     armnn::IConnectableLayer* activationLayer = delegateData.m_Network->AddActivationLayer(activationDesc);
118     ARMNN_ASSERT(activationLayer != nullptr);
119 
120     armnn::IOutputSlot& outputSlot = activationLayer->GetOutputSlot(0);
121     outputSlot.SetTensorInfo(outputTensorInfo);
122 
123     // try to connect the Constant Inputs if there are any
124     if(ProcessInputs(activationLayer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
125     {
126         return kTfLiteError;
127     }
128 
129     // Connect
130     return Connect(activationLayer, tfLiteNode, delegateData);
131 }
132 
133 } // namespace armnnDelegate
134