//
// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <tensorflow/lite/builtin_ops.h>
#include <tensorflow/lite/c/builtin_op_data.h>
#include <tensorflow/lite/c/common.h>
#include <tensorflow/lite/minimal_logging.h>

// Delegate helpers (IsValid, GetTensorInfoForTfLiteTensor,
// FORWARD_LAYER_SUPPORT_FUNC) used by the visitor functions below.
#include "DelegateUtils.hpp"

namespace armnnDelegate
{

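// Validates an L2_NORMALIZATION node and, when a network is being built,
// adds the corresponding Arm NN L2Normalization layer and connects it up.
// When delegateData.m_Network is null, only backend support is checked.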
TfLiteStatus VisitL2NormalizationOperator(DelegateData& delegateData,
                                          TfLiteContext* tfLiteContext,
                                          TfLiteNode* tfLiteNode,
                                          int nodeIndex,
                                          int32_t operatorCode)
{
    TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
    TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));

    const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
    const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
    if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
    {
        return kTfLiteError;
    }

    const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
    if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
    {
        return kTfLiteError;
    }

    const armnn::TensorInfo& inputTensorInfo  = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
    const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);

    armnn::L2NormalizationDescriptor descriptor;
    descriptor.m_DataLayout = armnn::DataLayout::NHWC;

    bool isSupported = false;
    armnn::BackendId setBackend;
    auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
    {
        FORWARD_LAYER_SUPPORT_FUNC("L2_NORMALIZATION",
                                   tfLiteContext,
                                   IsL2NormalizationSupported,
                                   delegateData.m_Backends,
                                   isSupported,
                                   setBackend,
                                   inputTensorInfo,
                                   outInfo,
                                   descriptor);
    };

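    // A null m_Network means the delegate is in the support-checking phase:
    // run the validation and report support without creating any layers.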
    if (!delegateData.m_Network)
    {
        validateFunc(outputTensorInfo, isSupported);
        return isSupported ? kTfLiteOk : kTfLiteError;
    }

    // Add an L2Normalization layer
    armnn::IConnectableLayer* layer = delegateData.m_Network->AddL2NormalizationLayer(descriptor);
    ARMNN_ASSERT(layer != nullptr);
    layer->SetBackendId(setBackend);

    armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
    outputSlot.SetTensorInfo(outputTensorInfo);

    // Try to connect the Constant Inputs if there are any
    if (ProcessInputs(layer, delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk)
    {
        return kTfLiteError;
    }

    // Connect
    return Connect(layer, tfLiteNode, delegateData);
}

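// Validates a LOCAL_RESPONSE_NORMALIZATION node and, when a network is being
// built, adds the corresponding Arm NN Normalization layer configured for
// across-channel LocalBrightness normalization (the LRN algorithm).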
TfLiteStatus VisitLocalResponseNormalizationOperator(DelegateData& delegateData,
                                                     TfLiteContext* tfLiteContext,
                                                     TfLiteNode* tfLiteNode,
                                                     int nodeIndex,
                                                     int32_t normalizationOperatorCode)
{
    TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
    TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));

    const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
    const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
    if (!IsValid(tfLiteContext, tfLiteInputTensor, normalizationOperatorCode, nodeIndex))
    {
        return kTfLiteError;
    }

    const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
    if (!IsValid(tfLiteContext, tfLiteOutputTensor, normalizationOperatorCode, nodeIndex))
    {
        return kTfLiteError;
    }

    const armnn::TensorInfo& inputTensorInfo  = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
    const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);

    armnn::NormalizationDescriptor descriptor;
    descriptor.m_DataLayout      = armnn::DataLayout::NHWC;
    descriptor.m_NormChannelType = armnn::NormalizationAlgorithmChannel::Across;
    descriptor.m_NormMethodType  = armnn::NormalizationAlgorithmMethod::LocalBrightness;

    auto* params = reinterpret_cast<TfLiteLocalResponseNormParams*>(tfLiteNode->builtin_data);
    descriptor.m_NormSize = params->radius;
    descriptor.m_K        = params->bias;
    descriptor.m_Alpha    = params->alpha;
    descriptor.m_Beta     = params->beta;

    // TfLite's radius counts the elements on either side of the centre element,
    // while Arm NN expects m_NormSize to be the full width of the normalization
    // window, so convert radius r to a window size of 2r + 1.
    descriptor.m_NormSize = 1 + (2 * descriptor.m_NormSize);
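    // e.g. a radius of 2 yields m_NormSize = 5.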

    bool isSupported = false;
    armnn::BackendId setBackend;
    auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
    {
        FORWARD_LAYER_SUPPORT_FUNC("NORMALIZATION",
                                   tfLiteContext,
                                   IsNormalizationSupported,
                                   delegateData.m_Backends,
                                   isSupported,
                                   setBackend,
                                   inputTensorInfo,
                                   outInfo,
                                   descriptor);
    };

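    // As above: with no network to build, just report whether any backend
    // supports this configuration.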
    if (!delegateData.m_Network)
    {
        validateFunc(outputTensorInfo, isSupported);
        return isSupported ? kTfLiteOk : kTfLiteError;
    }

    // Add a Normalization layer
    armnn::IConnectableLayer* layer = delegateData.m_Network->AddNormalizationLayer(descriptor);
    ARMNN_ASSERT(layer != nullptr);
    layer->SetBackendId(setBackend);

    armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
    outputSlot.SetTensorInfo(outputTensorInfo);

    // Try to connect the Constant Inputs if there are any
    if (ProcessInputs(layer, delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk)
    {
        return kTfLiteError;
    }

    // Connect
    return Connect(layer, tfLiteNode, delegateData);
}

} // namespace armnnDelegate