xref: /aosp_15_r20/external/armnn/src/backends/reference/test/RefCreateWorkloadTests.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017, 2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <CreateWorkload.hpp>
7 
8 #include <armnn/utility/PolymorphicDowncast.hpp>
9 #include <reference/RefTensorHandle.hpp>
10 #include <reference/RefTensorHandleFactory.hpp>
11 #include <reference/RefWorkloadFactory.hpp>
12 #include <reference/workloads/RefWorkloads.hpp>
13 
14 #include <doctest/doctest.h>
15 
16 namespace
17 {
18 
19 template<typename Workload>
CheckInputOutput(std::unique_ptr<Workload> workload,const TensorInfo & inputInfo,const TensorInfo & outputInfo)20 void CheckInputOutput(std::unique_ptr<Workload> workload, const TensorInfo& inputInfo, const TensorInfo& outputInfo)
21 {
22     auto queueDescriptor = workload->GetData();
23     auto inputHandle  = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Inputs[0]);
24     auto outputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[0]);
25     CHECK((inputHandle->GetTensorInfo() == inputInfo));
26     CHECK((outputHandle->GetTensorInfo() == outputInfo));
27 }
28 
29 template <typename Workload>
CheckInputsOutput(std::unique_ptr<Workload> workload,const TensorInfo & inputInfo0,const TensorInfo & inputInfo1,const TensorInfo & outputInfo)30 void CheckInputsOutput(std::unique_ptr<Workload> workload,
31                        const TensorInfo&         inputInfo0,
32                        const TensorInfo&         inputInfo1,
33                        const TensorInfo&         outputInfo)
34 {
35     auto queueDescriptor = workload->GetData();
36     auto inputHandle0     = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Inputs[0]);
37     auto inputHandle1     = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Inputs[1]);
38     auto outputHandle    = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[0]);
39     CHECK((inputHandle0->GetTensorInfo() == inputInfo0));
40     CHECK((inputHandle1->GetTensorInfo() == inputInfo1));
41     CHECK((outputHandle->GetTensorInfo() == outputInfo));
42 }
43 
GetFactory()44 armnn::RefWorkloadFactory GetFactory()
45 {
46     std::shared_ptr<RefMemoryManager> memoryManager = std::make_shared<RefMemoryManager>();
47     return RefWorkloadFactory(memoryManager);
48 }
49 
50 }
51 
52 TEST_SUITE("CreateWorkloadRef")
53 {
54 template <typename ActivationWorkloadType, armnn::DataType DataType>
RefCreateActivationWorkloadTest()55 static void RefCreateActivationWorkloadTest()
56 {
57     Graph graph;
58     RefWorkloadFactory factory = GetFactory();
59     auto workload = CreateActivationWorkloadTest<ActivationWorkloadType, DataType>(factory, graph);
60 
61     // Checks that outputs are as we expect them (see definition of CreateActivationWorkloadTest).
62     CheckInputOutput(std::move(workload),
63         TensorInfo({ 1, 1 }, DataType),
64         TensorInfo({ 1, 1 }, DataType));
65 }
66 
67 TEST_CASE("CreateActivationFloat32Workload")
68 {
69     RefCreateActivationWorkloadTest<RefActivationWorkload, armnn::DataType::Float32>();
70 }
71 
72 TEST_CASE("CreateActivationUint8Workload")
73 {
74     RefCreateActivationWorkloadTest<RefActivationWorkload, armnn::DataType::QAsymmU8>();
75 }
76 
77 template <typename WorkloadType,
78           typename DescriptorType,
79           typename LayerType,
80           armnn::DataType DataType>
RefCreateElementwiseWorkloadTest()81 static void RefCreateElementwiseWorkloadTest()
82 {
83     Graph graph;
84     RefWorkloadFactory factory = GetFactory();
85     auto workload = CreateElementwiseWorkloadTest<WorkloadType, DescriptorType, LayerType, DataType>(
86         factory, graph);
87 
88     CheckInputsOutput(std::move(workload),
89         TensorInfo({ 2, 3 }, DataType),
90         TensorInfo({ 2, 3 }, DataType),
91         TensorInfo({ 2, 3 }, DataType));
92 }
93 
94 TEST_CASE("CreateSubtractionWorkloadWithBlobTest")
95 {
96     Graph graph;
97     RefWorkloadFactory factory = GetFactory();
98     armnn::DataType DataType = armnn::DataType::Float32;
99 
100     auto workload = CreateSubtractionWithBlobWorkloadTest<RefSubtractionWorkload<>,
101                                                           SubtractionQueueDescriptor,
102                                                           armnn::DataType::Float32>
103                                                           (factory, graph);
104 
105     CheckInputsOutput(std::move(workload),
106         TensorInfo({ 2, 3 }, DataType),
107         TensorInfo({ 2, 3 }, DataType),
108         TensorInfo({ 2, 3 }, DataType));
109 }
110 
111 TEST_CASE("CreateAdditionWorkloadWithBlobTest")
112 {
113     Graph graph;
114     RefWorkloadFactory factory = GetFactory();
115     armnn::DataType DataType = armnn::DataType::Float32;
116 
117     auto workload = CreateAdditionWithBlobWorkloadTest<RefAdditionWorkload<>,
118                                                        AdditionQueueDescriptor,
119                                                        armnn::DataType::Float32>(factory, graph);
120 
121     CheckInputsOutput(std::move(workload),
122         TensorInfo({ 2, 3 }, DataType),
123         TensorInfo({ 2, 3 }, DataType),
124         TensorInfo({ 2, 3 }, DataType));
125 }
126 
127 TEST_CASE("CreateMultiplicationWorkloadWithBlobTest")
128 {
129     Graph              graph;
130     RefWorkloadFactory factory  = GetFactory();
131     armnn::DataType    DataType = armnn::DataType::Float32;
132 
133     auto workload = CreateMultiplicationWithBlobWorkloadTest<RefMultiplicationWorkload<>,
134                                                              MultiplicationQueueDescriptor,
135                                                              armnn::DataType::Float32>(factory, graph);
136 
137     CheckInputsOutput(std::move(workload),
138                       TensorInfo({2, 3}, DataType),
139                       TensorInfo({2, 3}, DataType),
140                       TensorInfo({2, 3}, DataType));
141 }
142 
143 TEST_CASE("CreateAdditionFloatWorkload")
144 {
145     RefCreateElementwiseWorkloadTest<RefAdditionWorkload<>,
146         AdditionQueueDescriptor,
147         AdditionLayer,
148         armnn::DataType::Float32>();
149 }
150 
151 TEST_CASE("CreateAdditionUint8Workload")
152 {
153     RefCreateElementwiseWorkloadTest<RefAdditionWorkload<>,
154         AdditionQueueDescriptor,
155         AdditionLayer,
156         armnn::DataType::QAsymmU8>();
157 }
158 
159 TEST_CASE("CreateAdditionInt16Workload")
160 {
161     RefCreateElementwiseWorkloadTest<RefAdditionWorkload<>,
162         AdditionQueueDescriptor,
163         AdditionLayer,
164         armnn::DataType::QSymmS16>();
165 }
166 
167 TEST_CASE("CreateAdditionInt32Workload")
168 {
169     RefCreateElementwiseWorkloadTest<RefAdditionWorkload<int32_t>,
170             AdditionQueueDescriptor,
171             AdditionLayer,
172             armnn::DataType::Signed32>();
173 }
174 
175 TEST_CASE("CreateSubtractionFloat32Workload")
176 {
177     RefCreateElementwiseWorkloadTest<RefSubtractionWorkload<>,
178         SubtractionQueueDescriptor,
179         SubtractionLayer,
180         armnn::DataType::Float32>();
181 }
182 
183 TEST_CASE("CreateSubtractionFloat16Workload")
184 {
185     RefCreateElementwiseWorkloadTest<RefSubtractionWorkload<>,
186         SubtractionQueueDescriptor,
187         SubtractionLayer,
188         armnn::DataType::Float16>();
189 }
190 
191 TEST_CASE("CreateSubtractionUint8Workload")
192 {
193     RefCreateElementwiseWorkloadTest<RefSubtractionWorkload<>,
194         SubtractionQueueDescriptor,
195         SubtractionLayer,
196         armnn::DataType::QAsymmU8>();
197 }
198 
199 TEST_CASE("CreateSubtractionInt16Workload")
200 {
201     RefCreateElementwiseWorkloadTest<RefSubtractionWorkload<>,
202         SubtractionQueueDescriptor,
203         SubtractionLayer,
204         armnn::DataType::QSymmS16>();
205 }
206 
207 TEST_CASE("CreateSubtractionInt32Workload")
208 {
209     RefCreateElementwiseWorkloadTest<RefSubtractionWorkload<int32_t>,
210             SubtractionQueueDescriptor,
211             SubtractionLayer,
212             armnn::DataType::Signed32>();
213 }
214 
215 TEST_CASE("CreateMultiplicationFloatWorkload")
216 {
217     RefCreateElementwiseWorkloadTest<RefMultiplicationWorkload<>,
218         MultiplicationQueueDescriptor,
219         MultiplicationLayer,
220         armnn::DataType::Float32>();
221 }
222 
223 TEST_CASE("CreateMultiplicationUint8Workload")
224 {
225     RefCreateElementwiseWorkloadTest<RefMultiplicationWorkload<>,
226         MultiplicationQueueDescriptor,
227         MultiplicationLayer,
228         armnn::DataType::QAsymmU8>();
229 }
230 
231 TEST_CASE("CreateMultiplicationInt16Workload")
232 {
233     RefCreateElementwiseWorkloadTest<RefMultiplicationWorkload<>,
234         MultiplicationQueueDescriptor,
235         MultiplicationLayer,
236         armnn::DataType::QSymmS16>();
237 }
238 
239 TEST_CASE("CreateMultiplicationInt32Workload")
240 {
241     RefCreateElementwiseWorkloadTest<RefMultiplicationWorkload<int32_t>,
242             MultiplicationQueueDescriptor,
243             MultiplicationLayer,
244             armnn::DataType::Signed32>();
245 }
246 
247 TEST_CASE("CreateDivisionFloat32Workload")
248 {
249     RefCreateElementwiseWorkloadTest<RefDivisionWorkload<>,
250         DivisionQueueDescriptor,
251         DivisionLayer,
252         armnn::DataType::Float32>();
253 }
254 
255 TEST_CASE("CreateDivisionFloat16Workload")
256 {
257     RefCreateElementwiseWorkloadTest<RefDivisionWorkload<>,
258         DivisionQueueDescriptor,
259         DivisionLayer,
260         armnn::DataType::Float16>();
261 }
262 
263 TEST_CASE("CreateDivisionUint8Workload")
264 {
265     RefCreateElementwiseWorkloadTest<RefDivisionWorkload<>,
266         DivisionQueueDescriptor,
267         DivisionLayer,
268         armnn::DataType::QAsymmU8>();
269 }
270 
271 TEST_CASE("CreateDivisionInt16Workload")
272 {
273     RefCreateElementwiseWorkloadTest<RefDivisionWorkload<>,
274         DivisionQueueDescriptor,
275         DivisionLayer,
276         armnn::DataType::QSymmS16>();
277 }
278 
279 TEST_CASE("CreateDivisionInt32Workload")
280 {
281     RefCreateElementwiseWorkloadTest<RefDivisionWorkload<int32_t>,
282             DivisionQueueDescriptor,
283             DivisionLayer,
284             armnn::DataType::Signed32>();
285 }
286 
287 template <typename BatchNormalizationWorkloadType, armnn::DataType DataType>
RefCreateBatchNormalizationWorkloadTest(DataLayout dataLayout)288 static void RefCreateBatchNormalizationWorkloadTest(DataLayout dataLayout)
289 {
290     Graph graph;
291     RefWorkloadFactory factory = GetFactory();
292     auto workload = CreateBatchNormalizationWorkloadTest<BatchNormalizationWorkloadType, DataType>(factory,
293                                                                                                    graph,
294                                                                                                    dataLayout);
295 
296     TensorShape inputShape;
297     TensorShape outputShape;
298 
299     switch (dataLayout)
300     {
301         case DataLayout::NHWC:
302             inputShape  = { 2, 4, 4, 3 };
303             outputShape = { 2, 4, 4, 3 };
304             break;
305         case DataLayout::NCHW:
306         default:
307             inputShape  = { 2, 3, 4, 4 };
308             outputShape = { 2, 3, 4, 4 };
309             break;
310     }
311 
312     // Checks that outputs and inputs are as we expect them (see definition of CreateBatchNormalizationWorkloadTest).
313     CheckInputOutput(std::move(workload), TensorInfo(inputShape, DataType), TensorInfo(outputShape, DataType));
314 }
315 
316 TEST_CASE("CreateBatchNormalizationWithBlobFloat32Workload")
317 {
318     Graph graph;
319     RefWorkloadFactory factory = GetFactory();
320     auto dataType = armnn::DataType::Float32;
321     auto workload = CreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload,
322                                                          armnn::DataType::Float32>(factory, graph, DataLayout::NHWC);
323 
324     TensorShape inputShape;
325     TensorShape outputShape;
326 
327     inputShape  = { 2, 4, 4, 3 };
328     outputShape = { 2, 4, 4, 3 };
329 
330     // Checks that outputs and inputs are as we expect them (see definition of CreateBatchNormalizationWorkloadTest).
331     CheckInputOutput(std::move(workload), TensorInfo(inputShape, dataType), TensorInfo(outputShape, dataType));
332 }
333 
334 TEST_CASE("CreateBatchNormalizationFloat32Workload")
335 {
336     RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload,armnn::DataType::Float32>
337             (DataLayout::NCHW);
338 }
339 
340 TEST_CASE("CreateBatchNormalizationFloat32WorkloadNhwc")
341 {
342     RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload, armnn::DataType::Float32>
343             (DataLayout::NHWC);
344 }
345 
346 TEST_CASE("CreateBatchNormalizationFloat16Workload")
347 {
348     RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload,armnn::DataType::Float16>
349             (DataLayout::NCHW);
350 }
351 
352 TEST_CASE("CreateBatchNormalizationFloat16WorkloadNhwc")
353 {
354     RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload, armnn::DataType::Float16>
355             (DataLayout::NHWC);
356 }
357 
358 TEST_CASE("CreateBatchNormalizationUint8Workload")
359 {
360     RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload, armnn::DataType::QAsymmU8>
361             (DataLayout::NCHW);
362 }
363 
364 TEST_CASE("CreateBatchNormalizationUint8WorkloadNhwc")
365 {
366     RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload, armnn::DataType::QAsymmU8>
367             (DataLayout::NHWC);
368 }
369 
370 TEST_CASE("CreateBatchNormalizationInt16Workload")
371 {
372     RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload, armnn::DataType::QSymmS16>
373             (DataLayout::NCHW);
374 }
375 
376 TEST_CASE("CreateBatchNormalizationInt16WorkloadNhwc")
377 {
378     RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload, armnn::DataType::QSymmS16>
379             (DataLayout::NHWC);
380 }
381 
382 TEST_CASE("CreateConvertFp16ToFp32Float32Workload")
383 {
384     Graph                graph;
385     RefWorkloadFactory factory = GetFactory();
386     auto workload = CreateConvertFp16ToFp32WorkloadTest<RefConvertFp16ToFp32Workload>(factory, graph);
387 
388     // Checks that outputs and inputs are as we expect them
389     CheckInputOutput(
390         std::move(workload), TensorInfo({1, 3, 2, 3}, DataType::Float16), TensorInfo({1, 3, 2, 3}, DataType::Float32));
391 }
392 
393 TEST_CASE("CreateConvertFp32ToFp16Float16Workload")
394 {
395     Graph                graph;
396     RefWorkloadFactory factory = GetFactory();
397     auto workload = CreateConvertFp32ToFp16WorkloadTest<RefConvertFp32ToFp16Workload>(factory, graph);
398 
399     // Checks that outputs and inputs are as we expect them
400     CheckInputOutput(
401         std::move(workload), TensorInfo({1, 3, 2, 3}, DataType::Float32), TensorInfo({1, 3, 2, 3}, DataType::Float16));
402 }
403 
RefCreateConvolution2dWorkloadTest(DataLayout dataLayout=DataLayout::NCHW)404 static void RefCreateConvolution2dWorkloadTest(DataLayout dataLayout = DataLayout::NCHW)
405 {
406     Graph graph;
407     RefWorkloadFactory factory = GetFactory();
408     auto workload = CreateConvolution2dWorkloadTest<RefConvolution2dWorkload, DataType::Float32>
409                     (factory, graph, dataLayout);
410 
411     TensorShape inputShape  = (dataLayout == DataLayout::NCHW) ? std::initializer_list<unsigned int>({2, 3, 8, 16})
412                                                                : std::initializer_list<unsigned int>({2, 8, 16, 3});
413     TensorShape outputShape = (dataLayout == DataLayout::NCHW) ? std::initializer_list<unsigned int>({2, 2, 2, 10})
414                                                                : std::initializer_list<unsigned int>({2, 2, 10, 2});
415 
416     // Checks that outputs and inputs are as we expect them (see definition of CreateConvolution2dWorkloadTest).
417     CheckInputOutput(std::move(workload),
418                      TensorInfo(inputShape, DataType::Float32),
419                      TensorInfo(outputShape, DataType::Float32));
420 }
421 
422 TEST_CASE("CreateConvolution2dFloatNchwWorkload")
423 {
424     RefCreateConvolution2dWorkloadTest(DataLayout::NCHW);
425 }
426 
427 TEST_CASE("CreateConvolution2dFloatNhwcWorkload")
428 {
429     RefCreateConvolution2dWorkloadTest(DataLayout::NHWC);
430 }
431 
432 TEST_CASE("CreateConvolution2dWithBlobWorkload")
433 {
434     DataLayout dataLayout = DataLayout::NHWC;
435     Graph graph;
436     RefWorkloadFactory factory = GetFactory();
437     auto workload = CreateConvolution2dFusedActivationWithBlobWorkloadTest<RefConvolution2dWorkload, DataType::Float32>
438                     (factory, graph, dataLayout);
439 
440     TensorShape inputShape  = (dataLayout == DataLayout::NCHW) ? std::initializer_list<unsigned int>({2, 3, 8, 16})
441                                                                : std::initializer_list<unsigned int>({2, 8, 16, 3});
442     TensorShape outputShape = (dataLayout == DataLayout::NCHW) ? std::initializer_list<unsigned int>({2, 2, 2, 10})
443                                                                : std::initializer_list<unsigned int>({2, 2, 10, 2});
444 
445     // Checks that outputs and inputs are as we expect them (see definition of CreateConvolution2dWorkloadTest).
446     CheckInputOutput(std::move(workload),
447                      TensorInfo(inputShape, DataType::Float32),
448                      TensorInfo(outputShape, DataType::Float32));
449 }
450 
RefCreateDepthwiseConvolutionWorkloadTest(DataLayout dataLayout)451 static void RefCreateDepthwiseConvolutionWorkloadTest(DataLayout dataLayout)
452 {
453     Graph graph;
454     RefWorkloadFactory factory = GetFactory();
455     auto workload = CreateDepthwiseConvolution2dWorkloadTest<RefDepthwiseConvolution2dWorkload, DataType::Float32>
456             (factory, graph, dataLayout);
457 
458     TensorShape inputShape  = (dataLayout == DataLayout::NCHW) ? std::initializer_list<unsigned int>({ 2, 2, 5, 5 })
459                                                                : std::initializer_list<unsigned int>({ 2, 5, 5, 2 });
460     TensorShape outputShape = (dataLayout == DataLayout::NCHW) ? std::initializer_list<unsigned int>({ 2, 2, 5, 5 })
461                                                                : std::initializer_list<unsigned int>({ 2, 5, 5, 2 });
462 
463     // Checks that inputs/outputs are as we expect them (see definition of CreateDepthwiseConvolution2dWorkloadTest).
464     CheckInputOutput(std::move(workload),
465                      TensorInfo(inputShape, DataType::Float32),
466                      TensorInfo(outputShape, DataType::Float32));
467 }
468 
469 TEST_CASE("CreateDepthwiseConvolutionFloat32NhwcWorkload")
470 {
471     RefCreateDepthwiseConvolutionWorkloadTest(DataLayout::NHWC);
472 }
473 
474 TEST_CASE("RefCreateFullyConnectedWithBlobWorkloadTest")
475 {
476     Graph graph;
477     RefWorkloadFactory factory = GetFactory();
478     auto workload = CreateFullyConnectedWithBlobWorkloadTest<RefFullyConnectedWorkload,
479                                                          armnn::DataType::Float32>(factory, graph);
480 
481     // Checks that outputs and inputs are as we expect them (see definition of CreateFullyConnectedWorkloadTest).
482     float inputsQScale = 1.0f;
483     float outputQScale = 1.0f;
484     CheckInputOutput(std::move(workload),
485         TensorInfo({ 3, 1, 4, 5 }, armnn::DataType::Float32, inputsQScale),
486         TensorInfo({ 3, 7 }, armnn::DataType::Float32, outputQScale));
487 }
488 
489 TEST_CASE("CreateFullyConnectedWorkloadWeightsBiasesAsInputsFloat32")
490 {
491     Graph graph;
492     RefWorkloadFactory factory = GetFactory();
493 
494     auto workload =
495             CreateFullyConnectedWorkloadWeightsBiasesAsInputsTest<RefFullyConnectedWorkload,
496                                                                   armnn::DataType::Float32>(factory, graph);
497 
498     // Checks that outputs and inputs are as we expect them (see definition of CreateFullyConnectedWorkloadTest).
499     float inputsQScale = 1.0f;
500     float outputQScale = 1.0f;
501     CheckInputsOutput(std::move(workload),
502                       TensorInfo({ 3, 1, 4, 5 }, armnn::DataType::Float32, inputsQScale),
503                       TensorInfo({ 7, 20 }, armnn::DataType::Float32, inputsQScale),
504                       TensorInfo({ 3, 7 }, armnn::DataType::Float32, outputQScale));
505 }
506 
507 template <typename FullyConnectedWorkloadType, armnn::DataType DataType>
RefCreateFullyConnectedWorkloadTest()508 static void RefCreateFullyConnectedWorkloadTest()
509 {
510     Graph graph;
511     RefWorkloadFactory factory = GetFactory();
512     auto workload = CreateFullyConnectedWorkloadTest<FullyConnectedWorkloadType, DataType>(factory, graph);
513 
514     // Checks that outputs and inputs are as we expect them (see definition of CreateFullyConnectedWorkloadTest).
515     float inputsQScale = DataType == armnn::DataType::QAsymmU8 ? 1.0f : 1.0f;
516     float outputQScale = DataType == armnn::DataType::QAsymmU8 ? 2.0f : 1.0f;
517     CheckInputOutput(std::move(workload),
518         TensorInfo({ 3, 1, 4, 5 }, DataType, inputsQScale),
519         TensorInfo({ 3, 7 }, DataType, outputQScale));
520 }
521 
522 TEST_CASE("CreateFullyConnectedWorkloadFloat32")
523 {
524     RefCreateFullyConnectedWorkloadTest<RefFullyConnectedWorkload, armnn::DataType::Float32>();
525 }
526 
527 TEST_CASE("CreateFullyConnectedWorkloadQuantisedAsymm8")
528 {
529     RefCreateFullyConnectedWorkloadTest<RefFullyConnectedWorkload, armnn::DataType::QAsymmU8>();
530 }
531 
532 TEST_CASE("CreateFullyConnectedWorkloadQuantisedSymm16")
533 {
534     RefCreateFullyConnectedWorkloadTest<RefFullyConnectedWorkload, armnn::DataType::QSymmS16>();
535 }
536 
537 template <typename NormalizationWorkloadType, armnn::DataType DataType>
RefCreateNormalizationWorkloadTest(DataLayout dataLayout)538 static void RefCreateNormalizationWorkloadTest(DataLayout dataLayout)
539 {
540     Graph graph;
541     RefWorkloadFactory factory = GetFactory();
542     auto workload = CreateNormalizationWorkloadTest<NormalizationWorkloadType, DataType>(factory, graph, dataLayout);
543 
544     TensorShape inputShape;
545     TensorShape outputShape;
546 
547     switch (dataLayout)
548     {
549         case DataLayout::NHWC:
550             inputShape  = { 3, 1, 5, 5 };
551             outputShape = { 3, 1, 5, 5 };
552             break;
553         case DataLayout::NCHW:
554         default:
555             inputShape  = { 3, 5, 5, 1 };
556             outputShape = { 3, 5, 5, 1 };
557             break;
558     }
559 
560     // Checks that outputs and inputs are as we expect them (see definition of CreateNormalizationWorkloadTest).
561     CheckInputOutput(std::move(workload), TensorInfo(inputShape, DataType), TensorInfo(outputShape, DataType));
562 }
563 
564 TEST_CASE("CreateRefNormalizationFloat32NchwWorkload")
565 {
566     RefCreateNormalizationWorkloadTest<RefNormalizationWorkload, armnn::DataType::Float32>(DataLayout::NCHW);
567 }
568 
569 TEST_CASE("CreateRefNormalizationFloat32NhwcWorkload")
570 {
571     RefCreateNormalizationWorkloadTest<RefNormalizationWorkload, armnn::DataType::Float32>(DataLayout::NHWC);
572 }
573 
574 TEST_CASE("CreateRefNormalizationUint8NchwWorkload")
575 {
576     RefCreateNormalizationWorkloadTest<RefNormalizationWorkload, armnn::DataType::QAsymmU8>(DataLayout::NCHW);
577 }
578 
579 TEST_CASE("CreateRefNormalizationUint8NhwcWorkload")
580 {
581     RefCreateNormalizationWorkloadTest<RefNormalizationWorkload, armnn::DataType::QAsymmU8>(DataLayout::NHWC);
582 }
583 
584 TEST_CASE("CreateRefNormalizationInt16NchwWorkload")
585 {
586     RefCreateNormalizationWorkloadTest<RefNormalizationWorkload, armnn::DataType::QSymmS16>(DataLayout::NCHW);
587 }
588 
589 TEST_CASE("CreateRefNormalizationInt16NhwcWorkload")
590 {
591     RefCreateNormalizationWorkloadTest<RefNormalizationWorkload, armnn::DataType::QSymmS16>(DataLayout::NHWC);
592 }
593 
594 template <typename Pooling2dWorkloadType, armnn::DataType DataType>
RefCreatePooling2dWorkloadTest(DataLayout dataLayout)595 static void RefCreatePooling2dWorkloadTest(DataLayout dataLayout)
596 {
597     Graph graph;
598     RefWorkloadFactory factory = GetFactory();
599     auto workload = CreatePooling2dWorkloadTest<Pooling2dWorkloadType, DataType>(factory, graph, dataLayout);
600 
601     TensorShape inputShape;
602     TensorShape outputShape;
603 
604     switch (dataLayout)
605     {
606         case DataLayout::NHWC:
607             inputShape  = { 3, 5, 5, 2 };
608             outputShape = { 3, 2, 4, 2 };
609             break;
610         case DataLayout::NCHW:
611         default:
612             inputShape =  { 3, 2, 5, 5 };
613             outputShape = { 3, 2, 2, 4 };
614     }
615 
616     // Checks that outputs and inputs are as we expect them (see definition of CreatePooling2dWorkloadTest).
617     CheckInputOutput(std::move(workload),
618                      TensorInfo(inputShape, DataType),
619                      TensorInfo(outputShape, DataType));
620 }
621 
622 TEST_CASE("CreatePooling2dFloat32Workload")
623 {
624     RefCreatePooling2dWorkloadTest<RefPooling2dWorkload, armnn::DataType::Float32>(DataLayout::NCHW);
625 }
626 
627 TEST_CASE("CreatePooling2dFloat32NhwcWorkload")
628 {
629     RefCreatePooling2dWorkloadTest<RefPooling2dWorkload, armnn::DataType::Float32>(DataLayout::NHWC);
630 }
631 
632 TEST_CASE("CreatePooling2dUint8Workload")
633 {
634     RefCreatePooling2dWorkloadTest<RefPooling2dWorkload, armnn::DataType::QAsymmU8>(DataLayout::NCHW);
635 }
636 
637 TEST_CASE("CreatePooling2dUint8NhwcWorkload")
638 {
639     RefCreatePooling2dWorkloadTest<RefPooling2dWorkload, armnn::DataType::QAsymmU8>(DataLayout::NHWC);
640 }
641 
642 TEST_CASE("CreatePooling2dInt16Workload")
643 {
644     RefCreatePooling2dWorkloadTest<RefPooling2dWorkload, armnn::DataType::QSymmS16>(DataLayout::NCHW);
645 }
646 
647 TEST_CASE("CreatePooling2dInt16NhwcWorkload")
648 {
649     RefCreatePooling2dWorkloadTest<RefPooling2dWorkload, armnn::DataType::QSymmS16>(DataLayout::NHWC);
650 }
651 
652 template <typename SoftmaxWorkloadType, armnn::DataType DataType>
RefCreateSoftmaxWorkloadTest()653 static void RefCreateSoftmaxWorkloadTest()
654 {
655     Graph graph;
656     RefWorkloadFactory factory = GetFactory();
657     auto workload = CreateSoftmaxWorkloadTest<SoftmaxWorkloadType, DataType>(factory, graph);
658 
659     // Checks that outputs and inputs are as we expect them (see definition of CreateSoftmaxWorkloadTest).
660 
661     armnn::TensorInfo tensorInfo({4, 1}, DataType);
662     if (DataType == armnn::DataType::QAsymmU8)
663     {
664         tensorInfo.SetQuantizationOffset(0);
665         tensorInfo.SetQuantizationScale(1.f / 256);
666     }
667     else if (DataType == armnn::DataType::QAsymmS8)
668     {
669         tensorInfo.SetQuantizationOffset(-128);
670         tensorInfo.SetQuantizationScale(1.f / 256);
671     }
672     CheckInputOutput(
673         std::move(workload),
674         tensorInfo,
675         tensorInfo);
676 }
677 
678 TEST_CASE("CreateSoftmaxFloat32Workload")
679 {
680     RefCreateSoftmaxWorkloadTest<RefSoftmaxWorkload, armnn::DataType::Float32>();
681 }
682 
683 TEST_CASE("CreateSoftmaxFloat16Workload")
684 {
685     RefCreateSoftmaxWorkloadTest<RefSoftmaxWorkload, armnn::DataType::Float16>();
686 }
687 
688 TEST_CASE("CreateSoftmaxQuantisedAsymm8Workload")
689 {
690     RefCreateSoftmaxWorkloadTest<RefSoftmaxWorkload, armnn::DataType::QAsymmU8>();
691 }
692 
693 TEST_CASE("CreateSoftmaxQuantisedSymm16Workload")
694 {
695     RefCreateSoftmaxWorkloadTest<RefSoftmaxWorkload, armnn::DataType::QSymmS16>();
696 }
697 
698 template <typename SplitterWorkloadType, armnn::DataType DataType>
RefCreateSplitterWorkloadTest()699 static void RefCreateSplitterWorkloadTest()
700 {
701     Graph graph;
702     RefWorkloadFactory factory = GetFactory();
703     auto workload = CreateSplitterWorkloadTest<SplitterWorkloadType, DataType>(factory, graph);
704 
705     // Checks that outputs are as we expect them (see definition of CreateSplitterWorkloadTest).
706     SplitterQueueDescriptor queueDescriptor = workload->GetData();
707     auto inputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Inputs[0]);
708     CHECK((inputHandle->GetTensorInfo() == TensorInfo({ 5, 7, 7 }, DataType)));
709 
710     auto outputHandle0 = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[0]);
711     CHECK((outputHandle0->GetTensorInfo() == TensorInfo({ 1, 7, 7 }, DataType)));
712 
713     auto outputHandle1 = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[1]);
714     CHECK((outputHandle1->GetTensorInfo() == TensorInfo({ 2, 7, 7 }, DataType)));
715 
716     auto outputHandle2 = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[2]);
717     CHECK((outputHandle2->GetTensorInfo() == TensorInfo({ 2, 7, 7 }, DataType)));
718 }
719 
720 TEST_CASE("CreateSplitterFloat32Workload")
721 {
722     RefCreateSplitterWorkloadTest<RefSplitterWorkload, armnn::DataType::Float32>();
723 }
724 
725 TEST_CASE("CreateSplitterFloat16Workload")
726 {
727     RefCreateSplitterWorkloadTest<RefSplitterWorkload, armnn::DataType::Float16>();
728 }
729 
730 TEST_CASE("CreateSplitterUint8Workload")
731 {
732     RefCreateSplitterWorkloadTest<RefSplitterWorkload, armnn::DataType::QAsymmU8>();
733 }
734 
735 template <typename SplitterWorkloadType, typename ConcatWorkloadType, armnn::DataType DataType>
RefCreateSplitterConcatWorkloadTest()736 static void RefCreateSplitterConcatWorkloadTest()
737 {
738     // Tests that it is possible to decide which output of the splitter layer
739     // should be lined to which input of the concat layer.
740     // We tested that is is possible to specify 0th output
741     // of the splitter to be the 1st input to the concat and the 1st output of the splitter to be 0th input
742     // of the concat.
743 
744     Graph graph;
745     RefWorkloadFactory factory = GetFactory();
746     auto workloads = CreateSplitterConcatWorkloadTest<SplitterWorkloadType, ConcatWorkloadType, DataType>
747             (factory, graph);
748 
749     auto wlSplitter = std::move(workloads.first);
750     auto wlConcat = std::move(workloads.second);
751 
752     //Checks that the index of inputs/outputs matches what we declared on InputDescriptor construction.
753     armnn::RefTensorHandle* sOut0 = dynamic_cast<armnn::RefTensorHandle*>(wlSplitter->GetData().m_Outputs[0]);
754     armnn::RefTensorHandle* sOut1 = dynamic_cast<armnn::RefTensorHandle*>(wlSplitter->GetData().m_Outputs[1]);
755     armnn::RefTensorHandle* mIn0 = dynamic_cast<armnn::RefTensorHandle*>(wlConcat->GetData().m_Inputs[0]);
756     armnn::RefTensorHandle* mIn1 = dynamic_cast<armnn::RefTensorHandle*>(wlConcat->GetData().m_Inputs[1]);
757 
758     CHECK(sOut0);
759     CHECK(sOut1);
760     CHECK(mIn0);
761     CHECK(mIn1);
762 
763     bool validDataPointers = (sOut0 == mIn1) && (sOut1 == mIn0);
764 
765     CHECK(validDataPointers);
766 }
767 
// Splitter->Concat cross-wiring tests, one per supported data type.
TEST_CASE("CreateSplitterConcatFloat32")
{
    RefCreateSplitterConcatWorkloadTest<RefSplitterWorkload, RefConcatWorkload, DataType::Float32>();
}

TEST_CASE("CreateSplitterConcatFloat16")
{
    RefCreateSplitterConcatWorkloadTest<RefSplitterWorkload, RefConcatWorkload, DataType::Float16>();
}

TEST_CASE("CreateSplitterConcatUint8")
{
    RefCreateSplitterConcatWorkloadTest<RefSplitterWorkload, RefConcatWorkload, DataType::QAsymmU8>();
}
782 
template <typename SplitterWorkloadType, typename ActivationWorkloadType, armnn::DataType DataType>
static void RefCreateSingleOutputMultipleInputsTest()
{
    // Tests that it is possible to assign multiple (two) different layers to each of the outputs of a splitter layer.
    // We create a splitter with two outputs; each of those outputs is consumed by two different activation layers.

    Graph graph;
    RefWorkloadFactory factory = GetFactory();
    std::unique_ptr<SplitterWorkloadType> wlSplitter;
    std::unique_ptr<ActivationWorkloadType> wlActiv0_0;
    std::unique_ptr<ActivationWorkloadType> wlActiv0_1;
    std::unique_ptr<ActivationWorkloadType> wlActiv1_0;
    std::unique_ptr<ActivationWorkloadType> wlActiv1_1;

    CreateSplitterMultipleInputsOneOutputWorkloadTest<SplitterWorkloadType,
        ActivationWorkloadType, DataType>(factory, graph, wlSplitter, wlActiv0_0, wlActiv0_1, wlActiv1_0, wlActiv1_1);

    // dynamic_cast returns nullptr on failure, which the CHECKs below detect.
    armnn::RefTensorHandle* sOut0 = dynamic_cast<armnn::RefTensorHandle*>(wlSplitter->GetData().m_Outputs[0]);
    armnn::RefTensorHandle* sOut1 = dynamic_cast<armnn::RefTensorHandle*>(wlSplitter->GetData().m_Outputs[1]);
    armnn::RefTensorHandle* activ0_0Im = dynamic_cast<armnn::RefTensorHandle*>(wlActiv0_0->GetData().m_Inputs[0]);
    armnn::RefTensorHandle* activ0_1Im = dynamic_cast<armnn::RefTensorHandle*>(wlActiv0_1->GetData().m_Inputs[0]);
    armnn::RefTensorHandle* activ1_0Im = dynamic_cast<armnn::RefTensorHandle*>(wlActiv1_0->GetData().m_Inputs[0]);
    armnn::RefTensorHandle* activ1_1Im = dynamic_cast<armnn::RefTensorHandle*>(wlActiv1_1->GetData().m_Inputs[0]);


    CHECK(sOut0);
    CHECK(sOut1);
    CHECK(activ0_0Im);
    CHECK(activ0_1Im);
    CHECK(activ1_0Im);
    CHECK(activ1_1Im);

    // Pointer identity: both activations on branch 0 read splitter output 0,
    // both activations on branch 1 read splitter output 1.
    bool validDataPointers = (sOut0 == activ0_0Im) && (sOut0 == activ0_1Im) &&
                             (sOut1 == activ1_0Im) && (sOut1 == activ1_1Im);

    CHECK(validDataPointers);
}
820 
// One-splitter-output-feeding-multiple-consumers tests.
TEST_CASE("CreateSingleOutputMultipleInputsFloat32")
{
    RefCreateSingleOutputMultipleInputsTest<RefSplitterWorkload, RefActivationWorkload,
        armnn::DataType::Float32>();
}

TEST_CASE("CreateSingleOutputMultipleInputsUint8")
{
    RefCreateSingleOutputMultipleInputsTest<RefSplitterWorkload, RefActivationWorkload,
        armnn::DataType::QAsymmU8>();
}
832 
833 template <typename ResizeBilinearWorkloadType, armnn::DataType DataType>
RefCreateResizeBilinearTest(DataLayout dataLayout)834 static void RefCreateResizeBilinearTest(DataLayout dataLayout)
835 {
836     Graph graph;
837     RefWorkloadFactory factory = GetFactory();
838     auto workload = CreateResizeBilinearWorkloadTest<ResizeBilinearWorkloadType, DataType>(factory, graph, dataLayout);
839 
840     TensorShape inputShape;
841     TensorShape outputShape;
842 
843     switch (dataLayout)
844     {
845         case DataLayout::NHWC:
846             inputShape  = { 2, 4, 4, 3 };
847             outputShape = { 2, 2, 2, 3 };
848             break;
849         case DataLayout::NCHW:
850         default:
851             inputShape  = { 2, 3, 4, 4 };
852             outputShape = { 2, 3, 2, 2 };
853     }
854 
855     // Checks that outputs and inputs are as we expect them (see definition of CreateResizeBilinearWorkloadTest).
856     CheckInputOutput(std::move(workload),
857                      TensorInfo(inputShape, DataType),
858                      TensorInfo(outputShape, DataType));
859 }
860 
// ResizeBilinear workload creation tests across data types and layouts.
TEST_CASE("CreateResizeBilinearFloat32")
{
    RefCreateResizeBilinearTest<RefResizeWorkload, armnn::DataType::Float32>(DataLayout::NCHW);
}

TEST_CASE("CreateResizeBilinearFloat16")
{
    RefCreateResizeBilinearTest<RefResizeWorkload, armnn::DataType::Float16>(DataLayout::NCHW);
}

TEST_CASE("CreateResizeBilinearUint8")
{
    RefCreateResizeBilinearTest<RefResizeWorkload, armnn::DataType::QAsymmU8>(DataLayout::NCHW);
}

TEST_CASE("CreateResizeBilinearQuantisedAsymm16")
{
    // NOTE(review): the name says "QuantisedAsymm16" but the type under test is the
    // symmetric QSymmS16 — consider renaming the test case.
    RefCreateResizeBilinearTest<RefResizeWorkload, armnn::DataType::QSymmS16>(DataLayout::NCHW);
}

TEST_CASE("CreateResizeBilinearFloat32Nhwc")
{
    RefCreateResizeBilinearTest<RefResizeWorkload, armnn::DataType::Float32>(DataLayout::NHWC);
}
885 
// Creates a BatchToSpaceNd workload and verifies its input/output tensor infos
// (both {1,1,1,1} — see definition of CreateBatchToSpaceNdWorkloadTest).
template <typename BatchToSpaceNdWorkloadType, armnn::DataType DataType>
static void RefCreateBatchToSpaceNdTest()
{
    Graph graph;
    // NOTE(review): most helpers in this file obtain the factory via GetFactory();
    // confirm whether the default-constructed factory here is intentional.
    RefWorkloadFactory factory;

    auto workload = CreateBatchToSpaceNdWorkloadTest<BatchToSpaceNdWorkloadType, DataType>(factory, graph);

    CheckInputOutput(std::move(workload),
                     TensorInfo({ 1, 1, 1, 1 }, DataType),
                     TensorInfo({ 1, 1, 1, 1 }, DataType));
}
898 
// BatchToSpaceNd workload creation tests across data types.
TEST_CASE("CreateBatchToSpaceNdFloat32")
{
    RefCreateBatchToSpaceNdTest<RefBatchToSpaceNdWorkload, armnn::DataType::Float32>();
}

TEST_CASE("CreateBatchToSpaceNdFloat16")
{
    RefCreateBatchToSpaceNdTest<RefBatchToSpaceNdWorkload, armnn::DataType::Float16>();
}

TEST_CASE("CreateBatchToSpaceNdUint8")
{
    RefCreateBatchToSpaceNdTest<RefBatchToSpaceNdWorkload, armnn::DataType::QAsymmU8>();
}

TEST_CASE("CreateBatchToSpaceNdQSymm16")
{
    RefCreateBatchToSpaceNdTest<RefBatchToSpaceNdWorkload, armnn::DataType::QSymmS16>();
}
918 
919 template <typename L2NormalizationWorkloadType, armnn::DataType DataType>
RefCreateL2NormalizationTest(DataLayout dataLayout)920 static void RefCreateL2NormalizationTest(DataLayout dataLayout)
921 {
922     Graph graph;
923     RefWorkloadFactory factory = GetFactory();
924     auto workload =
925             CreateL2NormalizationWorkloadTest<L2NormalizationWorkloadType, DataType>(factory, graph, dataLayout);
926 
927     TensorShape inputShape;
928     TensorShape outputShape;
929 
930     switch (dataLayout)
931     {
932         case DataLayout::NHWC:
933             inputShape  = { 5, 50, 67, 20 };
934             outputShape = { 5, 50, 67, 20 };
935             break;
936         case DataLayout::NCHW:
937         default:
938             inputShape  = { 5, 20, 50, 67 };
939             outputShape = { 5, 20, 50, 67 };
940             break;
941     }
942 
943     // Checks that outputs and inputs are as we expect them (see definition of CreateL2NormalizationWorkloadTest).
944     CheckInputOutput(std::move(workload), TensorInfo(inputShape, DataType), TensorInfo(outputShape, DataType));
945 }
946 
// L2Normalization workload creation tests across data types and layouts.
TEST_CASE("CreateL2NormalizationFloat32")
{
    RefCreateL2NormalizationTest<RefL2NormalizationWorkload, armnn::DataType::Float32>(DataLayout::NCHW);
}

TEST_CASE("CreateL2NormalizationFloat32Nhwc")
{
    RefCreateL2NormalizationTest<RefL2NormalizationWorkload, armnn::DataType::Float32>(DataLayout::NHWC);
}

TEST_CASE("CreateL2NormalizationInt16")
{
    RefCreateL2NormalizationTest<RefL2NormalizationWorkload, armnn::DataType::QSymmS16>(DataLayout::NCHW);
}

TEST_CASE("CreateL2NormalizationInt16Nhwc")
{
    RefCreateL2NormalizationTest<RefL2NormalizationWorkload, armnn::DataType::QSymmS16>(DataLayout::NHWC);
}

TEST_CASE("CreateL2NormalizationUint8")
{
    RefCreateL2NormalizationTest<RefL2NormalizationWorkload, armnn::DataType::QAsymmU8>(DataLayout::NCHW);
}

TEST_CASE("CreateL2NormalizationUint8Nhwc")
{
    RefCreateL2NormalizationTest<RefL2NormalizationWorkload, armnn::DataType::QAsymmU8>(DataLayout::NHWC);
}
976 
977 template <typename ReshapeWorkloadType, armnn::DataType DataType>
RefCreateReshapeWorkloadTest()978 static void RefCreateReshapeWorkloadTest()
979 {
980     Graph graph;
981     RefWorkloadFactory factory = GetFactory();
982     auto workload = CreateReshapeWorkloadTest<ReshapeWorkloadType, DataType>(factory, graph);
983 
984     // Checks that outputs and inputs are as we expect them (see definition of CreateReshapeWorkloadTest).
985     CheckInputOutput(
986         std::move(workload),
987         TensorInfo({ 4, 1 }, DataType),
988         TensorInfo({ 1, 4 }, DataType));
989 }
990 
// Reshape workload creation tests across data types.
TEST_CASE("CreateReshapeWorkloadFloat32")
{
    RefCreateReshapeWorkloadTest<RefReshapeWorkload, armnn::DataType::Float32>();
}

TEST_CASE("CreateReshapeWorkloadQuantisedAsymm8")
{
    RefCreateReshapeWorkloadTest<RefReshapeWorkload, armnn::DataType::QAsymmU8>();
}

TEST_CASE("CreateReshapeWorkloadQuantisedSymm16")
{
    RefCreateReshapeWorkloadTest<RefReshapeWorkload, armnn::DataType::QSymmS16>();
}
1005 
1006 template <typename ConcatWorkloadType, armnn::DataType DataType>
RefCreateConcatWorkloadTest(const armnn::TensorShape & outputShape,unsigned int concatAxis)1007 static void RefCreateConcatWorkloadTest(const armnn::TensorShape& outputShape,
1008                                         unsigned int concatAxis)
1009 {
1010     Graph graph;
1011     RefWorkloadFactory factory = GetFactory();
1012     auto workload = CreateConcatWorkloadTest<ConcatWorkloadType, DataType>(factory, graph, outputShape, concatAxis);
1013 
1014     CheckInputsOutput(std::move(workload),
1015                       TensorInfo({ 2, 3, 2, 5 }, DataType),
1016                       TensorInfo({ 2, 3, 2, 5 }, DataType),
1017                       TensorInfo(outputShape, DataType));
1018 }
1019 
// Concat workload creation tests covering every concatenation axis (0..3).
TEST_CASE("CreateConcatDim0Float32Workload")
{
    RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::Float32>({ 4, 3, 2, 5 }, 0);
}

TEST_CASE("CreateConcatDim0Float16Workload")
{
    RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::Float16>({ 4, 3, 2, 5 }, 0);
}

TEST_CASE("CreateConcatDim0Uint8Workload")
{
    RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::QAsymmU8>({ 4, 3, 2, 5 }, 0);
}

TEST_CASE("CreateConcatDim0Uint16Workload")
{
    // NOTE(review): the name says "Uint16" but the type under test is the signed
    // symmetric QSymmS16 — consider renaming the test case.
    RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::QSymmS16>({ 4, 3, 2, 5 }, 0);
}

TEST_CASE("CreateConcatDim1Float32Workload")
{
    RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::Float32>({ 2, 6, 2, 5 }, 1);
}

TEST_CASE("CreateConcatDim1Uint8Workload")
{
    RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::QAsymmU8>({ 2, 6, 2, 5 }, 1);
}

TEST_CASE("CreateConcatDim2Float32Workload")
{
    RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::Float32>({ 2, 3, 4, 5 }, 2);
}

TEST_CASE("CreateConcatDim2Uint8Workload")
{
    RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::QAsymmU8>({ 2, 3, 4, 5 }, 2);
}

TEST_CASE("CreateConcatDim3Float32Workload")
{
    RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::Float32>({ 2, 3, 2, 10 }, 3);
}

TEST_CASE("CreateConcatDim3Uint8Workload")
{
    RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::QAsymmU8>({ 2, 3, 2, 10 }, 3);
}
1069 
1070 template <typename ConstantWorkloadType, armnn::DataType DataType>
RefCreateConstantWorkloadTest(const armnn::TensorShape & outputShape)1071 static void RefCreateConstantWorkloadTest(const armnn::TensorShape& outputShape)
1072 {
1073     armnn::Graph graph;
1074     RefWorkloadFactory factory = GetFactory();
1075     auto workload = CreateConstantWorkloadTest<ConstantWorkloadType, DataType>(factory, graph, outputShape);
1076 
1077     // Check output is as expected
1078     auto queueDescriptor = workload->GetData();
1079     auto outputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[0]);
1080     CHECK((outputHandle->GetTensorInfo() == TensorInfo(outputShape, DataType)));
1081 }
1082 
// Constant workload creation tests across data types.
TEST_CASE("CreateConstantUint8Workload")
{
    RefCreateConstantWorkloadTest<RefConstantWorkload, armnn::DataType::QAsymmU8>({ 2, 3, 2, 10 });
}

TEST_CASE("CreateConstantInt16Workload")
{
    RefCreateConstantWorkloadTest<RefConstantWorkload, armnn::DataType::QSymmS16>({ 2, 3, 2, 10 });
}

TEST_CASE("CreateConstantFloat32Workload")
{
    RefCreateConstantWorkloadTest<RefConstantWorkload, armnn::DataType::Float32>({ 2, 3, 2, 10 });
}

TEST_CASE("CreateConstantSigned32Workload")
{
    RefCreateConstantWorkloadTest<RefConstantWorkload, armnn::DataType::Signed32>({ 2, 3, 2, 10 });
}
1102 
RefCreatePreluWorkloadTest(const armnn::TensorShape & inputShape,const armnn::TensorShape & alphaShape,const armnn::TensorShape & outputShape,armnn::DataType dataType)1103 static void RefCreatePreluWorkloadTest(const armnn::TensorShape& inputShape,
1104                                        const armnn::TensorShape& alphaShape,
1105                                        const armnn::TensorShape& outputShape,
1106                                        armnn::DataType dataType)
1107 {
1108     armnn::Graph graph;
1109     RefWorkloadFactory factory;
1110     auto workload = CreatePreluWorkloadTest<RefPreluWorkload>(factory,
1111                                                               graph,
1112                                                               inputShape,
1113                                                               alphaShape,
1114                                                               outputShape,
1115                                                               dataType);
1116 
1117     // Check output is as expected
1118     auto queueDescriptor = workload->GetData();
1119     auto outputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[0]);
1120     CHECK((outputHandle->GetTensorInfo() == TensorInfo(outputShape, dataType)));
1121 }
1122 
// Prelu workload creation tests: broadcast-compatible shapes succeed, while the
// *NoBroadcast variants must raise InvalidArgumentException.
TEST_CASE("CreatePreluFloat32Workload")
{
    RefCreatePreluWorkloadTest({ 1, 4, 1, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 }, armnn::DataType::Float32);
}

TEST_CASE("CreatePreluFloat16Workload")
{
    RefCreatePreluWorkloadTest({ 1, 4, 1, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 }, armnn::DataType::Float16);
}

TEST_CASE("CreatePreluUint8Workload")
{
    RefCreatePreluWorkloadTest({ 1, 4, 1, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 }, armnn::DataType::QAsymmU8);
}

TEST_CASE("CreatePreluInt16Workload")
{
    RefCreatePreluWorkloadTest({ 1, 4, 1, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 }, armnn::DataType::QSymmS16);
}

TEST_CASE("CreatePreluFloat32NoBroadcastWorkload")
{
    CHECK_THROWS_AS(RefCreatePreluWorkloadTest({ 1, 4, 7, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 },
                                                 armnn::DataType::Float32),
                      armnn::InvalidArgumentException);
}

TEST_CASE("CreatePreluFloat16NoBroadcastWorkload")
{
    CHECK_THROWS_AS(RefCreatePreluWorkloadTest({ 1, 4, 7, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 },
                                                 armnn::DataType::Float16),
                      armnn::InvalidArgumentException);
}

TEST_CASE("CreatePreluUint8NoBroadcastWorkload")
{
    CHECK_THROWS_AS(RefCreatePreluWorkloadTest({ 1, 4, 7, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 },
                                                 armnn::DataType::QAsymmU8),
                      armnn::InvalidArgumentException);
}

TEST_CASE("CreatePreluInt16NoBroadcastWorkload")
{
    CHECK_THROWS_AS(RefCreatePreluWorkloadTest({ 1, 4, 7, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 },
                                                 armnn::DataType::QSymmS16),
                      armnn::InvalidArgumentException);
}
1170 
1171 template <typename SpaceToDepthWorkloadType, armnn::DataType DataType>
RefCreateSpaceToDepthWorkloadTest()1172 static void RefCreateSpaceToDepthWorkloadTest()
1173 {
1174     Graph graph;
1175     RefWorkloadFactory factory;
1176 
1177     auto workload = CreateSpaceToDepthWorkloadTest<SpaceToDepthWorkloadType, DataType>(factory, graph);
1178 
1179     CheckInputOutput(std::move(workload),
1180                      TensorInfo({ 1, 2, 2, 1 }, DataType),
1181                      TensorInfo({ 1, 1, 1, 4 }, DataType));
1182 }
1183 
// SpaceToDepth workload creation tests across data types.
TEST_CASE("CreateSpaceToDepthWorkloadFloat32")
{
    RefCreateSpaceToDepthWorkloadTest<RefSpaceToDepthWorkload, armnn::DataType::Float32>();
}

TEST_CASE("CreateSpaceToDepthWorkloadFloat16")
{
    RefCreateSpaceToDepthWorkloadTest<RefSpaceToDepthWorkload, armnn::DataType::Float16>();
}

TEST_CASE("CreateSpaceToDepthWorkloadQASymm8")
{
    RefCreateSpaceToDepthWorkloadTest<RefSpaceToDepthWorkload, armnn::DataType::QAsymmU8>();
}

TEST_CASE("CreateSpaceToDepthWorkloadQSymm16")
{
    RefCreateSpaceToDepthWorkloadTest<RefSpaceToDepthWorkload, armnn::DataType::QSymmS16>();
}
1203 
1204 template <armnn::DataType DataType>
RefCreateStackWorkloadTest(const armnn::TensorShape & inputShape,const armnn::TensorShape & outputShape,unsigned int axis,unsigned int numInputs)1205 static void RefCreateStackWorkloadTest(const armnn::TensorShape& inputShape,
1206                                        const armnn::TensorShape& outputShape,
1207                                        unsigned int axis,
1208                                        unsigned int numInputs)
1209 {
1210     armnn::Graph graph;
1211     RefWorkloadFactory factory;
1212     auto workload = CreateStackWorkloadTest<RefStackWorkload, DataType>(factory,
1213                                                                         graph,
1214                                                                         inputShape,
1215                                                                         outputShape,
1216                                                                         axis,
1217                                                                         numInputs);
1218 
1219     // Check inputs and output are as expected
1220     StackQueueDescriptor queueDescriptor = workload->GetData();
1221     for (unsigned int i = 0; i < numInputs; ++i)
1222     {
1223         auto inputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Inputs[i]);
1224         CHECK((inputHandle->GetTensorInfo() == TensorInfo(inputShape, DataType)));
1225     }
1226     auto outputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[0]);
1227     CHECK((outputHandle->GetTensorInfo() == TensorInfo(outputShape, DataType)));
1228 }
1229 
// Stack workload creation tests across data types.
TEST_CASE("CreateStackFloat32Workload")
{
    RefCreateStackWorkloadTest<armnn::DataType::Float32>({ 3, 4, 5 }, { 3, 4, 2, 5 }, 2, 2);
}

TEST_CASE("CreateStackUint8Workload")
{
    RefCreateStackWorkloadTest<armnn::DataType::QAsymmU8>({ 3, 4, 5 }, { 3, 4, 2, 5 }, 2, 2);
}

TEST_CASE("CreateStackUint16Workload")
{
    // NOTE(review): the name says "Uint16" but the type under test is the signed
    // symmetric QSymmS16 — consider renaming the test case.
    RefCreateStackWorkloadTest<armnn::DataType::QSymmS16>({ 3, 4, 5 }, { 3, 4, 2, 5 }, 2, 2);
}
1244 
// Creates a QLstm workload and verifies the tensor infos (shape, data type and
// quantization scale/offset) of the input, cell-state output and final output
// against the values set up by CreateQLstmWorkloadTest.
template <typename QLstmWorkloadType>
static void RefCreateQLstmWorkloadTest()
{
    Graph graph;
    RefWorkloadFactory factory;

    auto workload = CreateQLstmWorkloadTest<QLstmWorkloadType>(factory, graph);

    // Input: QAsymmS8, scale 0.0078125 (= 1/128), zero offset.
    armnn::TensorInfo inputInfo({2 , 4}, armnn::DataType::QAsymmS8, 0.0078125f, 0);

    // Cell state: QSymmS16, scale 3.05176e-05 (~ 1/32768), zero offset.
    armnn::TensorInfo cellStateInfo({2 , 4}, armnn::DataType::QSymmS16, 3.05176e-05f, 0);

    // Output: QAsymmS8, scale 0.007, zero offset.
    armnn::TensorInfo outputInfo({2 , 4}, armnn::DataType::QAsymmS8, 0.007f, 0);

    QLstmQueueDescriptor queueDescriptor = workload->GetData();
    // m_Outputs[1] holds the cell state, m_Outputs[2] the final output.
    auto inputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Inputs[0]);
    auto cellStateOutHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[1]);
    auto outputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[2]);

    CHECK((inputHandle->GetTensorInfo() == inputInfo));
    CHECK((cellStateOutHandle->GetTensorInfo() == cellStateInfo));
    CHECK((outputHandle->GetTensorInfo() == outputInfo));
}
1268 
// QLstm workload creation test (quantized LSTM).
TEST_CASE("CreateQLstmWorkload")
{
    RefCreateQLstmWorkloadTest<RefQLstmWorkload>();
}
1273 
// Verifies that ReplaceInputTensorHandle/ReplaceOutputTensorHandle substitute the
// handles held in the workload's queue descriptor: after replacement the
// descriptor must expose the exact objects we supplied (checked by tensor info,
// pointer identity and mapped memory address).
template <armnn::DataType DataType>
static void RefCreateActivationWorkloadReplaceFunctionsTest()
{
    Graph graph;
    RefWorkloadFactory factory = GetFactory();
    // input and output are created as armnn::TensorInfo tensorInfo({1, 1}, DataType)
    auto workloadPtr = CreateActivationWorkloadTest<RefActivationWorkload, DataType>(factory, graph);

    // New input and output tensor handles — deliberately a different shape ({2,2})
    // and data type (Float16) from the originals — are created and then
    // substituted into the workload at slot 0.
    shared_ptr<RefMemoryManager> memoryManager = make_shared<RefMemoryManager>();
    const RefTensorHandleFactory tensorHandleFactory(memoryManager);
    TensorInfo inputInfo({2 , 2}, armnn::DataType::Float16);
    TensorInfo outputInfo({2 , 2}, armnn::DataType::Float16);
    unique_ptr<ITensorHandle> inputHandle  = tensorHandleFactory.CreateTensorHandle(inputInfo);
    unique_ptr<ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputInfo);
    unsigned int slot = 0;
    workloadPtr->ReplaceInputTensorHandle(inputHandle.get(), slot);
    workloadPtr->ReplaceOutputTensorHandle(outputHandle.get(), slot);

    // Check that the tensor handles inside the workload are the same objects as
    // the ones we replaced them with.
    auto queueDescriptor = workloadPtr->GetData();
    auto inputHandleTest  = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Inputs[0]);
    auto outputHandleTest = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[0]);
    CHECK((inputHandleTest->GetTensorInfo() == inputInfo));
    CHECK((outputHandleTest->GetTensorInfo() == outputInfo));
    CHECK(inputHandle.get() == inputHandleTest);
    CHECK(outputHandle.get() == outputHandleTest);
    // Mapping both views of each handle must yield the same backing memory.
    inputHandle->Allocate();
    CHECK(inputHandle->Map() == inputHandleTest->Map());
    outputHandle->Allocate();
    CHECK(outputHandle->Map() == outputHandleTest->Map());
}
1306 
// Tensor-handle replacement tests: the workload is created with the template
// data type, then its handles are swapped for Float16 ones.
TEST_CASE("ReplaceFunctionsfromFloat32toFloat16ActivationWorkload")
{
    RefCreateActivationWorkloadReplaceFunctionsTest<armnn::DataType::Float32>();
}

TEST_CASE("ReplaceFunctionsfromUint8toFloat16ActivationWorkload")
{
    RefCreateActivationWorkloadReplaceFunctionsTest<armnn::DataType::QAsymmU8>();
}
1316 
1317 }
1318