xref: /aosp_15_r20/external/armnn/src/backends/backendsCommon/test/layerTests/SoftmaxTestImpl.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 
#include "SoftmaxTestImpl.hpp"

#include <armnnUtils/QuantizeHelper.hpp>
#include <ResolveType.hpp>

#include <armnn/backends/TensorHandle.hpp>

#include <armnnTestUtils/TensorCopyUtils.hpp>
#include <armnnTestUtils/WorkloadTestUtils.hpp>

#include <armnnTestUtils/TensorHelpers.hpp>

#include <algorithm>
#include <cmath>
#include <stdexcept>
#include <string>
20 
21 namespace
22 {
23 
24 struct Simple3dSoftmaxOutputData
25 {
26     const std::vector<float> outputData =
27     {
28         0.0964599f, 0.26220518f, 0.0964599f, 0.0964599f,
29         0.15903549f, 0.0964599f, 0.0964599f, 0.0964599f
30     };
31 
32     const armnn::TensorShape inputShape{ 1, 8, 1 };
33 
34     const std::vector<float> inputData =
35     {
36         0.0f, 1.0f, 0.0f, 0.0f,
37         0.5f, 0.0f, 0.0f, 0.0f,
38     };
39 };
40 
41 struct Simple4dSoftmaxData
42 {
43     const armnn::TensorShape inputShape{ 1, 8, 1, 1 };
44 
45     const std::vector<float> outputData =
46     {
47         0.0964599f, 0.26220518f, 0.0964599f, 0.0964599f,
48         0.15903549f, 0.0964599f, 0.0964599f, 0.0964599f
49     };
50 
51     const std::vector<float> inputData =
52     {
53          0.0f, 1.0f, 0.0f, 0.0f,
54          0.5f, 0.0f, 0.0f, 0.0f
55     };
56 };
57 
58 template<armnn::DataType ArmnnType, std::size_t n, typename T = armnn::ResolveType<ArmnnType>>
SimpleSoftmaxBaseTestImpl(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory,float beta,const armnn::TensorShape & inputShape,const std::vector<float> & outputData,const std::vector<float> & inputData,int axis=-1)59 LayerTestResult<T, n> SimpleSoftmaxBaseTestImpl(
60     armnn::IWorkloadFactory& workloadFactory,
61     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
62     const armnn::ITensorHandleFactory& tensorHandleFactory,
63     float beta,
64     const armnn::TensorShape& inputShape,
65     const std::vector<float>& outputData,
66     const std::vector<float>& inputData,
67     int axis = -1)
68 {
69     using std::exp;
70 
71     const float qScale = 1.f / 256.f;
72     const int qOffset = 0;
73 
74     armnn::TensorInfo inputTensorInfo;
75     armnn::TensorInfo outputTensorInfo;
76 
77     inputTensorInfo = armnn::TensorInfo(inputShape, ArmnnType);
78     inputTensorInfo.SetQuantizationScale(qScale);
79     inputTensorInfo.SetQuantizationOffset(qOffset);
80 
81     outputTensorInfo = armnn::TensorInfo(inputShape, ArmnnType);
82     outputTensorInfo.SetQuantizationScale(qScale);
83     outputTensorInfo.SetQuantizationOffset(qOffset);
84 
85     // Each row is independently softmax'd.
86     std::vector<T> input = armnnUtils::QuantizedVector<T>(inputData, qScale, qOffset);
87     std::vector<T> expectedOutput = armnnUtils::QuantizedVector<T>(outputData, qScale, qOffset);
88     std::vector<T> actualOutput(outputTensorInfo.GetNumElements());
89 
90     std::unique_ptr<armnn::ITensorHandle> inputHandle  = tensorHandleFactory.CreateTensorHandle(inputTensorInfo);
91     std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputTensorInfo);
92 
93     armnn::SoftmaxQueueDescriptor data;
94     data.m_Parameters.m_Beta = beta;
95     data.m_Parameters.m_Axis = axis;
96 
97     armnn::WorkloadInfo info;
98     AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
99     AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
100 
101     std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::Softmax, data, info);
102 
103     inputHandle->Allocate();
104     outputHandle->Allocate();
105     CopyDataToITensorHandle(inputHandle.get(), input.data());
106 
107     ARMNN_ASSERT(workload);
108 
109     ExecuteWorkload(*workload, memoryManager);
110 
111     CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());
112 
113     return LayerTestResult<T, n>(actualOutput,
114                                  expectedOutput,
115                                  outputHandle->GetShape(),
116                                  outputTensorInfo.GetShape());
117 }
118 
119 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
SimpleSoftmaxTestImpl(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory,float beta)120 LayerTestResult<T, 2> SimpleSoftmaxTestImpl(
121     armnn::IWorkloadFactory& workloadFactory,
122     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
123     const armnn::ITensorHandleFactory& tensorHandleFactory,
124     float beta)
125 {
126     using std::exp;
127     const armnn::TensorShape inputShape{ 2, 4 };
128 
129     float x0[4] = { exp((0.f - 1.0f) * beta), exp((1.0f - 1.0f) * beta),
130                     exp((0.0f - 1.0f) * beta), exp((0.0f - 1.0f) * beta) };
131     float sum0 = x0[0] + x0[1] + x0[2] + x0[3];
132     float x1[4] = { exp((0.5f - 0.5f) * beta), exp((0.0f - 0.5f) * beta),
133                     exp((0.0f - 0.5f) * beta), exp((0.0f - 0.5f) * beta) };
134     float sum1 = x1[0] + x1[1] + x1[2] + x1[3];
135 
136     const std::vector<float> outputData = { x0[0] / sum0, x0[1] / sum0, x0[2] / sum0, x0[3] / sum0,
137                                             x1[0] / sum1, x1[1] / sum1, x1[2] / sum1, x1[3] / sum1 };
138 
139     const std::vector<float> inputData =
140             {
141                 0.f, 1.f, 0.f, 0.f,
142                 .5f, 0.f, 0.f, 0.f,
143             };
144 
145     return SimpleSoftmaxBaseTestImpl<ArmnnType, 2>(workloadFactory, memoryManager, tensorHandleFactory, beta,
146                                                    inputShape, outputData, inputData);
147 }
148 
149 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
SimpleSoftmaxTestImpl(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory,float beta,int axis)150 LayerTestResult<T, 2> SimpleSoftmaxTestImpl(
151         armnn::IWorkloadFactory& workloadFactory,
152         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
153         const armnn::ITensorHandleFactory& tensorHandleFactory,
154         float beta,
155         int axis)
156 {
157     armnn::TensorShape inputShape;
158     std::vector<float> inputData;
159     std::vector<float> outputData;
160     switch (axis)
161     {
162     case -2:
163     case 0:
164         {
165         inputShape = {5, 2};
166 
167         inputData =
168                 {
169                         17.0f, -1.0f, 16.0f, -2.0f, 15.0f, -3.0f, 14.0f, -4.0f, 1.0f, -17.0f
170                 };
171 
172         outputData =
173                 {
174                         0.643914213228014f, 0.643914213228014f, 0.236882800924671f, 0.236882800924671f,
175                         0.087144312427294f,
176                         0.087144312427294f, 0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f,
177                         7.246299848982885e-08f
178                 };
179         break;
180         }
181     case -1:
182     case 1:
183         {
184         inputShape = {2, 5};
185 
186         inputData =
187                 {
188                         17.0f, 16.0f, 15.0f, 14.0f, 1.0f, -1.0f, -2.0f, -3.0f, -4.0f, -17.0f
189                 };
190 
191         outputData =
192                 {
193                         0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
194                         7.246299848982885e-08f,
195                         0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
196                         7.246299848982885e-08f
197                 };
198         break;
199         }
200     }
201     return SimpleSoftmaxBaseTestImpl<ArmnnType, 2>(workloadFactory, memoryManager, tensorHandleFactory, beta,
202                                                    inputShape, outputData, inputData, axis);
203 }
204 
205 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
Simple3dSoftmaxTestImpl(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory,float beta,const armnn::TensorShape & inputShape,const std::vector<float> & outputData,const std::vector<float> & inputData,int axis=1)206 LayerTestResult<T, 3> Simple3dSoftmaxTestImpl(
207     armnn::IWorkloadFactory& workloadFactory,
208     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
209     const armnn::ITensorHandleFactory& tensorHandleFactory,
210     float beta,
211     const armnn::TensorShape& inputShape,
212     const std::vector<float>& outputData,
213     const std::vector<float>& inputData,
214     int axis = 1)
215 {
216     return SimpleSoftmaxBaseTestImpl<ArmnnType, 3>(workloadFactory, memoryManager, tensorHandleFactory, beta,
217                                                    inputShape, outputData, inputData, axis);
218 }
219 
220 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
Simple4dSoftmaxTestImpl(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory,float beta,const armnn::TensorShape & inputShape,const std::vector<float> & outputData,const std::vector<float> & inputData,int axis=1)221 LayerTestResult<T, 4> Simple4dSoftmaxTestImpl(
222     armnn::IWorkloadFactory& workloadFactory,
223     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
224     const armnn::ITensorHandleFactory& tensorHandleFactory,
225     float beta,
226     const armnn::TensorShape& inputShape,
227     const std::vector<float>& outputData,
228     const std::vector<float>& inputData,
229     int axis = 1)
230 {
231 
232     return SimpleSoftmaxBaseTestImpl<ArmnnType, 4>(workloadFactory, memoryManager, tensorHandleFactory, beta,
233                                                    inputShape, outputData, inputData, axis);
234 }
235 
236 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
CompareSoftmaxTestImpl(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,armnn::IWorkloadFactory & refWorkloadFactory,const armnn::ITensorHandleFactory & tensorHandleFactory,const armnn::ITensorHandleFactory & refTensorHandleFactory,float beta)237 LayerTestResult<T, 2> CompareSoftmaxTestImpl(
238         armnn::IWorkloadFactory& workloadFactory,
239         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
240         armnn::IWorkloadFactory& refWorkloadFactory,
241         const armnn::ITensorHandleFactory& tensorHandleFactory,
242         const armnn::ITensorHandleFactory& refTensorHandleFactory,
243         float beta)
244 {
245     const int batchSize = 20;
246     const int channels = 30;
247 
248     armnn::TensorInfo inputTensorInfo;
249     armnn::TensorInfo outputTensorInfo;
250 
251     unsigned int inputShape[] = { batchSize, channels };
252 
253     inputTensorInfo = armnn::TensorInfo(2, inputShape, ArmnnType);
254     outputTensorInfo = armnn::TensorInfo(2, inputShape, ArmnnType);
255     float qScale = 1.f / 256.f;
256     int qOffset = 0;
257     inputTensorInfo.SetQuantizationScale(qScale);
258     inputTensorInfo.SetQuantizationOffset(qOffset);
259     outputTensorInfo.SetQuantizationScale(qScale);
260     outputTensorInfo.SetQuantizationOffset(qOffset);
261 
262     auto input = MakeRandomTensor<T>(inputTensorInfo, 0xF00D, 0.0f, 1.0f);
263     std::vector<T> actualOutput(outputTensorInfo.GetNumElements());
264     std::vector<T> expectedOutput(outputTensorInfo.GetNumElements());
265 
266     std::unique_ptr<armnn::ITensorHandle> inputHandle  = tensorHandleFactory.CreateTensorHandle(inputTensorInfo);
267     std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputTensorInfo);
268 
269     armnn::SoftmaxQueueDescriptor data;
270     data.m_Parameters.m_Beta = beta;
271 
272     armnn::WorkloadInfo info;
273     AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
274     AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
275 
276     std::unique_ptr<armnn::ITensorHandle> outputHandleRef =
277         refTensorHandleFactory.CreateTensorHandle(outputTensorInfo);
278     std::unique_ptr<armnn::ITensorHandle> inputHandleRef  =
279         refTensorHandleFactory.CreateTensorHandle(inputTensorInfo);
280 
281     armnn::SoftmaxQueueDescriptor refData = data;
282     armnn::WorkloadInfo refInfo = info;
283     SetWorkloadInput(refData, refInfo, 0, inputTensorInfo, inputHandleRef.get());
284     SetWorkloadOutput(refData, refInfo, 0, outputTensorInfo, outputHandleRef.get());
285 
286     std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::Softmax, data, info);
287     std::unique_ptr<armnn::IWorkload> workloadRef = refWorkloadFactory.CreateWorkload(armnn::LayerType::Softmax,
288                                                                                       refData,
289                                                                                       refInfo);
290 
291     outputHandleRef->Allocate();
292     inputHandleRef->Allocate();
293 
294     inputHandle->Allocate();
295     outputHandle->Allocate();
296 
297     CopyDataToITensorHandle(inputHandle.get(), input.data());
298     CopyDataToITensorHandle(inputHandleRef.get(), input.data());
299 
300     ExecuteWorkload(*workload, memoryManager);
301 
302     workloadRef->Execute();
303 
304     CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());
305     CopyDataFromITensorHandle(expectedOutput.data(), outputHandleRef.get());
306 
307     return LayerTestResult<T, 2>(actualOutput,
308                                  expectedOutput,
309                                  outputHandle->GetShape(),
310                                  outputTensorInfo.GetShape());
311 }
312 
313 } // anonymous namespace
314 
SimpleSoftmaxTest(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory,float beta)315 LayerTestResult<float,2> SimpleSoftmaxTest(
316     armnn::IWorkloadFactory& workloadFactory,
317     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
318     const armnn::ITensorHandleFactory& tensorHandleFactory,
319     float beta)
320 {
321     return SimpleSoftmaxTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager, tensorHandleFactory, beta);
322 }
323 
SimpleAxisSoftmaxTest(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory,float beta,int axis)324 LayerTestResult<float,2> SimpleAxisSoftmaxTest(
325         armnn::IWorkloadFactory& workloadFactory,
326         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
327         const armnn::ITensorHandleFactory& tensorHandleFactory,
328         float beta,
329         int axis)
330 {
331     return SimpleSoftmaxTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager,
332                                                            tensorHandleFactory, beta, axis);
333 }
334 
Simple3dSoftmaxTest(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory,float beta)335 LayerTestResult<float,3> Simple3dSoftmaxTest(
336         armnn::IWorkloadFactory& workloadFactory,
337         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
338         const armnn::ITensorHandleFactory& tensorHandleFactory,
339         float beta)
340 {
341     Simple3dSoftmaxOutputData data;
342     return Simple3dSoftmaxTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager, tensorHandleFactory, beta,
343                                                              data.inputShape, data.outputData, data.inputData);
344 }
345 
Simple3dAxisSoftmaxTest(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory,float beta,int axis)346 LayerTestResult<float,3> Simple3dAxisSoftmaxTest(
347         armnn::IWorkloadFactory& workloadFactory,
348         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
349         const armnn::ITensorHandleFactory& tensorHandleFactory,
350         float beta,
351         int axis)
352 {
353     armnn::TensorShape inputShape;
354     std::vector<float> inputData;
355     std::vector<float> outputData;
356     switch (axis)
357     {
358     case -3:
359     case 0:
360         {
361             inputShape = {5, 2, 2};
362 
363             inputData =
364                     {
365                             17.0f, -1.0f, 17.0f, -1.0f, 16.0f, -2.0f, 16.0f, -2.0f, 15.0f, -3.0f,
366 
367                             15.0f, -3.0f, 14.0f, -4.0f, 14.0f, -4.0f, 1.0f, -17.0f, 1.0f, -17.0f
368                     };
369 
370             outputData =
371                     {
372                             0.643914213228014f, 0.643914213228014f, 0.643914213228014f, 0.643914213228014f,
373                             0.236882800924671f,
374                             0.236882800924671f, 0.236882800924671f, 0.236882800924671f, 0.087144312427294f,
375                             0.087144312427294f,
376 
377                             0.087144312427294f, 0.087144312427294f, 0.032058600957022f, 0.032058600957022f,
378                             0.032058600957022f,
379                             0.032058600957022f, 7.246299848982885e-08f, 7.246299848982885e-08f, 7.246299848982885e-08f,
380                             7.246299848982885e-08f
381                     };
382             break;
383         }
384     case -2:
385     case 1:
386         {
387             inputShape = {2, 5, 2};
388 
389             inputData =
390                     {
391                             17.0f, -1.0f, 16.0f, -2.0f, 15.0f, -3.0f, 14.0f, -4.0f, 1.0f, -17.0f,
392 
393                             17.0f, -1.0f, 16.0f, -2.0f, 15.0f, -3.0f, 14.0f, -4.0f, 1.0f, -17.0f
394                     };
395 
396             outputData =
397                     {
398                             0.643914213228014f, 0.643914213228014f, 0.236882800924671f, 0.236882800924671f,
399                             0.087144312427294f,
400                             0.087144312427294f, 0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f,
401                             7.246299848982885e-08f,
402 
403                             0.643914213228014f, 0.643914213228014f, 0.236882800924671f, 0.236882800924671f,
404                             0.087144312427294f,
405                             0.087144312427294f, 0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f,
406                             7.246299848982885e-08f
407                     };
408         break;
409         }
410     case -1:
411     case 2:
412         {
413             inputShape = {2, 2, 5};
414 
415             inputData =
416                     {
417                             17.0f, 16.0f, 15.0f, 14.0f, 1.0f, -1.0f, -2.0f, -3.0f, -4.0f, -17.0f,
418                             17.0f, 16.0f, 15.0f, 14.0f, 1.0f, -1.0f, -2.0f, -3.0f, -4.0f, -17.0f
419                     };
420 
421             outputData =
422                     {
423                             0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
424                             7.246299848982885e-08f,
425                             0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
426                             7.246299848982885e-08f,
427 
428                             0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
429                             7.246299848982885e-08f,
430                             0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
431                             7.246299848982885e-08f
432                     };
433             break;
434         }
435     }
436 
437     return Simple3dSoftmaxTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager, tensorHandleFactory, beta,
438                                                              inputShape, outputData, inputData, axis);
439 }
440 
Simple4dSoftmaxTest(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory,float beta)441 LayerTestResult<float,4> Simple4dSoftmaxTest(
442         armnn::IWorkloadFactory& workloadFactory,
443         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
444         const armnn::ITensorHandleFactory& tensorHandleFactory,
445         float beta)
446 {
447     Simple4dSoftmaxData data;
448     return Simple4dSoftmaxTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager, tensorHandleFactory,
449                                                              beta, data.inputShape, data.outputData, data.inputData);
450 }
451 
Simple4dAxisSoftmaxTest(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory,float beta,int axis)452 LayerTestResult<float,4> Simple4dAxisSoftmaxTest(
453         armnn::IWorkloadFactory& workloadFactory,
454         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
455         const armnn::ITensorHandleFactory& tensorHandleFactory,
456         float beta,
457         int axis)
458 {
459     armnn::TensorShape inputShape;
460     std::vector<float> inputData;
461     std::vector<float> outputData;
462     switch (axis)
463     {
464     case -4:
465     case 0:
466         {
467             inputShape = {5, 2, 2, 2};
468 
469             inputData =
470                     {
471                             17.0f, -1.0f, 17.0f, -1.0f, 17.0f, -1.0f, 17.0f, -1.0f, 16.0f, -2.0f,
472                             16.0f, -2.0f, 16.0f, -2.0f, 16.0f, -2.0f, 15.0f, -3.0f, 15.0f, -3.0f,
473                             15.0f, -3.0f, 15.0f, -3.0f, 14.0f, -4.0f, 14.0f, -4.0f, 14.0f, -4.0f,
474                             14.0f, -4.0f, 1.0f, -17.0f, 1.0f, -17.0f, 1.0f, -17.0f, 1.0f, -17.0f
475                     };
476 
477             outputData =
478                     {
479                             0.643914213228014f, 0.643914213228014f, 0.643914213228014f, 0.643914213228014f,
480                             0.643914213228014f,
481                             0.643914213228014f, 0.643914213228014f, 0.643914213228014f, 0.236882800924671f,
482                             0.236882800924671f,
483                             0.236882800924671f, 0.236882800924671f, 0.236882800924671f, 0.236882800924671f,
484                             0.236882800924671f,
485                             0.236882800924671f, 0.087144312427294f, 0.087144312427294f, 0.087144312427294f,
486                             0.087144312427294f,
487 
488                             0.087144312427294f, 0.087144312427294f, 0.087144312427294f, 0.087144312427294f,
489                             0.032058600957022f,
490                             0.032058600957022f, 0.032058600957022f, 0.032058600957022f, 0.032058600957022f,
491                             0.032058600957022f,
492                             0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f, 7.246299848982885e-08f,
493                             7.246299848982885e-08f,
494                             7.246299848982885e-08f, 7.246299848982885e-08f, 7.246299848982885e-08f,
495                             7.246299848982885e-08f, 7.246299848982885e-08f
496                     };
497             break;
498         }
499     case -3:
500     case 1:
501         {
502             inputShape = {2, 5, 2, 2};
503 
504             inputData =
505                     {
506                             17.0f, -1.0f, 17.0f, -1.0f, 16.0f, -2.0f, 16.0f, -2.0f, 15.0f, -3.0f,
507                             15.0f, -3.0f, 14.0f, -4.0f, 14.0f, -4.0f, 1.0f, -17.0f, 1.0f, -17.0f,
508                             17.0f, -1.0f, 17.0f, -1.0f, 16.0f, -2.0f, 16.0f, -2.0f, 15.0f, -3.0f,
509                             15.0f, -3.0f, 14.0f, -4.0f, 14.0f, -4.0f, 1.0f, -17.0f, 1.0f, -17.0f
510                     };
511 
512             outputData =
513                     {
514                             0.643914213228014f, 0.643914213228014f, 0.643914213228014f, 0.643914213228014f,
515                             0.236882800924671f,
516                             0.236882800924671f, 0.236882800924671f, 0.236882800924671f, 0.087144312427294f,
517                             0.087144312427294f,
518                             0.087144312427294f, 0.087144312427294f, 0.032058600957022f, 0.032058600957022f,
519                             0.032058600957022f,
520                             0.032058600957022f, 7.246299848982885e-08f, 7.246299848982885e-08f, 7.246299848982885e-08f,
521                             7.246299848982885e-08f,
522 
523 
524                             0.643914213228014f, 0.643914213228014f, 0.643914213228014f, 0.643914213228014f,
525                             0.236882800924671f,
526                             0.236882800924671f, 0.236882800924671f, 0.236882800924671f, 0.087144312427294f,
527                             0.087144312427294f,
528                             0.087144312427294f, 0.087144312427294f, 0.032058600957022f, 0.032058600957022f,
529                             0.032058600957022f,
530                             0.032058600957022f, 7.246299848982885e-08f, 7.246299848982885e-08f, 7.246299848982885e-08f,
531                             7.246299848982885e-08f
532                     };
533             break;
534         }
535     case -2:
536     case 2:
537         {
538         inputShape = {2, 2, 5, 2};
539 
540         inputData =
541                 {
542                         17.0f, -1.0f, 16.0f, -2.0f, 15.0f, -3.0f, 14.0f, -4.0f, 1.0f, -17.0f,
543                         17.0f, -1.0f, 16.0f, -2.0f, 15.0f, -3.0f, 14.0f, -4.0f, 1.0f, -17.0f,
544                         17.0f, -1.0f, 16.0f, -2.0f, 15.0f, -3.0f, 14.0f, -4.0f, 1.0f, -17.0f,
545                         17.0f, -1.0f, 16.0f, -2.0f, 15.0f, -3.0f, 14.0f, -4.0f, 1.0f, -17.0f
546                 };
547 
548         outputData =
549                 {
550                         0.643914213228014f, 0.643914213228014f, 0.236882800924671f, 0.236882800924671f,
551                         0.087144312427294f,
552                         0.087144312427294f, 0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f,
553                         7.246299848982885e-08f,
554                         0.643914213228014f, 0.643914213228014f, 0.236882800924671f, 0.236882800924671f,
555                         0.087144312427294f,
556                         0.087144312427294f, 0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f,
557                         7.246299848982885e-08f,
558 
559                         0.643914213228014f, 0.643914213228014f, 0.236882800924671f, 0.236882800924671f,
560                         0.087144312427294f,
561                         0.087144312427294f, 0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f,
562                         7.246299848982885e-08f,
563                         0.643914213228014f, 0.643914213228014f, 0.236882800924671f, 0.236882800924671f,
564                         0.087144312427294f,
565                         0.087144312427294f, 0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f,
566                         7.246299848982885e-08f
567                 };
568         break;
569         }
570     case -1:
571     case 3:
572         {
573             inputShape = {2, 2, 2, 5};
574 
575             inputData =
576                     {
577                             17.0f, 16.0f, 15.0f, 14.0f, 1.0f, -1.0f, -2.0f, -3.0f, -4.0f, -17.0f,
578                             17.0f, 16.0f, 15.0f, 14.0f, 1.0f, -1.0f, -2.0f, -3.0f, -4.0f, -17.0f,
579                             17.0f, 16.0f, 15.0f, 14.0f, 1.0f, -1.0f, -2.0f, -3.0f, -4.0f, -17.0f,
580                             17.0f, 16.0f, 15.0f, 14.0f, 1.0f, -1.0f, -2.0f, -3.0f, -4.0f, -17.0f
581                     };
582 
583             outputData =
584                     {
585                             0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
586                             7.246299848982885e-08f,
587                             0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
588                             7.246299848982885e-08f,
589                             0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
590                             7.246299848982885e-08f,
591                             0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
592                             7.246299848982885e-08f,
593 
594                             0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
595                             7.246299848982885e-08f,
596                             0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
597                             7.246299848982885e-08f,
598                             0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
599                             7.246299848982885e-08f,
600                             0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
601                             7.246299848982885e-08f
602                     };
603             break;
604         }
605     }
606 
607     return Simple4dSoftmaxTestImpl<armnn::DataType::Float32>(
608         workloadFactory,
609         memoryManager,
610         tensorHandleFactory,
611         beta,
612         inputShape,
613         outputData,
614         inputData,
615         axis);
616 }
617 
SimpleSoftmaxUint8Test(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory,float beta)618 LayerTestResult<uint8_t,2> SimpleSoftmaxUint8Test(
619     armnn::IWorkloadFactory& workloadFactory,
620     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
621     const armnn::ITensorHandleFactory& tensorHandleFactory,
622     float beta)
623 {
624     return SimpleSoftmaxTestImpl<armnn::DataType::QAsymmU8>(workloadFactory, memoryManager, tensorHandleFactory, beta);
625 }
626 
Simple3dSoftmaxUint8Test(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory,float beta)627 LayerTestResult<uint8_t,3> Simple3dSoftmaxUint8Test(
628         armnn::IWorkloadFactory& workloadFactory,
629         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
630         const armnn::ITensorHandleFactory& tensorHandleFactory,
631         float beta)
632 {
633     Simple3dSoftmaxOutputData data;
634     return Simple3dSoftmaxTestImpl<armnn::DataType::QAsymmU8>(
635         workloadFactory,
636         memoryManager,
637         tensorHandleFactory,
638         beta,
639         data.inputShape,
640         data.outputData,
641         data.inputData);
642 }
643 
Simple4dSoftmaxUint8Test(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory,float beta)644 LayerTestResult<uint8_t,4> Simple4dSoftmaxUint8Test(
645         armnn::IWorkloadFactory& workloadFactory,
646         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
647         const armnn::ITensorHandleFactory& tensorHandleFactory,
648         float beta)
649 {
650     Simple4dSoftmaxData data;
651 
652     return Simple4dSoftmaxTestImpl<armnn::DataType::QAsymmU8>(workloadFactory, memoryManager, tensorHandleFactory, beta,
653                                                                      data.inputShape, data.outputData, data.inputData);
654 }
655 
SimpleSoftmaxFloat16Test(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory,float beta)656 LayerTestResult<armnn::Half,2> SimpleSoftmaxFloat16Test(
657         armnn::IWorkloadFactory& workloadFactory,
658         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
659         const armnn::ITensorHandleFactory& tensorHandleFactory,
660         float beta)
661 {
662     return SimpleSoftmaxTestImpl<armnn::DataType::Float16>(workloadFactory, memoryManager, tensorHandleFactory, beta);
663 }
664 
Simple3dSoftmaxFloat16Test(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory,float beta)665 LayerTestResult<armnn::Half,3> Simple3dSoftmaxFloat16Test(
666         armnn::IWorkloadFactory& workloadFactory,
667         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
668         const armnn::ITensorHandleFactory& tensorHandleFactory,
669         float beta)
670 {
671     Simple3dSoftmaxOutputData data;
672     return Simple3dSoftmaxTestImpl<armnn::DataType::Float16>(workloadFactory, memoryManager, tensorHandleFactory, beta,
673                                                              data.inputShape, data.outputData, data.inputData);
674 }
675 
Simple4dSoftmaxFloat16Test(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory,float beta)676 LayerTestResult<armnn::Half,4> Simple4dSoftmaxFloat16Test(
677         armnn::IWorkloadFactory& workloadFactory,
678         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
679         const armnn::ITensorHandleFactory& tensorHandleFactory,
680         float beta)
681 {
682     Simple4dSoftmaxData data;
683     return Simple4dSoftmaxTestImpl<armnn::DataType::Float16>(workloadFactory, memoryManager, tensorHandleFactory, beta,
684                                                              data.inputShape, data.outputData, data.inputData);
685 }
686 
SimpleSoftmaxUint16Test(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory,float beta)687 LayerTestResult<int16_t,2> SimpleSoftmaxUint16Test(
688         armnn::IWorkloadFactory& workloadFactory,
689         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
690         const armnn::ITensorHandleFactory& tensorHandleFactory,
691         float beta)
692 {
693     return SimpleSoftmaxTestImpl<armnn::DataType::QSymmS16>(workloadFactory, memoryManager, tensorHandleFactory, beta);
694 }
695 
Simple3dSoftmaxUint16Test(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory,float beta)696 LayerTestResult<int16_t,3> Simple3dSoftmaxUint16Test(
697         armnn::IWorkloadFactory& workloadFactory,
698         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
699         const armnn::ITensorHandleFactory& tensorHandleFactory,
700         float beta)
701 {
702     Simple3dSoftmaxOutputData data;
703     return Simple3dSoftmaxTestImpl<armnn::DataType::QSymmS16>(workloadFactory, memoryManager, tensorHandleFactory, beta,
704                                                                      data.inputShape, data.outputData, data.inputData);
705 }
706 
Simple4dSoftmaxUint16Test(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory,float beta)707 LayerTestResult<int16_t,4> Simple4dSoftmaxUint16Test(
708         armnn::IWorkloadFactory& workloadFactory,
709         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
710         const armnn::ITensorHandleFactory& tensorHandleFactory,
711         float beta)
712 {
713     Simple4dSoftmaxData data;
714 
715     return Simple4dSoftmaxTestImpl<armnn::DataType::QSymmS16>(workloadFactory, memoryManager, tensorHandleFactory, beta,
716                                                                      data.inputShape, data.outputData, data.inputData);
717 }
718 
CompareSoftmaxTest(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,armnn::IWorkloadFactory & refWorkloadFactory,const armnn::ITensorHandleFactory & tensorHandleFactory,const armnn::ITensorHandleFactory & refTensorHandleFactory,float beta)719 LayerTestResult<float,2> CompareSoftmaxTest(
720     armnn::IWorkloadFactory& workloadFactory,
721     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
722     armnn::IWorkloadFactory& refWorkloadFactory,
723     const armnn::ITensorHandleFactory& tensorHandleFactory,
724     const armnn::ITensorHandleFactory& refTensorHandleFactory,
725     float beta)
726 {
727     return CompareSoftmaxTestImpl<armnn::DataType::Float32>(
728         workloadFactory, memoryManager, refWorkloadFactory, tensorHandleFactory, refTensorHandleFactory, beta);
729 }
730 
CompareSoftmaxUint8Test(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,armnn::IWorkloadFactory & refWorkloadFactory,const armnn::ITensorHandleFactory & tensorHandleFactory,const armnn::ITensorHandleFactory & refTensorHandleFactory,float beta)731 LayerTestResult<uint8_t,2> CompareSoftmaxUint8Test(
732     armnn::IWorkloadFactory& workloadFactory,
733     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
734     armnn::IWorkloadFactory& refWorkloadFactory,
735     const armnn::ITensorHandleFactory& tensorHandleFactory,
736     const armnn::ITensorHandleFactory& refTensorHandleFactory,
737     float beta)
738 {
739     return CompareSoftmaxTestImpl<armnn::DataType::QAsymmU8>(
740         workloadFactory, memoryManager, refWorkloadFactory, tensorHandleFactory, refTensorHandleFactory, beta);
741 }
742