1 //
2 // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #pragma once
7
8 #include <ClassicDelegateUtils.hpp>
9
10 #include <tensorflow/lite/builtin_ops.h>
11 #include <tensorflow/lite/c/builtin_op_data.h>
12 #include <tensorflow/lite/c/common.h>
13 #include <tensorflow/lite/minimal_logging.h>
14
15 namespace armnnDelegate
16 {
17
ValidateActivationOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,const armnn::TensorInfo & inputInfo,const armnn::TensorInfo & outputInfo,armnn::ActivationDescriptor & activationDesc)18 TfLiteStatus ValidateActivationOperator(DelegateData& delegateData,
19 TfLiteContext* tfLiteContext,
20 const armnn::TensorInfo& inputInfo,
21 const armnn::TensorInfo& outputInfo,
22 armnn::ActivationDescriptor& activationDesc)
23 {
24 bool isSupported = false;
25 auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
26 {
27 FORWARD_LAYER_SUPPORT_FUNC("ACTIVATION",
28 tfLiteContext,
29 IsActivationSupported,
30 delegateData.m_Backends,
31 isSupported,
32 armnn::BackendId(),
33 inputInfo,
34 outputInfo,
35 activationDesc);
36 };
37
38 validateFunc(outputInfo, isSupported);
39 return isSupported ? kTfLiteOk : kTfLiteError;
40 }
41
VisitActivationOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t operatorCode)42 TfLiteStatus VisitActivationOperator(DelegateData& delegateData,
43 TfLiteContext* tfLiteContext,
44 TfLiteNode* tfLiteNode,
45 int nodeIndex,
46 int32_t operatorCode)
47 {
48 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
49 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
50
51 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
52 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
53 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
54 {
55 return kTfLiteError;
56 }
57
58 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
59 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
60 {
61 return kTfLiteError;
62 }
63
64 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
65 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
66
67 armnn::ActivationDescriptor activationDesc;
68 switch(operatorCode)
69 {
70 case kTfLiteBuiltinRelu:
71 {
72 activationDesc.m_Function = armnn::ActivationFunction::ReLu;
73 break;
74 }
75 case kTfLiteBuiltinRelu6:
76 {
77 activationDesc.m_Function = armnn::ActivationFunction::BoundedReLu;
78 activationDesc.m_A = 6.0f;
79 break;
80 }
81 case kTfLiteBuiltinLogistic:
82 {
83 activationDesc.m_Function = armnn::ActivationFunction::Sigmoid;
84 break;
85 }
86 case kTfLiteBuiltinTanh:
87 {
88 activationDesc.m_Function = armnn::ActivationFunction::TanH;
89 activationDesc.m_A = 1.0f;
90 activationDesc.m_B = 1.0f;
91 break;
92 }
93 case kTfLiteBuiltinElu:
94 {
95 activationDesc.m_Function = armnn::ActivationFunction::Elu;
96 activationDesc.m_A = 1.0f;
97 break;
98 }
99 case kTfLiteBuiltinHardSwish:
100 {
101 activationDesc.m_Function = armnn::ActivationFunction::HardSwish;
102 break;
103 }
104 default:
105 {
106 return kTfLiteError;
107 }
108 }
109 if (!delegateData.m_Network)
110 {
111 return ValidateActivationOperator(delegateData,
112 tfLiteContext,
113 inputTensorInfo,
114 outputTensorInfo,
115 activationDesc);
116 }
117 armnn::IConnectableLayer* activationLayer = delegateData.m_Network->AddActivationLayer(activationDesc);
118 ARMNN_ASSERT(activationLayer != nullptr);
119
120 armnn::IOutputSlot& outputSlot = activationLayer->GetOutputSlot(0);
121 outputSlot.SetTensorInfo(outputTensorInfo);
122
123 // try to connect the Constant Inputs if there are any
124 if(ProcessInputs(activationLayer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
125 {
126 return kTfLiteError;
127 }
128
129 // Connect
130 return Connect(activationLayer, tfLiteNode, delegateData);
131 }
132
133 } // namespace armnnDelegate
134