xref: /aosp_15_r20/external/armnn/delegate/classic/src/SpaceDepth.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <tensorflow/lite/builtin_ops.h>
9 #include <tensorflow/lite/c/builtin_op_data.h>
10 #include <tensorflow/lite/c/common.h>
11 #include <tensorflow/lite/minimal_logging.h>
12 
13 namespace armnnDelegate
14 {
15 
VisitSpaceToDepthOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t operatorCode)16 TfLiteStatus VisitSpaceToDepthOperator(DelegateData& delegateData,
17                                        TfLiteContext* tfLiteContext,
18                                        TfLiteNode* tfLiteNode,
19                                        int nodeIndex,
20                                        int32_t operatorCode)
21 {
22     TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
23     TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
24 
25     const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
26     const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
27     if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
28     {
29         return kTfLiteError;
30     }
31 
32     const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
33     if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
34     {
35         return kTfLiteError;
36     }
37 
38     const armnn::TensorInfo& inputTensorInfo  = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
39     const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
40 
41     armnn::SpaceToDepthDescriptor descriptor;
42     auto* params = reinterpret_cast<TfLiteSpaceToDepthParams*>(tfLiteNode->builtin_data);
43     descriptor.m_BlockSize = params->block_size;
44 
45     bool isSupported = false;
46     armnn::BackendId setBackend;
47     auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
48     {
49         FORWARD_LAYER_SUPPORT_FUNC("SPACE_TO_DEPTH",
50                                    tfLiteContext,
51                                    IsSpaceToDepthSupported,
52                                    delegateData.m_Backends,
53                                    isSupported,
54                                    setBackend,
55                                    inputTensorInfo,
56                                    outInfo,
57                                    descriptor);
58     };
59 
60     if (!delegateData.m_Network)
61     {
62         validateFunc(outputTensorInfo, isSupported);
63         return isSupported ? kTfLiteOk : kTfLiteError;
64     }
65 
66     // Add a SpaceToDepth layer
67     armnn::IConnectableLayer* layer = delegateData.m_Network->AddSpaceToDepthLayer(descriptor);
68     layer->SetBackendId(setBackend);
69     ARMNN_ASSERT(layer != nullptr);
70 
71     // try to connect the Constant Inputs if there are any
72     if(ProcessInputs(layer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
73     {
74         return kTfLiteError;
75     }
76 
77     armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
78     outputSlot.SetTensorInfo(outputTensorInfo);
79 
80     // Connect
81     return Connect(layer, tfLiteNode, delegateData);
82 }
83 
VisitDepthToSpaceOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t operatorCode)84 TfLiteStatus VisitDepthToSpaceOperator(DelegateData& delegateData,
85                                        TfLiteContext* tfLiteContext,
86                                        TfLiteNode* tfLiteNode,
87                                        int nodeIndex,
88                                        int32_t operatorCode)
89 {
90     TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
91     TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
92 
93     const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
94     const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
95     if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
96     {
97         return kTfLiteError;
98     }
99 
100     const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
101     if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
102     {
103         return kTfLiteError;
104     }
105 
106     const armnn::TensorInfo& inputTensorInfo  = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
107     const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
108 
109     armnn::DepthToSpaceDescriptor descriptor;
110     auto* params = reinterpret_cast<TfLiteDepthToSpaceParams*>(tfLiteNode->builtin_data);
111     descriptor.m_BlockSize = params->block_size;
112 
113     bool isSupported = false;
114     armnn::BackendId setBackend;
115     auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
116     {
117         FORWARD_LAYER_SUPPORT_FUNC("DEPTH_TO_SPACE",
118                                    tfLiteContext,
119                                    IsDepthToSpaceSupported,
120                                    delegateData.m_Backends,
121                                    isSupported,
122                                    setBackend,
123                                    inputTensorInfo,
124                                    outInfo,
125                                    descriptor);
126     };
127 
128     if (!delegateData.m_Network)
129     {
130         validateFunc(outputTensorInfo, isSupported);
131         return isSupported ? kTfLiteOk : kTfLiteError;
132     }
133 
134     // Add a DepthToSpace layer
135     armnn::IConnectableLayer* layer = delegateData.m_Network->AddDepthToSpaceLayer(descriptor);
136     layer->SetBackendId(setBackend);
137     ARMNN_ASSERT(layer != nullptr);
138 
139     armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
140     outputSlot.SetTensorInfo(outputTensorInfo);
141 
142     // try to connect the Constant Inputs if there are any
143     if(ProcessInputs(layer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
144     {
145         return kTfLiteError;
146     }
147 
148     // Connect
149     return Connect(layer, tfLiteNode, delegateData);
150 }
151 
152 } // namespace armnnDelegate
153