xref: /aosp_15_r20/external/armnn/delegate/classic/src/Normalization.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <tensorflow/lite/builtin_ops.h>
9 #include <tensorflow/lite/c/builtin_op_data.h>
10 #include <tensorflow/lite/c/common.h>
11 #include <tensorflow/lite/minimal_logging.h>
12 
13 namespace armnnDelegate
14 {
15 
VisitL2NormalizationOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t operatorCode)16 TfLiteStatus VisitL2NormalizationOperator(DelegateData& delegateData,
17                                           TfLiteContext* tfLiteContext,
18                                           TfLiteNode* tfLiteNode,
19                                           int nodeIndex,
20                                           int32_t operatorCode)
21 {
22     TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
23     TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
24 
25     const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
26     const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
27     if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
28     {
29         return kTfLiteError;
30     }
31 
32     const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
33     if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
34     {
35         return kTfLiteError;
36     }
37 
38     const armnn::TensorInfo& inputTensorInfo  = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
39     const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
40 
41     armnn::L2NormalizationDescriptor descriptor;
42     descriptor.m_DataLayout = armnn::DataLayout::NHWC;
43 
44     bool isSupported = false;
45     armnn::BackendId setBackend;
46     auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
47     {
48         FORWARD_LAYER_SUPPORT_FUNC("L2_NORMALIZATION",
49                                    tfLiteContext,
50                                    IsL2NormalizationSupported,
51                                    delegateData.m_Backends,
52                                    isSupported,
53                                    setBackend,
54                                    inputTensorInfo,
55                                    outInfo,
56                                    descriptor);
57     };
58 
59     if (!delegateData.m_Network)
60     {
61         validateFunc(outputTensorInfo, isSupported);
62         return isSupported ? kTfLiteOk : kTfLiteError;
63     }
64 
65     // Add a L2Normalization layer
66     armnn::IConnectableLayer* layer = delegateData.m_Network->AddL2NormalizationLayer(descriptor);
67     layer->SetBackendId(setBackend);
68     ARMNN_ASSERT(layer != nullptr);
69 
70     armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
71     outputSlot.SetTensorInfo(outputTensorInfo);
72 
73     // try to connect the Constant Inputs if there are any
74     if(ProcessInputs(layer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
75     {
76         return kTfLiteError;
77     }
78 
79     // Connect
80     return Connect(layer, tfLiteNode, delegateData);
81 }
82 
83 
VisitLocalResponseNormalizationOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t normalizationOperatorCode)84 TfLiteStatus VisitLocalResponseNormalizationOperator(DelegateData& delegateData,
85                                                      TfLiteContext* tfLiteContext,
86                                                      TfLiteNode* tfLiteNode,
87                                                      int nodeIndex,
88                                                      int32_t normalizationOperatorCode)
89 {
90     TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
91     TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
92 
93     const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
94     const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
95     if (!IsValid(tfLiteContext, tfLiteInputTensor, normalizationOperatorCode, nodeIndex))
96     {
97         return kTfLiteError;
98     }
99 
100     const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
101     if (!IsValid(tfLiteContext, tfLiteOutputTensor, normalizationOperatorCode, nodeIndex))
102     {
103         return kTfLiteError;
104     }
105 
106     const armnn::TensorInfo& inputTensorInfo  = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
107     const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
108 
109     armnn::NormalizationDescriptor descriptor;
110     descriptor.m_DataLayout = armnn::DataLayout::NHWC;
111     descriptor.m_NormChannelType = armnn::NormalizationAlgorithmChannel::Across;
112     descriptor.m_NormMethodType  = armnn::NormalizationAlgorithmMethod::LocalBrightness;
113 
114     auto* params = reinterpret_cast<TfLiteLocalResponseNormParams*>(tfLiteNode->builtin_data);
115     descriptor.m_NormSize = params->radius;
116     descriptor.m_K        = params->bias;
117     descriptor.m_Alpha    = params->alpha;
118     descriptor.m_Beta     = params->beta;
119 
120     // ArmNN expects normSize to be the full size of the normalization window
121     descriptor.m_NormSize = 1 + (2 * descriptor.m_NormSize);
122 
123     bool isSupported = false;
124     armnn::BackendId setBackend;
125     auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
126     {
127         FORWARD_LAYER_SUPPORT_FUNC("NORMALIZATION",
128                                    tfLiteContext,
129                                    IsNormalizationSupported,
130                                    delegateData.m_Backends,
131                                    isSupported,
132                                    setBackend,
133                                    inputTensorInfo,
134                                    outInfo,
135                                    descriptor);
136     };
137 
138     if (!delegateData.m_Network)
139     {
140         validateFunc(outputTensorInfo, isSupported);
141         return isSupported ? kTfLiteOk : kTfLiteError;
142     }
143 
144     // Add a Normalization layer
145     armnn::IConnectableLayer* layer = delegateData.m_Network->AddNormalizationLayer(descriptor);
146     layer->SetBackendId(setBackend);
147     ARMNN_ASSERT(layer != nullptr);
148 
149     armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
150     outputSlot.SetTensorInfo(outputTensorInfo);
151 
152     // try to connect the Constant Inputs if there are any
153     if(ProcessInputs(layer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
154     {
155         return kTfLiteError;
156     }
157 
158     // Connect
159     return Connect(layer, tfLiteNode, delegateData);
160 }
161 
162 } // namespace armnnDelegate
163