xref: /aosp_15_r20/external/armnn/delegate/classic/src/Redefine.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <armnn/utility/IgnoreUnused.hpp>
9 
10 #include <ClassicDelegateUtils.hpp>
11 
12 #include <tensorflow/lite/builtin_ops.h>
13 #include <tensorflow/lite/c/builtin_op_data.h>
14 #include <tensorflow/lite/c/common.h>
15 #include <tensorflow/lite/minimal_logging.h>
16 #include <numeric>
17 
18 namespace armnnDelegate
19 {
20 
VisitCastOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t operatorCode)21 TfLiteStatus VisitCastOperator(DelegateData& delegateData,
22                                TfLiteContext* tfLiteContext,
23                                TfLiteNode* tfLiteNode,
24                                int nodeIndex,
25                                int32_t operatorCode)
26 {
27     TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
28     TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
29 
30     const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
31     const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
32     if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
33     {
34         return kTfLiteError;
35     }
36 
37     const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
38     if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
39     {
40         return kTfLiteError;
41     }
42 
43     const armnn::TensorInfo& inputTensorInfo  = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
44     const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
45 
46     bool isSupported = false;
47     armnn::BackendId setBackend;
48     auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
49     {
50         FORWARD_LAYER_SUPPORT_FUNC("CAST",
51                                    tfLiteContext,
52                                    IsCastSupported,
53                                    delegateData.m_Backends,
54                                    isSupported,
55                                    setBackend,
56                                    inputTensorInfo,
57                                    outInfo);
58     };
59 
60     // If the m_Network is a nullptr, this signals that a prerequisite TfLite callback is required to clarify the
61     // support for the operator
62     // If supported, VisitCastOperator will be called again to add the layer to the network as seen further below
63     if (!delegateData.m_Network)
64     {
65         validateFunc(outputTensorInfo, isSupported);
66         return isSupported ? kTfLiteOk : kTfLiteError;
67     }
68 
69     // Add a Cast layer
70     armnn::IConnectableLayer* layer = delegateData.m_Network->AddCastLayer();
71     layer->SetBackendId(setBackend);
72     ARMNN_ASSERT(layer != nullptr);
73 
74     armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
75     outputSlot.SetTensorInfo(outputTensorInfo);
76 
77     // try to connect the Constant Inputs if there are any
78     if(ProcessInputs(layer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
79     {
80         return kTfLiteError;
81     }
82 
83     // Connect
84     return Connect(layer, tfLiteNode, delegateData);
85 }
86 
87 
CreateOutputTensorShape(const armnn::TensorInfo & inputTensorInfo,const std::vector<int32_t> & targetShape,armnn::ReshapeDescriptor & reshapeDesc)88 TfLiteStatus CreateOutputTensorShape(const armnn::TensorInfo& inputTensorInfo,
89                                      const std::vector<int32_t>& targetShape,
90                                      armnn::ReshapeDescriptor& reshapeDesc)
91 {
92     std::vector<unsigned int> outputDims(targetShape.begin(), targetShape.end());
93     const auto stretchDim = std::find(targetShape.begin(), targetShape.end(), -1);
94 
95     if (stretchDim != targetShape.end())
96     {
97         if (std::find(std::next(stretchDim), targetShape.end(), -1) != targetShape.end())
98         {
99             // Return kTfLiteError and log the error after returning
100             return kTfLiteError;
101         }
102 
103         auto targetNumElements =
104             armnn::numeric_cast<unsigned int>(
105                 std::accumulate(targetShape.begin(), targetShape.end(), -1, std::multiplies<int32_t>()));
106 
107         auto stretchIndex = static_cast<size_t>(std::distance(targetShape.begin(), stretchDim));
108         outputDims[stretchIndex] = inputTensorInfo.GetNumElements() / targetNumElements;
109     }
110 
111     armnn::TensorShape outputShape = armnn::TensorShape(static_cast<unsigned int>(outputDims.size()),
112                                                         outputDims.data());
113     reshapeDesc.m_TargetShape = outputShape;
114     return kTfLiteOk;
115 }
116 
VisitReshapeOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t operatorCode)117 TfLiteStatus VisitReshapeOperator(DelegateData& delegateData,
118                                   TfLiteContext* tfLiteContext,
119                                   TfLiteNode* tfLiteNode,
120                                   int nodeIndex,
121                                   int32_t operatorCode)
122 {
123     auto numInputs = tfLiteNode->inputs->size;
124 
125     if (numInputs == 2)
126     {
127         TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
128     }
129     else
130     {
131         TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
132     }
133     TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
134 
135     const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
136     const TfLiteTensor& tfLiteInputTensor0 = tfLiteTensors[tfLiteNode->inputs->data[0]];
137     if (!IsValid(tfLiteContext, tfLiteInputTensor0, operatorCode, nodeIndex))
138     {
139         return kTfLiteError;
140     }
141 
142     const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
143     if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
144     {
145         return kTfLiteError;
146     }
147 
148     const armnn::TensorInfo& inputTensorInfo0 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor0);
149     const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
150 
151     armnn::ReshapeDescriptor reshapeDesc;
152     std::vector<int32_t> targetShape;
153 
154     TfLiteReshapeParams* reshapeOptions = reinterpret_cast<TfLiteReshapeParams*>(tfLiteNode->builtin_data);
155 
156     // The new shape can be defined by either a second input tensor or by a builtin option, we need to check for both.
157     // Options might be set without valid data. we need to check the dimensions are in a valid range.
158     if (reshapeOptions && reshapeOptions->num_dimensions > 0 && reshapeOptions->num_dimensions <= 8)
159     {
160         for (int i=0; i < reshapeOptions->num_dimensions; ++i)
161         {
162             targetShape.push_back(reshapeOptions->shape[i]);
163         }
164     }
165     else if (numInputs == 2)
166     {
167         // Get shape from the second input tensor
168         const TfLiteTensor& tfLiteShapeInputTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
169         if (!IsValid(tfLiteContext, tfLiteShapeInputTensor, operatorCode, nodeIndex))
170         {
171             return kTfLiteError;
172         }
173 
174         if (tfLiteShapeInputTensor.dims->size != 1)
175         {
176             TF_LITE_MAYBE_KERNEL_LOG(tfLiteContext,
177                                      "TfLiteArmnnDelegate: Target 'shape' input is not a 1D tensor in "
178                                      "operator #%d node #%d: Falling back to TfLiteOptions.",
179                                      operatorCode, nodeIndex);
180         }
181         else
182         {
183             // Get the shape data out of the input tensor
184             auto* shapeTensorDataPtr = tflite::GetTensorData<int32_t>(&tfLiteShapeInputTensor);
185             auto shapeTensorNumValues = tfLiteShapeInputTensor.dims->data[0];
186             for (auto i=0; i < shapeTensorNumValues; ++i)
187             {
188                 targetShape.push_back(*(shapeTensorDataPtr+i));
189             }
190         }
191     }
192     else
193     {
194         TF_LITE_MAYBE_KERNEL_LOG(tfLiteContext,
195                                  "Target shape not defined in reshape parameters or input tensor. "
196                                  "At least one method required in operator #%d node #%d: ",
197                                  operatorCode, nodeIndex);
198         return kTfLiteError;
199     }
200 
201     // Use the data to create the required tensor shape.
202     if (CreateOutputTensorShape(inputTensorInfo0, targetShape, reshapeDesc) != kTfLiteOk)
203     {
204         TF_LITE_MAYBE_KERNEL_LOG(tfLiteContext,
205                                  "TfLiteArmnnDelegate: At most one component of shape can be -1 in: "
206                                  "operator #%d node #%d: ",
207                                  operatorCode, nodeIndex);
208         return kTfLiteError;
209     }
210 
211     if (reshapeDesc.m_TargetShape.GetNumElements() != inputTensorInfo0.GetNumElements())
212     {
213         TF_LITE_MAYBE_KERNEL_LOG(
214             tfLiteContext,
215             "TfLiteArmnnDelegate: Reshape, number of elements in output shape does not match input "
216             "operator #%d node #%d: ",
217             operatorCode, nodeIndex);
218         return kTfLiteError;
219     }
220 
221     bool isSupported = false;
222     armnn::BackendId setBackend;
223     auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
224     {
225         FORWARD_LAYER_SUPPORT_FUNC("RESHAPE",
226                                    tfLiteContext,
227                                    IsReshapeSupported,
228                                    delegateData.m_Backends,
229                                    isSupported,
230                                    setBackend,
231                                    inputTensorInfo0,
232                                    outInfo,
233                                    reshapeDesc);
234     };
235 
236     if (!delegateData.m_Network)
237     {
238         validateFunc(outputTensorInfo, isSupported);
239         return isSupported ? kTfLiteOk : kTfLiteError;
240     }
241 
242     armnn::IConnectableLayer* layer = delegateData.m_Network->AddReshapeLayer(reshapeDesc);
243     layer->SetBackendId(setBackend);
244     ARMNN_ASSERT(layer != nullptr);
245 
246     armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
247     outputSlot.SetTensorInfo(outputTensorInfo);
248 
249     // try to connect the Constant Inputs if there are any
250     if(ProcessInputs(layer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
251     {
252         return kTfLiteError;
253     }
254 
255     // Connect
256     return Connect(layer, tfLiteNode, delegateData);
257 }
258 
VisitSqueezeOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t operatorCode)259 TfLiteStatus VisitSqueezeOperator(DelegateData& delegateData,
260                                   TfLiteContext* tfLiteContext,
261                                   TfLiteNode* tfLiteNode,
262                                   int nodeIndex,
263                                   int32_t operatorCode)
264 {
265     armnn::IgnoreUnused(delegateData,
266                         tfLiteContext,
267                         tfLiteNode,
268                         nodeIndex,
269                         operatorCode);
270 
271     return kTfLiteError;
272 }
273 
VisitExpandDimsOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t operatorCode)274 TfLiteStatus VisitExpandDimsOperator(DelegateData& delegateData,
275                                      TfLiteContext* tfLiteContext,
276                                      TfLiteNode* tfLiteNode,
277                                      int nodeIndex,
278                                      int32_t operatorCode)
279 {
280     armnn::IgnoreUnused(delegateData,
281                         tfLiteContext,
282                         tfLiteNode,
283                         nodeIndex,
284                         operatorCode);
285 
286     return kTfLiteError;
287 }
288 
289 } // namespace armnnDelegate
290