xref: /aosp_15_r20/external/armnn/delegate/classic/src/Slice.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <armnn/utility/IgnoreUnused.hpp>
9 
10 #include <tensorflow/lite/builtin_ops.h>
11 #include <tensorflow/lite/c/builtin_op_data.h>
12 #include <tensorflow/lite/c/common.h>
13 #include <tensorflow/lite/minimal_logging.h>
14 
15 namespace armnnDelegate
16 {
17 
VisitSliceOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t sliceOperatorCode)18 TfLiteStatus VisitSliceOperator(DelegateData& delegateData,
19                                 TfLiteContext* tfLiteContext,
20                                 TfLiteNode* tfLiteNode,
21                                 int nodeIndex,
22                                 int32_t sliceOperatorCode)
23 {
24     TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 3, nodeIndex));
25     TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
26 
27     // Read inputs [input, begin, size]
28     int numInputs = tfLiteNode->inputs->size;
29     std::vector<const TfLiteTensor*> tfLiteInputs;
30     tfLiteInputs.reserve(numInputs);
31     const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
32     for (int i = 0; i < numInputs; i++)
33     {
34         const TfLiteTensor* inputTensor = &tfLiteTensors[tfLiteNode->inputs->data[i]];
35         tfLiteInputs.push_back(inputTensor);
36         if (!IsValid(tfLiteContext, *inputTensor, sliceOperatorCode, nodeIndex))
37         {
38             return kTfLiteError;
39         }
40     }
41 
42     // We save the begin and size tensors in our descriptor. Therefore we have to read those values from inputs
43     int inputRank = tfLiteInputs[0]->dims->size;
44     auto ReadInt32Input = [&](int inputIndex, std::vector<uint32_t>& outputData) ->  TfLiteStatus
45     {
46         if (tfLiteInputs[inputIndex]->type != kTfLiteInt32)
47         {
48             TF_LITE_MAYBE_KERNEL_LOG(
49                     tfLiteContext,
50                     "TfLiteArmnnDelegate: The Begin- and Size-Tensors of the Slice operation need to "
51                     "be of type int32. Operator: #%d node #%d: ",
52                     sliceOperatorCode, nodeIndex);
53             return kTfLiteError;
54         }
55         int rank = tfLiteInputs[inputIndex]->dims->size;
56         if (rank != 1)
57         {
58             TF_LITE_MAYBE_KERNEL_LOG(
59                     tfLiteContext,
60                     "TfLiteArmnnDelegate: The Begin- and Size-Tensors of the Slice operation need to "
61                     "be a 1D-Tensor. Operator: #%d node #%d: ",
62                     sliceOperatorCode, nodeIndex);
63             return kTfLiteError;
64         }
65         int numValues = tfLiteInputs[inputIndex]->dims->data[0];
66         if (numValues != inputRank)
67         {
68             TF_LITE_MAYBE_KERNEL_LOG(
69                     tfLiteContext,
70                     "TfLiteArmnnDelegate: The number of values in the Begin- and Size-Tensors of the "
71                     "Slice operation need to be equal to the rank of the Input-Tensor. Operator: #%d node #%d: ",
72                     sliceOperatorCode, nodeIndex);
73             return kTfLiteError;
74         }
75         // return tensor data
76         auto* tensorDataPtr = tflite::GetTensorData<uint32_t>(tfLiteInputs[inputIndex]);
77         outputData.assign(tensorDataPtr, tensorDataPtr+numValues);
78         return kTfLiteOk;
79     };
80 
81     std::vector<uint32_t> begin;
82     if (ReadInt32Input(1, begin) != kTfLiteOk)
83         return kTfLiteError;
84     std::vector<uint32_t> size;
85     if (ReadInt32Input(2, size) != kTfLiteOk)
86         return kTfLiteError;
87 
88     // Write all data to the descriptor
89     armnn::SliceDescriptor descriptor(begin, size);
90 
91     // Validate output
92     const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
93     if (!IsValid(tfLiteContext, tfLiteOutputTensor, sliceOperatorCode, nodeIndex))
94     {
95         return kTfLiteError;
96     }
97 
98     const armnn::TensorInfo& inputTensorInfo  = GetTensorInfoForTfLiteTensor(*tfLiteInputs[0]);
99     const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
100 
101     bool isSupported = false;
102     armnn::BackendId setBackend;
103     auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
104     {
105         FORWARD_LAYER_SUPPORT_FUNC("SLICE",
106                                    tfLiteContext,
107                                    IsSliceSupported,
108                                    delegateData.m_Backends,
109                                    isSupported,
110                                    setBackend,
111                                    inputTensorInfo,
112                                    outInfo,
113                                    descriptor);
114     };
115 
116     if (!delegateData.m_Network)
117     {
118         validateFunc(outputTensorInfo, isSupported);
119         return isSupported ? kTfLiteOk : kTfLiteError;
120     }
121 
122     // Add a Slice layer
123     armnn::IConnectableLayer* layer = delegateData.m_Network->AddSliceLayer(descriptor);
124     layer->SetBackendId(setBackend);
125     ARMNN_ASSERT(layer != nullptr);
126 
127     armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
128     outputSlot.SetTensorInfo(outputTensorInfo);
129 
130     // try to connect the Constant Inputs if there are any
131     if(ProcessInputs(layer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
132     {
133         return kTfLiteError;
134     }
135 
136     // Connect
137     return Connect(layer, tfLiteNode, delegateData);
138 }
139 
140 } // namespace armnnDelegate
141 
142