1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #include "StridedSliceTestHelper.hpp"
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn_delegate.hpp>
9*89c4ff92SAndroid Build Coastguard Worker
10*89c4ff92SAndroid Build Coastguard Worker #include <flatbuffers/flatbuffers.h>
11*89c4ff92SAndroid Build Coastguard Worker
12*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
13*89c4ff92SAndroid Build Coastguard Worker
14*89c4ff92SAndroid Build Coastguard Worker namespace armnnDelegate
15*89c4ff92SAndroid Build Coastguard Worker {
16*89c4ff92SAndroid Build Coastguard Worker
StridedSlice4DTest(std::vector<armnn::BackendId> & backends)17*89c4ff92SAndroid Build Coastguard Worker void StridedSlice4DTest(std::vector<armnn::BackendId>& backends)
18*89c4ff92SAndroid Build Coastguard Worker {
19*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> inputShape { 3, 2, 3, 1 };
20*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> outputShape { 1, 2, 3, 1 };
21*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> beginShape { 4 };
22*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> endShape { 4 };
23*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> strideShape { 4 };
24*89c4ff92SAndroid Build Coastguard Worker
25*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> beginData { 1, 0, 0, 0 };
26*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> endData { 2, 2, 3, 1 };
27*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> strideData { 1, 1, 1, 1 };
28*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
29*89c4ff92SAndroid Build Coastguard Worker 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
30*89c4ff92SAndroid Build Coastguard Worker 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };
31*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputData { 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f };
32*89c4ff92SAndroid Build Coastguard Worker
33*89c4ff92SAndroid Build Coastguard Worker StridedSliceTestImpl<float>(
34*89c4ff92SAndroid Build Coastguard Worker backends,
35*89c4ff92SAndroid Build Coastguard Worker inputData,
36*89c4ff92SAndroid Build Coastguard Worker outputData,
37*89c4ff92SAndroid Build Coastguard Worker beginData,
38*89c4ff92SAndroid Build Coastguard Worker endData,
39*89c4ff92SAndroid Build Coastguard Worker strideData,
40*89c4ff92SAndroid Build Coastguard Worker inputShape,
41*89c4ff92SAndroid Build Coastguard Worker beginShape,
42*89c4ff92SAndroid Build Coastguard Worker endShape,
43*89c4ff92SAndroid Build Coastguard Worker strideShape,
44*89c4ff92SAndroid Build Coastguard Worker outputShape
45*89c4ff92SAndroid Build Coastguard Worker );
46*89c4ff92SAndroid Build Coastguard Worker }
47*89c4ff92SAndroid Build Coastguard Worker
StridedSlice4DReverseTest(std::vector<armnn::BackendId> & backends)48*89c4ff92SAndroid Build Coastguard Worker void StridedSlice4DReverseTest(std::vector<armnn::BackendId>& backends)
49*89c4ff92SAndroid Build Coastguard Worker {
50*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> inputShape { 3, 2, 3, 1 };
51*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> outputShape { 1, 2, 3, 1 };
52*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> beginShape { 4 };
53*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> endShape { 4 };
54*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> strideShape { 4 };
55*89c4ff92SAndroid Build Coastguard Worker
56*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> beginData { 1, -1, 0, 0 };
57*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> endData { 2, -3, 3, 1 };
58*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> strideData { 1, -1, 1, 1 };
59*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
60*89c4ff92SAndroid Build Coastguard Worker 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
61*89c4ff92SAndroid Build Coastguard Worker 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };
62*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputData { 4.0f, 4.0f, 4.0f, 3.0f, 3.0f, 3.0f };
63*89c4ff92SAndroid Build Coastguard Worker
64*89c4ff92SAndroid Build Coastguard Worker StridedSliceTestImpl<float>(
65*89c4ff92SAndroid Build Coastguard Worker backends,
66*89c4ff92SAndroid Build Coastguard Worker inputData,
67*89c4ff92SAndroid Build Coastguard Worker outputData,
68*89c4ff92SAndroid Build Coastguard Worker beginData,
69*89c4ff92SAndroid Build Coastguard Worker endData,
70*89c4ff92SAndroid Build Coastguard Worker strideData,
71*89c4ff92SAndroid Build Coastguard Worker inputShape,
72*89c4ff92SAndroid Build Coastguard Worker beginShape,
73*89c4ff92SAndroid Build Coastguard Worker endShape,
74*89c4ff92SAndroid Build Coastguard Worker strideShape,
75*89c4ff92SAndroid Build Coastguard Worker outputShape
76*89c4ff92SAndroid Build Coastguard Worker );
77*89c4ff92SAndroid Build Coastguard Worker }
78*89c4ff92SAndroid Build Coastguard Worker
StridedSliceSimpleStrideTest(std::vector<armnn::BackendId> & backends)79*89c4ff92SAndroid Build Coastguard Worker void StridedSliceSimpleStrideTest(std::vector<armnn::BackendId>& backends)
80*89c4ff92SAndroid Build Coastguard Worker {
81*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> inputShape { 3, 2, 3, 1 };
82*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> outputShape { 2, 1, 2, 1 };
83*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> beginShape { 4 };
84*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> endShape { 4 };
85*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> strideShape { 4 };
86*89c4ff92SAndroid Build Coastguard Worker
87*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> beginData { 0, 0, 0, 0 };
88*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> endData { 3, 2, 3, 1 };
89*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> strideData { 2, 2, 2, 1 };
90*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
91*89c4ff92SAndroid Build Coastguard Worker 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
92*89c4ff92SAndroid Build Coastguard Worker 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };
93*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputData { 1.0f, 1.0f,
94*89c4ff92SAndroid Build Coastguard Worker 5.0f, 5.0f };
95*89c4ff92SAndroid Build Coastguard Worker
96*89c4ff92SAndroid Build Coastguard Worker StridedSliceTestImpl<float>(
97*89c4ff92SAndroid Build Coastguard Worker backends,
98*89c4ff92SAndroid Build Coastguard Worker inputData,
99*89c4ff92SAndroid Build Coastguard Worker outputData,
100*89c4ff92SAndroid Build Coastguard Worker beginData,
101*89c4ff92SAndroid Build Coastguard Worker endData,
102*89c4ff92SAndroid Build Coastguard Worker strideData,
103*89c4ff92SAndroid Build Coastguard Worker inputShape,
104*89c4ff92SAndroid Build Coastguard Worker beginShape,
105*89c4ff92SAndroid Build Coastguard Worker endShape,
106*89c4ff92SAndroid Build Coastguard Worker strideShape,
107*89c4ff92SAndroid Build Coastguard Worker outputShape
108*89c4ff92SAndroid Build Coastguard Worker );
109*89c4ff92SAndroid Build Coastguard Worker }
110*89c4ff92SAndroid Build Coastguard Worker
StridedSliceSimpleRangeMaskTest(std::vector<armnn::BackendId> & backends)111*89c4ff92SAndroid Build Coastguard Worker void StridedSliceSimpleRangeMaskTest(std::vector<armnn::BackendId>& backends)
112*89c4ff92SAndroid Build Coastguard Worker {
113*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> inputShape { 3, 2, 3, 1 };
114*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> outputShape { 3, 2, 3, 1 };
115*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> beginShape { 4 };
116*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> endShape { 4 };
117*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> strideShape { 4 };
118*89c4ff92SAndroid Build Coastguard Worker
119*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> beginData { 1, 1, 1, 1 };
120*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> endData { 1, 1, 1, 1 };
121*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> strideData { 1, 1, 1, 1 };
122*89c4ff92SAndroid Build Coastguard Worker
123*89c4ff92SAndroid Build Coastguard Worker int beginMask = -1;
124*89c4ff92SAndroid Build Coastguard Worker int endMask = -1;
125*89c4ff92SAndroid Build Coastguard Worker
126*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
127*89c4ff92SAndroid Build Coastguard Worker 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
128*89c4ff92SAndroid Build Coastguard Worker 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };
129*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputData { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
130*89c4ff92SAndroid Build Coastguard Worker 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
131*89c4ff92SAndroid Build Coastguard Worker 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };
132*89c4ff92SAndroid Build Coastguard Worker
133*89c4ff92SAndroid Build Coastguard Worker StridedSliceTestImpl<float>(
134*89c4ff92SAndroid Build Coastguard Worker backends,
135*89c4ff92SAndroid Build Coastguard Worker inputData,
136*89c4ff92SAndroid Build Coastguard Worker outputData,
137*89c4ff92SAndroid Build Coastguard Worker beginData,
138*89c4ff92SAndroid Build Coastguard Worker endData,
139*89c4ff92SAndroid Build Coastguard Worker strideData,
140*89c4ff92SAndroid Build Coastguard Worker inputShape,
141*89c4ff92SAndroid Build Coastguard Worker beginShape,
142*89c4ff92SAndroid Build Coastguard Worker endShape,
143*89c4ff92SAndroid Build Coastguard Worker strideShape,
144*89c4ff92SAndroid Build Coastguard Worker outputShape,
145*89c4ff92SAndroid Build Coastguard Worker beginMask,
146*89c4ff92SAndroid Build Coastguard Worker endMask
147*89c4ff92SAndroid Build Coastguard Worker );
148*89c4ff92SAndroid Build Coastguard Worker }
149*89c4ff92SAndroid Build Coastguard Worker
150*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("StridedSlice_CpuRefTests")
151*89c4ff92SAndroid Build Coastguard Worker {
152*89c4ff92SAndroid Build Coastguard Worker
153*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("StridedSlice_4D_CpuRef_Test")
154*89c4ff92SAndroid Build Coastguard Worker {
155*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
156*89c4ff92SAndroid Build Coastguard Worker StridedSlice4DTest(backends);
157*89c4ff92SAndroid Build Coastguard Worker }
158*89c4ff92SAndroid Build Coastguard Worker
159*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("StridedSlice_4D_Reverse_CpuRef_Test")
160*89c4ff92SAndroid Build Coastguard Worker {
161*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
162*89c4ff92SAndroid Build Coastguard Worker StridedSlice4DReverseTest(backends);
163*89c4ff92SAndroid Build Coastguard Worker }
164*89c4ff92SAndroid Build Coastguard Worker
165*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("StridedSlice_SimpleStride_CpuRef_Test")
166*89c4ff92SAndroid Build Coastguard Worker {
167*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
168*89c4ff92SAndroid Build Coastguard Worker StridedSliceSimpleStrideTest(backends);
169*89c4ff92SAndroid Build Coastguard Worker }
170*89c4ff92SAndroid Build Coastguard Worker
171*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("StridedSlice_SimpleRange_CpuRef_Test")
172*89c4ff92SAndroid Build Coastguard Worker {
173*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
174*89c4ff92SAndroid Build Coastguard Worker StridedSliceSimpleRangeMaskTest(backends);
175*89c4ff92SAndroid Build Coastguard Worker }
176*89c4ff92SAndroid Build Coastguard Worker
177*89c4ff92SAndroid Build Coastguard Worker } // StridedSlice_CpuRefTests TestSuite
178*89c4ff92SAndroid Build Coastguard Worker
179*89c4ff92SAndroid Build Coastguard Worker
180*89c4ff92SAndroid Build Coastguard Worker
181*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("StridedSlice_CpuAccTests")
182*89c4ff92SAndroid Build Coastguard Worker {
183*89c4ff92SAndroid Build Coastguard Worker
184*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("StridedSlice_4D_CpuAcc_Test")
185*89c4ff92SAndroid Build Coastguard Worker {
186*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
187*89c4ff92SAndroid Build Coastguard Worker StridedSlice4DTest(backends);
188*89c4ff92SAndroid Build Coastguard Worker }
189*89c4ff92SAndroid Build Coastguard Worker
190*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("StridedSlice_4D_Reverse_CpuAcc_Test")
191*89c4ff92SAndroid Build Coastguard Worker {
192*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
193*89c4ff92SAndroid Build Coastguard Worker StridedSlice4DReverseTest(backends);
194*89c4ff92SAndroid Build Coastguard Worker }
195*89c4ff92SAndroid Build Coastguard Worker
196*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("StridedSlice_SimpleStride_CpuAcc_Test")
197*89c4ff92SAndroid Build Coastguard Worker {
198*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
199*89c4ff92SAndroid Build Coastguard Worker StridedSliceSimpleStrideTest(backends);
200*89c4ff92SAndroid Build Coastguard Worker }
201*89c4ff92SAndroid Build Coastguard Worker
202*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("StridedSlice_SimpleRange_CpuAcc_Test")
203*89c4ff92SAndroid Build Coastguard Worker {
204*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
205*89c4ff92SAndroid Build Coastguard Worker StridedSliceSimpleRangeMaskTest(backends);
206*89c4ff92SAndroid Build Coastguard Worker }
207*89c4ff92SAndroid Build Coastguard Worker
208*89c4ff92SAndroid Build Coastguard Worker } // StridedSlice_CpuAccTests TestSuite
209*89c4ff92SAndroid Build Coastguard Worker
210*89c4ff92SAndroid Build Coastguard Worker
211*89c4ff92SAndroid Build Coastguard Worker
212*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("StridedSlice_GpuAccTests")
213*89c4ff92SAndroid Build Coastguard Worker {
214*89c4ff92SAndroid Build Coastguard Worker
215*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("StridedSlice_4D_GpuAcc_Test")
216*89c4ff92SAndroid Build Coastguard Worker {
217*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
218*89c4ff92SAndroid Build Coastguard Worker StridedSlice4DTest(backends);
219*89c4ff92SAndroid Build Coastguard Worker }
220*89c4ff92SAndroid Build Coastguard Worker
221*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("StridedSlice_4D_Reverse_GpuAcc_Test")
222*89c4ff92SAndroid Build Coastguard Worker {
223*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
224*89c4ff92SAndroid Build Coastguard Worker StridedSlice4DReverseTest(backends);
225*89c4ff92SAndroid Build Coastguard Worker }
226*89c4ff92SAndroid Build Coastguard Worker
227*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("StridedSlice_SimpleStride_GpuAcc_Test")
228*89c4ff92SAndroid Build Coastguard Worker {
229*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
230*89c4ff92SAndroid Build Coastguard Worker StridedSliceSimpleStrideTest(backends);
231*89c4ff92SAndroid Build Coastguard Worker }
232*89c4ff92SAndroid Build Coastguard Worker
233*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("StridedSlice_SimpleRange_GpuAcc_Test")
234*89c4ff92SAndroid Build Coastguard Worker {
235*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
236*89c4ff92SAndroid Build Coastguard Worker StridedSliceSimpleRangeMaskTest(backends);
237*89c4ff92SAndroid Build Coastguard Worker }
238*89c4ff92SAndroid Build Coastguard Worker
239*89c4ff92SAndroid Build Coastguard Worker } // StridedSlice_GpuAccTests TestSuite
240*89c4ff92SAndroid Build Coastguard Worker
241*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnDelegate