//
// Copyright © 2021, 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #include "ArgMinMaxTestHelper.hpp"
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn_delegate.hpp>
9*89c4ff92SAndroid Build Coastguard Worker
10*89c4ff92SAndroid Build Coastguard Worker #include <flatbuffers/flatbuffers.h>
11*89c4ff92SAndroid Build Coastguard Worker #include <schema_generated.h>
12*89c4ff92SAndroid Build Coastguard Worker
13*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
14*89c4ff92SAndroid Build Coastguard Worker
15*89c4ff92SAndroid Build Coastguard Worker namespace armnnDelegate
16*89c4ff92SAndroid Build Coastguard Worker {
17*89c4ff92SAndroid Build Coastguard Worker
ArgMaxFP32Test(std::vector<armnn::BackendId> & backends,int axisValue)18*89c4ff92SAndroid Build Coastguard Worker void ArgMaxFP32Test(std::vector<armnn::BackendId>& backends, int axisValue)
19*89c4ff92SAndroid Build Coastguard Worker {
20*89c4ff92SAndroid Build Coastguard Worker // Set input data
21*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> inputShape { 1, 3, 2, 4 };
22*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> outputShape { 1, 3, 4 };
23*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> axisShape { 1 };
24*89c4ff92SAndroid Build Coastguard Worker
25*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputValues = { 1.0f, 2.0f, 3.0f, 4.0f,
26*89c4ff92SAndroid Build Coastguard Worker 5.0f, 6.0f, 7.0f, 8.0f,
27*89c4ff92SAndroid Build Coastguard Worker
28*89c4ff92SAndroid Build Coastguard Worker 10.0f, 20.0f, 30.0f, 40.0f,
29*89c4ff92SAndroid Build Coastguard Worker 50.0f, 60.0f, 70.0f, 80.0f,
30*89c4ff92SAndroid Build Coastguard Worker
31*89c4ff92SAndroid Build Coastguard Worker 100.0f, 200.0f, 300.0f, 400.0f,
32*89c4ff92SAndroid Build Coastguard Worker 500.0f, 600.0f, 700.0f, 800.0f };
33*89c4ff92SAndroid Build Coastguard Worker
34*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> expectedOutputValues = { 1, 1, 1, 1,
35*89c4ff92SAndroid Build Coastguard Worker 1, 1, 1, 1,
36*89c4ff92SAndroid Build Coastguard Worker 1, 1, 1, 1 };
37*89c4ff92SAndroid Build Coastguard Worker
38*89c4ff92SAndroid Build Coastguard Worker ArgMinMaxTest<float, int32_t>(tflite::BuiltinOperator_ARG_MAX,
39*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_FLOAT32,
40*89c4ff92SAndroid Build Coastguard Worker backends,
41*89c4ff92SAndroid Build Coastguard Worker inputShape,
42*89c4ff92SAndroid Build Coastguard Worker axisShape,
43*89c4ff92SAndroid Build Coastguard Worker outputShape,
44*89c4ff92SAndroid Build Coastguard Worker inputValues,
45*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues,
46*89c4ff92SAndroid Build Coastguard Worker axisValue,
47*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_INT32);
48*89c4ff92SAndroid Build Coastguard Worker }
49*89c4ff92SAndroid Build Coastguard Worker
ArgMinFP32Test(std::vector<armnn::BackendId> & backends,int axisValue)50*89c4ff92SAndroid Build Coastguard Worker void ArgMinFP32Test(std::vector<armnn::BackendId>& backends, int axisValue)
51*89c4ff92SAndroid Build Coastguard Worker {
52*89c4ff92SAndroid Build Coastguard Worker // Set input data
53*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> inputShape { 1, 3, 2, 4 };
54*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> outputShape { 1, 3, 2 };
55*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> axisShape { 1 };
56*89c4ff92SAndroid Build Coastguard Worker
57*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputValues = { 1.0f, 2.0f, 3.0f, 4.0f,
58*89c4ff92SAndroid Build Coastguard Worker 5.0f, 6.0f, 7.0f, 8.0f,
59*89c4ff92SAndroid Build Coastguard Worker
60*89c4ff92SAndroid Build Coastguard Worker 10.0f, 20.0f, 30.0f, 40.0f,
61*89c4ff92SAndroid Build Coastguard Worker 50.0f, 60.0f, 70.0f, 80.0f,
62*89c4ff92SAndroid Build Coastguard Worker
63*89c4ff92SAndroid Build Coastguard Worker 100.0f, 200.0f, 300.0f, 400.0f,
64*89c4ff92SAndroid Build Coastguard Worker 500.0f, 600.0f, 700.0f, 800.0f };
65*89c4ff92SAndroid Build Coastguard Worker
66*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> expectedOutputValues = { 0, 0,
67*89c4ff92SAndroid Build Coastguard Worker 0, 0,
68*89c4ff92SAndroid Build Coastguard Worker 0, 0 };
69*89c4ff92SAndroid Build Coastguard Worker
70*89c4ff92SAndroid Build Coastguard Worker ArgMinMaxTest<float, int32_t>(tflite::BuiltinOperator_ARG_MIN,
71*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_FLOAT32,
72*89c4ff92SAndroid Build Coastguard Worker backends,
73*89c4ff92SAndroid Build Coastguard Worker inputShape,
74*89c4ff92SAndroid Build Coastguard Worker axisShape,
75*89c4ff92SAndroid Build Coastguard Worker outputShape,
76*89c4ff92SAndroid Build Coastguard Worker inputValues,
77*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues,
78*89c4ff92SAndroid Build Coastguard Worker axisValue,
79*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_INT32);
80*89c4ff92SAndroid Build Coastguard Worker }
81*89c4ff92SAndroid Build Coastguard Worker
ArgMaxUint8Test(std::vector<armnn::BackendId> & backends,int axisValue)82*89c4ff92SAndroid Build Coastguard Worker void ArgMaxUint8Test(std::vector<armnn::BackendId>& backends, int axisValue)
83*89c4ff92SAndroid Build Coastguard Worker {
84*89c4ff92SAndroid Build Coastguard Worker // Set input data
85*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> inputShape { 1, 1, 1, 5 };
86*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> outputShape { 1, 1, 1 };
87*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> axisShape { 1 };
88*89c4ff92SAndroid Build Coastguard Worker
89*89c4ff92SAndroid Build Coastguard Worker std::vector<uint8_t> inputValues = { 5, 2, 8, 10, 9 };
90*89c4ff92SAndroid Build Coastguard Worker
91*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> expectedOutputValues = { 3 };
92*89c4ff92SAndroid Build Coastguard Worker
93*89c4ff92SAndroid Build Coastguard Worker ArgMinMaxTest<uint8_t, int32_t>(tflite::BuiltinOperator_ARG_MAX,
94*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_UINT8,
95*89c4ff92SAndroid Build Coastguard Worker backends,
96*89c4ff92SAndroid Build Coastguard Worker inputShape,
97*89c4ff92SAndroid Build Coastguard Worker axisShape,
98*89c4ff92SAndroid Build Coastguard Worker outputShape,
99*89c4ff92SAndroid Build Coastguard Worker inputValues,
100*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues,
101*89c4ff92SAndroid Build Coastguard Worker axisValue,
102*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_INT32);
103*89c4ff92SAndroid Build Coastguard Worker }
104*89c4ff92SAndroid Build Coastguard Worker
105*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("ArgMinMax_CpuRefTests")
106*89c4ff92SAndroid Build Coastguard Worker {
107*89c4ff92SAndroid Build Coastguard Worker
108*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("ArgMaxFP32Test_CpuRef_Test")
109*89c4ff92SAndroid Build Coastguard Worker {
110*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
111*89c4ff92SAndroid Build Coastguard Worker ArgMaxFP32Test(backends, 2);
112*89c4ff92SAndroid Build Coastguard Worker }
113*89c4ff92SAndroid Build Coastguard Worker
114*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("ArgMinFP32Test_CpuRef_Test")
115*89c4ff92SAndroid Build Coastguard Worker {
116*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
117*89c4ff92SAndroid Build Coastguard Worker ArgMinFP32Test(backends, 3);
118*89c4ff92SAndroid Build Coastguard Worker }
119*89c4ff92SAndroid Build Coastguard Worker
120*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("ArgMaxUint8Test_CpuRef_Test")
121*89c4ff92SAndroid Build Coastguard Worker {
122*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
123*89c4ff92SAndroid Build Coastguard Worker ArgMaxUint8Test(backends, -1);
124*89c4ff92SAndroid Build Coastguard Worker }
125*89c4ff92SAndroid Build Coastguard Worker
126*89c4ff92SAndroid Build Coastguard Worker } // TEST_SUITE("ArgMinMax_CpuRefTests")
127*89c4ff92SAndroid Build Coastguard Worker
128*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("ArgMinMax_CpuAccTests")
129*89c4ff92SAndroid Build Coastguard Worker {
130*89c4ff92SAndroid Build Coastguard Worker
131*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("ArgMaxFP32Test_CpuAcc_Test")
132*89c4ff92SAndroid Build Coastguard Worker {
133*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { armnn::Compute::CpuAcc };
134*89c4ff92SAndroid Build Coastguard Worker ArgMaxFP32Test(backends, 2);
135*89c4ff92SAndroid Build Coastguard Worker }
136*89c4ff92SAndroid Build Coastguard Worker
137*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("ArgMinFP32Test_CpuAcc_Test")
138*89c4ff92SAndroid Build Coastguard Worker {
139*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { armnn::Compute::CpuAcc };
140*89c4ff92SAndroid Build Coastguard Worker ArgMinFP32Test(backends, 3);
141*89c4ff92SAndroid Build Coastguard Worker }
142*89c4ff92SAndroid Build Coastguard Worker
143*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("ArgMaxUint8Test_CpuAcc_Test")
144*89c4ff92SAndroid Build Coastguard Worker {
145*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { armnn::Compute::CpuAcc };
146*89c4ff92SAndroid Build Coastguard Worker ArgMaxUint8Test(backends, -1);
147*89c4ff92SAndroid Build Coastguard Worker }
148*89c4ff92SAndroid Build Coastguard Worker
149*89c4ff92SAndroid Build Coastguard Worker } // TEST_SUITE("ArgMinMax_CpuAccTests")
150*89c4ff92SAndroid Build Coastguard Worker
151*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("ArgMinMax_GpuAccTests")
152*89c4ff92SAndroid Build Coastguard Worker {
153*89c4ff92SAndroid Build Coastguard Worker
154*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("ArgMaxFP32Test_GpuAcc_Test")
155*89c4ff92SAndroid Build Coastguard Worker {
156*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { armnn::Compute::GpuAcc };
157*89c4ff92SAndroid Build Coastguard Worker ArgMaxFP32Test(backends, 2);
158*89c4ff92SAndroid Build Coastguard Worker }
159*89c4ff92SAndroid Build Coastguard Worker
160*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("ArgMinFP32Test_GpuAcc_Test")
161*89c4ff92SAndroid Build Coastguard Worker {
162*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { armnn::Compute::GpuAcc };
163*89c4ff92SAndroid Build Coastguard Worker ArgMinFP32Test(backends, 3);
164*89c4ff92SAndroid Build Coastguard Worker }
165*89c4ff92SAndroid Build Coastguard Worker
166*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("ArgMaxUint8Test_GpuAcc_Test")
167*89c4ff92SAndroid Build Coastguard Worker {
168*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { armnn::Compute::GpuAcc };
169*89c4ff92SAndroid Build Coastguard Worker ArgMaxUint8Test(backends, -1);
170*89c4ff92SAndroid Build Coastguard Worker }
171*89c4ff92SAndroid Build Coastguard Worker
172*89c4ff92SAndroid Build Coastguard Worker } // TEST_SUITE("ArgMinMax_GpuAccTests")
173*89c4ff92SAndroid Build Coastguard Worker
174*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnDelegate