1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #include <GraphUtils.hpp>
7*89c4ff92SAndroid Build Coastguard Worker #include <TestUtils.hpp>
8*89c4ff92SAndroid Build Coastguard Worker
9*89c4ff92SAndroid Build Coastguard Worker #include <armnn/INetwork.hpp>
10*89c4ff92SAndroid Build Coastguard Worker
11*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
12*89c4ff92SAndroid Build Coastguard Worker
13*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
14*89c4ff92SAndroid Build Coastguard Worker
15*89c4ff92SAndroid Build Coastguard Worker namespace
16*89c4ff92SAndroid Build Coastguard Worker {
17*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMCOMPUTENEON_ENABLED)||defined(ARMCOMPUTECL_ENABLED)
CreateSimpleReduceNetwork(ReduceDescriptor reduceDescriptor,TensorShape & inputShape,TensorShape & outputShape)18*89c4ff92SAndroid Build Coastguard Worker INetworkPtr CreateSimpleReduceNetwork(ReduceDescriptor reduceDescriptor,
19*89c4ff92SAndroid Build Coastguard Worker TensorShape& inputShape,
20*89c4ff92SAndroid Build Coastguard Worker TensorShape& outputShape)
21*89c4ff92SAndroid Build Coastguard Worker {
22*89c4ff92SAndroid Build Coastguard Worker // Create a network
23*89c4ff92SAndroid Build Coastguard Worker INetworkPtr network = INetwork::Create();
24*89c4ff92SAndroid Build Coastguard Worker
25*89c4ff92SAndroid Build Coastguard Worker const std::string layerName("reduce_layer");
26*89c4ff92SAndroid Build Coastguard Worker const TensorInfo inputInfo(inputShape, DataType::Float32);
27*89c4ff92SAndroid Build Coastguard Worker const TensorInfo outputInfo(outputShape, DataType::Float32);
28*89c4ff92SAndroid Build Coastguard Worker
29*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* const inputLayer = network->AddInputLayer(0);
30*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* const reduceLayer = network->AddReduceLayer(reduceDescriptor, layerName.c_str());
31*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* const outputLayer1 = network->AddOutputLayer(0);
32*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* const outputLayer2 = network->AddOutputLayer(1);
33*89c4ff92SAndroid Build Coastguard Worker
34*89c4ff92SAndroid Build Coastguard Worker inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
35*89c4ff92SAndroid Build Coastguard Worker reduceLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
36*89c4ff92SAndroid Build Coastguard Worker
37*89c4ff92SAndroid Build Coastguard Worker inputLayer->GetOutputSlot(0).Connect(reduceLayer->GetInputSlot(0));
38*89c4ff92SAndroid Build Coastguard Worker reduceLayer->GetOutputSlot(0).Connect(outputLayer1->GetInputSlot(0));
39*89c4ff92SAndroid Build Coastguard Worker reduceLayer->GetOutputSlot(0).Connect(outputLayer2->GetInputSlot(0));
40*89c4ff92SAndroid Build Coastguard Worker
41*89c4ff92SAndroid Build Coastguard Worker return network;
42*89c4ff92SAndroid Build Coastguard Worker }
43*89c4ff92SAndroid Build Coastguard Worker
ReduceWithMultipleAxesTest(INetworkPtr & network,const TensorShape & outputShape,const std::vector<float> & inputData,const std::vector<float> & expectedOutput,const size_t numOfAxes,Compute backendId)44*89c4ff92SAndroid Build Coastguard Worker void ReduceWithMultipleAxesTest(INetworkPtr& network,
45*89c4ff92SAndroid Build Coastguard Worker const TensorShape& outputShape,
46*89c4ff92SAndroid Build Coastguard Worker const std::vector<float>& inputData,
47*89c4ff92SAndroid Build Coastguard Worker const std::vector<float>& expectedOutput,
48*89c4ff92SAndroid Build Coastguard Worker const size_t numOfAxes,
49*89c4ff92SAndroid Build Coastguard Worker Compute backendId)
50*89c4ff92SAndroid Build Coastguard Worker {
51*89c4ff92SAndroid Build Coastguard Worker // Create ArmNN runtime
52*89c4ff92SAndroid Build Coastguard Worker IRuntimePtr run = IRuntime::Create(IRuntime::CreationOptions());
53*89c4ff92SAndroid Build Coastguard Worker
54*89c4ff92SAndroid Build Coastguard Worker // Optimise ArmNN network
55*89c4ff92SAndroid Build Coastguard Worker IOptimizedNetworkPtr optNet = Optimize(*network, {backendId}, run->GetDeviceSpec());
56*89c4ff92SAndroid Build Coastguard Worker
57*89c4ff92SAndroid Build Coastguard Worker Graph& graph = GetGraphForTesting(optNet.get());
58*89c4ff92SAndroid Build Coastguard Worker if (numOfAxes == 2)
59*89c4ff92SAndroid Build Coastguard Worker {
60*89c4ff92SAndroid Build Coastguard Worker CHECK(graph.GetNumLayers() == 5);
61*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckSequence(graph.cbegin(),
62*89c4ff92SAndroid Build Coastguard Worker graph.cend(),
63*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<InputLayer>,
64*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<ReduceLayer>,
65*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<ReduceLayer>,
66*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<OutputLayer>,
67*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<OutputLayer>));
68*89c4ff92SAndroid Build Coastguard Worker } else
69*89c4ff92SAndroid Build Coastguard Worker {
70*89c4ff92SAndroid Build Coastguard Worker CHECK(graph.GetNumLayers() == 6);
71*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckSequence(graph.cbegin(),
72*89c4ff92SAndroid Build Coastguard Worker graph.cend(),
73*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<InputLayer>,
74*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<ReduceLayer>,
75*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<ReduceLayer>,
76*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<ReduceLayer>,
77*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<OutputLayer>,
78*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<OutputLayer>));
79*89c4ff92SAndroid Build Coastguard Worker }
80*89c4ff92SAndroid Build Coastguard Worker
81*89c4ff92SAndroid Build Coastguard Worker // Get last layer in new chain, layers name follow 0, 1, 2 pattern
82*89c4ff92SAndroid Build Coastguard Worker std::string layerName = "reduce_layer_" + std::to_string(numOfAxes - 1);
83*89c4ff92SAndroid Build Coastguard Worker Layer* const reduceLayer = GetFirstLayerWithName(graph, layerName);
84*89c4ff92SAndroid Build Coastguard Worker CHECK(reduceLayer);
85*89c4ff92SAndroid Build Coastguard Worker auto reduceTensorInfo = reduceLayer->GetOutputSlot().GetTensorInfo();
86*89c4ff92SAndroid Build Coastguard Worker
87*89c4ff92SAndroid Build Coastguard Worker // Tensorshape and the data type are correct
88*89c4ff92SAndroid Build Coastguard Worker CHECK((reduceTensorInfo.GetShape() == outputShape));
89*89c4ff92SAndroid Build Coastguard Worker CHECK((reduceTensorInfo.GetDataType() == DataType::Float32));
90*89c4ff92SAndroid Build Coastguard Worker
91*89c4ff92SAndroid Build Coastguard Worker // Load network into runtime
92*89c4ff92SAndroid Build Coastguard Worker NetworkId networkIdentifier;
93*89c4ff92SAndroid Build Coastguard Worker run->LoadNetwork(networkIdentifier, std::move(optNet));
94*89c4ff92SAndroid Build Coastguard Worker
95*89c4ff92SAndroid Build Coastguard Worker // Create input and output tensors
96*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputData(expectedOutput.size());
97*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo = run->GetInputTensorInfo(networkIdentifier, 0);
98*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo.SetConstant(true);
99*89c4ff92SAndroid Build Coastguard Worker InputTensors inputTensors
100*89c4ff92SAndroid Build Coastguard Worker {
101*89c4ff92SAndroid Build Coastguard Worker {0, armnn::ConstTensor(inputTensorInfo, inputData.data())}
102*89c4ff92SAndroid Build Coastguard Worker };
103*89c4ff92SAndroid Build Coastguard Worker OutputTensors outputTensors
104*89c4ff92SAndroid Build Coastguard Worker {
105*89c4ff92SAndroid Build Coastguard Worker {0, armnn::Tensor(run->GetOutputTensorInfo(networkIdentifier, 0), outputData.data())},
106*89c4ff92SAndroid Build Coastguard Worker {1, armnn::Tensor(run->GetOutputTensorInfo(networkIdentifier, 1), outputData.data())}
107*89c4ff92SAndroid Build Coastguard Worker };
108*89c4ff92SAndroid Build Coastguard Worker
109*89c4ff92SAndroid Build Coastguard Worker // Run inference
110*89c4ff92SAndroid Build Coastguard Worker run->EnqueueWorkload(networkIdentifier, inputTensors, outputTensors);
111*89c4ff92SAndroid Build Coastguard Worker
112*89c4ff92SAndroid Build Coastguard Worker // Checks the results
113*89c4ff92SAndroid Build Coastguard Worker CHECK(outputData == expectedOutput);
114*89c4ff92SAndroid Build Coastguard Worker }
115*89c4ff92SAndroid Build Coastguard Worker
ReduceSumWithTwoAxesKeepDimsTest(Compute backendId)116*89c4ff92SAndroid Build Coastguard Worker void ReduceSumWithTwoAxesKeepDimsTest(Compute backendId)
117*89c4ff92SAndroid Build Coastguard Worker {
118*89c4ff92SAndroid Build Coastguard Worker armnn::ReduceDescriptor reduceDescriptor;
119*89c4ff92SAndroid Build Coastguard Worker reduceDescriptor.m_vAxis = {1, 2};
120*89c4ff92SAndroid Build Coastguard Worker reduceDescriptor.m_KeepDims = true;
121*89c4ff92SAndroid Build Coastguard Worker reduceDescriptor.m_ReduceOperation = armnn::ReduceOperation::Sum;
122*89c4ff92SAndroid Build Coastguard Worker
123*89c4ff92SAndroid Build Coastguard Worker TensorShape inputShape = {1, 3, 2, 4};
124*89c4ff92SAndroid Build Coastguard Worker TensorShape outputShape = {1, 1, 1, 4};
125*89c4ff92SAndroid Build Coastguard Worker
126*89c4ff92SAndroid Build Coastguard Worker // Construct ArmNN network
127*89c4ff92SAndroid Build Coastguard Worker INetworkPtr network = CreateSimpleReduceNetwork(reduceDescriptor, inputShape, outputShape);
128*89c4ff92SAndroid Build Coastguard Worker
129*89c4ff92SAndroid Build Coastguard Worker // Creates structures for input & output.
130*89c4ff92SAndroid Build Coastguard Worker const std::vector<float> inputData({1.0f, 2.0f, 3.0f, 4.0f,
131*89c4ff92SAndroid Build Coastguard Worker 5.0f, 6.0f, 7.0f, 8.0f,
132*89c4ff92SAndroid Build Coastguard Worker
133*89c4ff92SAndroid Build Coastguard Worker 10.0f, 20.0f, 30.0f, 40.0f,
134*89c4ff92SAndroid Build Coastguard Worker 50.0f, 60.0f, 70.0f, 80.0f,
135*89c4ff92SAndroid Build Coastguard Worker
136*89c4ff92SAndroid Build Coastguard Worker 100.0f, 200.0f, 300.0f, 400.0f,
137*89c4ff92SAndroid Build Coastguard Worker 500.0f, 600.0f, 700.0f, 800.0f});
138*89c4ff92SAndroid Build Coastguard Worker const std::vector<float> expectedOutput({666.0f, 888.0f, 1110.0f, 1332.0f});
139*89c4ff92SAndroid Build Coastguard Worker
140*89c4ff92SAndroid Build Coastguard Worker ReduceWithMultipleAxesTest(network,
141*89c4ff92SAndroid Build Coastguard Worker outputShape,
142*89c4ff92SAndroid Build Coastguard Worker inputData,
143*89c4ff92SAndroid Build Coastguard Worker expectedOutput,
144*89c4ff92SAndroid Build Coastguard Worker reduceDescriptor.m_vAxis.size(),
145*89c4ff92SAndroid Build Coastguard Worker backendId);
146*89c4ff92SAndroid Build Coastguard Worker }
147*89c4ff92SAndroid Build Coastguard Worker
ReduceSumWithTwoAxesTest(Compute backendId)148*89c4ff92SAndroid Build Coastguard Worker void ReduceSumWithTwoAxesTest(Compute backendId)
149*89c4ff92SAndroid Build Coastguard Worker {
150*89c4ff92SAndroid Build Coastguard Worker armnn::ReduceDescriptor reduceDescriptor;
151*89c4ff92SAndroid Build Coastguard Worker reduceDescriptor.m_vAxis = {1, 2};
152*89c4ff92SAndroid Build Coastguard Worker reduceDescriptor.m_KeepDims = false;
153*89c4ff92SAndroid Build Coastguard Worker reduceDescriptor.m_ReduceOperation = armnn::ReduceOperation::Sum;
154*89c4ff92SAndroid Build Coastguard Worker
155*89c4ff92SAndroid Build Coastguard Worker TensorShape inputShape = {1, 3, 2, 4};
156*89c4ff92SAndroid Build Coastguard Worker TensorShape outputShape = {1, 4};
157*89c4ff92SAndroid Build Coastguard Worker
158*89c4ff92SAndroid Build Coastguard Worker // Construct ArmNN network
159*89c4ff92SAndroid Build Coastguard Worker INetworkPtr network = CreateSimpleReduceNetwork(reduceDescriptor, inputShape, outputShape);
160*89c4ff92SAndroid Build Coastguard Worker
161*89c4ff92SAndroid Build Coastguard Worker // Creates structures for input & output.
162*89c4ff92SAndroid Build Coastguard Worker const std::vector<float> inputData({1.0f, 2.0f, 3.0f, 4.0f,
163*89c4ff92SAndroid Build Coastguard Worker 5.0f, 6.0f, 7.0f, 8.0f,
164*89c4ff92SAndroid Build Coastguard Worker
165*89c4ff92SAndroid Build Coastguard Worker 10.0f, 20.0f, 30.0f, 40.0f,
166*89c4ff92SAndroid Build Coastguard Worker 50.0f, 60.0f, 70.0f, 80.0f,
167*89c4ff92SAndroid Build Coastguard Worker
168*89c4ff92SAndroid Build Coastguard Worker 100.0f, 200.0f, 300.0f, 400.0f,
169*89c4ff92SAndroid Build Coastguard Worker 500.0f, 600.0f, 700.0f, 800.0f});
170*89c4ff92SAndroid Build Coastguard Worker const std::vector<float> expectedOutput({666.0f, 888.0f, 1110.0f, 1332.0f});
171*89c4ff92SAndroid Build Coastguard Worker
172*89c4ff92SAndroid Build Coastguard Worker ReduceWithMultipleAxesTest(network,
173*89c4ff92SAndroid Build Coastguard Worker outputShape,
174*89c4ff92SAndroid Build Coastguard Worker inputData,
175*89c4ff92SAndroid Build Coastguard Worker expectedOutput,
176*89c4ff92SAndroid Build Coastguard Worker reduceDescriptor.m_vAxis.size(),
177*89c4ff92SAndroid Build Coastguard Worker backendId);
178*89c4ff92SAndroid Build Coastguard Worker }
179*89c4ff92SAndroid Build Coastguard Worker
ReduceSumWithThreeAxesKeepDimsTest(Compute backendId)180*89c4ff92SAndroid Build Coastguard Worker void ReduceSumWithThreeAxesKeepDimsTest(Compute backendId)
181*89c4ff92SAndroid Build Coastguard Worker {
182*89c4ff92SAndroid Build Coastguard Worker armnn::ReduceDescriptor reduceDescriptor;
183*89c4ff92SAndroid Build Coastguard Worker reduceDescriptor.m_vAxis = {0, 2, 3};
184*89c4ff92SAndroid Build Coastguard Worker reduceDescriptor.m_KeepDims = true;
185*89c4ff92SAndroid Build Coastguard Worker reduceDescriptor.m_ReduceOperation = armnn::ReduceOperation::Sum;
186*89c4ff92SAndroid Build Coastguard Worker
187*89c4ff92SAndroid Build Coastguard Worker TensorShape inputShape = {2, 2, 2, 2};
188*89c4ff92SAndroid Build Coastguard Worker TensorShape outputShape = {1, 2, 1, 1};
189*89c4ff92SAndroid Build Coastguard Worker
190*89c4ff92SAndroid Build Coastguard Worker // Construct ArmNN network
191*89c4ff92SAndroid Build Coastguard Worker INetworkPtr network = CreateSimpleReduceNetwork(reduceDescriptor, inputShape, outputShape);
192*89c4ff92SAndroid Build Coastguard Worker
193*89c4ff92SAndroid Build Coastguard Worker // Creates structures for input & output.
194*89c4ff92SAndroid Build Coastguard Worker const std::vector<float> inputData({1.0f, 2.0f,
195*89c4ff92SAndroid Build Coastguard Worker 3.0f, 4.0f,
196*89c4ff92SAndroid Build Coastguard Worker
197*89c4ff92SAndroid Build Coastguard Worker 5.0f, 6.0f,
198*89c4ff92SAndroid Build Coastguard Worker 7.0f, 8.0f,
199*89c4ff92SAndroid Build Coastguard Worker
200*89c4ff92SAndroid Build Coastguard Worker 10.0f, 20.0f,
201*89c4ff92SAndroid Build Coastguard Worker 30.0f, 40.0f,
202*89c4ff92SAndroid Build Coastguard Worker
203*89c4ff92SAndroid Build Coastguard Worker 50.0f, 60.0f,
204*89c4ff92SAndroid Build Coastguard Worker 70.0f, 80.0f});
205*89c4ff92SAndroid Build Coastguard Worker const std::vector<float> expectedOutput({110.0f, 286.0f});
206*89c4ff92SAndroid Build Coastguard Worker
207*89c4ff92SAndroid Build Coastguard Worker ReduceWithMultipleAxesTest(network,
208*89c4ff92SAndroid Build Coastguard Worker outputShape,
209*89c4ff92SAndroid Build Coastguard Worker inputData,
210*89c4ff92SAndroid Build Coastguard Worker expectedOutput,
211*89c4ff92SAndroid Build Coastguard Worker reduceDescriptor.m_vAxis.size(),
212*89c4ff92SAndroid Build Coastguard Worker backendId);
213*89c4ff92SAndroid Build Coastguard Worker }
214*89c4ff92SAndroid Build Coastguard Worker
ReduceSumWithThreeAxesTest(Compute backendId)215*89c4ff92SAndroid Build Coastguard Worker void ReduceSumWithThreeAxesTest(Compute backendId)
216*89c4ff92SAndroid Build Coastguard Worker {
217*89c4ff92SAndroid Build Coastguard Worker armnn::ReduceDescriptor reduceDescriptor;
218*89c4ff92SAndroid Build Coastguard Worker reduceDescriptor.m_vAxis = {0, 2, 3};
219*89c4ff92SAndroid Build Coastguard Worker reduceDescriptor.m_KeepDims = false;
220*89c4ff92SAndroid Build Coastguard Worker reduceDescriptor.m_ReduceOperation = armnn::ReduceOperation::Sum;
221*89c4ff92SAndroid Build Coastguard Worker
222*89c4ff92SAndroid Build Coastguard Worker TensorShape inputShape = {2, 2, 2, 2};
223*89c4ff92SAndroid Build Coastguard Worker TensorShape outputShape = {2};
224*89c4ff92SAndroid Build Coastguard Worker
225*89c4ff92SAndroid Build Coastguard Worker // Construct ArmNN network
226*89c4ff92SAndroid Build Coastguard Worker INetworkPtr network = CreateSimpleReduceNetwork(reduceDescriptor, inputShape, outputShape);
227*89c4ff92SAndroid Build Coastguard Worker
228*89c4ff92SAndroid Build Coastguard Worker // Creates structures for input & output.
229*89c4ff92SAndroid Build Coastguard Worker const std::vector<float> inputData({1.0f, 2.0f,
230*89c4ff92SAndroid Build Coastguard Worker 3.0f, 4.0f,
231*89c4ff92SAndroid Build Coastguard Worker
232*89c4ff92SAndroid Build Coastguard Worker 5.0f, 6.0f,
233*89c4ff92SAndroid Build Coastguard Worker 7.0f, 8.0f,
234*89c4ff92SAndroid Build Coastguard Worker
235*89c4ff92SAndroid Build Coastguard Worker 10.0f, 20.0f,
236*89c4ff92SAndroid Build Coastguard Worker 30.0f, 40.0f,
237*89c4ff92SAndroid Build Coastguard Worker
238*89c4ff92SAndroid Build Coastguard Worker 50.0f, 60.0f,
239*89c4ff92SAndroid Build Coastguard Worker 70.0f, 80.0f});
240*89c4ff92SAndroid Build Coastguard Worker const std::vector<float> expectedOutput({110.0f, 286.0f});
241*89c4ff92SAndroid Build Coastguard Worker
242*89c4ff92SAndroid Build Coastguard Worker ReduceWithMultipleAxesTest(network,
243*89c4ff92SAndroid Build Coastguard Worker outputShape,
244*89c4ff92SAndroid Build Coastguard Worker inputData,
245*89c4ff92SAndroid Build Coastguard Worker expectedOutput,
246*89c4ff92SAndroid Build Coastguard Worker reduceDescriptor.m_vAxis.size(),
247*89c4ff92SAndroid Build Coastguard Worker backendId);
248*89c4ff92SAndroid Build Coastguard Worker }
249*89c4ff92SAndroid Build Coastguard Worker #endif
250*89c4ff92SAndroid Build Coastguard Worker }
251*89c4ff92SAndroid Build Coastguard Worker
252*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMCOMPUTENEON_ENABLED)
253*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Optimizer_ReduceMultipleAxesCpu")
254*89c4ff92SAndroid Build Coastguard Worker {
255*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ReduceSumWithTwoAxesKeepDimsCpuAccTest")
256*89c4ff92SAndroid Build Coastguard Worker {
257*89c4ff92SAndroid Build Coastguard Worker ReduceSumWithTwoAxesKeepDimsTest(Compute::CpuAcc);
258*89c4ff92SAndroid Build Coastguard Worker }
259*89c4ff92SAndroid Build Coastguard Worker
260*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ReduceSumWithTwoAxesCpuAccTest")
261*89c4ff92SAndroid Build Coastguard Worker {
262*89c4ff92SAndroid Build Coastguard Worker ReduceSumWithTwoAxesTest(Compute::CpuAcc);
263*89c4ff92SAndroid Build Coastguard Worker }
264*89c4ff92SAndroid Build Coastguard Worker
265*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ReduceSumWithThreeAxesKeepDimsCpuAccTest")
266*89c4ff92SAndroid Build Coastguard Worker {
267*89c4ff92SAndroid Build Coastguard Worker ReduceSumWithThreeAxesKeepDimsTest(Compute::CpuAcc);
268*89c4ff92SAndroid Build Coastguard Worker }
269*89c4ff92SAndroid Build Coastguard Worker
270*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ReduceSumWithThreeAxesCpuAccTest")
271*89c4ff92SAndroid Build Coastguard Worker {
272*89c4ff92SAndroid Build Coastguard Worker ReduceSumWithThreeAxesTest(Compute::CpuAcc);
273*89c4ff92SAndroid Build Coastguard Worker }
274*89c4ff92SAndroid Build Coastguard Worker }
275*89c4ff92SAndroid Build Coastguard Worker #endif
276*89c4ff92SAndroid Build Coastguard Worker
277*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMCOMPUTECL_ENABLED)
278*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Optimizer_ReduceMultipleAxesGpu")
279*89c4ff92SAndroid Build Coastguard Worker {
280*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ReduceSumWithTwoAxesKeepDimsGpuAccTest")
281*89c4ff92SAndroid Build Coastguard Worker {
282*89c4ff92SAndroid Build Coastguard Worker ReduceSumWithTwoAxesKeepDimsTest(Compute::GpuAcc);
283*89c4ff92SAndroid Build Coastguard Worker }
284*89c4ff92SAndroid Build Coastguard Worker
285*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ReduceSumWithTwoAxesGpuAccTest")
286*89c4ff92SAndroid Build Coastguard Worker {
287*89c4ff92SAndroid Build Coastguard Worker ReduceSumWithTwoAxesTest(Compute::GpuAcc);
288*89c4ff92SAndroid Build Coastguard Worker }
289*89c4ff92SAndroid Build Coastguard Worker
290*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ReduceSumWithThreeAxesKeepDimsGpuAccTest")
291*89c4ff92SAndroid Build Coastguard Worker {
292*89c4ff92SAndroid Build Coastguard Worker ReduceSumWithThreeAxesKeepDimsTest(Compute::GpuAcc);
293*89c4ff92SAndroid Build Coastguard Worker }
294*89c4ff92SAndroid Build Coastguard Worker
295*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ReduceSumWithThreeAxesGpuAccTest")
296*89c4ff92SAndroid Build Coastguard Worker {
297*89c4ff92SAndroid Build Coastguard Worker ReduceSumWithThreeAxesTest(Compute::GpuAcc);
298*89c4ff92SAndroid Build Coastguard Worker }
299*89c4ff92SAndroid Build Coastguard Worker }
300*89c4ff92SAndroid Build Coastguard Worker #endif
301