xref: /aosp_15_r20/external/armnn/delegate/classic/src/Gather.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2020,2022-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <ClassicDelegateUtils.hpp>
9 
10 #include <algorithm>
11 #include <iterator>
12 #include <string>
13 #include <vector>
14 
15 namespace armnnDelegate
16 {
VisitGatherOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t operatorCode)17 TfLiteStatus VisitGatherOperator(DelegateData& delegateData,
18                                  TfLiteContext* tfLiteContext,
19                                  TfLiteNode* tfLiteNode,
20                                  int nodeIndex,
21                                  int32_t operatorCode)
22 {
23     TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
24     TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
25 
26     const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
27 
28     const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
29     if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
30     {
31         return kTfLiteError;
32     }
33 
34     const TfLiteTensor& tfLiteIndicesTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
35     if (!IsValid(tfLiteContext, tfLiteIndicesTensor, operatorCode, nodeIndex))
36     {
37         return kTfLiteError;
38     }
39 
40     const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
41     if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
42     {
43         return kTfLiteError;
44     }
45 
46     auto* gatherParameters = reinterpret_cast<TfLiteGatherParams*>(tfLiteNode->builtin_data);
47     auto axis = gatherParameters->axis;
48 
49     const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
50     const armnn::TensorInfo& indicesTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteIndicesTensor);
51     const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
52     armnn::GatherDescriptor gatherDescriptor;
53     gatherDescriptor.m_Axis = axis;
54 
55     auto inputDimensions = static_cast<int32_t>(inputTensorInfo.GetNumDimensions());
56     auto indicesDimensions = indicesTensorInfo.GetNumDimensions();
57     auto outputDimensions = outputTensorInfo.GetNumDimensions();
58     if (((axis < -inputDimensions) && (axis < 0)) || ((axis >= inputDimensions) && (axis > 0)))
59     {
60         TF_LITE_MAYBE_KERNEL_LOG( tfLiteContext,
61             "TfLiteArmnnDelegate: Operation has invalid axis: %d. It is out of bounds [-%d, %d))",
62             axis, inputDimensions, inputDimensions);
63         return kTfLiteError;
64     }
65     if (outputDimensions != static_cast<unsigned int>(inputDimensions) + indicesDimensions - 1)
66     {
67         TF_LITE_MAYBE_KERNEL_LOG( tfLiteContext,
68             "Operation has invalid output dimensions: %d. Output must be an (%d + %d - 1)-D tensor",
69             outputDimensions, inputDimensions, indicesDimensions);
70         return kTfLiteError;
71     }
72 
73     armnn::BackendId setBackend;
74     if (!delegateData.m_Network)
75     {
76         // Check if supported
77         bool isSupported = false;
78         FORWARD_LAYER_SUPPORT_FUNC("GATHER",
79                                    tfLiteContext,
80                                    IsGatherSupported,
81                                    delegateData.m_Backends,
82                                    isSupported,
83                                    setBackend,
84                                    inputTensorInfo,
85                                    indicesTensorInfo,
86                                    outputTensorInfo,
87                                    gatherDescriptor);
88         return isSupported ? kTfLiteOk : kTfLiteError;
89     }
90 
91     armnn::IConnectableLayer* layer = delegateData.m_Network->AddGatherLayer(gatherDescriptor);
92     layer->SetBackendId(setBackend);
93     ARMNN_ASSERT(layer != nullptr);
94     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
95 
96     auto inputsTensorsProcess = ProcessInputs(layer,
97                                               delegateData,
98                                               tfLiteContext,
99                                               tfLiteNode);
100     if (inputsTensorsProcess == kTfLiteError)
101     {
102         return inputsTensorsProcess;
103     }
104 
105     return Connect(layer, tfLiteNode, delegateData);
106 }
107 } // namespace armnnDelegate