//
// Copyright © 2020-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <ClassicDelegateUtils.hpp>

#include "armnnUtils/TensorUtils.hpp"
#include <armnn/utility/IgnoreUnused.hpp>

#include <tensorflow/lite/builtin_ops.h>
#include <tensorflow/lite/c/builtin_op_data.h>
#include <tensorflow/lite/c/common.h>
#include <tensorflow/lite/minimal_logging.h>

namespace armnnDelegate
{

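// Validates a TfLite FULLY_CONNECTED node and, when a network is being built,
// adds the equivalent Arm NN FullyConnected layer together with any reshape,
// constant and fused-activation layers it needs.
//
// A minimal sketch of how a visitor like this is typically dispatched
// (illustrative only; the real switch lives in the delegate's node-visiting
// code):
//
//     case kTfLiteBuiltinFullyConnected:
//         return VisitFullyConnectedOperator(delegateData, tfLiteContext,
//                                            tfLiteNode, nodeIndex,
//                                            kTfLiteBuiltinFullyConnected);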
TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
                                         TfLiteContext* tfLiteContext,
                                         TfLiteNode* tfLiteNode,
                                         int nodeIndex,
                                         int32_t operatorCode)
{
    auto numInputs = tfLiteNode->inputs->size;
    if (numInputs < 2)
    {
        TF_LITE_MAYBE_KERNEL_LOG(
            tfLiteContext, "TfLiteArmnnDelegate: Expected at least %d inputs but found %d in node #%d",
            2, numInputs, nodeIndex);
        return kTfLiteError;
    }
    TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
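    // The inputs are [input, weights, (optional) bias]; the bias is operand #2.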
    bool biasEnabled = IsOptionalOperandPresent(tfLiteNode, 2);

    const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
    const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
    if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
    {
        return kTfLiteError;
    }

    const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
    if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
    {
        return kTfLiteError;
    }

    const TfLiteTensor& tfLiteWeightsTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
    if (!IsValid(tfLiteContext, tfLiteWeightsTensor, operatorCode, nodeIndex))
    {
        return kTfLiteError;
    }

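    // Build Arm NN tensor infos from the TfLite tensors. The extra 'true'
    // argument is assumed to flag the tensor as a network output.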
    const armnn::TensorInfo& inputTensorInfo   = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
    const armnn::TensorInfo& weightsTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteWeightsTensor);
    const armnn::TensorInfo& outputTensorInfo  = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);

    // Check that we support fused activation before we attempt to create a layer
    auto* tfLiteNodeParameters = reinterpret_cast<TfLiteFullyConnectedParams*>(tfLiteNode->builtin_data);
    TfLiteFusedActivation activationType = kTfLiteActNone;
    if (tfLiteNodeParameters)
    {
        activationType = tfLiteNodeParameters->activation;
        TfLiteStatus activationStatus = ValidateFusedActivationOperator(delegateData, tfLiteContext, outputTensorInfo,
                                                                        outputTensorInfo, activationType);
        if (activationStatus != kTfLiteOk)
        {
            return kTfLiteError;
        }
    }

    // Fully Connected Layer accepts two dimensional weights input
    int32_t weightsDimension = static_cast<int32_t>(weightsTensorInfo.GetNumDimensions());
    if (weightsDimension != 2)
    {
        TF_LITE_MAYBE_KERNEL_LOG(
            tfLiteContext,
            "TfLiteArmnnDelegate: Dimension #%d for Fully Connected weights is not supported by Armnn"
            " in operator #%d node #%d: ", weightsDimension, operatorCode, nodeIndex);
        return kTfLiteError;
    }

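    // The backend support check below always receives a bias TensorInfo, so when
    // no bias is present a one-element placeholder of the input's type is used.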
    armnn::TensorInfo biasTensorInfo;
    if (biasEnabled)
    {
        const TfLiteTensor& tfLiteBiasTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
        if (!IsValid(tfLiteContext, tfLiteBiasTensor, operatorCode, nodeIndex))
        {
            return kTfLiteError;
        }
        biasTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteBiasTensor);
    }
    else
    {
        biasTensorInfo = armnn::TensorInfo(armnn::TensorShape({1}), GetDataType(tfLiteInputTensor));
    }

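    // Arm NN's FullyConnected expects a 2D input of shape [batch_size, input_size].
    // For higher-rank inputs, compute the flattened shape here; the actual Reshape
    // layer is added further down. For example, an input of shape [1, 2, 12]
    // (24 elements) with weights of shape [5, 12] flattens to [2, 12].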
    armnn::TensorInfo reshapedTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
    if (inputTensorInfo.GetNumDimensions() > 2)
    {
        // Calculate reshape to flatten to 2D [batch_size, input_size]
        std::vector<unsigned int> reshapedDimensions(2);
        reshapedDimensions[1] = weightsTensorInfo.GetShape()[1];
        reshapedDimensions[0] = inputTensorInfo.GetNumElements() / reshapedDimensions[1];

        if (inputTensorInfo.GetNumElements() % reshapedDimensions[1] != 0)
        {
            TF_LITE_MAYBE_KERNEL_LOG(
                tfLiteContext,
                "TfLiteArmnnDelegate: Failed to deduce input tensor shape from filter size #%d #%d node #%d: ",
                reshapedDimensions[1], operatorCode, nodeIndex);
            return kTfLiteError;
        }

        reshapedTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() });
    }
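    // Likewise flatten a higher-rank output to 2D [batch_size, output_size], where
    // output_size is the weights' first dimension. For example, an output of shape
    // [1, 2, 5] (10 elements) with weights of shape [5, 12] flattens to [2, 5].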
    armnn::TensorInfo reshapedOutputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);

    if (outputTensorInfo.GetNumDimensions() > 2)
    {
        // Calculate reshape to flatten to 2D [batch_size, output_size]
        std::vector<unsigned int> reshapedDimensions(2);
        reshapedDimensions[1] = weightsTensorInfo.GetShape()[0];
        reshapedDimensions[0] = outputTensorInfo.GetNumElements() / reshapedDimensions[1];

        if (outputTensorInfo.GetNumElements() % reshapedDimensions[1] != 0)
        {
            TF_LITE_MAYBE_KERNEL_LOG(
                    tfLiteContext,
                    "TfLiteArmnnDelegate: Failed to deduce output tensor shape from filter size #%d #%d node #%d: ",
                    reshapedDimensions[1], operatorCode, nodeIndex);
            return kTfLiteError;
        }
        reshapedOutputTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() });
    }

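    // TfLite stores fully connected weights as [output_size, input_size], the
    // transpose of Arm NN's expected layout, hence m_TransposeWeightMatrix.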
    armnn::FullyConnectedDescriptor descriptor;
    descriptor.m_TransposeWeightMatrix = true;
    descriptor.m_BiasEnabled           = biasEnabled;
    descriptor.m_ConstantWeights       = weightsTensorInfo.IsConstant();

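    // Wrap the backend support check in a lambda so it can run both during the
    // standalone validation phase (no network) and before the layer is created.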
    bool isSupported = false;
    armnn::BackendId setBackend;
    auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
    {
        FORWARD_LAYER_SUPPORT_FUNC("FULLY_CONNECTED",
                                   tfLiteContext,
                                   IsFullyConnectedSupported,
                                   delegateData.m_Backends,
                                   isSupported,
                                   setBackend,
                                   reshapedTensorInfo,
                                   outputTensorInfo,
                                   weightsTensorInfo,
                                   biasTensorInfo,
                                   descriptor);
    };

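    // With no network to mutate, this call only checks operator support.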
    if (!delegateData.m_Network)
    {
        validateFunc(reshapedOutputTensorInfo, isSupported);
        return isSupported ? kTfLiteOk : kTfLiteError;
    }

    armnn::IConnectableLayer* layer = delegateData.m_Network->AddFullyConnectedLayer(descriptor);
    ARMNN_ASSERT(layer != nullptr);
    layer->SetBackendId(setBackend);

    // Add a constant layer for weights and biases if inputs are constant.
    if (weightsTensorInfo.IsConstant())
    {
        auto weightsTensor = CreateConstTensor(&tfLiteWeightsTensor, weightsTensorInfo);

        armnn::IConnectableLayer* weightsLayer = delegateData.m_Network->AddConstantLayer(weightsTensor);

        weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
        weightsLayer->GetOutputSlot(0).SetTensorInfo(weightsTensorInfo);
    }

    if (biasEnabled)
    {
        const TfLiteTensor& tfLiteBiasTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
        if (biasTensorInfo.IsConstant())
        {
            auto biasTensor = CreateConstTensor(&tfLiteBiasTensor, biasTensorInfo);

            armnn::IConnectableLayer* biasLayer = delegateData.m_Network->AddConstantLayer(biasTensor);
            ARMNN_ASSERT(biasLayer != nullptr);

            biasLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u));
            biasLayer->GetOutputSlot(0).SetTensorInfo(biasTensorInfo);
        }
    }

    // The data input can also be constant, so we must check that this is also allocated to an input slot
    if (inputTensorInfo.IsConstant())
    {
        auto input = CreateConstTensor(&tfLiteContext->tensors[tfLiteNode->inputs->data[0]],
                                       inputTensorInfo);

        armnn::IConnectableLayer* inputLayer = delegateData.m_Network->AddConstantLayer(input);
        inputLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0u));
        inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
    }

    armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
    outputSlot.SetTensorInfo(outputTensorInfo);

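    // When the input needed flattening, wire the node up manually: the Reshape
    // layer feeds the FullyConnected input, and any non-constant weights or bias
    // connect directly to their slots.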
    armnn::IConnectableLayer* reshapeLayer = nullptr;
    if (inputTensorInfo.GetNumDimensions() > 2)
    {
        // Add reshape to flatten to 2D [batch_size, input_size]
        armnn::ReshapeDescriptor reshapeDescriptor;
        reshapeDescriptor.m_TargetShape = reshapedTensorInfo.GetShape();
        reshapeLayer = delegateData.m_Network->AddReshapeLayer(reshapeDescriptor);
        ARMNN_ASSERT(reshapeLayer != nullptr);

        reshapeLayer->GetOutputSlot(0).SetTensorInfo(reshapedTensorInfo);

        // Connect
        delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[0]]->Connect(reshapeLayer->GetInputSlot(0));
        reshapeLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0));

        if (!descriptor.m_ConstantWeights)
        {
            delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[1]]->Connect(layer->GetInputSlot(1));
        }

        if (biasEnabled && !biasTensorInfo.IsConstant())
        {
            delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[2]]->Connect(layer->GetInputSlot(2));
        }
        delegateData.m_OutputSlotForNode[tfLiteNode->outputs->data[0]] = &outputSlot;
    }

    if (reshapeLayer == nullptr)
    {
        if (Connect(layer, tfLiteNode, delegateData) != kTfLiteOk)
        {
            return kTfLiteError;
        }
    }

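    // If the output was flattened to 2D, add a Reshape after the FullyConnected
    // layer to restore the original output dimensions.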
    if (outputTensorInfo.GetNumDimensions() > 2)
    {
        layer = AddReshapeLayer(tfLiteContext, tfLiteNode, layer, reshapedOutputTensorInfo, outputTensorInfo,
                                delegateData);
        if (!layer)
        {
            TF_LITE_MAYBE_KERNEL_LOG(
                    tfLiteContext,
                    "TfLiteArmnnDelegate: Failed to add reshape for FullyConnected #%d node #%d: ",
                    operatorCode,
                    nodeIndex);
            return kTfLiteError;
        }
    }

    if (!tfLiteNodeParameters)
    {
        // No Activation
        return kTfLiteOk;
    }

    // Check and Create Activation
    return FusedActivation(tfLiteContext, tfLiteNode, activationType, layer, 0, delegateData);
}

} // namespace armnnDelegate