xref: /aosp_15_r20/external/armnn/src/backends/reference/RefLayerSupport.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "RefLayerSupport.hpp"
7 
8 #include <armnn/TypesUtils.hpp>
9 #include <armnn/Types.hpp>
10 #include <armnn/utility/IgnoreUnused.hpp>
11 #include <armnn/utility/NumericCast.hpp>
12 #include <armnn/utility/PolymorphicDowncast.hpp>
13 
14 #include <LayerSupportCommon.hpp>
15 #include <backendsCommon/LayerSupportRules.hpp>
16 
17 #include <vector>
18 #include <array>
19 
20 namespace armnn
21 {
22 
23 namespace
24 {
25 
26 template<typename Float32Func, typename Uint8Func, typename ... Params>
IsSupportedForDataTypeRef(Optional<std::string &> reasonIfUnsupported,DataType dataType,Float32Func floatFuncPtr,Uint8Func uint8FuncPtr,Params &&...params)27 bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
28                                DataType dataType,
29                                Float32Func floatFuncPtr,
30                                Uint8Func uint8FuncPtr,
31                                Params&&... params)
32 {
33     return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
34                                          dataType,
35                                          &FalseFunc<Params...>,
36                                          floatFuncPtr,
37                                          uint8FuncPtr,
38                                          &FalseFunc<Params...>,
39                                          &FalseFunc<Params...>,
40                                          std::forward<Params>(params)...);
41 }
42 
43 } // anonymous namespace
44 
namespace
{

// Builds the reference backend's "wrong tensor rank" diagnostic, of the form:
//   Reference <layerStr>: Expected <expected> dimensions but got <actual> dimensions instead,
//   for the '<tensorName>' tensor.
//
// layerStr and tensorName are only read, so they are taken by const
// reference. Existing callers passing std::string lvalues keep working, and
// string literals and temporaries are now accepted too.
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    return "Reference " + layerStr + ": Expected " + std::to_string(expected) +
           " dimensions but got " + std::to_string(actual) +
           " dimensions instead, for the '" + tensorName + "' tensor.";
}

} // anonymous namespace
60 
IsLayerSupported(const LayerType & type,const std::vector<TensorInfo> & infos,const BaseDescriptor & descriptor,const Optional<LstmInputParamsInfo> & lstmParamsInfo,const Optional<QuantizedLstmInputParamsInfo> & quantizedLstmInputParamsInfo,Optional<std::string &> reasonIfUnsupported) const61 bool RefLayerSupport::IsLayerSupported(const LayerType& type,
62                                        const std::vector<TensorInfo>& infos,
63                                        const BaseDescriptor& descriptor,
64                                        const Optional<LstmInputParamsInfo>& lstmParamsInfo,
65                                        const Optional<QuantizedLstmInputParamsInfo>& quantizedLstmInputParamsInfo,
66                                        Optional<std::string&> reasonIfUnsupported) const
67 {
68     switch (type)
69     {
70         case LayerType::Activation:
71             return IsActivationSupported(infos[0],
72                                          infos[1],
73                                          *(PolymorphicDowncast<const ActivationDescriptor*>(&descriptor)),
74                                          reasonIfUnsupported);
75         case LayerType::Addition:
76             return IsAdditionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
77         case LayerType::ArgMinMax:
78             return IsArgMinMaxSupported(infos[0],
79                                         infos[1],
80                                         *(PolymorphicDowncast<const ArgMinMaxDescriptor*>(&descriptor)),
81                                         reasonIfUnsupported);
82         case LayerType::BatchMatMul:
83             return IsBatchMatMulSupported(infos[0],
84                                           infos[1],
85                                           infos[2],
86                                           *(PolymorphicDowncast<const BatchMatMulDescriptor*>(&descriptor)),
87                                           reasonIfUnsupported);
88         case LayerType::BatchNormalization:
89             return IsBatchNormalizationSupported(infos[0],
90                                                  infos[1],
91                                                  infos[2],
92                                                  infos[3],
93                                                  infos[4],
94                                                  infos[5],
95                                                  *(PolymorphicDowncast<const BatchNormalizationDescriptor*>
96                                                      (&descriptor)),
97                                                  reasonIfUnsupported);
98         case LayerType::BatchToSpaceNd:
99             return IsBatchToSpaceNdSupported(infos[0],
100                                              infos[1],
101                                              *(PolymorphicDowncast<const BatchToSpaceNdDescriptor*>(&descriptor)),
102                                              reasonIfUnsupported);
103         case LayerType::Comparison:
104             return IsComparisonSupported(infos[0],
105                                          infos[1],
106                                          infos[2],
107                                          *(PolymorphicDowncast<const ComparisonDescriptor*>(&descriptor)),
108                                          reasonIfUnsupported);
109         case LayerType::Concat:
110         {
111             std::vector<const TensorInfo*> inputInfos;
112             for (uint32_t i = 0; i < (infos.size() - 1); i++)
113             {
114                 inputInfos.push_back(&infos[i]);
115             }
116             return IsConcatSupported(inputInfos,
117                                      infos[infos.size() - 1],
118                                      *(PolymorphicDowncast<const OriginsDescriptor*>(&descriptor)),
119                                      reasonIfUnsupported);
120         }
121         case LayerType::Constant:
122             return IsConstantSupported(infos[0], reasonIfUnsupported);
123         case LayerType::ConvertFp16ToFp32:
124             return IsConvertFp16ToFp32Supported(infos[0], infos[1], reasonIfUnsupported);
125         case LayerType::ConvertFp32ToFp16:
126             return IsConvertFp32ToFp16Supported(infos[0], infos[1], reasonIfUnsupported);
127         case LayerType::Convolution2d:
128         {
129             if (infos.size() != 4)
130             {
131                 throw InvalidArgumentException("Invalid number of Convolution2d TensorInfos. "
132                                                "TensorInfos should be of format: {input, output, weights, biases}.");
133             }
134 
135             auto desc = *(PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor));
136             if (infos[3] == TensorInfo())
137             {
138                 return IsConvolution2dSupported(infos[0],
139                                                 infos[1],
140                                                 desc,
141                                                 infos[2],
142                                                 EmptyOptional(),
143                                                 reasonIfUnsupported);
144             }
145             else
146             {
147                 return IsConvolution2dSupported(infos[0],
148                                                 infos[1],
149                                                 desc,
150                                                 infos[2],
151                                                 infos[3],
152                                                 reasonIfUnsupported);
153             }
154         }
155         case LayerType::DepthToSpace:
156             return IsDepthToSpaceSupported(infos[0],
157                                            infos[1],
158                                            *(PolymorphicDowncast<const DepthToSpaceDescriptor*>(&descriptor)),
159                                            reasonIfUnsupported);
160         case LayerType::DepthwiseConvolution2d:
161         {
162             if (infos.size() != 4)
163             {
164                 throw InvalidArgumentException("Invalid number of DepthwiseConvolution2d TensorInfos. "
165                                                "TensorInfos should be of format: {input, output, weights, biases}.");
166             }
167 
168             auto desc = *(PolymorphicDowncast<const DepthwiseConvolution2dDescriptor*>(&descriptor));
169             if (infos[3] == TensorInfo())
170             {
171                 return IsDepthwiseConvolutionSupported(infos[0],
172                                                        infos[1],
173                                                        desc,
174                                                        infos[2],
175                                                        EmptyOptional(),
176                                                        reasonIfUnsupported);
177             }
178             else
179             {
180                 return IsDepthwiseConvolutionSupported(infos[0],
181                                                        infos[1],
182                                                        desc,
183                                                        infos[2],
184                                                        infos[3],
185                                                        reasonIfUnsupported);
186             }
187         }
188         case LayerType::Dequantize:
189             return IsDequantizeSupported(infos[0], infos[1], reasonIfUnsupported);
190         case LayerType::Division:
191             return IsDivisionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
192         case LayerType::ElementwiseBinary:
193         {
194             std::array<DataType, 7> supportedTypes =
195                     {
196                             DataType::Float32,
197                             DataType::Float16,
198                             DataType::QAsymmS8,
199                             DataType::QAsymmU8,
200                             DataType::QSymmS16,
201                             DataType::Signed32
202                     };
203 
204             bool supported = true;
205             supported &= CheckSupportRule(TypeAnyOf(infos[0], supportedTypes), reasonIfUnsupported,
206                                           "Reference elementwise unary: input type not supported");
207 
208             supported &= CheckSupportRule(TypeAnyOf(infos[1], supportedTypes), reasonIfUnsupported,
209                                           "Reference elementwise unary: input type not supported");
210 
211             supported &= CheckSupportRule(TypeAnyOf(infos[2], supportedTypes), reasonIfUnsupported,
212                                           "Reference elementwise unary: output type not supported");
213 
214             supported &= CheckSupportRule(TypesAreEqual(infos[0], infos[1]), reasonIfUnsupported,
215                                           "Reference elementwise unary: input types not matching");
216 
217             supported &= CheckSupportRule(TypesAreEqual(infos[0], infos[2]), reasonIfUnsupported,
218                                           "Reference elementwise unary: input and output types not matching");
219 
220             return supported;
221         }
222         case LayerType::ElementwiseUnary:
223             return IsElementwiseUnarySupported(infos[0],
224                                                infos[1],
225                                                *(PolymorphicDowncast<const ElementwiseUnaryDescriptor*>(&descriptor)),
226                                                reasonIfUnsupported);
227         case LayerType::Fill:
228             return IsFillSupported(infos[0],
229                                    infos[1],
230                                    *(PolymorphicDowncast<const FillDescriptor*>(&descriptor)),
231                                    reasonIfUnsupported);
232         case LayerType::Floor:
233             return IsFloorSupported(infos[0], infos[1], reasonIfUnsupported);
234         case LayerType::FullyConnected:
235             return IsFullyConnectedSupported(infos[0],
236                                              infos[1],
237                                              infos[2],
238                                              infos[3],
239                                              *(PolymorphicDowncast<const FullyConnectedDescriptor*>(&descriptor)),
240                                              reasonIfUnsupported);
241         case LayerType::Gather:
242             return IsGatherSupported(infos[0],
243                                      infos[1],
244                                      infos[2],
245                                      *(PolymorphicDowncast<const GatherDescriptor*>(&descriptor)),
246                                      reasonIfUnsupported);
247         case LayerType::GatherNd:
248             return IsGatherNdSupported(infos[0],
249                                        infos[1],
250                                        infos[2],
251                                        reasonIfUnsupported);
252         case LayerType::Input:
253             return IsInputSupported(infos[0], reasonIfUnsupported);
254         case LayerType::InstanceNormalization:
255             return IsInstanceNormalizationSupported(infos[0],
256                                                     infos[1],
257                                                     *(PolymorphicDowncast<const InstanceNormalizationDescriptor*>
258                                                         (&descriptor)),
259                                                     reasonIfUnsupported);
260         case LayerType::L2Normalization:
261             return IsL2NormalizationSupported(infos[0],
262                                               infos[1],
263                                               *(PolymorphicDowncast<const L2NormalizationDescriptor*>(&descriptor)),
264                                               reasonIfUnsupported);
265         case LayerType::LogicalBinary:
266             return IsLogicalBinarySupported(infos[0],
267                                             infos[1],
268                                             infos[2],
269                                             *(PolymorphicDowncast<const LogicalBinaryDescriptor*>(&descriptor)),
270                                             reasonIfUnsupported);
271         case LayerType::LogSoftmax:
272             return IsLogSoftmaxSupported(infos[0],
273                                          infos[1],
274                                          *(PolymorphicDowncast<const LogSoftmaxDescriptor*>(&descriptor)),
275                                          reasonIfUnsupported);
276         case LayerType::Lstm:
277             return IsLstmSupported(infos[0],
278                                    infos[1],
279                                    infos[2],
280                                    infos[3],
281                                    infos[4],
282                                    infos[5],
283                                    infos[6],
284                                    *(PolymorphicDowncast<const LstmDescriptor*>(&descriptor)),
285                                    lstmParamsInfo.value(),
286                                    reasonIfUnsupported);
287         case LayerType::QLstm:
288             return IsQLstmSupported(infos[0],
289                                     infos[1],
290                                     infos[2],
291                                     infos[3],
292                                     infos[4],
293                                     infos[5],
294                                     *(PolymorphicDowncast<const QLstmDescriptor*>(&descriptor)),
295                                     lstmParamsInfo.value(),
296                                     reasonIfUnsupported);
297         case LayerType::Maximum:
298             return IsMaximumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
299         case LayerType::Mean:
300             return IsMeanSupported(infos[0],
301                                    infos[1],
302                                    *(PolymorphicDowncast<const MeanDescriptor*>(&descriptor)),
303                                    reasonIfUnsupported);
304         case LayerType::Minimum:
305             return IsMinimumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
306         case LayerType::Multiplication:
307             return IsMultiplicationSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
308         case LayerType::Normalization:
309             return IsNormalizationSupported(infos[0],
310                                             infos[1],
311                                             *(PolymorphicDowncast<const NormalizationDescriptor*>(&descriptor)),
312                                             reasonIfUnsupported);
313         case LayerType::Output:
314             return IsOutputSupported(infos[0], reasonIfUnsupported);
315         case LayerType::Pad:
316             return IsPadSupported(infos[0],
317                                   infos[1],
318                                   *(PolymorphicDowncast<const PadDescriptor*>(&descriptor)),
319                                   reasonIfUnsupported);
320         case LayerType::Permute:
321             return IsPermuteSupported(infos[0],
322                                       infos[1],
323                                       *(PolymorphicDowncast<const PermuteDescriptor*>(&descriptor)),
324                                       reasonIfUnsupported);
325         case LayerType::Pooling2d:
326             return IsPooling2dSupported(infos[0],
327                                         infos[1],
328                                         *(PolymorphicDowncast<const Pooling2dDescriptor*>(&descriptor)),
329                                         reasonIfUnsupported);
330         case LayerType::Prelu:
331             return IsPreluSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
332         case LayerType::Quantize:
333             return IsQuantizeSupported(infos[0], infos[1], reasonIfUnsupported);
334         case LayerType::Reshape:
335             return IsReshapeSupported(infos[0],
336                                       infos[1],
337                                       *(PolymorphicDowncast<const ReshapeDescriptor*>(&descriptor)),
338                                       reasonIfUnsupported);
339         case LayerType::Resize:
340             return IsResizeSupported(infos[0],
341                                      infos[1],
342                                      *(PolymorphicDowncast<const ResizeDescriptor*>(&descriptor)),
343                                      reasonIfUnsupported);
344         case LayerType::Reduce:
345             return IsReduceSupported(infos[0],
346                                      infos[1],
347                                      *(PolymorphicDowncast<const ReduceDescriptor*>(&descriptor)),
348                                      reasonIfUnsupported);
349         case LayerType::Slice:
350             return IsSliceSupported(infos[0],
351                                     infos[1],
352                                     *(PolymorphicDowncast<const SliceDescriptor*>(&descriptor)),
353                                     reasonIfUnsupported);
354         case LayerType::Softmax:
355             return IsSoftmaxSupported(infos[0],
356                                       infos[1],
357                                       *(PolymorphicDowncast<const SoftmaxDescriptor*>(&descriptor)),
358                                       reasonIfUnsupported);
359         case LayerType::SpaceToBatchNd:
360             return IsSpaceToBatchNdSupported(infos[0],
361                                              infos[1],
362                                              *(PolymorphicDowncast<const SpaceToBatchNdDescriptor*>(&descriptor)),
363                                              reasonIfUnsupported);
364         case LayerType::SpaceToDepth:
365             return IsSpaceToDepthSupported(infos[0],
366                                            infos[1],
367                                            *(PolymorphicDowncast<const SpaceToDepthDescriptor*>(&descriptor)),
368                                            reasonIfUnsupported);
369         case LayerType::Splitter:
370         {
371             std::vector<TensorInfo> outputInfos;
372             for (uint32_t i = 1; i < infos.size(); i++)
373             {
374                 outputInfos.push_back(infos[i]);
375             }
376             return IsSplitterSupported(infos[0],
377                                        {outputInfos.begin(), outputInfos.end()},
378                                        *(PolymorphicDowncast<const ViewsDescriptor*>(&descriptor)),
379                                        reasonIfUnsupported);
380         }
381         case LayerType::Stack:
382         {
383             std::vector<const TensorInfo*> inputInfos;
384             for (uint32_t i = 0; i < infos.size() - 1; i++)
385             {
386                 inputInfos.push_back(&infos[i]);
387             }
388             return IsStackSupported(inputInfos,
389                                     infos[infos.size() - 1],
390                                     *(PolymorphicDowncast<const StackDescriptor*>(&descriptor)),
391                                     reasonIfUnsupported);
392         }
393         case LayerType::StridedSlice:
394             return IsStridedSliceSupported(infos[0],
395                                            infos[1],
396                                            *(PolymorphicDowncast<const StridedSliceDescriptor*>(&descriptor)),
397                                            reasonIfUnsupported);
398         case LayerType::Subtraction:
399             return IsSubtractionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
400         case LayerType::Transpose:
401             return IsTransposeSupported(infos[0],
402                                         infos[1],
403                                         *(PolymorphicDowncast<const TransposeDescriptor*>(&descriptor)),
404                                         reasonIfUnsupported);
405         case LayerType::TransposeConvolution2d:
406         {
407             if (infos.size() != 4)
408             {
409                 throw InvalidArgumentException("Invalid number of TransposeConvolution2d TensorInfos. "
410                                                "TensorInfos should be of format: {input, output, weights, biases}.");
411             }
412 
413             auto desc = *(PolymorphicDowncast<const TransposeConvolution2dDescriptor*>(&descriptor));
414             if (infos[3] == TensorInfo())
415             {
416                 return IsTransposeConvolution2dSupported(infos[0],
417                                                          infos[1],
418                                                          desc,
419                                                          infos[2],
420                                                          EmptyOptional(),
421                                                          reasonIfUnsupported);
422             }
423             else
424             {
425                 return IsTransposeConvolution2dSupported(infos[0],
426                                                          infos[1],
427                                                          desc,
428                                                          infos[2],
429                                                          infos[3],
430                                                          reasonIfUnsupported);
431             }
432         }
433         case LayerType::Cast:
434             return IsCastSupported(infos[0], infos[1], reasonIfUnsupported);
435         case LayerType::ChannelShuffle:
436             return IsChannelShuffleSupported(infos[0],
437                                              infos[1],
438                                              *(PolymorphicDowncast<const ChannelShuffleDescriptor*>(&descriptor)),
439                                              reasonIfUnsupported);
440         case LayerType::Convolution3d:
441         {
442             if (infos.size() != 4)
443             {
444                 throw InvalidArgumentException("Invalid number of Convolution3d TensorInfos. "
445                                                "TensorInfos should be of format: {input, output, weights, biases}.");
446             }
447 
448             auto desc = *(PolymorphicDowncast<const Convolution3dDescriptor*>(&descriptor));
449             if (infos[3] == TensorInfo())
450             {
451                 return IsConvolution3dSupported(infos[0],
452                                                 infos[1],
453                                                 desc,
454                                                 infos[2],
455                                                 EmptyOptional(),
456                                                 reasonIfUnsupported);
457             }
458             else
459             {
460                 return IsConvolution3dSupported(infos[0],
461                                                 infos[1],
462                                                 desc,
463                                                 infos[2],
464                                                 infos[3],
465                                                 reasonIfUnsupported);
466             }
467         }
468         case LayerType::Debug:
469             return IsDebugSupported(infos[0], infos[1], reasonIfUnsupported);
470         case LayerType::DetectionPostProcess:
471             return IsDetectionPostProcessSupported(infos[0],
472                                                    infos[1],
473                                                    infos[2],
474                                                    infos[3],
475                                                    infos[4],
476                                                    infos[5],
477                                                    infos[6],
478                                                    *(PolymorphicDowncast<const DetectionPostProcessDescriptor*>
479                                                        (&descriptor)),
480                                                    reasonIfUnsupported);
481         case LayerType::FakeQuantization:
482             return IsFakeQuantizationSupported(infos[0],
483                                                *(PolymorphicDowncast<const FakeQuantizationDescriptor*>(&descriptor)),
484                                                reasonIfUnsupported);
485         case LayerType::MemCopy:
486             return IsMemCopySupported(infos[0], infos[1], reasonIfUnsupported);
487         case LayerType::Rank:
488             return IsRankSupported(infos[0], infos[1], reasonIfUnsupported);
489         case LayerType::Shape:
490             return IsShapeSupported(infos[0], infos[1], reasonIfUnsupported);
491         case LayerType::UnidirectionalSequenceLstm:
492         {
493             if (infos.size() != 6)
494             {
495                 throw InvalidArgumentException("Invalid number of UnidirectionalSequenceLstm TensorInfos. TensorInfos "
496                                                "should be of format: {input, outputStateIn, cellStateIn, "
497                                                "hiddenStateOutputVal, cellStateOutputVal, output}");
498             }
499             auto desc = *(PolymorphicDowncast<const UnidirectionalSequenceLstmDescriptor*>(&descriptor));
500             return IsUnidirectionalSequenceLstmSupported(infos[0],
501                                                          infos[1],
502                                                          infos[2],
503                                                          infos[3],
504                                                          infos[4],
505                                                          infos[5],
506                                                          desc,
507                                                          lstmParamsInfo.value(),
508                                                          reasonIfUnsupported);
509         }
510         case LayerType::Pooling3d:
511             return IsPooling3dSupported(infos[0],
512                                         infos[1],
513                                         *(PolymorphicDowncast<const Pooling3dDescriptor*>(&descriptor)),
514                                         reasonIfUnsupported);
515         case LayerType::Map:
516             return true;
517         case LayerType::Unmap:
518             return true;
519         case LayerType::MemImport:
520             return LayerSupportBase::IsMemImportSupported(infos[0], infos[1], reasonIfUnsupported);
521         case LayerType::Merge:
522             return LayerSupportBase::IsMergeSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
523         case LayerType::QuantizedLstm:
524             return LayerSupportBase::IsQuantizedLstmSupported(infos[0],
525                                                               infos[1],
526                                                               infos[2],
527                                                               infos[3],
528                                                               infos[4],
529                                                               quantizedLstmInputParamsInfo.value(),
530                                                               reasonIfUnsupported);
531         default:
532             // layers not supported in neon by default:
533             // precompiled, standin, switch
534             return false;
535     }
536 }
537 
IsActivationSupported(const TensorInfo & input,const TensorInfo & output,const ActivationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const538 bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
539                                             const TensorInfo& output,
540                                             const ActivationDescriptor& descriptor,
541                                             Optional<std::string&> reasonIfUnsupported) const
542 {
543    bool supported = true;
544 
545     // Define supported types.
546     std::array<DataType,6> supportedTypes = {
547         DataType::Float32,
548         DataType::Float16,
549         DataType::QAsymmS8,
550         DataType::QAsymmU8,
551         DataType::QSymmS16
552     };
553 
554     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
555                                   "Reference activation: input type not supported.");
556 
557     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
558                                   "Reference activation: output type not supported.");
559 
560     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
561                                   "Reference activation: input and output types mismatched.");
562 
563     supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
564                                   "Reference activation: input and output shapes are of different rank.");
565 
566 
567     struct ActivationFunctionSupported : public Rule
568     {
569         ActivationFunctionSupported(const ActivationDescriptor& desc)
570         {
571             switch(desc.m_Function)
572             {
573                 case ActivationFunction::Abs:
574                 case ActivationFunction::BoundedReLu:
575                 case ActivationFunction::Elu:
576                 case ActivationFunction::HardSwish:
577                 case ActivationFunction::LeakyReLu:
578                 case ActivationFunction::Linear:
579                 case ActivationFunction::ReLu:
580                 case ActivationFunction::Sigmoid:
581                 case ActivationFunction::SoftReLu:
582                 case ActivationFunction::Sqrt:
583                 case ActivationFunction::Square:
584                 case ActivationFunction::TanH:
585                 {
586                     m_Res = true;
587                     break;
588                 }
589                 default:
590                 {
591                     m_Res = false;
592                     break;
593                 }
594             }
595         }
596     };
597 
598     // Function is supported
599     supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
600                                   "Reference activation: function not supported.");
601 
602     return supported;
603 }
604 
IsAdditionSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const605 bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
606                                           const TensorInfo& input1,
607                                           const TensorInfo& output,
608                                           Optional<std::string&> reasonIfUnsupported) const
609 {
610     bool supported = true;
611 
612     std::array<DataType,7> supportedTypes = {
613         DataType::Float32,
614         DataType::Float16,
615         DataType::QAsymmS8,
616         DataType::QAsymmU8,
617         DataType::QSymmS16,
618         DataType::Signed32
619     };
620 
621     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
622                                   "Reference addition: input 0 is not a supported type.");
623 
624     supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
625                                   "Reference addition: input 1 is not a supported type.");
626 
627     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
628                                   "Reference addition: output is not a supported type.");
629 
630     supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
631                                   "Reference addition: input 0 and Input 1 types are mismatched");
632 
633     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
634                                   "Reference addition: input and output types are mismatched");
635 
636     supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
637                                   "Reference addition: shapes are not suitable for implicit broadcast.");
638 
639     return supported;
640 }
641 
IsArgMinMaxSupported(const armnn::TensorInfo & input,const armnn::TensorInfo & output,const armnn::ArgMinMaxDescriptor & descriptor,armnn::Optional<std::string &> reasonIfUnsupported) const642 bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
643                                            const armnn::ArgMinMaxDescriptor &descriptor,
644                                            armnn::Optional<std::string &> reasonIfUnsupported) const
645 {
646     IgnoreUnused(descriptor);
647 
648     std::array<DataType, 8> supportedInputTypes =
649     {
650         DataType::Float16,
651         DataType::Float32,
652         DataType::QAsymmS8,
653         DataType::QAsymmU8,
654         DataType::QSymmS16,
655         DataType::Signed32,
656         DataType::Signed64
657     };
658 
659     std::array<DataType,2> supportedOutputTypes = {
660         DataType::Signed32,
661         DataType::Signed64
662     };
663 
664     bool supported = true;
665 
666     supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
667                                   "Reference ArgMinMax: input is not a supported type.");
668     supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
669                                   "Reference ArgMinMax: output type not supported");
670 
671     return supported;
672 }
673 
IsBatchMatMulSupported(const TensorInfo & inputX,const TensorInfo & inputY,const TensorInfo & output,const BatchMatMulDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const674 bool RefLayerSupport::IsBatchMatMulSupported(const TensorInfo& inputX,
675                                              const TensorInfo& inputY,
676                                              const TensorInfo& output,
677                                              const BatchMatMulDescriptor& descriptor,
678                                              Optional<std::string &> reasonIfUnsupported) const
679 {
680     IgnoreUnused(descriptor);
681 
682     std::array<DataType, 6> supportedTypes =
683     {
684         DataType::Float16,
685         DataType::Float32,
686         DataType::QAsymmS8,
687         DataType::QAsymmU8,
688         DataType::QSymmS16
689     };
690 
691     bool supported = true;
692 
693     supported &= CheckSupportRule(TypeAnyOf(inputX, supportedTypes), reasonIfUnsupported,
694                                   "Reference batch matrix multiplication: input X is not a supported type");
695 
696     supported &= CheckSupportRule(TypeAnyOf(inputY, supportedTypes), reasonIfUnsupported,
697                                   "Reference batch matrix multiplication: input Y is not a supported type");
698 
699     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
700                                   "Reference batch matrix multiplication: output is not a supported type");
701 
702     supported &= CheckSupportRule(TypesAreEqual(inputX, inputY), reasonIfUnsupported,
703                                   "Reference batch matrix multiplication: input X and input Y types are mismatched");
704 
705     supported &= CheckSupportRule(TypesAreEqual(inputX, output), reasonIfUnsupported,
706                                   "Reference batch matrix multiplication: inputs and output types are mismatched");
707 
708     supported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(inputX, 2),
709                                   reasonIfUnsupported,
710                                   "Reference batch matrix multiplication: input X is not of rank 2 or greater");
711 
712     supported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(inputY, 2),
713                                   reasonIfUnsupported,
714                                   "Reference batch matrix multiplication: input Y is not of rank 2 or greater");
715 
716     return supported;
717 }
718 
IsBatchNormalizationSupported(const TensorInfo & input,const TensorInfo & output,const TensorInfo & mean,const TensorInfo & variance,const TensorInfo & beta,const TensorInfo & gamma,const BatchNormalizationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const719 bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
720                                                     const TensorInfo& output,
721                                                     const TensorInfo& mean,
722                                                     const TensorInfo& variance,
723                                                     const TensorInfo& beta,
724                                                     const TensorInfo& gamma,
725                                                     const BatchNormalizationDescriptor& descriptor,
726                                                     Optional<std::string&> reasonIfUnsupported) const
727 {
728     IgnoreUnused(descriptor);
729 
730     std::array<DataType, 6> supportedTypes =
731     {
732         DataType::Float32,
733         DataType::Float16,
734         DataType::QAsymmS8,
735         DataType::QAsymmU8,
736         DataType::QSymmS16
737     };
738 
739     bool supported = true;
740 
741     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
742                                   "Reference batch normalization: input is not a supported type.");
743 
744     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
745                                   "Reference batch normalization: output is not a supported type.");
746 
747     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
748                                   "Reference batch normalization: input and output types are mismatched");
749 
750     supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
751                                   "Reference batch normalization: mean is not a supported type.");
752 
753     supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
754                                   "Reference batch normalization: variance is not a supported type.");
755 
756     supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
757                                   "Reference batch normalization: beta is not a supported type.");
758 
759     supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
760                                   "Reference batch normalization: gamma is not a supported type.");
761 
762     return supported;
763 }
764 
IsBatchToSpaceNdSupported(const TensorInfo & input,const TensorInfo & output,const BatchToSpaceNdDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const765 bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
766                                                 const TensorInfo& output,
767                                                 const BatchToSpaceNdDescriptor& descriptor,
768                                                 Optional<std::string&> reasonIfUnsupported) const
769 {
770     IgnoreUnused(descriptor);
771 
772     bool supported = true;
773 
774     std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
775     std::string inputTensorStr = "input";
776     std::string outputTensorStr = "output";
777 
778     // Define supported types.
779     std::array<DataType,6> supportedTypes =
780     {
781         DataType::Float32,
782         DataType::Float16,
783         DataType::QAsymmS8,
784         DataType::QAsymmU8,
785         DataType::QSymmS16
786     };
787 
788     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
789                                   "Reference BatchToSpaceNd: input type not supported.");
790 
791     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
792                                   "Reference BatchToSpaceNd: output type not supported.");
793 
794     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
795                                   "Reference BatchToSpaceNd: input and output types mismatched.");
796 
797     supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
798                                   reasonIfUnsupported,
799                                   CreateIncorrectDimensionsErrorMsg(4,
800                                                                     output.GetNumDimensions(),
801                                                                     batchToSpaceNdLayerStr,
802                                                                     outputTensorStr).data());
803 
804     supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
805                                   reasonIfUnsupported,
806                                   CreateIncorrectDimensionsErrorMsg(4,
807                                                                     input.GetNumDimensions(),
808                                                                     batchToSpaceNdLayerStr,
809                                                                     inputTensorStr).data());
810 
811     return supported;
812 }
813 
IsCastSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const814 bool RefLayerSupport::IsCastSupported(const TensorInfo& input,
815                                       const TensorInfo& output,
816                                       Optional<std::string&> reasonIfUnsupported) const
817 {
818     std::array<DataType, 9> supportedInputTypes =
819             {
820                     DataType::Float32,
821                     DataType::Float16,
822                     DataType::QSymmS8,
823                     DataType::QAsymmS8,
824                     DataType::QAsymmU8,
825                     DataType::QSymmS16,
826                     DataType::Signed32
827             };
828 
829     bool supported = true;
830     supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
831                                   "Reference cast: input is not a supported type");
832 
833 
834     supported &= CheckSupportRule(TypeAnyOf(output, supportedInputTypes), reasonIfUnsupported,
835                                   "Reference cast: output is not a supported type");
836 
837     supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
838                                   "Reference cast: input and output shapes have different number of total elements");
839 
840     return supported;
841 }
842 
IsChannelShuffleSupported(const TensorInfo & input,const TensorInfo & output,const ChannelShuffleDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const843 bool RefLayerSupport::IsChannelShuffleSupported(const TensorInfo& input,
844                                                 const TensorInfo& output,
845                                                 const ChannelShuffleDescriptor& descriptor,
846                                                 Optional<std::string&> reasonIfUnsupported) const
847 {
848     IgnoreUnused(descriptor);
849     bool supported = true;
850 
851     // Define supported output and inputs types.
852     std::array<DataType, 7> supportedTypes =
853     {
854         DataType::Float32,
855         DataType::Float16,
856         DataType::QAsymmS8,
857         DataType::QAsymmU8,
858         DataType::QSymmS8,
859         DataType::QSymmS16
860     };
861 
862     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
863                                   "Reference ChannelShuffle: input is not a supported type.");
864 
865     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
866                                   "Reference ChannelShuffle: output is not a supported type.");
867 
868     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
869                                   "Reference ChannelShuffle: input and output types are mismatched.");
870 
871     return supported;
872 }
873 
874 
IsComparisonSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,const ComparisonDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const875 bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
876                                             const TensorInfo& input1,
877                                             const TensorInfo& output,
878                                             const ComparisonDescriptor& descriptor,
879                                             Optional<std::string&> reasonIfUnsupported) const
880 {
881     IgnoreUnused(descriptor);
882     std::array<DataType, 8> supportedInputTypes =
883     {
884         DataType::Boolean,
885         DataType::Float32,
886         DataType::Float16,
887         DataType::QAsymmS8,
888         DataType::QAsymmU8,
889         DataType::QSymmS16,
890         DataType::Signed32
891     };
892 
893     bool supported = true;
894     supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
895                                   "Reference comparison: input 0 is not a supported type");
896 
897     supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
898                                   "Reference comparison: input 0 and Input 1 types are mismatched");
899 
900     supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
901                                   "Reference comparison: output is not of type Boolean");
902 
903     return supported;
904 }
905 
IsConcatSupported(const std::vector<const TensorInfo * > inputs,const TensorInfo & output,const OriginsDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const906 bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
907                                         const TensorInfo& output,
908                                         const OriginsDescriptor& descriptor,
909                                         Optional<std::string&> reasonIfUnsupported) const
910 {
911     IgnoreUnused(descriptor);
912 
913     bool supported = true;
914     std::array<DataType,7> supportedTypes =
915     {
916         DataType::Float32,
917         DataType::Float16,
918         DataType::QAsymmS8,
919         DataType::QAsymmU8,
920         DataType::QSymmS16,
921         DataType::Signed32
922     };
923 
924     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
925                                   "Reference concatenation: output type not supported");
926     for (const TensorInfo* input : inputs)
927     {
928         ARMNN_ASSERT(input != nullptr);
929         supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
930             "Reference concatenation: input type not supported");
931 
932         supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
933             "Reference concatenation: input and output types mismatched.");
934     }
935 
936     return supported;
937 }
938 
IsConstantSupported(const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const939 bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
940                                           Optional<std::string&> reasonIfUnsupported) const
941 {
942     std::array<DataType,8> supportedTypes =
943     {
944         DataType::Float16,
945         DataType::Float32,
946         DataType::QAsymmS8,
947         DataType::QAsymmU8,
948         DataType::QSymmS8,
949         DataType::QSymmS16,
950         DataType::Signed32
951     };
952 
953     return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
954                                   "Reference constant: output is not a supported type.");
955 }
956 
IsConvertFp16ToFp32Supported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const957 bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
958                                                    const TensorInfo& output,
959                                                    Optional<std::string&> reasonIfUnsupported) const
960 {
961     return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
962                                           input.GetDataType(),
963                                           &TrueFunc<>,
964                                           &FalseInputFuncF32<>,
965                                           &FalseFuncU8<>,
966                                           &FalseFuncI32<>,
967                                           &FalseFuncU8<>) &&
968             IsSupportedForDataTypeGeneric(reasonIfUnsupported,
969                                           output.GetDataType(),
970                                           &FalseOutputFuncF16<>,
971                                           &TrueFunc<>,
972                                           &FalseFuncU8<>,
973                                           &FalseFuncI32<>,
974                                           &FalseFuncU8<>));
975 }
976 
IsConvertFp32ToFp16Supported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const977 bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
978                                                    const TensorInfo& output,
979                                                    Optional<std::string&> reasonIfUnsupported) const
980 {
981     return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
982                                           input.GetDataType(),
983                                           &FalseInputFuncF16<>,
984                                           &TrueFunc<>,
985                                           &FalseFuncU8<>,
986                                           &FalseFuncI32<>,
987                                           &FalseFuncU8<>) &&
988             IsSupportedForDataTypeGeneric(reasonIfUnsupported,
989                                           output.GetDataType(),
990                                           &TrueFunc<>,
991                                           &FalseOutputFuncF32<>,
992                                           &FalseFuncU8<>,
993                                           &FalseFuncI32<>,
994                                           &FalseFuncU8<>));
995 }
996 
IsConvolution2dSupported(const TensorInfo & input,const TensorInfo & output,const Convolution2dDescriptor & descriptor,const TensorInfo & weights,const Optional<TensorInfo> & biases,Optional<std::string &> reasonIfUnsupported) const997 bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
998                                                const TensorInfo& output,
999                                                const Convolution2dDescriptor& descriptor,
1000                                                const TensorInfo& weights,
1001                                                const Optional<TensorInfo>& biases,
1002                                                Optional<std::string&> reasonIfUnsupported) const
1003 {
1004     bool supported = true;
1005 
1006     // Define supported types.
1007     std::array<DataType,7> supportedTypes =
1008     {
1009         DataType::Float32,
1010         DataType::Float16,
1011         DataType::QAsymmS8,
1012         DataType::QAsymmU8,
1013         DataType::QSymmS8,
1014         DataType::QSymmS16
1015     };
1016 
1017     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1018                                   "Reference Convolution2d: input is not a supported type.");
1019 
1020     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1021                                   "Reference Convolution2d: output is not a supported type.");
1022 
1023     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1024                               "Reference Convolution2d: input and output types mismatched.");
1025 
1026 
1027     const DataType inputType = input.GetDataType();
1028     if (IsQuantized8BitType(inputType))
1029     {
1030         std::array<DataType, 3> supportedWeightTypes =
1031         {
1032             DataType::QAsymmS8,
1033             DataType::QAsymmU8,
1034             DataType::QSymmS8
1035         };
1036 
1037         supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1038                                       "Reference Convolution2d: weights type not supported for quantized input.");
1039     }
1040     else
1041     {
1042         supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1043                                       "Reference Convolution2d: weights is not a supported type.");
1044 
1045         supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1046                                       "Reference Convolution2d: input and weights types mismatched.");
1047     }
1048 
1049     if (biases.has_value())
1050     {
1051         std::array<DataType,4> biasesSupportedTypes =
1052         {
1053             DataType::Float32,
1054             DataType::Float16,
1055             DataType::Signed32
1056         };
1057 
1058         supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1059                                       "Reference Convolution2d: biases is not a supported type.");
1060     }
1061     IgnoreUnused(descriptor);
1062 
1063     return supported;
1064 }
1065 
IsConvolution3dSupported(const TensorInfo & input,const TensorInfo & output,const Convolution3dDescriptor & descriptor,const TensorInfo & weights,const Optional<TensorInfo> & biases,Optional<std::string &> reasonIfUnsupported) const1066 bool RefLayerSupport::IsConvolution3dSupported(const TensorInfo& input,
1067                                                const TensorInfo& output,
1068                                                const Convolution3dDescriptor& descriptor,
1069                                                const TensorInfo& weights,
1070                                                const Optional<TensorInfo>& biases,
1071                                                Optional<std::string&> reasonIfUnsupported) const
1072 {
1073     bool supported = true;
1074 
1075     // Define supported types.
1076     std::array<DataType,7> supportedTypes =
1077     {
1078         DataType::Float32,
1079         DataType::Float16,
1080         DataType::QAsymmS8,
1081         DataType::QAsymmU8,
1082         DataType::QSymmS8,
1083         DataType::QSymmS16
1084     };
1085 
1086     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1087                                   "Reference Convolution3d: input is not a supported type.");
1088 
1089     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1090                                   "Reference Convolution3d: output is not a supported type.");
1091 
1092     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1093                                   "Reference Convolution3d: input and output types mismatched.");
1094 
1095     const DataType inputType = input.GetDataType();
1096     if (IsQuantized8BitType(inputType))
1097     {
1098         std::array<DataType, 3> supportedWeightTypes =
1099         {
1100             DataType::QAsymmS8,
1101             DataType::QAsymmU8,
1102             DataType::QSymmS8
1103         };
1104 
1105         supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1106                                       "Reference Convolution3d: weights type not supported for quantized input.");
1107     }
1108     else
1109     {
1110         supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1111                                       "Reference Convolution3d: weights is not a supported type.");
1112 
1113         supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1114                                       "Reference Convolution3d: input and weights types mismatched.");
1115     }
1116 
1117     if (biases.has_value())
1118     {
1119         std::array<DataType,4> biasesSupportedTypes =
1120         {
1121             DataType::Float32,
1122             DataType::Float16,
1123             DataType::Signed32
1124         };
1125 
1126         supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1127                                       "Reference Convolution3d: biases is not a supported type.");
1128     }
1129     IgnoreUnused(descriptor);
1130 
1131     return supported;
1132 }
1133 
IsDebugSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1134 bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
1135                                        const TensorInfo& output,
1136                                        Optional<std::string&> reasonIfUnsupported) const
1137 {
1138     bool supported = true;
1139 
1140     std::array<DataType, 8> supportedTypes =
1141     {
1142         DataType::BFloat16,
1143         DataType::Float16,
1144         DataType::Float32,
1145         DataType::QAsymmS8,
1146         DataType::QAsymmU8,
1147         DataType::QSymmS8,
1148         DataType::QSymmS16,
1149         DataType::Signed32
1150     };
1151 
1152     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1153                                   "Reference for Debug layer: input type not supported");
1154 
1155     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1156                                   "Reference for Debug layer: output type not supported");
1157 
1158     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1159                                   "Reference for Debug layer: input and output types are mismatched");
1160 
1161     return supported;
1162 }
1163 
IsDepthToSpaceSupported(const TensorInfo & input,const TensorInfo & output,const DepthToSpaceDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1164 bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
1165                                               const TensorInfo& output,
1166                                               const DepthToSpaceDescriptor& descriptor,
1167                                               Optional<std::string&> reasonIfUnsupported) const
1168 {
1169     IgnoreUnused(descriptor);
1170     bool supported = true;
1171 
1172     std::array<DataType,6> supportedTypes =
1173     {
1174         DataType::Float32,
1175         DataType::Float16,
1176         DataType::QAsymmS8,
1177         DataType::QAsymmU8,
1178         DataType::QSymmS16
1179     };
1180 
1181     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1182         "Reference DepthToSpace: input type not supported");
1183 
1184     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1185         "Reference DepthToSpace: output type not supported");
1186 
1187     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1188         "Reference DepthToSpace: input and output types are mismatched");
1189 
1190     return supported;
1191 }
1192 
IsDepthwiseConvolutionSupported(const TensorInfo & input,const TensorInfo & output,const DepthwiseConvolution2dDescriptor & descriptor,const TensorInfo & weights,const Optional<TensorInfo> & biases,Optional<std::string &> reasonIfUnsupported) const1193 bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
1194                                                       const TensorInfo& output,
1195                                                       const DepthwiseConvolution2dDescriptor& descriptor,
1196                                                       const TensorInfo& weights,
1197                                                       const Optional<TensorInfo>& biases,
1198                                                       Optional<std::string&> reasonIfUnsupported) const
1199 {
1200     IgnoreUnused(descriptor);
1201     bool supported = true;
1202 
1203     // Define supported types.
1204     std::array<DataType,7> supportedTypes =
1205     {
1206         DataType::Float32,
1207         DataType::Float16,
1208         DataType::QAsymmS8,
1209         DataType::QAsymmU8,
1210         DataType::QSymmS8,
1211         DataType::QSymmS16
1212     };
1213 
1214     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1215                                   "Reference DepthwiseConvolution2d: input is not a supported type.");
1216 
1217     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1218                                   "Reference DepthwiseConvolution2d: output is not a supported type.");
1219 
1220     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1221                                   "Reference DepthwiseConvolution2d: input and output types mismatched.");
1222 
1223     const DataType inputType = input.GetDataType();
1224     if (IsQuantized8BitType(inputType))
1225     {
1226         std::array<DataType, 3> supportedWeightTypes =
1227                 {
1228                         DataType::QAsymmS8,
1229                         DataType::QAsymmU8,
1230                         DataType::QSymmS8,
1231                 };
1232 
1233         supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1234                                        "Reference DepthwiseConvolution2d: weights type not supported for "
1235                                        "quantized input.");
1236     }
1237     else
1238     {
1239         supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1240                                       "Reference DepthwiseConvolution2d: weights is not a supported type.");
1241 
1242         supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1243                                       "Reference DepthwiseConvolution2d: input and weights types mismatched.");
1244     }
1245 
1246     if (biases.has_value())
1247     {
1248         std::array<DataType,4> biasesSupportedTypes =
1249         {
1250             DataType::Float32,
1251             DataType::Float16,
1252             DataType::Signed32
1253         };
1254         supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1255                                       "Reference DepthwiseConvolution2d: biases is not a supported type.");
1256     }
1257 
1258     return supported;
1259 
1260 }
1261 
IsDequantizeSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1262 bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
1263                                             const TensorInfo& output,
1264                                             Optional<std::string&> reasonIfUnsupported) const
1265 {
1266    bool supported = true;
1267 
1268     std::array<DataType,5> supportedInputTypes = {
1269         DataType::QAsymmS8,
1270         DataType::QAsymmU8,
1271         DataType::QSymmS8,
1272         DataType::QSymmS16,
1273         DataType::Float16
1274     };
1275 
1276     supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1277                                   "Reference for Dequantize layer: input type not supported.");
1278 
1279     supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
1280                                   "Reference for Dequantize layer: per-axis quantized input not supported.");
1281 
1282     std::array<DataType,3> supportedOutputTypes = {
1283         DataType::Float32,
1284         DataType::Float16
1285     };
1286 
1287     supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1288                                   "Reference for Dequantize layer: output type not supported.");
1289 
1290     supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1291                                   "Reference for Dequantize layer: input/output shapes have different num total "
1292                                   "elements.");
1293 
1294     return supported;
1295 }
1296 
IsDetectionPostProcessSupported(const TensorInfo & boxEncodings,const TensorInfo & scores,const TensorInfo & anchors,const TensorInfo & detectionBoxes,const TensorInfo & detectionClasses,const TensorInfo & detectionScores,const TensorInfo & numDetections,const DetectionPostProcessDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1297 bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
1298                                                       const TensorInfo& scores,
1299                                                       const TensorInfo& anchors,
1300                                                       const TensorInfo& detectionBoxes,
1301                                                       const TensorInfo& detectionClasses,
1302                                                       const TensorInfo& detectionScores,
1303                                                       const TensorInfo& numDetections,
1304                                                       const DetectionPostProcessDescriptor& descriptor,
1305                                                       Optional<std::string&> reasonIfUnsupported) const
1306 {
1307     IgnoreUnused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);
1308 
1309     bool supported = true;
1310 
1311     std::array<DataType,6> supportedInputTypes =
1312     {
1313         DataType::Float32,
1314         DataType::Float16,
1315         DataType::QAsymmS8,
1316         DataType::QAsymmU8,
1317         DataType::QSymmS16
1318     };
1319 
1320     supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
1321                                   "Reference DetectionPostProcess: input 0 is not a supported type.");
1322 
1323     supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
1324                                   "Reference DetectionPostProcess: input 1 is not a supported type.");
1325 
1326     return supported;
1327 }
1328 
IsDilatedDepthwiseConvolutionSupported(const TensorInfo & input,const TensorInfo & output,const DepthwiseConvolution2dDescriptor & descriptor,const TensorInfo & weights,const Optional<TensorInfo> & biases,Optional<std::string &> reasonIfUnsupported) const1329 bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
1330                                                              const TensorInfo& output,
1331                                                              const DepthwiseConvolution2dDescriptor& descriptor,
1332                                                              const TensorInfo& weights,
1333                                                              const Optional<TensorInfo>& biases,
1334                                                              Optional<std::string&> reasonIfUnsupported) const
1335 {
1336     return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
1337 }
1338 
IsDivisionSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1339 bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
1340                                           const TensorInfo& input1,
1341                                           const TensorInfo& output,
1342                                           Optional<std::string&> reasonIfUnsupported) const
1343 {
1344     bool supported = true;
1345 
1346     std::array<DataType,7> supportedTypes = {
1347         DataType::Float32,
1348         DataType::Float16,
1349         DataType::QAsymmS8,
1350         DataType::QAsymmU8,
1351         DataType::QSymmS16,
1352         DataType::Signed32
1353     };
1354 
1355     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1356                                   "Reference division: input 0 is not a supported type.");
1357 
1358     supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1359                                   "Reference division: input 1 is not a supported type.");
1360 
1361     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1362                                   "Reference division: output is not a supported type.");
1363 
1364     supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1365                                   "Reference division: input 0 and Input 1 types are mismatched");
1366 
1367     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1368                                   "Reference division: input and output types are mismatched");
1369 
1370     supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1371                                   "Reference division: shapes are not suitable for implicit broadcast.");
1372 
1373     return supported;
1374 }
1375 
IsElementwiseUnarySupported(const TensorInfo & input,const TensorInfo & output,const ElementwiseUnaryDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1376 bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
1377                                                   const TensorInfo& output,
1378                                                   const ElementwiseUnaryDescriptor& descriptor,
1379                                                   Optional<std::string&> reasonIfUnsupported) const
1380 {
1381     IgnoreUnused(descriptor);
1382 
1383     std::array<DataType, 7> supportedTypes =
1384     {
1385         DataType::Float32,
1386         DataType::Float16,
1387         DataType::QAsymmS8,
1388         DataType::QAsymmU8,
1389         DataType::QSymmS16,
1390         DataType::Signed32
1391     };
1392 
1393     std::array<DataType, 1> logicalSupportedTypes =
1394     {
1395         DataType::Boolean
1396     };
1397 
1398     bool supported = true;
1399 
1400     if (descriptor.m_Operation == UnaryOperation::LogicalNot)
1401     {
1402         supported &= CheckSupportRule(TypeAnyOf(input, logicalSupportedTypes), reasonIfUnsupported,
1403                                       "Reference elementwise unary: input type not supported");
1404 
1405         supported &= CheckSupportRule(TypeAnyOf(output, logicalSupportedTypes), reasonIfUnsupported,
1406                                       "Reference elementwise unary: output type not supported");
1407     }
1408     else
1409     {
1410         supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1411                                       "Reference elementwise unary: input type not supported");
1412 
1413         supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1414                                       "Reference elementwise unary: output type not supported");
1415     }
1416 
1417     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1418                                   "Reference elementwise unary: input and output types not matching");
1419 
1420     supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1421                                   "Reference elementwise unary: input and output shapes"
1422                                   "have different number of total elements");
1423 
1424     return supported;
1425 }
1426 
IsFakeQuantizationSupported(const TensorInfo & input,const FakeQuantizationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1427 bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
1428                                                   const FakeQuantizationDescriptor& descriptor,
1429                                                   Optional<std::string&> reasonIfUnsupported) const
1430 {
1431     IgnoreUnused(descriptor);
1432     bool supported = true;
1433 
1434     std::array<DataType,1> supportedTypes =
1435     {
1436         DataType::Float32
1437     };
1438 
1439     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1440                                   "Reference fake quantization: input type not supported.");
1441 
1442     return supported;
1443 }
1444 
IsFillSupported(const TensorInfo & input,const TensorInfo & output,const FillDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1445 bool RefLayerSupport::IsFillSupported(const TensorInfo& input,
1446                                       const TensorInfo& output,
1447                                       const FillDescriptor& descriptor,
1448                                       Optional<std::string&> reasonIfUnsupported) const
1449 {
1450     IgnoreUnused(descriptor);
1451     IgnoreUnused(output);
1452 
1453     bool supported = true;
1454 
1455     std::array<DataType,3> supportedTypes =
1456     {
1457         DataType::Float32,
1458         DataType::Float16,
1459         DataType::Signed32
1460     };
1461 
1462     supported &= CheckSupportRule(TypeIs(input, DataType::Signed32), reasonIfUnsupported,
1463                                   "Reference Fill: input type not supported.");
1464 
1465     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1466                                   "Reference Fill: output type not supported.");
1467     return supported;
1468 }
1469 
IsFloorSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1470 bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
1471                                        const TensorInfo& output,
1472                                        Optional<std::string&> reasonIfUnsupported) const
1473 {
1474     IgnoreUnused(output);
1475     bool supported = true;
1476 
1477     std::array<DataType,3> supportedTypes =
1478     {
1479         DataType::Float32,
1480         DataType::Float16
1481     };
1482 
1483     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1484                                   "Reference Floor: input type not supported.");
1485 
1486     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1487                                   "Reference Floor: output type not supported.");
1488 
1489     return supported;
1490 }
1491 
IsFullyConnectedSupported(const TensorInfo & input,const TensorInfo & output,const TensorInfo & weights,const TensorInfo & biases,const FullyConnectedDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1492 bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
1493                                                 const TensorInfo& output,
1494                                                 const TensorInfo& weights,
1495                                                 const TensorInfo& biases,
1496                                                 const FullyConnectedDescriptor& descriptor,
1497                                                 Optional<std::string&> reasonIfUnsupported) const
1498 {
1499     bool supported = true;
1500 
1501     // Define supported types.
1502     std::array<DataType,6> supportedTypes =
1503     {
1504         DataType::Float32,
1505         DataType::Float16,
1506         DataType::QAsymmS8,
1507         DataType::QAsymmU8,
1508         DataType::QSymmS16
1509     };
1510 
1511     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1512                                   "Reference Fully Connected: input type not supported.");
1513 
1514     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1515                                   "Reference Fully Connected: output type not supported.");
1516 
1517     supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1518                                   "Reference Fully Connected: weights type not supported.");
1519 
1520     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1521                               "Reference Fully Connected: input and output types mismatched.");
1522 
1523     supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1524                                   "Reference Fully Connected: weights is not a supported type.");
1525 
1526     supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1527                                   "Reference Fully Connected: input and weights types mismatched.");
1528 
1529     if (descriptor.m_BiasEnabled)
1530     {
1531         // Defined supported types for bias
1532         std::array<DataType, 5>
1533         supportedBiasTypes =
1534         {
1535             DataType::Float32,
1536             DataType::Float16,
1537             DataType::Signed32,
1538             DataType::QAsymmS8
1539         };
1540 
1541         supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
1542                                       "Reference Fully Connected: bias type not supported.");
1543 
1544         supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
1545                                       "Reference Fully Connected: bias and weight types mismatch.");
1546 
1547         supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
1548                                       "Reference Fully Connected: bias type inferred from weights is incompatible.");
1549 
1550         supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(biases, 1U), reasonIfUnsupported,
1551                                       "Reference Fully Connected: bias must have 1 dimension.");
1552 
1553     }
1554 
1555     return supported;
1556 }
1557 
IsGatherNdSupported(const armnn::TensorInfo & input0,const armnn::TensorInfo & input1,const armnn::TensorInfo & output,armnn::Optional<std::string &> reasonIfUnsupported) const1558 bool RefLayerSupport::IsGatherNdSupported(const armnn::TensorInfo& input0,
1559                                           const armnn::TensorInfo& input1,
1560                                           const armnn::TensorInfo& output,
1561                                           armnn::Optional<std::string&> reasonIfUnsupported) const
1562 {
1563     bool supported = true;
1564     std::array<DataType,7> supportedTypes =
1565     {
1566             DataType::Float32,
1567             DataType::Float16,
1568             DataType::QAsymmS8,
1569             DataType::QAsymmU8,
1570             DataType::QSymmS16,
1571             DataType::Signed32
1572     };
1573 
1574     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1575                                   "Reference GatherNd: input type not supported");
1576 
1577     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1578                                   "Reference GatherNd: output type not supported");
1579 
1580     supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1581                                   "Reference GatherNd: indices (input1) type not supported");
1582 
1583     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1584                                   "Reference GatherNd: input and output types not matching");
1585 
1586     return supported;
1587 }
1588 
IsGatherSupported(const armnn::TensorInfo & input0,const armnn::TensorInfo & input1,const armnn::TensorInfo & output,const GatherDescriptor & descriptor,armnn::Optional<std::string &> reasonIfUnsupported) const1589 bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
1590                                         const armnn::TensorInfo& input1,
1591                                         const armnn::TensorInfo& output,
1592                                         const GatherDescriptor& descriptor,
1593                                         armnn::Optional<std::string&> reasonIfUnsupported) const
1594 {
1595     bool supported = true;
1596     std::array<DataType,7> supportedTypes =
1597     {
1598         DataType::Float32,
1599         DataType::Float16,
1600         DataType::QAsymmS8,
1601         DataType::QAsymmU8,
1602         DataType::QSymmS16,
1603         DataType::Signed32
1604     };
1605 
1606     IgnoreUnused(descriptor);
1607     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1608                                   "Reference Gather: input type not supported");
1609 
1610     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1611                                   "Reference Gather: output type not supported");
1612 
1613     supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1614                                   "Reference Gather: indices (input1) type not supported");
1615 
1616     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1617                                   "Reference Gather: input and output types not matching");
1618 
1619     return supported;
1620 }
1621 
IsInputSupported(const TensorInfo &,Optional<std::string &>) const1622 bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
1623                                        Optional<std::string&> /*reasonIfUnsupported*/) const
1624 {
1625     return true;
1626 }
1627 
IsInstanceNormalizationSupported(const TensorInfo & input,const TensorInfo & output,const InstanceNormalizationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1628 bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
1629                                                        const TensorInfo& output,
1630                                                        const InstanceNormalizationDescriptor& descriptor,
1631                                                        Optional<std::string&> reasonIfUnsupported) const
1632 {
1633     IgnoreUnused(descriptor);
1634     // Define supported types
1635     std::array<DataType, 3> supportedTypes =
1636         {
1637             DataType::Float32,
1638             DataType::Float16
1639         };
1640 
1641     bool supported = true;
1642 
1643     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1644                                   "Reference Instance Normalization: input type not supported.");
1645 
1646     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1647                                   "Reference Instance Normalization: output type not supported.");
1648 
1649     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1650                                   "Reference Instance Normalization: input and output types mismatched.");
1651 
1652     supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1653                                   "Reference Instance Normalization: input and output shapes have different "
1654                                   "num total elements.");
1655 
1656     return supported;
1657 }
1658 
IsL2NormalizationSupported(const TensorInfo & input,const TensorInfo & output,const L2NormalizationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1659 bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
1660                                                  const TensorInfo& output,
1661                                                  const L2NormalizationDescriptor& descriptor,
1662                                                  Optional<std::string&> reasonIfUnsupported) const
1663 {
1664     IgnoreUnused(descriptor);
1665     // Define supported types
1666     std::array<DataType, 6> supportedTypes =
1667     {
1668         DataType::Float32,
1669         DataType::Float16,
1670         DataType::QAsymmS8,
1671         DataType::QAsymmU8,
1672         DataType::QSymmS16
1673     };
1674 
1675     bool supported = true;
1676 
1677     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1678                                   "Reference L2normalization: input type not supported.");
1679 
1680     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1681                                   "Reference L2normalization: output type not supported.");
1682 
1683     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1684                                   "Reference L2normalization: input and output types mismatched.");
1685 
1686     supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1687                                   "Reference L2normalization: input and output shapes have different "
1688                                   "num total elements.");
1689 
1690     return supported;
1691 }
1692 
IsLogicalBinarySupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,const LogicalBinaryDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1693 bool RefLayerSupport::IsLogicalBinarySupported(const TensorInfo& input0,
1694                                                const TensorInfo& input1,
1695                                                const TensorInfo& output,
1696                                                const LogicalBinaryDescriptor& descriptor,
1697                                                Optional<std::string&> reasonIfUnsupported) const
1698 {
1699     IgnoreUnused(descriptor);
1700 
1701     std::array<DataType, 1> supportedTypes =
1702     {
1703         DataType::Boolean
1704     };
1705 
1706     bool supported = true;
1707     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1708                                   "Reference LogicalBinary: input 0 type not supported");
1709     supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1710                                   "Reference LogicalBinary: input 1 type not supported");
1711 
1712     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1713                                   "Reference LogicalBinary: input and output types do not match");
1714 
1715     return supported;
1716 }
1717 
IsLogSoftmaxSupported(const TensorInfo & input,const TensorInfo & output,const LogSoftmaxDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1718 bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
1719                                             const TensorInfo& output,
1720                                             const LogSoftmaxDescriptor& descriptor,
1721                                             Optional<std::string&> reasonIfUnsupported) const
1722 {
1723     IgnoreUnused(descriptor);
1724 
1725     std::array<DataType, 3> supportedTypes =
1726     {
1727         DataType::Float32,
1728         DataType::Float16
1729     };
1730 
1731     bool supported = true;
1732     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1733                                   "Reference LogSoftmax: input type not supported");
1734 
1735     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1736                                   "Reference LogSoftmax: output type not supported");
1737 
1738     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1739                                   "Reference LogSoftmax: input and output types do not match");
1740 
1741     return supported;
1742 }
1743 
IsLstmSupported(const TensorInfo & input,const TensorInfo & outputStateIn,const TensorInfo & cellStateIn,const TensorInfo & scratchBuffer,const TensorInfo & outputStateOut,const TensorInfo & cellStateOut,const TensorInfo & output,const LstmDescriptor & descriptor,const LstmInputParamsInfo & paramsInfo,Optional<std::string &> reasonIfUnsupported) const1744 bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
1745                                       const TensorInfo& outputStateIn,
1746                                       const TensorInfo& cellStateIn,
1747                                       const TensorInfo& scratchBuffer,
1748                                       const TensorInfo& outputStateOut,
1749                                       const TensorInfo& cellStateOut,
1750                                       const TensorInfo& output,
1751                                       const LstmDescriptor& descriptor,
1752                                       const LstmInputParamsInfo& paramsInfo,
1753                                       Optional<std::string&> reasonIfUnsupported) const
1754 {
1755     IgnoreUnused(descriptor);
1756     IgnoreUnused(paramsInfo);
1757 
1758     bool supported = true;
1759 
1760     std::array<DataType,3> supportedTypes = {
1761         DataType::Float32,
1762         DataType::QSymmS16
1763     };
1764 
1765     // check inputs and outputs
1766     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1767                                   "Reference Lstm: input is not a supported type.");
1768     supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
1769                                   "Reference Lstm: input and outputStateIn types are mismatched");
1770     supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
1771                                   "Reference Lstm: input and cellStateIn types are mismatched");
1772     supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
1773                                   "Reference Lstm: input and scratchBuffer types are mismatched");
1774     supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
1775                                   "Reference Lstm: input and outputStateOut types are mismatched");
1776     supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
1777                                   "Reference Lstm: input and cellStateOut types are mismatched");
1778 
1779     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1780                                   "Reference Lstm: input and output types are mismatched");
1781     // check layer parameters
1782     supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
1783                                   "Reference Lstm: input and InputToForgetWeights types are mismatched");
1784     supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
1785                                   "Reference Lstm: input and InputToCellWeights types are mismatched");
1786     supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
1787                                   "Reference Lstm: input and InputToOutputWeights types are mismatched");
1788     supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
1789                                   "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
1790     supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
1791                                   "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
1792     supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
1793                                   "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
1794     supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
1795                                   "Reference Lstm: input and ForgetGateBias types are mismatched");
1796     supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
1797                                   "Reference Lstm: input and CellBias types are mismatched");
1798     supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
1799                                   "Reference Lstm: input and OutputGateBias types are mismatched");
1800     if (!descriptor.m_CifgEnabled)
1801     {
1802         supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
1803                                       "Reference Lstm: input and InputToInputWeights types are mismatched");
1804         supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
1805                                       reasonIfUnsupported,
1806                                       "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
1807         supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
1808                                       "Reference Lstm: input and InputGateBias types are mismatched");
1809         if (descriptor.m_PeepholeEnabled)
1810         {
1811             supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
1812                                           reasonIfUnsupported,
1813                                           "Reference Lstm: input and CellToInputWeights types are mismatched");
1814         }
1815     }
1816     if (descriptor.m_PeepholeEnabled)
1817     {
1818         supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
1819                                       "Reference Lstm: input and CellToForgetWeights types are mismatched");
1820         supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
1821                                       "Reference Lstm: input and CellToOutputWeights types are mismatched");
1822     }
1823     if (descriptor.m_ProjectionEnabled)
1824     {
1825         supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
1826                                       "Reference Lstm: input and mProjectionWeights types are mismatched");
1827         if (paramsInfo.m_ProjectionBias != nullptr)
1828         {
1829             supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
1830                                           "Reference Lstm: input and ProjectionBias types are mismatched");
1831         }
1832     }
1833     if (descriptor.m_LayerNormEnabled)
1834     {
1835         if (!descriptor.m_CifgEnabled)
1836         {
1837             supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
1838                                           reasonIfUnsupported,
1839                                           "Reference Lstm: input and InputLayerNormWeights types are mismatched");
1840         }
1841         supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
1842                                       reasonIfUnsupported,
1843                                       "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
1844         supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
1845                                       reasonIfUnsupported,
1846                                       "Reference Lstm: input and CellLayerNormWeights types are mismatched");
1847         supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
1848                                       reasonIfUnsupported,
1849                                       "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
1850     }
1851 
1852     return supported;
1853 }
1854 
IsMaximumSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1855 bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1856                                          const TensorInfo& input1,
1857                                          const TensorInfo& output,
1858                                          Optional<std::string&> reasonIfUnsupported) const
1859 {
1860     bool supported = true;
1861 
1862     std::array<DataType,7> supportedTypes = {
1863         DataType::Float32,
1864         DataType::Float16,
1865         DataType::QAsymmS8,
1866         DataType::QAsymmU8,
1867         DataType::QSymmS16,
1868         DataType::Signed32
1869     };
1870 
1871     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1872                                   "Reference maximum: input 0 is not a supported type.");
1873 
1874     supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1875                                   "Reference maximum: input 1 is not a supported type.");
1876 
1877     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1878                                   "Reference maximum: output is not a supported type.");
1879 
1880     supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1881                                   "Reference maximum: input 0 and Input 1 types are mismatched");
1882 
1883     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1884                                   "Reference maximum: input and output types are mismatched");
1885 
1886     supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1887                                   "Reference maximum: shapes are not suitable for implicit broadcast.");
1888 
1889     return supported;
1890 }
1891 
IsMeanSupported(const TensorInfo & input,const TensorInfo & output,const MeanDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1892 bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
1893                                       const TensorInfo& output,
1894                                       const MeanDescriptor& descriptor,
1895                                       Optional<std::string&> reasonIfUnsupported) const
1896 {
1897     bool supported = true;
1898     std::string meanLayerStr = "Mean";
1899     std::string outputTensorStr = "output";
1900 
1901     std::array<DataType,6> supportedTypes =
1902     {
1903         DataType::Float32,
1904         DataType::Float16,
1905         DataType::QAsymmS8,
1906         DataType::QAsymmU8,
1907         DataType::QSymmS16
1908     };
1909 
1910     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1911                                   "Reference Mean: input type not supported.");
1912 
1913     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1914                                   "Reference Mean: input and output types are mismatched");
1915 
1916     if (descriptor.m_KeepDims)
1917     {
1918         supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1919                                       reasonIfUnsupported,
1920                                       CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1921                                                                         output.GetNumDimensions(),
1922                                                                         meanLayerStr, outputTensorStr).data());
1923     }
1924     else if (descriptor.m_Axis.empty())
1925     {
1926         supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1927                                       reasonIfUnsupported,
1928                                       CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1929                                                                         meanLayerStr, outputTensorStr).data());
1930     }
1931     else
1932     {
1933         auto outputDim = input.GetNumDimensions() - armnn::numeric_cast<unsigned int>(descriptor.m_Axis.size());
1934 
1935         if (outputDim > 0)
1936         {
1937             supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1938                                           reasonIfUnsupported,
1939                                           CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1940                                                                             meanLayerStr, outputTensorStr).data());
1941         }
1942         else
1943         {
1944             supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1945                                           reasonIfUnsupported,
1946                                           CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1947                                                                             meanLayerStr, outputTensorStr).data());
1948         }
1949     }
1950 
1951     return supported;
1952 }
1953 
IsMemCopySupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1954 bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1955                                          const TensorInfo &output,
1956                                          Optional<std::string &> reasonIfUnsupported) const
1957 {
1958     bool supported = true;
1959 
1960     std::array<DataType,7> supportedTypes =
1961     {
1962         DataType::BFloat16,
1963         DataType::Float32,
1964         DataType::Float16,
1965         DataType::QAsymmS8,
1966         DataType::QAsymmU8,
1967         DataType::QSymmS16,
1968         DataType::Boolean
1969     };
1970 
1971     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1972                                   "Reference MemCopy: input type not supported");
1973 
1974     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1975                                   "Reference MemCopy: output type not supported");
1976 
1977     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1978                                   "Reference MemCopy: input and output types are mismatched");
1979 
1980     return supported;
1981 }
1982 
IsMinimumSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1983 bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1984                                          const TensorInfo& input1,
1985                                          const TensorInfo& output,
1986                                          Optional<std::string&> reasonIfUnsupported) const
1987 {
1988     bool supported = true;
1989 
1990     std::array<DataType,7> supportedTypes = {
1991         DataType::Float32,
1992         DataType::Float16,
1993         DataType::QAsymmS8,
1994         DataType::QAsymmU8,
1995         DataType::QSymmS16,
1996         DataType::Signed32
1997     };
1998 
1999     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2000                                   "Reference minimum: input 0 is not a supported type.");
2001 
2002     supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2003                                   "Reference minimum: input 1 is not a supported type.");
2004 
2005     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2006                                   "Reference minimum: output is not a supported type.");
2007 
2008     supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2009                                   "Reference minimum: input 0 and Input 1 types are mismatched");
2010 
2011     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2012                                   "Reference minimum: input and output types are mismatched");
2013 
2014     supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2015                                   "Reference minimum: shapes are not suitable for implicit broadcast.");
2016 
2017     return supported;
2018 }
2019 
IsMultiplicationSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const2020 bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
2021                                                 const TensorInfo& input1,
2022                                                 const TensorInfo& output,
2023                                                 Optional<std::string&> reasonIfUnsupported) const
2024 {
2025     bool supported = true;
2026 
2027     std::array<DataType,7> supportedTypes = {
2028         DataType::Float32,
2029         DataType::Float16,
2030         DataType::QAsymmS8,
2031         DataType::QAsymmU8,
2032         DataType::QSymmS16,
2033         DataType::Signed32
2034     };
2035 
2036     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2037                                   "Reference multiplication: input 0 is not a supported type.");
2038 
2039     supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2040                                   "Reference multiplication: input 1 is not a supported type.");
2041 
2042     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2043                                   "Reference multiplication: output is not a supported type.");
2044 
2045     supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2046                                   "Reference multiplication: input 0 and Input 1 types are mismatched");
2047 
2048     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2049                                   "Reference multiplication: input and output types are mismatched");
2050 
2051     supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2052                                   "Reference multiplication: shapes are not suitable for implicit broadcast.");
2053 
2054     return supported;
2055 }
2056 
IsNormalizationSupported(const TensorInfo & input,const TensorInfo & output,const NormalizationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2057 bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
2058                                                const TensorInfo& output,
2059                                                const NormalizationDescriptor& descriptor,
2060                                                Optional<std::string&> reasonIfUnsupported) const
2061 {
2062     IgnoreUnused(descriptor);
2063 
2064     // Define supported types
2065     std::array<DataType, 6> supportedTypes =
2066     {
2067         DataType::Float16,
2068         DataType::Float32,
2069         DataType::QAsymmS8,
2070         DataType::QAsymmU8,
2071         DataType::QSymmS16
2072     };
2073 
2074     bool supported = true;
2075 
2076     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2077                                   "Reference normalization: input type not supported.");
2078 
2079     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2080                                   "Reference normalization: output type not supported.");
2081 
2082     supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2083                                   "Reference normalization: input and output shapes have different "
2084                                   "num total elements.");
2085 
2086     return supported;
2087 }
2088 
IsOutputSupported(const TensorInfo &,Optional<std::string &>) const2089 bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
2090                                         Optional<std::string&> /*reasonIfUnsupported*/) const
2091 {
2092     return true;
2093 }
2094 
IsPadSupported(const TensorInfo & input,const TensorInfo & output,const PadDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2095 bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
2096                                      const TensorInfo& output,
2097                                      const PadDescriptor& descriptor,
2098                                      Optional<std::string&> reasonIfUnsupported) const
2099 {
2100     IgnoreUnused(descriptor);
2101     bool supported = true;
2102 
2103     // Define supported output and inputs types.
2104     std::array<DataType,6> supportedTypes =
2105     {
2106         DataType::Float32,
2107         DataType::Float16,
2108         DataType::QAsymmS8,
2109         DataType::QAsymmU8,
2110         DataType::QSymmS16
2111     };
2112 
2113     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2114                                   "Reference pad: input is not a supported type.");
2115 
2116     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2117                                   "Reference pad: output is not a supported type.");
2118 
2119     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2120                                   "Reference pad: input and output types are mismatched.");
2121 
2122     return supported;
2123 }
2124 
IsPermuteSupported(const TensorInfo & input,const TensorInfo & output,const PermuteDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2125 bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
2126                                          const TensorInfo& output,
2127                                          const PermuteDescriptor& descriptor,
2128                                          Optional<std::string&> reasonIfUnsupported) const
2129 {
2130     IgnoreUnused(descriptor);
2131     bool supported = true;
2132 
2133     // Define supported output and inputs types.
2134     std::array<DataType, 6> supportedTypes =
2135     {
2136         DataType::BFloat16,
2137         DataType::Float32,
2138         DataType::Float16,
2139         DataType::QAsymmS8,
2140         DataType::QAsymmU8,
2141         DataType::QSymmS16
2142     };
2143 
2144     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2145                                   "Reference permute: input is not a supported type.");
2146 
2147     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2148                                   "Reference permute: output is not a supported type.");
2149 
2150     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2151                                   "Reference permute: input and output types are mismatched.");
2152 
2153     return supported;
2154 }
2155 
IsPooling2dSupported(const TensorInfo & input,const TensorInfo & output,const Pooling2dDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2156 bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
2157                                            const TensorInfo& output,
2158                                            const Pooling2dDescriptor& descriptor,
2159                                            Optional<std::string&> reasonIfUnsupported) const
2160 {
2161     IgnoreUnused(descriptor);
2162     bool supported = true;
2163 
2164     // Define supported output and inputs types.
2165     std::array<DataType,6> supportedTypes =
2166     {
2167         DataType::Float32,
2168         DataType::Float16,
2169         DataType::QAsymmS8,
2170         DataType::QAsymmU8,
2171         DataType::QSymmS16
2172     };
2173 
2174     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2175                                   "Reference poolind2d: input is not a supported type.");
2176 
2177     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2178                                   "Reference poolind2d: output is not a supported type.");
2179 
2180     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2181                                   "Reference poolind2d: input and output types are mismatched.");
2182 
2183     return supported;
2184 }
2185 
IsPooling3dSupported(const TensorInfo & input,const TensorInfo & output,const Pooling3dDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2186 bool RefLayerSupport::IsPooling3dSupported(const TensorInfo& input,
2187                                            const TensorInfo& output,
2188                                            const Pooling3dDescriptor& descriptor,
2189                                            Optional<std::string&> reasonIfUnsupported) const
2190 {
2191     IgnoreUnused(descriptor);
2192     bool supported = true;
2193 
2194     // Define supported output and inputs types.
2195     std::array<DataType,6> supportedTypes =
2196     {
2197         DataType::Float32,
2198         DataType::Float16,
2199         DataType::QAsymmS8,
2200         DataType::QAsymmU8,
2201         DataType::QSymmS16
2202     };
2203 
2204     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2205                                   "Reference poolind3d: input is not a supported type.");
2206 
2207     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2208                                   "Reference poolind3d: output is not a supported type.");
2209 
2210     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2211                                   "Reference poolind3d: input and output types are mismatched.");
2212 
2213     return supported;
2214 }
2215 
2216 
IsQLstmSupported(const TensorInfo & input,const TensorInfo & previousOutputIn,const TensorInfo & previousCellStateIn,const TensorInfo & outputStateOut,const TensorInfo & cellStateOut,const TensorInfo & output,const QLstmDescriptor & descriptor,const LstmInputParamsInfo & paramsInfo,Optional<std::string &> reasonIfUnsupported) const2217 bool RefLayerSupport::IsQLstmSupported(const TensorInfo& input,
2218                                        const TensorInfo& previousOutputIn,
2219                                        const TensorInfo& previousCellStateIn,
2220                                        const TensorInfo& outputStateOut,
2221                                        const TensorInfo& cellStateOut,
2222                                        const TensorInfo& output,
2223                                        const QLstmDescriptor& descriptor,
2224                                        const LstmInputParamsInfo& paramsInfo,
2225                                        Optional<std::string&> reasonIfUnsupported) const
2226 {
2227     IgnoreUnused(input);
2228     IgnoreUnused(previousOutputIn);
2229     IgnoreUnused(previousCellStateIn);
2230     IgnoreUnused(outputStateOut);
2231     IgnoreUnused(cellStateOut);
2232     IgnoreUnused(output);
2233     IgnoreUnused(descriptor);
2234     IgnoreUnused(paramsInfo);
2235 
2236     IgnoreUnused(reasonIfUnsupported);
2237 
2238     return true;
2239 }
2240 
IsQuantizeSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const2241 bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
2242                                           const TensorInfo& output,
2243                                           Optional<std::string&> reasonIfUnsupported) const
2244 {
2245    bool supported = true;
2246 
2247     // Define supported input types.
2248     std::array<DataType,7> supportedInputTypes = {
2249         DataType::Float32,
2250         DataType::Float16,
2251         DataType::QAsymmS8,
2252         DataType::QAsymmU8,
2253         DataType::QSymmS8,
2254         DataType::QSymmS16
2255     };
2256 
2257     supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
2258                                   "Reference quantize: input type not supported.");
2259 
2260     // Define supported output types.
2261     std::array<DataType,4> supportedOutputTypes = {
2262         DataType::QAsymmS8,
2263         DataType::QAsymmU8,
2264         DataType::QSymmS8,
2265         DataType::QSymmS16
2266     };
2267     supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2268                                   "Reference quantize: output type not supported.");
2269 
2270     supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2271                                   "Reference quantize: input and output shapes have different num total elements.");
2272 
2273     return supported;
2274 }
2275 
IsRankSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const2276 bool RefLayerSupport::IsRankSupported(const TensorInfo& input,
2277                                       const TensorInfo& output,
2278                                       Optional<std::string&> reasonIfUnsupported) const
2279 {
2280     IgnoreUnused(input);
2281     // Define supported output types.
2282     std::array<DataType,1> supportedOutputTypes =
2283     {
2284         DataType::Signed32,
2285     };
2286 
2287     return CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2288            "Reference rank: input type not supported.");
2289 }
2290 
IsReduceSupported(const TensorInfo & input,const TensorInfo & output,const ReduceDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2291 bool RefLayerSupport::IsReduceSupported(const TensorInfo& input,
2292                                         const TensorInfo& output,
2293                                         const ReduceDescriptor& descriptor,
2294                                         Optional<std::string&> reasonIfUnsupported) const
2295 {
2296     IgnoreUnused(descriptor);
2297     bool supported = true;
2298     std::array<DataType,7> supportedTypes =
2299     {
2300         DataType::Float32,
2301         DataType::Float16,
2302         DataType::QAsymmS8,
2303         DataType::QAsymmU8,
2304         DataType::QSymmS16,
2305         DataType::Signed32
2306     };
2307 
2308     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2309                                   "Reference Reduce: input type not supported");
2310 
2311     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2312                                   "Reference Reduce: output type not supported");
2313 
2314     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2315                                   "Reference Reduce: input and output types not matching");
2316 
2317     return supported;
2318 }
2319 
IsReshapeSupported(const TensorInfo & input,const TensorInfo & output,const ReshapeDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2320 bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
2321                                          const TensorInfo& output,
2322                                          const ReshapeDescriptor& descriptor,
2323                                          Optional<std::string&> reasonIfUnsupported) const
2324 {
2325     IgnoreUnused(output);
2326     IgnoreUnused(descriptor);
2327     // Define supported output types.
2328     std::array<DataType,8> supportedOutputTypes =
2329     {
2330         DataType::BFloat16,
2331         DataType::Float32,
2332         DataType::Float16,
2333         DataType::Signed32,
2334         DataType::QAsymmS8,
2335         DataType::QAsymmU8,
2336         DataType::QSymmS16,
2337         DataType::Boolean
2338     };
2339 
2340     return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
2341         "Reference reshape: input type not supported.");
2342 }
2343 
IsResizeSupported(const TensorInfo & input,const TensorInfo & output,const ResizeDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2344 bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
2345                                         const TensorInfo& output,
2346                                         const ResizeDescriptor& descriptor,
2347                                         Optional<std::string&> reasonIfUnsupported) const
2348 {
2349     IgnoreUnused(descriptor);
2350     bool supported = true;
2351     std::array<DataType,6> supportedTypes =
2352     {
2353         DataType::BFloat16,
2354         DataType::Float32,
2355         DataType::Float16,
2356         DataType::QAsymmS8,
2357         DataType::QAsymmU8,
2358         DataType::QSymmS16
2359     };
2360 
2361     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2362                                   "Reference Resize: input type not supported");
2363 
2364     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2365                                   "Reference Resize: output type not supported");
2366 
2367     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2368                                   "Reference Resize: input and output types not matching");
2369 
2370     return supported;
2371 }
2372 
IsShapeSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const2373 bool RefLayerSupport::IsShapeSupported(const TensorInfo& input,
2374                                        const TensorInfo& output,
2375                                        Optional<std::string&> reasonIfUnsupported) const
2376 {
2377     IgnoreUnused(input);
2378     bool supported = true;
2379 
2380     std::array<DataType, 1> supportedTypes =
2381     {
2382         DataType::Signed32
2383     };
2384 
2385     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2386                                   "Reference Shape: output type not supported");
2387 
2388     return supported;
2389 }
2390 
IsSliceSupported(const TensorInfo & input,const TensorInfo & output,const SliceDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2391 bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
2392                                        const TensorInfo& output,
2393                                        const SliceDescriptor& descriptor,
2394                                        Optional<std::string&> reasonIfUnsupported) const
2395 {
2396     IgnoreUnused(descriptor);
2397     bool supported = true;
2398 
2399     std::array<DataType, 5> supportedTypes =
2400     {
2401         DataType::Float32,
2402         DataType::QAsymmS8,
2403         DataType::QAsymmU8,
2404         DataType::QSymmS16
2405     };
2406 
2407     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2408                                   "Reference Slice: input type not supported");
2409 
2410     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2411                                   "Reference Slice: output type not supported");
2412 
2413     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2414                                   "Reference Slice: input and output types are mismatched");
2415 
2416     return supported;
2417 }
2418 
IsSoftmaxSupported(const TensorInfo & input,const TensorInfo & output,const SoftmaxDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2419 bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
2420                                          const TensorInfo& output,
2421                                          const SoftmaxDescriptor& descriptor,
2422                                          Optional<std::string&> reasonIfUnsupported) const
2423 {
2424     IgnoreUnused(descriptor);
2425     bool supported = true;
2426     std::array<DataType,7> supportedTypes =
2427     {
2428         DataType::Float32,
2429         DataType::Float16,
2430         DataType::QSymmS8,
2431         DataType::QAsymmS8,
2432         DataType::QAsymmU8,
2433         DataType::QSymmS16
2434     };
2435 
2436     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2437                                   "Reference Softmax: output type not supported");
2438 
2439     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2440                                   "Reference Softmax: input type not supported");
2441 
2442     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2443                                   "Reference Softmax: input type not supported");
2444 
2445     return supported;
2446 }
2447 
IsSpaceToBatchNdSupported(const TensorInfo & input,const TensorInfo & output,const SpaceToBatchNdDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2448 bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
2449                                                 const TensorInfo& output,
2450                                                 const SpaceToBatchNdDescriptor& descriptor,
2451                                                 Optional<std::string&> reasonIfUnsupported) const
2452 {
2453     IgnoreUnused(descriptor);
2454     bool supported = true;
2455     std::array<DataType,6> supportedTypes =
2456     {
2457         DataType::Float32,
2458         DataType::Float16,
2459         DataType::QAsymmS8,
2460         DataType::QAsymmU8,
2461         DataType::QSymmS16
2462     };
2463 
2464     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2465                                   "Reference SpaceToBatchNd: input type not supported");
2466 
2467     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2468                                   "Reference SpaceToBatchNd: output type not supported");
2469 
2470     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2471                                   "Reference SpaceToBatchNd: input and output types are mismatched");
2472 
2473     return supported;
2474 }
2475 
IsSpaceToDepthSupported(const TensorInfo & input,const TensorInfo & output,const SpaceToDepthDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2476 bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
2477                                               const TensorInfo& output,
2478                                               const SpaceToDepthDescriptor& descriptor,
2479                                               Optional<std::string&> reasonIfUnsupported) const
2480 {
2481 
2482     IgnoreUnused(descriptor);
2483     bool supported = true;
2484 
2485     std::array<DataType,6> supportedTypes =
2486     {
2487         DataType::Float32,
2488         DataType::Float16,
2489         DataType::QAsymmS8,
2490         DataType::QAsymmU8,
2491         DataType::QSymmS16
2492     };
2493 
2494     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2495         "Reference SpaceToDepth: input type not supported");
2496 
2497     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2498         "Reference SpaceToDepth: output type not supported");
2499 
2500     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2501         "Reference SpaceToDepth: input and output types are mismatched");
2502 
2503     return supported;
2504 }
2505 
IsSplitterSupported(const TensorInfo & input,const std::vector<std::reference_wrapper<TensorInfo>> & outputs,const ViewsDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2506 bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
2507                                           const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
2508                                           const ViewsDescriptor& descriptor,
2509                                           Optional<std::string&> reasonIfUnsupported) const
2510 {
2511     IgnoreUnused(descriptor);
2512     bool supported = true;
2513     std::array<DataType,6> supportedTypes =
2514     {
2515         DataType::Float32,
2516         DataType::Float16,
2517         DataType::QAsymmS8,
2518         DataType::QAsymmU8,
2519         DataType::QSymmS16
2520     };
2521 
2522     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2523                                   "Reference splitter: output type not supported");
2524     for (const TensorInfo& output : outputs)
2525     {
2526         supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2527                                       "Reference splitter: input type not supported");
2528 
2529         supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2530                                       "Reference splitter: input and output types mismatched.");
2531     }
2532 
2533     return supported;
2534 }
2535 
IsStackSupported(const std::vector<const TensorInfo * > & inputs,const TensorInfo & output,const StackDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2536 bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
2537                                        const TensorInfo& output,
2538                                        const StackDescriptor& descriptor,
2539                                        Optional<std::string&> reasonIfUnsupported) const
2540 {
2541     IgnoreUnused(descriptor);
2542 
2543     bool supported = true;
2544     std::array<DataType,7> supportedTypes =
2545     {
2546         DataType::Float32,
2547         DataType::Float16,
2548         DataType::QAsymmS8,
2549         DataType::QAsymmU8,
2550         DataType::QSymmS16,
2551         DataType::Signed32
2552     };
2553 
2554     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2555                                   "Reference stack: output type not supported");
2556     for (const TensorInfo* input : inputs)
2557     {
2558         ARMNN_ASSERT(input != nullptr);
2559         supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
2560             "Reference stack: input type not supported");
2561 
2562         supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
2563             "Reference stack: input and output types mismatched.");
2564     }
2565 
2566     return supported;
2567 }
2568 
IsStridedSliceSupported(const TensorInfo & input,const TensorInfo & output,const StridedSliceDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2569 bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
2570                                               const TensorInfo& output,
2571                                               const StridedSliceDescriptor& descriptor,
2572                                               Optional<std::string&> reasonIfUnsupported) const
2573 {
2574     IgnoreUnused(descriptor);
2575     bool supported = true;
2576 
2577     std::array<DataType,5> supportedTypes =
2578     {
2579         DataType::Float32,
2580         DataType::QAsymmS8,
2581         DataType::QAsymmU8,
2582         DataType::QSymmS16
2583     };
2584 
2585     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2586                                   "Reference StridedSlice: input type not supported");
2587 
2588     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2589                                   "Reference StridedSlice: output type not supported");
2590 
2591     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2592                                   "Reference StridedSlice: input and output types are mismatched");
2593 
2594     return supported;
2595 }
2596 
IsSubtractionSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const2597 bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
2598                                              const TensorInfo& input1,
2599                                              const TensorInfo& output,
2600                                              Optional<std::string&> reasonIfUnsupported) const
2601 {
2602     bool supported = true;
2603 
2604     std::array<DataType,7> supportedTypes = {
2605         DataType::Float32,
2606         DataType::Float16,
2607         DataType::QAsymmS8,
2608         DataType::QAsymmU8,
2609         DataType::QSymmS16,
2610         DataType::Signed32
2611     };
2612 
2613     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2614                                   "Reference subtraction: input 0 is not a supported type.");
2615 
2616     supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2617                                   "Reference subtraction: input 1 is not a supported type.");
2618 
2619     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2620                                   "Reference subtraction: output is not a supported type.");
2621 
2622     supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2623                                   "Reference subtraction: input 0 and Input 1 types are mismatched");
2624 
2625     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2626                                   "Reference subtraction: input and output types are mismatched");
2627 
2628     supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2629                                   "Reference subtraction: shapes are not suitable for implicit broadcast.");
2630 
2631     return supported;
2632 }
2633 
IsPreluSupported(const TensorInfo & input,const TensorInfo & alpha,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const2634 bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
2635                                        const TensorInfo& alpha,
2636                                        const TensorInfo& output,
2637                                        Optional<std::string&> reasonIfUnsupported) const
2638 {
2639     bool supported = true;
2640 
2641     std::array<DataType, 6> supportedTypes
2642     {
2643         DataType::Float32,
2644         DataType::Float16,
2645         DataType::QAsymmS8,
2646         DataType::QAsymmU8,
2647         DataType::QSymmS16
2648     };
2649 
2650     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2651                                   "PReLU: input is not a supported type.");
2652 
2653     supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
2654                                   "PReLU: alpha is not a supported type.");
2655 
2656     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2657                                   "PReLU: output is not a supported type.");
2658 
2659     supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
2660                                   "PReLU: input, alpha and output types are mismatched");
2661 
2662     supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
2663                                   "PReLU: shapes are not suitable for implicit broadcast");
2664 
2665     return supported;
2666 }
2667 
IsTransposeConvolution2dSupported(const TensorInfo & input,const TensorInfo & output,const TransposeConvolution2dDescriptor & descriptor,const TensorInfo & weights,const Optional<TensorInfo> & biases,Optional<std::string &> reasonIfUnsupported) const2668 bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
2669                                                         const TensorInfo& output,
2670                                                         const TransposeConvolution2dDescriptor& descriptor,
2671                                                         const TensorInfo& weights,
2672                                                         const Optional<TensorInfo>& biases,
2673                                                         Optional<std::string&> reasonIfUnsupported) const
2674 {
2675     IgnoreUnused(descriptor);
2676     bool supported = true;
2677 
2678     std::array<DataType,7> supportedTypes =
2679     {
2680         DataType::Float32,
2681         DataType::Float16,
2682         DataType::QAsymmS8,
2683         DataType::QAsymmU8,
2684         DataType::QSymmS8,
2685         DataType::QSymmS16
2686     };
2687 
2688     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2689                                   "Reference TransposeConvolution2d: input is not a supported type.");
2690 
2691     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2692                                   "Reference TransposeConvolution2d: output is not a supported type.");
2693 
2694     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2695                                   "Reference TransposeConvolution2d: input and output types mismatched.");
2696 
2697 
2698     const DataType inputType = input.GetDataType();
2699     if (IsQuantized8BitType(inputType))
2700     {
2701         std::array<DataType, 3> supportedWeightTypes =
2702         {
2703             DataType::QAsymmS8,
2704             DataType::QAsymmU8,
2705             DataType::QSymmS8
2706         };
2707 
2708         supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
2709                                       "Reference TransposeConvolution2d: weights type not supported for "
2710                                       "quantized input.");
2711     }
2712     else
2713     {
2714         supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
2715                                     "Reference TransposeConvolution2d: weights is not a supported type.");
2716 
2717         supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
2718                                     "Reference TransposeConvolution2d: input and weights types mismatched.");
2719     }
2720 
2721     if (biases.has_value())
2722     {
2723         std::array<DataType,4> biasesSupportedTypes =
2724         {
2725             DataType::Float32,
2726             DataType::Float16,
2727             DataType::Signed32
2728         };
2729         supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
2730                                       "Reference TransposeConvolution2d: biases is not a supported type.");
2731     }
2732 
2733     return supported;
2734 }
2735 
IsTransposeSupported(const TensorInfo & input,const TensorInfo & output,const TransposeDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2736 bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input,
2737                                            const TensorInfo& output,
2738                                            const TransposeDescriptor& descriptor,
2739                                            Optional<std::string&> reasonIfUnsupported) const
2740 {
2741     IgnoreUnused(descriptor);
2742     bool supported = true;
2743 
2744     // Define supported output and inputs types.
2745     std::array<DataType, 6> supportedTypes =
2746     {
2747         DataType::BFloat16,
2748         DataType::Float32,
2749         DataType::Float16,
2750         DataType::QAsymmS8,
2751         DataType::QAsymmU8,
2752         DataType::QSymmS16
2753     };
2754 
2755     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2756                                   "Reference transpose: input is not a supported type.");
2757 
2758     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2759                                   "Reference transpose: output is not a supported type.");
2760 
2761     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2762                                   "Reference transpose: input and output types are mismatched.");
2763 
2764     return supported;
2765 }
2766 
IsUnidirectionalSequenceLstmSupported(const TensorInfo & input,const TensorInfo & outputStateIn,const TensorInfo & cellStateIn,const TensorInfo & outputStateOut,const TensorInfo & cellStateOut,const TensorInfo & output,const UnidirectionalSequenceLstmDescriptor & descriptor,const LstmInputParamsInfo & paramsInfo,Optional<std::string &> reasonIfUnsupported) const2767 bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported(
2768         const TensorInfo& input,
2769         const TensorInfo& outputStateIn,
2770         const TensorInfo& cellStateIn,
2771         const TensorInfo& outputStateOut,
2772         const TensorInfo& cellStateOut,
2773         const TensorInfo& output,
2774         const UnidirectionalSequenceLstmDescriptor& descriptor,
2775         const LstmInputParamsInfo& paramsInfo,
2776         Optional<std::string&> reasonIfUnsupported) const
2777 {
2778     IgnoreUnused(descriptor);
2779     IgnoreUnused(paramsInfo);
2780     IgnoreUnused(outputStateIn);
2781     IgnoreUnused(cellStateIn);
2782     IgnoreUnused(outputStateOut);
2783     IgnoreUnused(cellStateOut);
2784     bool supported = true;
2785 
2786     std::array<DataType, 2> supportedTypes =
2787     {
2788         DataType::Float32,
2789         DataType::QAsymmS8
2790     };
2791 
2792     std::array<DataType, 2> supportedWeightTypes =
2793     {
2794         DataType::Float32,
2795         DataType::QAsymmS8
2796     };
2797 
2798     std::array<DataType, 3> supportedBiasTypes =
2799     {
2800         DataType::Float32,
2801         DataType::QAsymmS8,
2802         DataType::Signed32
2803     };
2804 
2805     // check inputs and outputs
2806     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2807                                   "Reference UnidirectionalSequenceLstm: input is not a supported type.");
2808     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2809                                   "Reference UnidirectionalSequenceLstm: output is not a supported type.");
2810 
2811     // check layer parameters
2812     supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToForgetWeights(), supportedWeightTypes),
2813                                   reasonIfUnsupported,
2814                                   "Reference UnidirectionalSequenceLstm: InputToForgetWeights "
2815                                   "is not a supported type.");
2816     supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToCellWeights(), supportedWeightTypes),
2817                                   reasonIfUnsupported,
2818                                   "Reference UnidirectionalSequenceLstm: InputToCellWeights is not a supported type.");
2819     supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToOutputWeights(), supportedWeightTypes),
2820                                   reasonIfUnsupported,
2821                                   "Reference UnidirectionalSequenceLstm: InputToOutputWeights "
2822                                   "is not a supported type.");
2823     supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToForgetWeights(), supportedWeightTypes),
2824                                   reasonIfUnsupported,
2825                                   "Reference UnidirectionalSequenceLstm: RecurrentToForgetWeights "
2826                                   "is not a supported type.");
2827     supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToCellWeights(), supportedWeightTypes),
2828                                   reasonIfUnsupported,
2829                                   "Reference UnidirectionalSequenceLstm: RecurrentToCellWeights "
2830                                   "is not a supported type.");
2831     supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToOutputWeights(), supportedWeightTypes),
2832                                   reasonIfUnsupported,
2833                                   "Reference UnidirectionalSequenceLstm: RecurrentToOutputWeights "
2834                                   "is not a supported type.");
2835 
2836     supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetGateBias(), supportedBiasTypes), reasonIfUnsupported,
2837                                   "Reference UnidirectionalSequenceLstm: ForgetGateBias is not a supported type.");
2838     supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellBias(), supportedBiasTypes), reasonIfUnsupported,
2839                                   "Reference UnidirectionalSequenceLstm: CellBias is not a supported type.");
2840     supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputGateBias(), supportedBiasTypes), reasonIfUnsupported,
2841                                   "Reference UnidirectionalSequenceLstm: OutputGateBias is not a supported type.");
2842     if (!descriptor.m_CifgEnabled)
2843     {
2844         supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToInputWeights(), supportedWeightTypes),
2845                                       reasonIfUnsupported,
2846                                       "Reference UnidirectionalSequenceLstm: InputToInputWeights "
2847                                       "is not a supported type.");
2848         supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToInputWeights(), supportedWeightTypes),
2849                                       reasonIfUnsupported,
2850                                       "Reference UnidirectionalSequenceLstm: RecurrentToInputWeights "
2851                                       "is not a supported type.");
2852         supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputGateBias(), supportedBiasTypes), reasonIfUnsupported,
2853                                       "Reference UnidirectionalSequenceLstm: InputGateBias is not a supported type.");
2854         if (descriptor.m_PeepholeEnabled)
2855         {
2856             supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToInputWeights(), supportedWeightTypes),
2857                                           reasonIfUnsupported,
2858                                           "Reference UnidirectionalSequenceLstm: CellToInputWeights "
2859                                           "is not a supported type.");
2860         }
2861     }
2862     if (descriptor.m_PeepholeEnabled)
2863     {
2864         supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToForgetWeights(), supportedWeightTypes),
2865                                       reasonIfUnsupported,
2866                                       "Reference UnidirectionalSequenceLstm: CellToForgetWeights "
2867                                       "is not a supported type.");
2868         supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToOutputWeights(), supportedWeightTypes),
2869                                       reasonIfUnsupported,
2870                                       "Reference UnidirectionalSequenceLstm: CellToOutputWeights "
2871                                       "is not a supported type.");
2872     }
2873     if (descriptor.m_ProjectionEnabled)
2874     {
2875         supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetProjectionWeights(), supportedWeightTypes),
2876                                       reasonIfUnsupported,
2877                                       "Reference UnidirectionalSequenceLstm: ProjectionWeights "
2878                                       "is not a supported type.");
2879         if (paramsInfo.m_ProjectionBias != nullptr)
2880         {
2881             supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
2882                                           "Reference UnidirectionalSequenceLstm: input and ProjectionBias types "
2883                                           "are mismatched");
2884         }
2885     }
2886     if (descriptor.m_LayerNormEnabled)
2887     {
2888         if (!descriptor.m_CifgEnabled)
2889         {
2890             supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputLayerNormWeights(), supportedWeightTypes),
2891                                           reasonIfUnsupported,
2892                                           "Reference UnidirectionalSequenceLstm: InputLayerNormWeights "
2893                                           "is not a supported type.");
2894         }
2895         supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetLayerNormWeights(), supportedWeightTypes),
2896                                       reasonIfUnsupported,
2897                                       "Reference UnidirectionalSequenceLstm: ForgetLayerNormWeights "
2898                                       "is not a supported type.");
2899         supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellLayerNormWeights(), supportedWeightTypes),
2900                                       reasonIfUnsupported,
2901                                       "Reference UnidirectionalSequenceLstm: CellLayerNormWeights "
2902                                       "is not a supported type.");
2903         supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputLayerNormWeights(), supportedWeightTypes),
2904                                       reasonIfUnsupported,
2905                                       "Reference UnidirectionalSequenceLstm: OutputLayerNormWeights "
2906                                       "is not a supported type.");
2907     }
2908 
2909     return supported;
2910 }
2911 
2912 } // namespace armnn
2913