xref: /aosp_15_r20/external/armnn/delegate/test/UnpackTest.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2021,2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "UnpackTestHelper.hpp"
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn_delegate.hpp>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <flatbuffers/flatbuffers.h>
11*89c4ff92SAndroid Build Coastguard Worker #include <schema_generated.h>
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
14*89c4ff92SAndroid Build Coastguard Worker 
15*89c4ff92SAndroid Build Coastguard Worker namespace armnnDelegate
16*89c4ff92SAndroid Build Coastguard Worker {
17*89c4ff92SAndroid Build Coastguard Worker 
18*89c4ff92SAndroid Build Coastguard Worker template <typename T>
UnpackAxis0Num4Test(tflite::TensorType tensorType,std::vector<armnn::BackendId> & backends)19*89c4ff92SAndroid Build Coastguard Worker void UnpackAxis0Num4Test(tflite::TensorType tensorType, std::vector<armnn::BackendId>& backends)
20*89c4ff92SAndroid Build Coastguard Worker {
21*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> inputShape { 4, 1, 6 };
22*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> expectedOutputShape { 1, 6 };
23*89c4ff92SAndroid Build Coastguard Worker 
24*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> inputValues { 1, 2, 3, 4, 5, 6,
25*89c4ff92SAndroid Build Coastguard Worker                                  7, 8, 9, 10, 11, 12,
26*89c4ff92SAndroid Build Coastguard Worker                                  13, 14, 15, 16, 17, 18,
27*89c4ff92SAndroid Build Coastguard Worker                                  19, 20, 21, 22, 23, 24 };
28*89c4ff92SAndroid Build Coastguard Worker 
29*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> expectedOutputValues0 { 1, 2, 3, 4, 5, 6 };
30*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> expectedOutputValues1 { 7, 8, 9, 10, 11, 12 };
31*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> expectedOutputValues2 { 13, 14, 15, 16, 17, 18 };
32*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> expectedOutputValues3 { 19, 20, 21, 22, 23, 24 };
33*89c4ff92SAndroid Build Coastguard Worker 
34*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::vector<T>> expectedOutputValues{ expectedOutputValues0,
35*89c4ff92SAndroid Build Coastguard Worker                                                       expectedOutputValues1,
36*89c4ff92SAndroid Build Coastguard Worker                                                       expectedOutputValues2,
37*89c4ff92SAndroid Build Coastguard Worker                                                       expectedOutputValues3 };
38*89c4ff92SAndroid Build Coastguard Worker 
39*89c4ff92SAndroid Build Coastguard Worker     UnpackTest<T>(tflite::BuiltinOperator_UNPACK,
40*89c4ff92SAndroid Build Coastguard Worker                   tensorType,
41*89c4ff92SAndroid Build Coastguard Worker                   backends,
42*89c4ff92SAndroid Build Coastguard Worker                   inputShape,
43*89c4ff92SAndroid Build Coastguard Worker                   expectedOutputShape,
44*89c4ff92SAndroid Build Coastguard Worker                   inputValues,
45*89c4ff92SAndroid Build Coastguard Worker                   expectedOutputValues,
46*89c4ff92SAndroid Build Coastguard Worker                   0);
47*89c4ff92SAndroid Build Coastguard Worker }
48*89c4ff92SAndroid Build Coastguard Worker 
49*89c4ff92SAndroid Build Coastguard Worker template <typename T>
UnpackAxis2Num6Test(tflite::TensorType tensorType,std::vector<armnn::BackendId> & backends)50*89c4ff92SAndroid Build Coastguard Worker void UnpackAxis2Num6Test(tflite::TensorType tensorType, std::vector<armnn::BackendId>& backends)
51*89c4ff92SAndroid Build Coastguard Worker {
52*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> inputShape { 4, 1, 6 };
53*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> expectedOutputShape { 4, 1 };
54*89c4ff92SAndroid Build Coastguard Worker 
55*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> inputValues { 1, 2, 3, 4, 5, 6,
56*89c4ff92SAndroid Build Coastguard Worker                                  7, 8, 9, 10, 11, 12,
57*89c4ff92SAndroid Build Coastguard Worker                                  13, 14, 15, 16, 17, 18,
58*89c4ff92SAndroid Build Coastguard Worker                                  19, 20, 21, 22, 23, 24 };
59*89c4ff92SAndroid Build Coastguard Worker 
60*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> expectedOutputValues0 { 1, 7, 13, 19 };
61*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> expectedOutputValues1 { 2, 8, 14, 20 };
62*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> expectedOutputValues2 { 3, 9, 15, 21 };
63*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> expectedOutputValues3 { 4, 10, 16, 22 };
64*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> expectedOutputValues4 { 5, 11, 17, 23 };
65*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> expectedOutputValues5 { 6, 12, 18, 24 };
66*89c4ff92SAndroid Build Coastguard Worker 
67*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::vector<T>> expectedOutputValues{ expectedOutputValues0,
68*89c4ff92SAndroid Build Coastguard Worker                                                       expectedOutputValues1,
69*89c4ff92SAndroid Build Coastguard Worker                                                       expectedOutputValues2,
70*89c4ff92SAndroid Build Coastguard Worker                                                       expectedOutputValues3,
71*89c4ff92SAndroid Build Coastguard Worker                                                       expectedOutputValues4,
72*89c4ff92SAndroid Build Coastguard Worker                                                       expectedOutputValues5 };
73*89c4ff92SAndroid Build Coastguard Worker 
74*89c4ff92SAndroid Build Coastguard Worker     UnpackTest<T>(tflite::BuiltinOperator_UNPACK,
75*89c4ff92SAndroid Build Coastguard Worker                   tensorType,
76*89c4ff92SAndroid Build Coastguard Worker                   backends,
77*89c4ff92SAndroid Build Coastguard Worker                   inputShape,
78*89c4ff92SAndroid Build Coastguard Worker                   expectedOutputShape,
79*89c4ff92SAndroid Build Coastguard Worker                   inputValues,
80*89c4ff92SAndroid Build Coastguard Worker                   expectedOutputValues,
81*89c4ff92SAndroid Build Coastguard Worker                   2);
82*89c4ff92SAndroid Build Coastguard Worker }
83*89c4ff92SAndroid Build Coastguard Worker 
84*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Unpack_CpuRefTests")
85*89c4ff92SAndroid Build Coastguard Worker {
86*89c4ff92SAndroid Build Coastguard Worker 
87*89c4ff92SAndroid Build Coastguard Worker // Fp32
88*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Unpack_Fp32_Axis0_Num4_CpuRef_Test")
89*89c4ff92SAndroid Build Coastguard Worker {
90*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
91*89c4ff92SAndroid Build Coastguard Worker UnpackAxis0Num4Test<float>(tflite::TensorType_FLOAT32, backends);
92*89c4ff92SAndroid Build Coastguard Worker }
93*89c4ff92SAndroid Build Coastguard Worker 
94*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Unpack_Fp32_Axis2_Num6_CpuRef_Test")
95*89c4ff92SAndroid Build Coastguard Worker {
96*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
97*89c4ff92SAndroid Build Coastguard Worker UnpackAxis2Num6Test<float>(tflite::TensorType_FLOAT32, backends);
98*89c4ff92SAndroid Build Coastguard Worker }
99*89c4ff92SAndroid Build Coastguard Worker 
100*89c4ff92SAndroid Build Coastguard Worker // Uint8
101*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Unpack_Uint8_Axis0_Num4_CpuRef_Test")
102*89c4ff92SAndroid Build Coastguard Worker {
103*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
104*89c4ff92SAndroid Build Coastguard Worker UnpackAxis0Num4Test<uint8_t>(tflite::TensorType_UINT8, backends);
105*89c4ff92SAndroid Build Coastguard Worker }
106*89c4ff92SAndroid Build Coastguard Worker 
107*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Unpack_Uint8_Axis2_Num6_CpuRef_Test")
108*89c4ff92SAndroid Build Coastguard Worker {
109*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
110*89c4ff92SAndroid Build Coastguard Worker UnpackAxis2Num6Test<uint8_t>(tflite::TensorType_UINT8, backends);
111*89c4ff92SAndroid Build Coastguard Worker }
112*89c4ff92SAndroid Build Coastguard Worker 
113*89c4ff92SAndroid Build Coastguard Worker } // End of Unpack_CpuRefTests
114*89c4ff92SAndroid Build Coastguard Worker 
115*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Unpack_CpuAccTests")
116*89c4ff92SAndroid Build Coastguard Worker {
117*89c4ff92SAndroid Build Coastguard Worker 
118*89c4ff92SAndroid Build Coastguard Worker // Fp32
119*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Unpack_Fp32_Axis0_Num4_CpuAcc_Test")
120*89c4ff92SAndroid Build Coastguard Worker {
121*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
122*89c4ff92SAndroid Build Coastguard Worker UnpackAxis0Num4Test<float>(tflite::TensorType_FLOAT32, backends);
123*89c4ff92SAndroid Build Coastguard Worker }
124*89c4ff92SAndroid Build Coastguard Worker 
125*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Unpack_Fp32_Axis2_Num6_CpuAcc_Test")
126*89c4ff92SAndroid Build Coastguard Worker {
127*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
128*89c4ff92SAndroid Build Coastguard Worker UnpackAxis2Num6Test<float>(tflite::TensorType_FLOAT32, backends);
129*89c4ff92SAndroid Build Coastguard Worker }
130*89c4ff92SAndroid Build Coastguard Worker 
131*89c4ff92SAndroid Build Coastguard Worker // Uint8
132*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Unpack_Uint8_Axis0_Num4_CpuAcc_Test")
133*89c4ff92SAndroid Build Coastguard Worker {
134*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
135*89c4ff92SAndroid Build Coastguard Worker UnpackAxis0Num4Test<uint8_t>(tflite::TensorType_UINT8, backends);
136*89c4ff92SAndroid Build Coastguard Worker }
137*89c4ff92SAndroid Build Coastguard Worker 
138*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Unpack_Uint8_Axis2_Num6_CpuAcc_Test")
139*89c4ff92SAndroid Build Coastguard Worker {
140*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
141*89c4ff92SAndroid Build Coastguard Worker UnpackAxis2Num6Test<uint8_t>(tflite::TensorType_UINT8, backends);
142*89c4ff92SAndroid Build Coastguard Worker }
143*89c4ff92SAndroid Build Coastguard Worker 
144*89c4ff92SAndroid Build Coastguard Worker } // End of Unpack_CpuAccTests
145*89c4ff92SAndroid Build Coastguard Worker 
146*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Unpack_GpuAccTests")
147*89c4ff92SAndroid Build Coastguard Worker {
148*89c4ff92SAndroid Build Coastguard Worker 
149*89c4ff92SAndroid Build Coastguard Worker // Fp32
150*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Unpack_Fp32_Axis0_Num4_GpuAcc_Test")
151*89c4ff92SAndroid Build Coastguard Worker {
152*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
153*89c4ff92SAndroid Build Coastguard Worker UnpackAxis0Num4Test<float>(tflite::TensorType_FLOAT32, backends);
154*89c4ff92SAndroid Build Coastguard Worker }
155*89c4ff92SAndroid Build Coastguard Worker 
156*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Unpack_Fp32_Axis2_Num6_GpuAcc_Test")
157*89c4ff92SAndroid Build Coastguard Worker {
158*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
159*89c4ff92SAndroid Build Coastguard Worker UnpackAxis2Num6Test<float>(tflite::TensorType_FLOAT32, backends);
160*89c4ff92SAndroid Build Coastguard Worker }
161*89c4ff92SAndroid Build Coastguard Worker 
162*89c4ff92SAndroid Build Coastguard Worker // Uint8
163*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Unpack_Uint8_Axis0_Num4_GpuAcc_Test")
164*89c4ff92SAndroid Build Coastguard Worker {
165*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
166*89c4ff92SAndroid Build Coastguard Worker UnpackAxis0Num4Test<uint8_t>(tflite::TensorType_UINT8, backends);
167*89c4ff92SAndroid Build Coastguard Worker }
168*89c4ff92SAndroid Build Coastguard Worker 
169*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Unpack_Uint8_Axis2_Num6_GpuAcc_Test")
170*89c4ff92SAndroid Build Coastguard Worker {
171*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
172*89c4ff92SAndroid Build Coastguard Worker UnpackAxis2Num6Test<uint8_t>(tflite::TensorType_UINT8, backends);
173*89c4ff92SAndroid Build Coastguard Worker }
174*89c4ff92SAndroid Build Coastguard Worker 
175*89c4ff92SAndroid Build Coastguard Worker } // End of Unpack_GpuAccTests
176*89c4ff92SAndroid Build Coastguard Worker 
177*89c4ff92SAndroid Build Coastguard Worker // End of Unpack Test Suite
178*89c4ff92SAndroid Build Coastguard Worker 
179*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnDelegate