//
// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <ClassicDelegateUtils.hpp>

#include <tensorflow/lite/builtin_ops.h>
#include <tensorflow/lite/c/builtin_op_data.h>
#include <tensorflow/lite/c/common.h>
#include <tensorflow/lite/minimal_logging.h>

namespace armnnDelegate
{

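// Checks whether the configured backends can execute a Softmax layer with the
// given input/output tensor infos and descriptor. Returns kTfLiteOk only if a
// backend reports support.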
TfLiteStatus ValidateSoftmaxOperator(DelegateData& delegateData,
                                     TfLiteContext* tfLiteContext,
                                     const armnn::TensorInfo& inputInfo,
                                     const armnn::TensorInfo& outputTensorInfo,
                                     const armnn::SoftmaxDescriptor& descriptor)
{
    bool isSupported = false;
    FORWARD_LAYER_SUPPORT_FUNC("SOFTMAX",
                               tfLiteContext,
                               IsSoftmaxSupported,
                               delegateData.m_Backends,
                               isSupported,
                               armnn::BackendId(),
                               inputInfo,
                               outputTensorInfo,
                               descriptor);
    return isSupported ? kTfLiteOk : kTfLiteError;
}


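// Same backend support check as above, but for a LogSoftmax layer.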
TfLiteStatus ValidateLogSoftmaxOperator(DelegateData& delegateData,
                                        TfLiteContext* tfLiteContext,
                                        const armnn::TensorInfo& inputInfo,
                                        const armnn::TensorInfo& outputTensorInfo,
                                        const armnn::LogSoftmaxDescriptor& descriptor)
{
    bool isSupported = false;
    FORWARD_LAYER_SUPPORT_FUNC("LOG_SOFTMAX",
                               tfLiteContext,
                               IsLogSoftmaxSupported,
                               delegateData.m_Backends,
                               isSupported,
                               armnn::BackendId(),
                               inputInfo,
                               outputTensorInfo,
                               descriptor);
    return isSupported ? kTfLiteOk : kTfLiteError;
}

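// Handles both the SOFTMAX and LOG_SOFTMAX builtin operators. During the
// support-check pass (no network yet) it only validates backend support;
// during the network-construction pass it adds the corresponding ArmNN layer
// and connects it into the graph.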
TfLiteStatus VisitSoftmaxOperator(DelegateData& delegateData,
                                  TfLiteContext* tfLiteContext,
                                  TfLiteNode* tfLiteNode,
                                  int nodeIndex,
                                  int32_t softmaxOperatorCode)
{
    TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
    TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));

    const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
    const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
    if (IsDynamicTensor(tfLiteInputTensor))
    {
        TF_LITE_MAYBE_KERNEL_LOG(
            tfLiteContext,
            "TfLiteArmnnDelegate: Dynamic input tensors are not supported in node #%d: ",
            nodeIndex);
        return kTfLiteError;
    }
    const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
    if (IsDynamicTensor(tfLiteOutputTensor))
    {
        TF_LITE_MAYBE_KERNEL_LOG(
            tfLiteContext,
            "TfLiteArmnnDelegate: Dynamic output tensors are not supported in node #%d: ",
            nodeIndex);
        return kTfLiteError;
    }

    const armnn::TensorInfo& inputTensorInfo  = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
    const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);

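    // No network yet: this is the support-check pass, so only validate whether
    // the backends can handle the operator.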
    if (!delegateData.m_Network)
    {
        switch(softmaxOperatorCode)
        {
            case kTfLiteBuiltinSoftmax:
            {
                armnn::SoftmaxDescriptor descriptor;
                auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(tfLiteNode->builtin_data);
                descriptor.m_Beta = params->beta;
                return ValidateSoftmaxOperator(delegateData,
                                               tfLiteContext,
                                               inputTensorInfo,
                                               outputTensorInfo,
                                               descriptor);
            }
            case kTfLiteBuiltinLogSoftmax:
            {
                armnn::LogSoftmaxDescriptor descriptor;
                return ValidateLogSoftmaxOperator(delegateData,
                                                  tfLiteContext,
                                                  inputTensorInfo,
                                                  outputTensorInfo,
                                                  descriptor);
            }
            default:
                return kTfLiteError;
        }
    }

    armnn::IConnectableLayer* softmaxLayer = nullptr;

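    // Network-construction pass: add the ArmNN layer matching the builtin operator.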
    switch(softmaxOperatorCode)
    {
        case kTfLiteBuiltinSoftmax:
        {
            armnn::SoftmaxDescriptor descriptor;
            auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(tfLiteNode->builtin_data);
            descriptor.m_Beta = params->beta;
            softmaxLayer = delegateData.m_Network->AddSoftmaxLayer(descriptor);
            break;
        }
        case kTfLiteBuiltinLogSoftmax:
        {
            armnn::LogSoftmaxDescriptor descriptor;
            softmaxLayer = delegateData.m_Network->AddLogSoftmaxLayer(descriptor);
            break;
        }
        default:
            return kTfLiteError;
    }
    ARMNN_ASSERT(softmaxLayer != nullptr);

    armnn::IOutputSlot& outputSlot = softmaxLayer->GetOutputSlot(0);
    outputSlot.SetTensorInfo(outputTensorInfo);

    // Try to connect the Constant Inputs if there are any
    if (ProcessInputs(softmaxLayer, delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk)
    {
        return kTfLiteError;
    }

    // Connect
    return Connect(softmaxLayer, tfLiteNode, delegateData);
}

} // namespace armnnDelegate