xref: /aosp_15_r20/external/armnn/delegate/test/StridedSliceTest.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "StridedSliceTestHelper.hpp"
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn_delegate.hpp>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <flatbuffers/flatbuffers.h>
11*89c4ff92SAndroid Build Coastguard Worker 
12*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
13*89c4ff92SAndroid Build Coastguard Worker 
14*89c4ff92SAndroid Build Coastguard Worker namespace armnnDelegate
15*89c4ff92SAndroid Build Coastguard Worker {
16*89c4ff92SAndroid Build Coastguard Worker 
StridedSlice4DTest(std::vector<armnn::BackendId> & backends)17*89c4ff92SAndroid Build Coastguard Worker void StridedSlice4DTest(std::vector<armnn::BackendId>& backends)
18*89c4ff92SAndroid Build Coastguard Worker {
19*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> inputShape  { 3, 2, 3, 1 };
20*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> outputShape { 1, 2, 3, 1 };
21*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> beginShape  { 4 };
22*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> endShape    { 4 };
23*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> strideShape { 4 };
24*89c4ff92SAndroid Build Coastguard Worker 
25*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> beginData  { 1, 0, 0, 0 };
26*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> endData    { 2, 2, 3, 1 };
27*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> strideData { 1, 1, 1, 1 };
28*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData  { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
29*89c4ff92SAndroid Build Coastguard Worker                                     3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
30*89c4ff92SAndroid Build Coastguard Worker                                     5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };
31*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> outputData { 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f };
32*89c4ff92SAndroid Build Coastguard Worker 
33*89c4ff92SAndroid Build Coastguard Worker     StridedSliceTestImpl<float>(
34*89c4ff92SAndroid Build Coastguard Worker             backends,
35*89c4ff92SAndroid Build Coastguard Worker             inputData,
36*89c4ff92SAndroid Build Coastguard Worker             outputData,
37*89c4ff92SAndroid Build Coastguard Worker             beginData,
38*89c4ff92SAndroid Build Coastguard Worker             endData,
39*89c4ff92SAndroid Build Coastguard Worker             strideData,
40*89c4ff92SAndroid Build Coastguard Worker             inputShape,
41*89c4ff92SAndroid Build Coastguard Worker             beginShape,
42*89c4ff92SAndroid Build Coastguard Worker             endShape,
43*89c4ff92SAndroid Build Coastguard Worker             strideShape,
44*89c4ff92SAndroid Build Coastguard Worker             outputShape
45*89c4ff92SAndroid Build Coastguard Worker             );
46*89c4ff92SAndroid Build Coastguard Worker }
47*89c4ff92SAndroid Build Coastguard Worker 
StridedSlice4DReverseTest(std::vector<armnn::BackendId> & backends)48*89c4ff92SAndroid Build Coastguard Worker void StridedSlice4DReverseTest(std::vector<armnn::BackendId>& backends)
49*89c4ff92SAndroid Build Coastguard Worker {
50*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> inputShape  { 3, 2, 3, 1 };
51*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> outputShape { 1, 2, 3, 1 };
52*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> beginShape  { 4 };
53*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> endShape    { 4 };
54*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> strideShape { 4 };
55*89c4ff92SAndroid Build Coastguard Worker 
56*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> beginData  { 1, -1, 0, 0 };
57*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> endData    { 2, -3, 3, 1 };
58*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> strideData { 1, -1, 1, 1 };
59*89c4ff92SAndroid Build Coastguard Worker     std::vector<float>   inputData  { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
60*89c4ff92SAndroid Build Coastguard Worker                                       3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
61*89c4ff92SAndroid Build Coastguard Worker                                       5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };
62*89c4ff92SAndroid Build Coastguard Worker     std::vector<float>   outputData { 4.0f, 4.0f, 4.0f, 3.0f, 3.0f, 3.0f };
63*89c4ff92SAndroid Build Coastguard Worker 
64*89c4ff92SAndroid Build Coastguard Worker     StridedSliceTestImpl<float>(
65*89c4ff92SAndroid Build Coastguard Worker             backends,
66*89c4ff92SAndroid Build Coastguard Worker             inputData,
67*89c4ff92SAndroid Build Coastguard Worker             outputData,
68*89c4ff92SAndroid Build Coastguard Worker             beginData,
69*89c4ff92SAndroid Build Coastguard Worker             endData,
70*89c4ff92SAndroid Build Coastguard Worker             strideData,
71*89c4ff92SAndroid Build Coastguard Worker             inputShape,
72*89c4ff92SAndroid Build Coastguard Worker             beginShape,
73*89c4ff92SAndroid Build Coastguard Worker             endShape,
74*89c4ff92SAndroid Build Coastguard Worker             strideShape,
75*89c4ff92SAndroid Build Coastguard Worker             outputShape
76*89c4ff92SAndroid Build Coastguard Worker     );
77*89c4ff92SAndroid Build Coastguard Worker }
78*89c4ff92SAndroid Build Coastguard Worker 
StridedSliceSimpleStrideTest(std::vector<armnn::BackendId> & backends)79*89c4ff92SAndroid Build Coastguard Worker void StridedSliceSimpleStrideTest(std::vector<armnn::BackendId>& backends)
80*89c4ff92SAndroid Build Coastguard Worker {
81*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> inputShape  { 3, 2, 3, 1 };
82*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> outputShape { 2, 1, 2, 1 };
83*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> beginShape  { 4 };
84*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> endShape    { 4 };
85*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> strideShape { 4 };
86*89c4ff92SAndroid Build Coastguard Worker 
87*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> beginData  { 0, 0, 0, 0 };
88*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> endData    { 3, 2, 3, 1 };
89*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> strideData { 2, 2, 2, 1 };
90*89c4ff92SAndroid Build Coastguard Worker     std::vector<float>   inputData  { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
91*89c4ff92SAndroid Build Coastguard Worker                                       3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
92*89c4ff92SAndroid Build Coastguard Worker                                       5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };
93*89c4ff92SAndroid Build Coastguard Worker     std::vector<float>   outputData { 1.0f, 1.0f,
94*89c4ff92SAndroid Build Coastguard Worker                                       5.0f, 5.0f };
95*89c4ff92SAndroid Build Coastguard Worker 
96*89c4ff92SAndroid Build Coastguard Worker     StridedSliceTestImpl<float>(
97*89c4ff92SAndroid Build Coastguard Worker             backends,
98*89c4ff92SAndroid Build Coastguard Worker             inputData,
99*89c4ff92SAndroid Build Coastguard Worker             outputData,
100*89c4ff92SAndroid Build Coastguard Worker             beginData,
101*89c4ff92SAndroid Build Coastguard Worker             endData,
102*89c4ff92SAndroid Build Coastguard Worker             strideData,
103*89c4ff92SAndroid Build Coastguard Worker             inputShape,
104*89c4ff92SAndroid Build Coastguard Worker             beginShape,
105*89c4ff92SAndroid Build Coastguard Worker             endShape,
106*89c4ff92SAndroid Build Coastguard Worker             strideShape,
107*89c4ff92SAndroid Build Coastguard Worker             outputShape
108*89c4ff92SAndroid Build Coastguard Worker     );
109*89c4ff92SAndroid Build Coastguard Worker }
110*89c4ff92SAndroid Build Coastguard Worker 
StridedSliceSimpleRangeMaskTest(std::vector<armnn::BackendId> & backends)111*89c4ff92SAndroid Build Coastguard Worker void StridedSliceSimpleRangeMaskTest(std::vector<armnn::BackendId>& backends)
112*89c4ff92SAndroid Build Coastguard Worker {
113*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> inputShape  { 3, 2, 3, 1 };
114*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> outputShape { 3, 2, 3, 1 };
115*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> beginShape  { 4 };
116*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> endShape    { 4 };
117*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> strideShape { 4 };
118*89c4ff92SAndroid Build Coastguard Worker 
119*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> beginData  { 1, 1, 1, 1 };
120*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> endData    { 1, 1, 1, 1 };
121*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> strideData { 1, 1, 1, 1 };
122*89c4ff92SAndroid Build Coastguard Worker 
123*89c4ff92SAndroid Build Coastguard Worker     int beginMask = -1;
124*89c4ff92SAndroid Build Coastguard Worker     int endMask   = -1;
125*89c4ff92SAndroid Build Coastguard Worker 
126*89c4ff92SAndroid Build Coastguard Worker     std::vector<float>   inputData  { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
127*89c4ff92SAndroid Build Coastguard Worker                                       3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
128*89c4ff92SAndroid Build Coastguard Worker                                       5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };
129*89c4ff92SAndroid Build Coastguard Worker     std::vector<float>   outputData { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
130*89c4ff92SAndroid Build Coastguard Worker                                       3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
131*89c4ff92SAndroid Build Coastguard Worker                                       5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };
132*89c4ff92SAndroid Build Coastguard Worker 
133*89c4ff92SAndroid Build Coastguard Worker     StridedSliceTestImpl<float>(
134*89c4ff92SAndroid Build Coastguard Worker             backends,
135*89c4ff92SAndroid Build Coastguard Worker             inputData,
136*89c4ff92SAndroid Build Coastguard Worker             outputData,
137*89c4ff92SAndroid Build Coastguard Worker             beginData,
138*89c4ff92SAndroid Build Coastguard Worker             endData,
139*89c4ff92SAndroid Build Coastguard Worker             strideData,
140*89c4ff92SAndroid Build Coastguard Worker             inputShape,
141*89c4ff92SAndroid Build Coastguard Worker             beginShape,
142*89c4ff92SAndroid Build Coastguard Worker             endShape,
143*89c4ff92SAndroid Build Coastguard Worker             strideShape,
144*89c4ff92SAndroid Build Coastguard Worker             outputShape,
145*89c4ff92SAndroid Build Coastguard Worker             beginMask,
146*89c4ff92SAndroid Build Coastguard Worker             endMask
147*89c4ff92SAndroid Build Coastguard Worker     );
148*89c4ff92SAndroid Build Coastguard Worker }
149*89c4ff92SAndroid Build Coastguard Worker 
150*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("StridedSlice_CpuRefTests")
151*89c4ff92SAndroid Build Coastguard Worker {
152*89c4ff92SAndroid Build Coastguard Worker 
153*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("StridedSlice_4D_CpuRef_Test")
154*89c4ff92SAndroid Build Coastguard Worker {
155*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
156*89c4ff92SAndroid Build Coastguard Worker     StridedSlice4DTest(backends);
157*89c4ff92SAndroid Build Coastguard Worker }
158*89c4ff92SAndroid Build Coastguard Worker 
159*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("StridedSlice_4D_Reverse_CpuRef_Test")
160*89c4ff92SAndroid Build Coastguard Worker {
161*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
162*89c4ff92SAndroid Build Coastguard Worker     StridedSlice4DReverseTest(backends);
163*89c4ff92SAndroid Build Coastguard Worker }
164*89c4ff92SAndroid Build Coastguard Worker 
165*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("StridedSlice_SimpleStride_CpuRef_Test")
166*89c4ff92SAndroid Build Coastguard Worker {
167*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
168*89c4ff92SAndroid Build Coastguard Worker     StridedSliceSimpleStrideTest(backends);
169*89c4ff92SAndroid Build Coastguard Worker }
170*89c4ff92SAndroid Build Coastguard Worker 
171*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("StridedSlice_SimpleRange_CpuRef_Test")
172*89c4ff92SAndroid Build Coastguard Worker {
173*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
174*89c4ff92SAndroid Build Coastguard Worker     StridedSliceSimpleRangeMaskTest(backends);
175*89c4ff92SAndroid Build Coastguard Worker }
176*89c4ff92SAndroid Build Coastguard Worker 
177*89c4ff92SAndroid Build Coastguard Worker } // StridedSlice_CpuRefTests TestSuite
178*89c4ff92SAndroid Build Coastguard Worker 
179*89c4ff92SAndroid Build Coastguard Worker 
180*89c4ff92SAndroid Build Coastguard Worker 
181*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("StridedSlice_CpuAccTests")
182*89c4ff92SAndroid Build Coastguard Worker {
183*89c4ff92SAndroid Build Coastguard Worker 
184*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("StridedSlice_4D_CpuAcc_Test")
185*89c4ff92SAndroid Build Coastguard Worker {
186*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
187*89c4ff92SAndroid Build Coastguard Worker     StridedSlice4DTest(backends);
188*89c4ff92SAndroid Build Coastguard Worker }
189*89c4ff92SAndroid Build Coastguard Worker 
190*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("StridedSlice_4D_Reverse_CpuAcc_Test")
191*89c4ff92SAndroid Build Coastguard Worker {
192*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
193*89c4ff92SAndroid Build Coastguard Worker     StridedSlice4DReverseTest(backends);
194*89c4ff92SAndroid Build Coastguard Worker }
195*89c4ff92SAndroid Build Coastguard Worker 
196*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("StridedSlice_SimpleStride_CpuAcc_Test")
197*89c4ff92SAndroid Build Coastguard Worker {
198*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
199*89c4ff92SAndroid Build Coastguard Worker     StridedSliceSimpleStrideTest(backends);
200*89c4ff92SAndroid Build Coastguard Worker }
201*89c4ff92SAndroid Build Coastguard Worker 
202*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("StridedSlice_SimpleRange_CpuAcc_Test")
203*89c4ff92SAndroid Build Coastguard Worker {
204*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
205*89c4ff92SAndroid Build Coastguard Worker     StridedSliceSimpleRangeMaskTest(backends);
206*89c4ff92SAndroid Build Coastguard Worker }
207*89c4ff92SAndroid Build Coastguard Worker 
208*89c4ff92SAndroid Build Coastguard Worker } // StridedSlice_CpuAccTests TestSuite
209*89c4ff92SAndroid Build Coastguard Worker 
210*89c4ff92SAndroid Build Coastguard Worker 
211*89c4ff92SAndroid Build Coastguard Worker 
212*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("StridedSlice_GpuAccTests")
213*89c4ff92SAndroid Build Coastguard Worker {
214*89c4ff92SAndroid Build Coastguard Worker 
215*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("StridedSlice_4D_GpuAcc_Test")
216*89c4ff92SAndroid Build Coastguard Worker {
217*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
218*89c4ff92SAndroid Build Coastguard Worker     StridedSlice4DTest(backends);
219*89c4ff92SAndroid Build Coastguard Worker }
220*89c4ff92SAndroid Build Coastguard Worker 
221*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("StridedSlice_4D_Reverse_GpuAcc_Test")
222*89c4ff92SAndroid Build Coastguard Worker {
223*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
224*89c4ff92SAndroid Build Coastguard Worker     StridedSlice4DReverseTest(backends);
225*89c4ff92SAndroid Build Coastguard Worker }
226*89c4ff92SAndroid Build Coastguard Worker 
227*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("StridedSlice_SimpleStride_GpuAcc_Test")
228*89c4ff92SAndroid Build Coastguard Worker {
229*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
230*89c4ff92SAndroid Build Coastguard Worker     StridedSliceSimpleStrideTest(backends);
231*89c4ff92SAndroid Build Coastguard Worker }
232*89c4ff92SAndroid Build Coastguard Worker 
233*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("StridedSlice_SimpleRange_GpuAcc_Test")
234*89c4ff92SAndroid Build Coastguard Worker {
235*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
236*89c4ff92SAndroid Build Coastguard Worker     StridedSliceSimpleRangeMaskTest(backends);
237*89c4ff92SAndroid Build Coastguard Worker }
238*89c4ff92SAndroid Build Coastguard Worker 
239*89c4ff92SAndroid Build Coastguard Worker } // StridedSlice_GpuAccTests TestSuite
240*89c4ff92SAndroid Build Coastguard Worker 
241*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnDelegate