xref: /aosp_15_r20/external/armnn/src/armnn/test/optimizations/ReduceMultipleAxesTests.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include <GraphUtils.hpp>
7*89c4ff92SAndroid Build Coastguard Worker #include <TestUtils.hpp>
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker #include <armnn/INetwork.hpp>
10*89c4ff92SAndroid Build Coastguard Worker 
11*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
14*89c4ff92SAndroid Build Coastguard Worker 
15*89c4ff92SAndroid Build Coastguard Worker namespace
16*89c4ff92SAndroid Build Coastguard Worker {
17*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMCOMPUTENEON_ENABLED)||defined(ARMCOMPUTECL_ENABLED)
CreateSimpleReduceNetwork(ReduceDescriptor reduceDescriptor,TensorShape & inputShape,TensorShape & outputShape)18*89c4ff92SAndroid Build Coastguard Worker INetworkPtr CreateSimpleReduceNetwork(ReduceDescriptor reduceDescriptor,
19*89c4ff92SAndroid Build Coastguard Worker                                       TensorShape& inputShape,
20*89c4ff92SAndroid Build Coastguard Worker                                       TensorShape& outputShape)
21*89c4ff92SAndroid Build Coastguard Worker {
22*89c4ff92SAndroid Build Coastguard Worker     // Create a network
23*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr network = INetwork::Create();
24*89c4ff92SAndroid Build Coastguard Worker 
25*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("reduce_layer");
26*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo inputInfo(inputShape, DataType::Float32);
27*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo outputInfo(outputShape, DataType::Float32);
28*89c4ff92SAndroid Build Coastguard Worker 
29*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* const inputLayer = network->AddInputLayer(0);
30*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* const reduceLayer = network->AddReduceLayer(reduceDescriptor, layerName.c_str());
31*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* const outputLayer1 = network->AddOutputLayer(0);
32*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* const outputLayer2 = network->AddOutputLayer(1);
33*89c4ff92SAndroid Build Coastguard Worker 
34*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
35*89c4ff92SAndroid Build Coastguard Worker     reduceLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
36*89c4ff92SAndroid Build Coastguard Worker 
37*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(reduceLayer->GetInputSlot(0));
38*89c4ff92SAndroid Build Coastguard Worker     reduceLayer->GetOutputSlot(0).Connect(outputLayer1->GetInputSlot(0));
39*89c4ff92SAndroid Build Coastguard Worker     reduceLayer->GetOutputSlot(0).Connect(outputLayer2->GetInputSlot(0));
40*89c4ff92SAndroid Build Coastguard Worker 
41*89c4ff92SAndroid Build Coastguard Worker     return network;
42*89c4ff92SAndroid Build Coastguard Worker }
43*89c4ff92SAndroid Build Coastguard Worker 
ReduceWithMultipleAxesTest(INetworkPtr & network,const TensorShape & outputShape,const std::vector<float> & inputData,const std::vector<float> & expectedOutput,const size_t numOfAxes,Compute backendId)44*89c4ff92SAndroid Build Coastguard Worker void ReduceWithMultipleAxesTest(INetworkPtr& network,
45*89c4ff92SAndroid Build Coastguard Worker                                 const TensorShape& outputShape,
46*89c4ff92SAndroid Build Coastguard Worker                                 const std::vector<float>& inputData,
47*89c4ff92SAndroid Build Coastguard Worker                                 const std::vector<float>& expectedOutput,
48*89c4ff92SAndroid Build Coastguard Worker                                 const size_t numOfAxes,
49*89c4ff92SAndroid Build Coastguard Worker                                 Compute backendId)
50*89c4ff92SAndroid Build Coastguard Worker {
51*89c4ff92SAndroid Build Coastguard Worker     // Create ArmNN runtime
52*89c4ff92SAndroid Build Coastguard Worker     IRuntimePtr run = IRuntime::Create(IRuntime::CreationOptions());
53*89c4ff92SAndroid Build Coastguard Worker 
54*89c4ff92SAndroid Build Coastguard Worker     // Optimise ArmNN network
55*89c4ff92SAndroid Build Coastguard Worker     IOptimizedNetworkPtr optNet = Optimize(*network, {backendId}, run->GetDeviceSpec());
56*89c4ff92SAndroid Build Coastguard Worker 
57*89c4ff92SAndroid Build Coastguard Worker     Graph& graph = GetGraphForTesting(optNet.get());
58*89c4ff92SAndroid Build Coastguard Worker     if (numOfAxes == 2)
59*89c4ff92SAndroid Build Coastguard Worker     {
60*89c4ff92SAndroid Build Coastguard Worker         CHECK(graph.GetNumLayers() == 5);
61*89c4ff92SAndroid Build Coastguard Worker         CHECK(CheckSequence(graph.cbegin(),
62*89c4ff92SAndroid Build Coastguard Worker                             graph.cend(),
63*89c4ff92SAndroid Build Coastguard Worker                             &IsLayerOfType<InputLayer>,
64*89c4ff92SAndroid Build Coastguard Worker                             &IsLayerOfType<ReduceLayer>,
65*89c4ff92SAndroid Build Coastguard Worker                             &IsLayerOfType<ReduceLayer>,
66*89c4ff92SAndroid Build Coastguard Worker                             &IsLayerOfType<OutputLayer>,
67*89c4ff92SAndroid Build Coastguard Worker                             &IsLayerOfType<OutputLayer>));
68*89c4ff92SAndroid Build Coastguard Worker     } else
69*89c4ff92SAndroid Build Coastguard Worker     {
70*89c4ff92SAndroid Build Coastguard Worker         CHECK(graph.GetNumLayers() == 6);
71*89c4ff92SAndroid Build Coastguard Worker         CHECK(CheckSequence(graph.cbegin(),
72*89c4ff92SAndroid Build Coastguard Worker                             graph.cend(),
73*89c4ff92SAndroid Build Coastguard Worker                             &IsLayerOfType<InputLayer>,
74*89c4ff92SAndroid Build Coastguard Worker                             &IsLayerOfType<ReduceLayer>,
75*89c4ff92SAndroid Build Coastguard Worker                             &IsLayerOfType<ReduceLayer>,
76*89c4ff92SAndroid Build Coastguard Worker                             &IsLayerOfType<ReduceLayer>,
77*89c4ff92SAndroid Build Coastguard Worker                             &IsLayerOfType<OutputLayer>,
78*89c4ff92SAndroid Build Coastguard Worker                             &IsLayerOfType<OutputLayer>));
79*89c4ff92SAndroid Build Coastguard Worker     }
80*89c4ff92SAndroid Build Coastguard Worker 
81*89c4ff92SAndroid Build Coastguard Worker     // Get last layer in new chain, layers name follow 0, 1, 2 pattern
82*89c4ff92SAndroid Build Coastguard Worker     std::string layerName = "reduce_layer_" + std::to_string(numOfAxes - 1);
83*89c4ff92SAndroid Build Coastguard Worker     Layer* const reduceLayer = GetFirstLayerWithName(graph, layerName);
84*89c4ff92SAndroid Build Coastguard Worker     CHECK(reduceLayer);
85*89c4ff92SAndroid Build Coastguard Worker     auto reduceTensorInfo = reduceLayer->GetOutputSlot().GetTensorInfo();
86*89c4ff92SAndroid Build Coastguard Worker 
87*89c4ff92SAndroid Build Coastguard Worker     // Tensorshape and the data type are correct
88*89c4ff92SAndroid Build Coastguard Worker     CHECK((reduceTensorInfo.GetShape() == outputShape));
89*89c4ff92SAndroid Build Coastguard Worker     CHECK((reduceTensorInfo.GetDataType() == DataType::Float32));
90*89c4ff92SAndroid Build Coastguard Worker 
91*89c4ff92SAndroid Build Coastguard Worker     // Load network into runtime
92*89c4ff92SAndroid Build Coastguard Worker     NetworkId networkIdentifier;
93*89c4ff92SAndroid Build Coastguard Worker     run->LoadNetwork(networkIdentifier, std::move(optNet));
94*89c4ff92SAndroid Build Coastguard Worker 
95*89c4ff92SAndroid Build Coastguard Worker     // Create input and output tensors
96*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> outputData(expectedOutput.size());
97*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo = run->GetInputTensorInfo(networkIdentifier, 0);
98*89c4ff92SAndroid Build Coastguard Worker     inputTensorInfo.SetConstant(true);
99*89c4ff92SAndroid Build Coastguard Worker     InputTensors inputTensors
100*89c4ff92SAndroid Build Coastguard Worker         {
101*89c4ff92SAndroid Build Coastguard Worker             {0, armnn::ConstTensor(inputTensorInfo, inputData.data())}
102*89c4ff92SAndroid Build Coastguard Worker         };
103*89c4ff92SAndroid Build Coastguard Worker     OutputTensors outputTensors
104*89c4ff92SAndroid Build Coastguard Worker         {
105*89c4ff92SAndroid Build Coastguard Worker             {0, armnn::Tensor(run->GetOutputTensorInfo(networkIdentifier, 0), outputData.data())},
106*89c4ff92SAndroid Build Coastguard Worker             {1, armnn::Tensor(run->GetOutputTensorInfo(networkIdentifier, 1), outputData.data())}
107*89c4ff92SAndroid Build Coastguard Worker         };
108*89c4ff92SAndroid Build Coastguard Worker 
109*89c4ff92SAndroid Build Coastguard Worker     // Run inference
110*89c4ff92SAndroid Build Coastguard Worker     run->EnqueueWorkload(networkIdentifier, inputTensors, outputTensors);
111*89c4ff92SAndroid Build Coastguard Worker 
112*89c4ff92SAndroid Build Coastguard Worker     // Checks the results
113*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputData == expectedOutput);
114*89c4ff92SAndroid Build Coastguard Worker }
115*89c4ff92SAndroid Build Coastguard Worker 
ReduceSumWithTwoAxesKeepDimsTest(Compute backendId)116*89c4ff92SAndroid Build Coastguard Worker void ReduceSumWithTwoAxesKeepDimsTest(Compute backendId)
117*89c4ff92SAndroid Build Coastguard Worker {
118*89c4ff92SAndroid Build Coastguard Worker     armnn::ReduceDescriptor reduceDescriptor;
119*89c4ff92SAndroid Build Coastguard Worker     reduceDescriptor.m_vAxis = {1, 2};
120*89c4ff92SAndroid Build Coastguard Worker     reduceDescriptor.m_KeepDims = true;
121*89c4ff92SAndroid Build Coastguard Worker     reduceDescriptor.m_ReduceOperation = armnn::ReduceOperation::Sum;
122*89c4ff92SAndroid Build Coastguard Worker 
123*89c4ff92SAndroid Build Coastguard Worker     TensorShape inputShape = {1, 3, 2, 4};
124*89c4ff92SAndroid Build Coastguard Worker     TensorShape outputShape = {1, 1, 1, 4};
125*89c4ff92SAndroid Build Coastguard Worker 
126*89c4ff92SAndroid Build Coastguard Worker     // Construct ArmNN network
127*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr network = CreateSimpleReduceNetwork(reduceDescriptor, inputShape, outputShape);
128*89c4ff92SAndroid Build Coastguard Worker 
129*89c4ff92SAndroid Build Coastguard Worker     // Creates structures for input & output.
130*89c4ff92SAndroid Build Coastguard Worker     const std::vector<float> inputData({1.0f, 2.0f, 3.0f, 4.0f,
131*89c4ff92SAndroid Build Coastguard Worker                                         5.0f, 6.0f, 7.0f, 8.0f,
132*89c4ff92SAndroid Build Coastguard Worker 
133*89c4ff92SAndroid Build Coastguard Worker                                         10.0f, 20.0f, 30.0f, 40.0f,
134*89c4ff92SAndroid Build Coastguard Worker                                         50.0f, 60.0f, 70.0f, 80.0f,
135*89c4ff92SAndroid Build Coastguard Worker 
136*89c4ff92SAndroid Build Coastguard Worker                                         100.0f, 200.0f, 300.0f, 400.0f,
137*89c4ff92SAndroid Build Coastguard Worker                                         500.0f, 600.0f, 700.0f, 800.0f});
138*89c4ff92SAndroid Build Coastguard Worker     const std::vector<float> expectedOutput({666.0f, 888.0f, 1110.0f, 1332.0f});
139*89c4ff92SAndroid Build Coastguard Worker 
140*89c4ff92SAndroid Build Coastguard Worker     ReduceWithMultipleAxesTest(network,
141*89c4ff92SAndroid Build Coastguard Worker                                outputShape,
142*89c4ff92SAndroid Build Coastguard Worker                                inputData,
143*89c4ff92SAndroid Build Coastguard Worker                                expectedOutput,
144*89c4ff92SAndroid Build Coastguard Worker                                reduceDescriptor.m_vAxis.size(),
145*89c4ff92SAndroid Build Coastguard Worker                                backendId);
146*89c4ff92SAndroid Build Coastguard Worker }
147*89c4ff92SAndroid Build Coastguard Worker 
ReduceSumWithTwoAxesTest(Compute backendId)148*89c4ff92SAndroid Build Coastguard Worker void ReduceSumWithTwoAxesTest(Compute backendId)
149*89c4ff92SAndroid Build Coastguard Worker {
150*89c4ff92SAndroid Build Coastguard Worker     armnn::ReduceDescriptor reduceDescriptor;
151*89c4ff92SAndroid Build Coastguard Worker     reduceDescriptor.m_vAxis = {1, 2};
152*89c4ff92SAndroid Build Coastguard Worker     reduceDescriptor.m_KeepDims = false;
153*89c4ff92SAndroid Build Coastguard Worker     reduceDescriptor.m_ReduceOperation = armnn::ReduceOperation::Sum;
154*89c4ff92SAndroid Build Coastguard Worker 
155*89c4ff92SAndroid Build Coastguard Worker     TensorShape inputShape = {1, 3, 2, 4};
156*89c4ff92SAndroid Build Coastguard Worker     TensorShape outputShape = {1, 4};
157*89c4ff92SAndroid Build Coastguard Worker 
158*89c4ff92SAndroid Build Coastguard Worker     // Construct ArmNN network
159*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr network = CreateSimpleReduceNetwork(reduceDescriptor, inputShape, outputShape);
160*89c4ff92SAndroid Build Coastguard Worker 
161*89c4ff92SAndroid Build Coastguard Worker     // Creates structures for input & output.
162*89c4ff92SAndroid Build Coastguard Worker     const std::vector<float> inputData({1.0f, 2.0f, 3.0f, 4.0f,
163*89c4ff92SAndroid Build Coastguard Worker                                         5.0f, 6.0f, 7.0f, 8.0f,
164*89c4ff92SAndroid Build Coastguard Worker 
165*89c4ff92SAndroid Build Coastguard Worker                                         10.0f, 20.0f, 30.0f, 40.0f,
166*89c4ff92SAndroid Build Coastguard Worker                                         50.0f, 60.0f, 70.0f, 80.0f,
167*89c4ff92SAndroid Build Coastguard Worker 
168*89c4ff92SAndroid Build Coastguard Worker                                         100.0f, 200.0f, 300.0f, 400.0f,
169*89c4ff92SAndroid Build Coastguard Worker                                         500.0f, 600.0f, 700.0f, 800.0f});
170*89c4ff92SAndroid Build Coastguard Worker     const std::vector<float> expectedOutput({666.0f, 888.0f, 1110.0f, 1332.0f});
171*89c4ff92SAndroid Build Coastguard Worker 
172*89c4ff92SAndroid Build Coastguard Worker     ReduceWithMultipleAxesTest(network,
173*89c4ff92SAndroid Build Coastguard Worker                                outputShape,
174*89c4ff92SAndroid Build Coastguard Worker                                inputData,
175*89c4ff92SAndroid Build Coastguard Worker                                expectedOutput,
176*89c4ff92SAndroid Build Coastguard Worker                                reduceDescriptor.m_vAxis.size(),
177*89c4ff92SAndroid Build Coastguard Worker                                backendId);
178*89c4ff92SAndroid Build Coastguard Worker }
179*89c4ff92SAndroid Build Coastguard Worker 
ReduceSumWithThreeAxesKeepDimsTest(Compute backendId)180*89c4ff92SAndroid Build Coastguard Worker void ReduceSumWithThreeAxesKeepDimsTest(Compute backendId)
181*89c4ff92SAndroid Build Coastguard Worker {
182*89c4ff92SAndroid Build Coastguard Worker     armnn::ReduceDescriptor reduceDescriptor;
183*89c4ff92SAndroid Build Coastguard Worker     reduceDescriptor.m_vAxis = {0, 2, 3};
184*89c4ff92SAndroid Build Coastguard Worker     reduceDescriptor.m_KeepDims = true;
185*89c4ff92SAndroid Build Coastguard Worker     reduceDescriptor.m_ReduceOperation = armnn::ReduceOperation::Sum;
186*89c4ff92SAndroid Build Coastguard Worker 
187*89c4ff92SAndroid Build Coastguard Worker     TensorShape inputShape = {2, 2, 2, 2};
188*89c4ff92SAndroid Build Coastguard Worker     TensorShape outputShape = {1, 2, 1, 1};
189*89c4ff92SAndroid Build Coastguard Worker 
190*89c4ff92SAndroid Build Coastguard Worker     // Construct ArmNN network
191*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr network = CreateSimpleReduceNetwork(reduceDescriptor, inputShape, outputShape);
192*89c4ff92SAndroid Build Coastguard Worker 
193*89c4ff92SAndroid Build Coastguard Worker     // Creates structures for input & output.
194*89c4ff92SAndroid Build Coastguard Worker     const std::vector<float> inputData({1.0f, 2.0f,
195*89c4ff92SAndroid Build Coastguard Worker                                         3.0f, 4.0f,
196*89c4ff92SAndroid Build Coastguard Worker 
197*89c4ff92SAndroid Build Coastguard Worker                                         5.0f, 6.0f,
198*89c4ff92SAndroid Build Coastguard Worker                                         7.0f, 8.0f,
199*89c4ff92SAndroid Build Coastguard Worker 
200*89c4ff92SAndroid Build Coastguard Worker                                         10.0f, 20.0f,
201*89c4ff92SAndroid Build Coastguard Worker                                         30.0f, 40.0f,
202*89c4ff92SAndroid Build Coastguard Worker 
203*89c4ff92SAndroid Build Coastguard Worker                                         50.0f, 60.0f,
204*89c4ff92SAndroid Build Coastguard Worker                                         70.0f, 80.0f});
205*89c4ff92SAndroid Build Coastguard Worker     const std::vector<float> expectedOutput({110.0f, 286.0f});
206*89c4ff92SAndroid Build Coastguard Worker 
207*89c4ff92SAndroid Build Coastguard Worker     ReduceWithMultipleAxesTest(network,
208*89c4ff92SAndroid Build Coastguard Worker                                outputShape,
209*89c4ff92SAndroid Build Coastguard Worker                                inputData,
210*89c4ff92SAndroid Build Coastguard Worker                                expectedOutput,
211*89c4ff92SAndroid Build Coastguard Worker                                reduceDescriptor.m_vAxis.size(),
212*89c4ff92SAndroid Build Coastguard Worker                                backendId);
213*89c4ff92SAndroid Build Coastguard Worker }
214*89c4ff92SAndroid Build Coastguard Worker 
ReduceSumWithThreeAxesTest(Compute backendId)215*89c4ff92SAndroid Build Coastguard Worker void ReduceSumWithThreeAxesTest(Compute backendId)
216*89c4ff92SAndroid Build Coastguard Worker {
217*89c4ff92SAndroid Build Coastguard Worker     armnn::ReduceDescriptor reduceDescriptor;
218*89c4ff92SAndroid Build Coastguard Worker     reduceDescriptor.m_vAxis = {0, 2, 3};
219*89c4ff92SAndroid Build Coastguard Worker     reduceDescriptor.m_KeepDims = false;
220*89c4ff92SAndroid Build Coastguard Worker     reduceDescriptor.m_ReduceOperation = armnn::ReduceOperation::Sum;
221*89c4ff92SAndroid Build Coastguard Worker 
222*89c4ff92SAndroid Build Coastguard Worker     TensorShape inputShape = {2, 2, 2, 2};
223*89c4ff92SAndroid Build Coastguard Worker     TensorShape outputShape = {2};
224*89c4ff92SAndroid Build Coastguard Worker 
225*89c4ff92SAndroid Build Coastguard Worker     // Construct ArmNN network
226*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr network = CreateSimpleReduceNetwork(reduceDescriptor, inputShape, outputShape);
227*89c4ff92SAndroid Build Coastguard Worker 
228*89c4ff92SAndroid Build Coastguard Worker     // Creates structures for input & output.
229*89c4ff92SAndroid Build Coastguard Worker     const std::vector<float> inputData({1.0f, 2.0f,
230*89c4ff92SAndroid Build Coastguard Worker                                         3.0f, 4.0f,
231*89c4ff92SAndroid Build Coastguard Worker 
232*89c4ff92SAndroid Build Coastguard Worker                                         5.0f, 6.0f,
233*89c4ff92SAndroid Build Coastguard Worker                                         7.0f, 8.0f,
234*89c4ff92SAndroid Build Coastguard Worker 
235*89c4ff92SAndroid Build Coastguard Worker                                         10.0f, 20.0f,
236*89c4ff92SAndroid Build Coastguard Worker                                         30.0f, 40.0f,
237*89c4ff92SAndroid Build Coastguard Worker 
238*89c4ff92SAndroid Build Coastguard Worker                                         50.0f, 60.0f,
239*89c4ff92SAndroid Build Coastguard Worker                                         70.0f, 80.0f});
240*89c4ff92SAndroid Build Coastguard Worker     const std::vector<float> expectedOutput({110.0f, 286.0f});
241*89c4ff92SAndroid Build Coastguard Worker 
242*89c4ff92SAndroid Build Coastguard Worker     ReduceWithMultipleAxesTest(network,
243*89c4ff92SAndroid Build Coastguard Worker                                outputShape,
244*89c4ff92SAndroid Build Coastguard Worker                                inputData,
245*89c4ff92SAndroid Build Coastguard Worker                                expectedOutput,
246*89c4ff92SAndroid Build Coastguard Worker                                reduceDescriptor.m_vAxis.size(),
247*89c4ff92SAndroid Build Coastguard Worker                                backendId);
248*89c4ff92SAndroid Build Coastguard Worker }
249*89c4ff92SAndroid Build Coastguard Worker #endif
250*89c4ff92SAndroid Build Coastguard Worker }
251*89c4ff92SAndroid Build Coastguard Worker 
252*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMCOMPUTENEON_ENABLED)
253*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Optimizer_ReduceMultipleAxesCpu")
254*89c4ff92SAndroid Build Coastguard Worker {
255*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ReduceSumWithTwoAxesKeepDimsCpuAccTest")
256*89c4ff92SAndroid Build Coastguard Worker {
257*89c4ff92SAndroid Build Coastguard Worker     ReduceSumWithTwoAxesKeepDimsTest(Compute::CpuAcc);
258*89c4ff92SAndroid Build Coastguard Worker }
259*89c4ff92SAndroid Build Coastguard Worker 
260*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ReduceSumWithTwoAxesCpuAccTest")
261*89c4ff92SAndroid Build Coastguard Worker {
262*89c4ff92SAndroid Build Coastguard Worker     ReduceSumWithTwoAxesTest(Compute::CpuAcc);
263*89c4ff92SAndroid Build Coastguard Worker }
264*89c4ff92SAndroid Build Coastguard Worker 
265*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ReduceSumWithThreeAxesKeepDimsCpuAccTest")
266*89c4ff92SAndroid Build Coastguard Worker {
267*89c4ff92SAndroid Build Coastguard Worker     ReduceSumWithThreeAxesKeepDimsTest(Compute::CpuAcc);
268*89c4ff92SAndroid Build Coastguard Worker }
269*89c4ff92SAndroid Build Coastguard Worker 
270*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ReduceSumWithThreeAxesCpuAccTest")
271*89c4ff92SAndroid Build Coastguard Worker {
272*89c4ff92SAndroid Build Coastguard Worker     ReduceSumWithThreeAxesTest(Compute::CpuAcc);
273*89c4ff92SAndroid Build Coastguard Worker }
274*89c4ff92SAndroid Build Coastguard Worker }
275*89c4ff92SAndroid Build Coastguard Worker #endif
276*89c4ff92SAndroid Build Coastguard Worker 
277*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMCOMPUTECL_ENABLED)
278*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Optimizer_ReduceMultipleAxesGpu")
279*89c4ff92SAndroid Build Coastguard Worker {
280*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ReduceSumWithTwoAxesKeepDimsGpuAccTest")
281*89c4ff92SAndroid Build Coastguard Worker {
282*89c4ff92SAndroid Build Coastguard Worker     ReduceSumWithTwoAxesKeepDimsTest(Compute::GpuAcc);
283*89c4ff92SAndroid Build Coastguard Worker }
284*89c4ff92SAndroid Build Coastguard Worker 
285*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ReduceSumWithTwoAxesGpuAccTest")
286*89c4ff92SAndroid Build Coastguard Worker {
287*89c4ff92SAndroid Build Coastguard Worker     ReduceSumWithTwoAxesTest(Compute::GpuAcc);
288*89c4ff92SAndroid Build Coastguard Worker }
289*89c4ff92SAndroid Build Coastguard Worker 
290*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ReduceSumWithThreeAxesKeepDimsGpuAccTest")
291*89c4ff92SAndroid Build Coastguard Worker {
292*89c4ff92SAndroid Build Coastguard Worker     ReduceSumWithThreeAxesKeepDimsTest(Compute::GpuAcc);
293*89c4ff92SAndroid Build Coastguard Worker }
294*89c4ff92SAndroid Build Coastguard Worker 
295*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ReduceSumWithThreeAxesGpuAccTest")
296*89c4ff92SAndroid Build Coastguard Worker {
297*89c4ff92SAndroid Build Coastguard Worker     ReduceSumWithThreeAxesTest(Compute::GpuAcc);
298*89c4ff92SAndroid Build Coastguard Worker }
299*89c4ff92SAndroid Build Coastguard Worker }
300*89c4ff92SAndroid Build Coastguard Worker #endif
301