1 //
2 // Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "RefLayerSupport.hpp"
7
8 #include <armnn/TypesUtils.hpp>
9 #include <armnn/Types.hpp>
10 #include <armnn/utility/IgnoreUnused.hpp>
11 #include <armnn/utility/NumericCast.hpp>
12 #include <armnn/utility/PolymorphicDowncast.hpp>
13
14 #include <LayerSupportCommon.hpp>
15 #include <backendsCommon/LayerSupportRules.hpp>
16
17 #include <vector>
18 #include <array>
19
20 namespace armnn
21 {
22
23 namespace
24 {
25
26 template<typename Float32Func, typename Uint8Func, typename ... Params>
IsSupportedForDataTypeRef(Optional<std::string &> reasonIfUnsupported,DataType dataType,Float32Func floatFuncPtr,Uint8Func uint8FuncPtr,Params &&...params)27 bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
28 DataType dataType,
29 Float32Func floatFuncPtr,
30 Uint8Func uint8FuncPtr,
31 Params&&... params)
32 {
33 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
34 dataType,
35 &FalseFunc<Params...>,
36 floatFuncPtr,
37 uint8FuncPtr,
38 &FalseFunc<Params...>,
39 &FalseFunc<Params...>,
40 std::forward<Params>(params)...);
41 }
42
43 } // anonymous namespace
44
45 namespace
46 {
47
// Builds the standard "wrong tensor rank" diagnostic used by the reference
// backend's layer-support checks, e.g.:
//   "Reference batchToSpaceNd: Expected 4 dimensions but got 3 dimensions
//    instead, for the 'input' tensor."
// The strings are taken by const reference: this helper never mutates them,
// and const& also allows temporaries at the call site.
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    return "Reference " + layerStr + ": Expected " + std::to_string(expected) +
           " dimensions but got " + std::to_string(actual) +
           " dimensions instead, for the '" + tensorName + "' tensor.";
}
58
59 } // anonymous namespace
60
IsLayerSupported(const LayerType & type,const std::vector<TensorInfo> & infos,const BaseDescriptor & descriptor,const Optional<LstmInputParamsInfo> & lstmParamsInfo,const Optional<QuantizedLstmInputParamsInfo> & quantizedLstmInputParamsInfo,Optional<std::string &> reasonIfUnsupported) const61 bool RefLayerSupport::IsLayerSupported(const LayerType& type,
62 const std::vector<TensorInfo>& infos,
63 const BaseDescriptor& descriptor,
64 const Optional<LstmInputParamsInfo>& lstmParamsInfo,
65 const Optional<QuantizedLstmInputParamsInfo>& quantizedLstmInputParamsInfo,
66 Optional<std::string&> reasonIfUnsupported) const
67 {
68 switch (type)
69 {
70 case LayerType::Activation:
71 return IsActivationSupported(infos[0],
72 infos[1],
73 *(PolymorphicDowncast<const ActivationDescriptor*>(&descriptor)),
74 reasonIfUnsupported);
75 case LayerType::Addition:
76 return IsAdditionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
77 case LayerType::ArgMinMax:
78 return IsArgMinMaxSupported(infos[0],
79 infos[1],
80 *(PolymorphicDowncast<const ArgMinMaxDescriptor*>(&descriptor)),
81 reasonIfUnsupported);
82 case LayerType::BatchMatMul:
83 return IsBatchMatMulSupported(infos[0],
84 infos[1],
85 infos[2],
86 *(PolymorphicDowncast<const BatchMatMulDescriptor*>(&descriptor)),
87 reasonIfUnsupported);
88 case LayerType::BatchNormalization:
89 return IsBatchNormalizationSupported(infos[0],
90 infos[1],
91 infos[2],
92 infos[3],
93 infos[4],
94 infos[5],
95 *(PolymorphicDowncast<const BatchNormalizationDescriptor*>
96 (&descriptor)),
97 reasonIfUnsupported);
98 case LayerType::BatchToSpaceNd:
99 return IsBatchToSpaceNdSupported(infos[0],
100 infos[1],
101 *(PolymorphicDowncast<const BatchToSpaceNdDescriptor*>(&descriptor)),
102 reasonIfUnsupported);
103 case LayerType::Comparison:
104 return IsComparisonSupported(infos[0],
105 infos[1],
106 infos[2],
107 *(PolymorphicDowncast<const ComparisonDescriptor*>(&descriptor)),
108 reasonIfUnsupported);
109 case LayerType::Concat:
110 {
111 std::vector<const TensorInfo*> inputInfos;
112 for (uint32_t i = 0; i < (infos.size() - 1); i++)
113 {
114 inputInfos.push_back(&infos[i]);
115 }
116 return IsConcatSupported(inputInfos,
117 infos[infos.size() - 1],
118 *(PolymorphicDowncast<const OriginsDescriptor*>(&descriptor)),
119 reasonIfUnsupported);
120 }
121 case LayerType::Constant:
122 return IsConstantSupported(infos[0], reasonIfUnsupported);
123 case LayerType::ConvertFp16ToFp32:
124 return IsConvertFp16ToFp32Supported(infos[0], infos[1], reasonIfUnsupported);
125 case LayerType::ConvertFp32ToFp16:
126 return IsConvertFp32ToFp16Supported(infos[0], infos[1], reasonIfUnsupported);
127 case LayerType::Convolution2d:
128 {
129 if (infos.size() != 4)
130 {
131 throw InvalidArgumentException("Invalid number of Convolution2d TensorInfos. "
132 "TensorInfos should be of format: {input, output, weights, biases}.");
133 }
134
135 auto desc = *(PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor));
136 if (infos[3] == TensorInfo())
137 {
138 return IsConvolution2dSupported(infos[0],
139 infos[1],
140 desc,
141 infos[2],
142 EmptyOptional(),
143 reasonIfUnsupported);
144 }
145 else
146 {
147 return IsConvolution2dSupported(infos[0],
148 infos[1],
149 desc,
150 infos[2],
151 infos[3],
152 reasonIfUnsupported);
153 }
154 }
155 case LayerType::DepthToSpace:
156 return IsDepthToSpaceSupported(infos[0],
157 infos[1],
158 *(PolymorphicDowncast<const DepthToSpaceDescriptor*>(&descriptor)),
159 reasonIfUnsupported);
160 case LayerType::DepthwiseConvolution2d:
161 {
162 if (infos.size() != 4)
163 {
164 throw InvalidArgumentException("Invalid number of DepthwiseConvolution2d TensorInfos. "
165 "TensorInfos should be of format: {input, output, weights, biases}.");
166 }
167
168 auto desc = *(PolymorphicDowncast<const DepthwiseConvolution2dDescriptor*>(&descriptor));
169 if (infos[3] == TensorInfo())
170 {
171 return IsDepthwiseConvolutionSupported(infos[0],
172 infos[1],
173 desc,
174 infos[2],
175 EmptyOptional(),
176 reasonIfUnsupported);
177 }
178 else
179 {
180 return IsDepthwiseConvolutionSupported(infos[0],
181 infos[1],
182 desc,
183 infos[2],
184 infos[3],
185 reasonIfUnsupported);
186 }
187 }
188 case LayerType::Dequantize:
189 return IsDequantizeSupported(infos[0], infos[1], reasonIfUnsupported);
190 case LayerType::Division:
191 return IsDivisionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
192 case LayerType::ElementwiseBinary:
193 {
194 std::array<DataType, 7> supportedTypes =
195 {
196 DataType::Float32,
197 DataType::Float16,
198 DataType::QAsymmS8,
199 DataType::QAsymmU8,
200 DataType::QSymmS16,
201 DataType::Signed32
202 };
203
204 bool supported = true;
205 supported &= CheckSupportRule(TypeAnyOf(infos[0], supportedTypes), reasonIfUnsupported,
206 "Reference elementwise unary: input type not supported");
207
208 supported &= CheckSupportRule(TypeAnyOf(infos[1], supportedTypes), reasonIfUnsupported,
209 "Reference elementwise unary: input type not supported");
210
211 supported &= CheckSupportRule(TypeAnyOf(infos[2], supportedTypes), reasonIfUnsupported,
212 "Reference elementwise unary: output type not supported");
213
214 supported &= CheckSupportRule(TypesAreEqual(infos[0], infos[1]), reasonIfUnsupported,
215 "Reference elementwise unary: input types not matching");
216
217 supported &= CheckSupportRule(TypesAreEqual(infos[0], infos[2]), reasonIfUnsupported,
218 "Reference elementwise unary: input and output types not matching");
219
220 return supported;
221 }
222 case LayerType::ElementwiseUnary:
223 return IsElementwiseUnarySupported(infos[0],
224 infos[1],
225 *(PolymorphicDowncast<const ElementwiseUnaryDescriptor*>(&descriptor)),
226 reasonIfUnsupported);
227 case LayerType::Fill:
228 return IsFillSupported(infos[0],
229 infos[1],
230 *(PolymorphicDowncast<const FillDescriptor*>(&descriptor)),
231 reasonIfUnsupported);
232 case LayerType::Floor:
233 return IsFloorSupported(infos[0], infos[1], reasonIfUnsupported);
234 case LayerType::FullyConnected:
235 return IsFullyConnectedSupported(infos[0],
236 infos[1],
237 infos[2],
238 infos[3],
239 *(PolymorphicDowncast<const FullyConnectedDescriptor*>(&descriptor)),
240 reasonIfUnsupported);
241 case LayerType::Gather:
242 return IsGatherSupported(infos[0],
243 infos[1],
244 infos[2],
245 *(PolymorphicDowncast<const GatherDescriptor*>(&descriptor)),
246 reasonIfUnsupported);
247 case LayerType::GatherNd:
248 return IsGatherNdSupported(infos[0],
249 infos[1],
250 infos[2],
251 reasonIfUnsupported);
252 case LayerType::Input:
253 return IsInputSupported(infos[0], reasonIfUnsupported);
254 case LayerType::InstanceNormalization:
255 return IsInstanceNormalizationSupported(infos[0],
256 infos[1],
257 *(PolymorphicDowncast<const InstanceNormalizationDescriptor*>
258 (&descriptor)),
259 reasonIfUnsupported);
260 case LayerType::L2Normalization:
261 return IsL2NormalizationSupported(infos[0],
262 infos[1],
263 *(PolymorphicDowncast<const L2NormalizationDescriptor*>(&descriptor)),
264 reasonIfUnsupported);
265 case LayerType::LogicalBinary:
266 return IsLogicalBinarySupported(infos[0],
267 infos[1],
268 infos[2],
269 *(PolymorphicDowncast<const LogicalBinaryDescriptor*>(&descriptor)),
270 reasonIfUnsupported);
271 case LayerType::LogSoftmax:
272 return IsLogSoftmaxSupported(infos[0],
273 infos[1],
274 *(PolymorphicDowncast<const LogSoftmaxDescriptor*>(&descriptor)),
275 reasonIfUnsupported);
276 case LayerType::Lstm:
277 return IsLstmSupported(infos[0],
278 infos[1],
279 infos[2],
280 infos[3],
281 infos[4],
282 infos[5],
283 infos[6],
284 *(PolymorphicDowncast<const LstmDescriptor*>(&descriptor)),
285 lstmParamsInfo.value(),
286 reasonIfUnsupported);
287 case LayerType::QLstm:
288 return IsQLstmSupported(infos[0],
289 infos[1],
290 infos[2],
291 infos[3],
292 infos[4],
293 infos[5],
294 *(PolymorphicDowncast<const QLstmDescriptor*>(&descriptor)),
295 lstmParamsInfo.value(),
296 reasonIfUnsupported);
297 case LayerType::Maximum:
298 return IsMaximumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
299 case LayerType::Mean:
300 return IsMeanSupported(infos[0],
301 infos[1],
302 *(PolymorphicDowncast<const MeanDescriptor*>(&descriptor)),
303 reasonIfUnsupported);
304 case LayerType::Minimum:
305 return IsMinimumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
306 case LayerType::Multiplication:
307 return IsMultiplicationSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
308 case LayerType::Normalization:
309 return IsNormalizationSupported(infos[0],
310 infos[1],
311 *(PolymorphicDowncast<const NormalizationDescriptor*>(&descriptor)),
312 reasonIfUnsupported);
313 case LayerType::Output:
314 return IsOutputSupported(infos[0], reasonIfUnsupported);
315 case LayerType::Pad:
316 return IsPadSupported(infos[0],
317 infos[1],
318 *(PolymorphicDowncast<const PadDescriptor*>(&descriptor)),
319 reasonIfUnsupported);
320 case LayerType::Permute:
321 return IsPermuteSupported(infos[0],
322 infos[1],
323 *(PolymorphicDowncast<const PermuteDescriptor*>(&descriptor)),
324 reasonIfUnsupported);
325 case LayerType::Pooling2d:
326 return IsPooling2dSupported(infos[0],
327 infos[1],
328 *(PolymorphicDowncast<const Pooling2dDescriptor*>(&descriptor)),
329 reasonIfUnsupported);
330 case LayerType::Prelu:
331 return IsPreluSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
332 case LayerType::Quantize:
333 return IsQuantizeSupported(infos[0], infos[1], reasonIfUnsupported);
334 case LayerType::Reshape:
335 return IsReshapeSupported(infos[0],
336 infos[1],
337 *(PolymorphicDowncast<const ReshapeDescriptor*>(&descriptor)),
338 reasonIfUnsupported);
339 case LayerType::Resize:
340 return IsResizeSupported(infos[0],
341 infos[1],
342 *(PolymorphicDowncast<const ResizeDescriptor*>(&descriptor)),
343 reasonIfUnsupported);
344 case LayerType::Reduce:
345 return IsReduceSupported(infos[0],
346 infos[1],
347 *(PolymorphicDowncast<const ReduceDescriptor*>(&descriptor)),
348 reasonIfUnsupported);
349 case LayerType::Slice:
350 return IsSliceSupported(infos[0],
351 infos[1],
352 *(PolymorphicDowncast<const SliceDescriptor*>(&descriptor)),
353 reasonIfUnsupported);
354 case LayerType::Softmax:
355 return IsSoftmaxSupported(infos[0],
356 infos[1],
357 *(PolymorphicDowncast<const SoftmaxDescriptor*>(&descriptor)),
358 reasonIfUnsupported);
359 case LayerType::SpaceToBatchNd:
360 return IsSpaceToBatchNdSupported(infos[0],
361 infos[1],
362 *(PolymorphicDowncast<const SpaceToBatchNdDescriptor*>(&descriptor)),
363 reasonIfUnsupported);
364 case LayerType::SpaceToDepth:
365 return IsSpaceToDepthSupported(infos[0],
366 infos[1],
367 *(PolymorphicDowncast<const SpaceToDepthDescriptor*>(&descriptor)),
368 reasonIfUnsupported);
369 case LayerType::Splitter:
370 {
371 std::vector<TensorInfo> outputInfos;
372 for (uint32_t i = 1; i < infos.size(); i++)
373 {
374 outputInfos.push_back(infos[i]);
375 }
376 return IsSplitterSupported(infos[0],
377 {outputInfos.begin(), outputInfos.end()},
378 *(PolymorphicDowncast<const ViewsDescriptor*>(&descriptor)),
379 reasonIfUnsupported);
380 }
381 case LayerType::Stack:
382 {
383 std::vector<const TensorInfo*> inputInfos;
384 for (uint32_t i = 0; i < infos.size() - 1; i++)
385 {
386 inputInfos.push_back(&infos[i]);
387 }
388 return IsStackSupported(inputInfos,
389 infos[infos.size() - 1],
390 *(PolymorphicDowncast<const StackDescriptor*>(&descriptor)),
391 reasonIfUnsupported);
392 }
393 case LayerType::StridedSlice:
394 return IsStridedSliceSupported(infos[0],
395 infos[1],
396 *(PolymorphicDowncast<const StridedSliceDescriptor*>(&descriptor)),
397 reasonIfUnsupported);
398 case LayerType::Subtraction:
399 return IsSubtractionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
400 case LayerType::Transpose:
401 return IsTransposeSupported(infos[0],
402 infos[1],
403 *(PolymorphicDowncast<const TransposeDescriptor*>(&descriptor)),
404 reasonIfUnsupported);
405 case LayerType::TransposeConvolution2d:
406 {
407 if (infos.size() != 4)
408 {
409 throw InvalidArgumentException("Invalid number of TransposeConvolution2d TensorInfos. "
410 "TensorInfos should be of format: {input, output, weights, biases}.");
411 }
412
413 auto desc = *(PolymorphicDowncast<const TransposeConvolution2dDescriptor*>(&descriptor));
414 if (infos[3] == TensorInfo())
415 {
416 return IsTransposeConvolution2dSupported(infos[0],
417 infos[1],
418 desc,
419 infos[2],
420 EmptyOptional(),
421 reasonIfUnsupported);
422 }
423 else
424 {
425 return IsTransposeConvolution2dSupported(infos[0],
426 infos[1],
427 desc,
428 infos[2],
429 infos[3],
430 reasonIfUnsupported);
431 }
432 }
433 case LayerType::Cast:
434 return IsCastSupported(infos[0], infos[1], reasonIfUnsupported);
435 case LayerType::ChannelShuffle:
436 return IsChannelShuffleSupported(infos[0],
437 infos[1],
438 *(PolymorphicDowncast<const ChannelShuffleDescriptor*>(&descriptor)),
439 reasonIfUnsupported);
440 case LayerType::Convolution3d:
441 {
442 if (infos.size() != 4)
443 {
444 throw InvalidArgumentException("Invalid number of Convolution3d TensorInfos. "
445 "TensorInfos should be of format: {input, output, weights, biases}.");
446 }
447
448 auto desc = *(PolymorphicDowncast<const Convolution3dDescriptor*>(&descriptor));
449 if (infos[3] == TensorInfo())
450 {
451 return IsConvolution3dSupported(infos[0],
452 infos[1],
453 desc,
454 infos[2],
455 EmptyOptional(),
456 reasonIfUnsupported);
457 }
458 else
459 {
460 return IsConvolution3dSupported(infos[0],
461 infos[1],
462 desc,
463 infos[2],
464 infos[3],
465 reasonIfUnsupported);
466 }
467 }
468 case LayerType::Debug:
469 return IsDebugSupported(infos[0], infos[1], reasonIfUnsupported);
470 case LayerType::DetectionPostProcess:
471 return IsDetectionPostProcessSupported(infos[0],
472 infos[1],
473 infos[2],
474 infos[3],
475 infos[4],
476 infos[5],
477 infos[6],
478 *(PolymorphicDowncast<const DetectionPostProcessDescriptor*>
479 (&descriptor)),
480 reasonIfUnsupported);
481 case LayerType::FakeQuantization:
482 return IsFakeQuantizationSupported(infos[0],
483 *(PolymorphicDowncast<const FakeQuantizationDescriptor*>(&descriptor)),
484 reasonIfUnsupported);
485 case LayerType::MemCopy:
486 return IsMemCopySupported(infos[0], infos[1], reasonIfUnsupported);
487 case LayerType::Rank:
488 return IsRankSupported(infos[0], infos[1], reasonIfUnsupported);
489 case LayerType::Shape:
490 return IsShapeSupported(infos[0], infos[1], reasonIfUnsupported);
491 case LayerType::UnidirectionalSequenceLstm:
492 {
493 if (infos.size() != 6)
494 {
495 throw InvalidArgumentException("Invalid number of UnidirectionalSequenceLstm TensorInfos. TensorInfos "
496 "should be of format: {input, outputStateIn, cellStateIn, "
497 "hiddenStateOutputVal, cellStateOutputVal, output}");
498 }
499 auto desc = *(PolymorphicDowncast<const UnidirectionalSequenceLstmDescriptor*>(&descriptor));
500 return IsUnidirectionalSequenceLstmSupported(infos[0],
501 infos[1],
502 infos[2],
503 infos[3],
504 infos[4],
505 infos[5],
506 desc,
507 lstmParamsInfo.value(),
508 reasonIfUnsupported);
509 }
510 case LayerType::Pooling3d:
511 return IsPooling3dSupported(infos[0],
512 infos[1],
513 *(PolymorphicDowncast<const Pooling3dDescriptor*>(&descriptor)),
514 reasonIfUnsupported);
515 case LayerType::Map:
516 return true;
517 case LayerType::Unmap:
518 return true;
519 case LayerType::MemImport:
520 return LayerSupportBase::IsMemImportSupported(infos[0], infos[1], reasonIfUnsupported);
521 case LayerType::Merge:
522 return LayerSupportBase::IsMergeSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
523 case LayerType::QuantizedLstm:
524 return LayerSupportBase::IsQuantizedLstmSupported(infos[0],
525 infos[1],
526 infos[2],
527 infos[3],
528 infos[4],
529 quantizedLstmInputParamsInfo.value(),
530 reasonIfUnsupported);
531 default:
532 // layers not supported in neon by default:
533 // precompiled, standin, switch
534 return false;
535 }
536 }
537
IsActivationSupported(const TensorInfo & input,const TensorInfo & output,const ActivationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const538 bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
539 const TensorInfo& output,
540 const ActivationDescriptor& descriptor,
541 Optional<std::string&> reasonIfUnsupported) const
542 {
543 bool supported = true;
544
545 // Define supported types.
546 std::array<DataType,6> supportedTypes = {
547 DataType::Float32,
548 DataType::Float16,
549 DataType::QAsymmS8,
550 DataType::QAsymmU8,
551 DataType::QSymmS16
552 };
553
554 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
555 "Reference activation: input type not supported.");
556
557 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
558 "Reference activation: output type not supported.");
559
560 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
561 "Reference activation: input and output types mismatched.");
562
563 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
564 "Reference activation: input and output shapes are of different rank.");
565
566
567 struct ActivationFunctionSupported : public Rule
568 {
569 ActivationFunctionSupported(const ActivationDescriptor& desc)
570 {
571 switch(desc.m_Function)
572 {
573 case ActivationFunction::Abs:
574 case ActivationFunction::BoundedReLu:
575 case ActivationFunction::Elu:
576 case ActivationFunction::HardSwish:
577 case ActivationFunction::LeakyReLu:
578 case ActivationFunction::Linear:
579 case ActivationFunction::ReLu:
580 case ActivationFunction::Sigmoid:
581 case ActivationFunction::SoftReLu:
582 case ActivationFunction::Sqrt:
583 case ActivationFunction::Square:
584 case ActivationFunction::TanH:
585 {
586 m_Res = true;
587 break;
588 }
589 default:
590 {
591 m_Res = false;
592 break;
593 }
594 }
595 }
596 };
597
598 // Function is supported
599 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
600 "Reference activation: function not supported.");
601
602 return supported;
603 }
604
IsAdditionSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const605 bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
606 const TensorInfo& input1,
607 const TensorInfo& output,
608 Optional<std::string&> reasonIfUnsupported) const
609 {
610 bool supported = true;
611
612 std::array<DataType,7> supportedTypes = {
613 DataType::Float32,
614 DataType::Float16,
615 DataType::QAsymmS8,
616 DataType::QAsymmU8,
617 DataType::QSymmS16,
618 DataType::Signed32
619 };
620
621 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
622 "Reference addition: input 0 is not a supported type.");
623
624 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
625 "Reference addition: input 1 is not a supported type.");
626
627 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
628 "Reference addition: output is not a supported type.");
629
630 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
631 "Reference addition: input 0 and Input 1 types are mismatched");
632
633 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
634 "Reference addition: input and output types are mismatched");
635
636 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
637 "Reference addition: shapes are not suitable for implicit broadcast.");
638
639 return supported;
640 }
641
IsArgMinMaxSupported(const armnn::TensorInfo & input,const armnn::TensorInfo & output,const armnn::ArgMinMaxDescriptor & descriptor,armnn::Optional<std::string &> reasonIfUnsupported) const642 bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
643 const armnn::ArgMinMaxDescriptor &descriptor,
644 armnn::Optional<std::string &> reasonIfUnsupported) const
645 {
646 IgnoreUnused(descriptor);
647
648 std::array<DataType, 8> supportedInputTypes =
649 {
650 DataType::Float16,
651 DataType::Float32,
652 DataType::QAsymmS8,
653 DataType::QAsymmU8,
654 DataType::QSymmS16,
655 DataType::Signed32,
656 DataType::Signed64
657 };
658
659 std::array<DataType,2> supportedOutputTypes = {
660 DataType::Signed32,
661 DataType::Signed64
662 };
663
664 bool supported = true;
665
666 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
667 "Reference ArgMinMax: input is not a supported type.");
668 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
669 "Reference ArgMinMax: output type not supported");
670
671 return supported;
672 }
673
IsBatchMatMulSupported(const TensorInfo & inputX,const TensorInfo & inputY,const TensorInfo & output,const BatchMatMulDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const674 bool RefLayerSupport::IsBatchMatMulSupported(const TensorInfo& inputX,
675 const TensorInfo& inputY,
676 const TensorInfo& output,
677 const BatchMatMulDescriptor& descriptor,
678 Optional<std::string &> reasonIfUnsupported) const
679 {
680 IgnoreUnused(descriptor);
681
682 std::array<DataType, 6> supportedTypes =
683 {
684 DataType::Float16,
685 DataType::Float32,
686 DataType::QAsymmS8,
687 DataType::QAsymmU8,
688 DataType::QSymmS16
689 };
690
691 bool supported = true;
692
693 supported &= CheckSupportRule(TypeAnyOf(inputX, supportedTypes), reasonIfUnsupported,
694 "Reference batch matrix multiplication: input X is not a supported type");
695
696 supported &= CheckSupportRule(TypeAnyOf(inputY, supportedTypes), reasonIfUnsupported,
697 "Reference batch matrix multiplication: input Y is not a supported type");
698
699 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
700 "Reference batch matrix multiplication: output is not a supported type");
701
702 supported &= CheckSupportRule(TypesAreEqual(inputX, inputY), reasonIfUnsupported,
703 "Reference batch matrix multiplication: input X and input Y types are mismatched");
704
705 supported &= CheckSupportRule(TypesAreEqual(inputX, output), reasonIfUnsupported,
706 "Reference batch matrix multiplication: inputs and output types are mismatched");
707
708 supported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(inputX, 2),
709 reasonIfUnsupported,
710 "Reference batch matrix multiplication: input X is not of rank 2 or greater");
711
712 supported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(inputY, 2),
713 reasonIfUnsupported,
714 "Reference batch matrix multiplication: input Y is not of rank 2 or greater");
715
716 return supported;
717 }
718
IsBatchNormalizationSupported(const TensorInfo & input,const TensorInfo & output,const TensorInfo & mean,const TensorInfo & variance,const TensorInfo & beta,const TensorInfo & gamma,const BatchNormalizationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const719 bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
720 const TensorInfo& output,
721 const TensorInfo& mean,
722 const TensorInfo& variance,
723 const TensorInfo& beta,
724 const TensorInfo& gamma,
725 const BatchNormalizationDescriptor& descriptor,
726 Optional<std::string&> reasonIfUnsupported) const
727 {
728 IgnoreUnused(descriptor);
729
730 std::array<DataType, 6> supportedTypes =
731 {
732 DataType::Float32,
733 DataType::Float16,
734 DataType::QAsymmS8,
735 DataType::QAsymmU8,
736 DataType::QSymmS16
737 };
738
739 bool supported = true;
740
741 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
742 "Reference batch normalization: input is not a supported type.");
743
744 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
745 "Reference batch normalization: output is not a supported type.");
746
747 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
748 "Reference batch normalization: input and output types are mismatched");
749
750 supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
751 "Reference batch normalization: mean is not a supported type.");
752
753 supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
754 "Reference batch normalization: variance is not a supported type.");
755
756 supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
757 "Reference batch normalization: beta is not a supported type.");
758
759 supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
760 "Reference batch normalization: gamma is not a supported type.");
761
762 return supported;
763 }
764
IsBatchToSpaceNdSupported(const TensorInfo & input,const TensorInfo & output,const BatchToSpaceNdDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const765 bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
766 const TensorInfo& output,
767 const BatchToSpaceNdDescriptor& descriptor,
768 Optional<std::string&> reasonIfUnsupported) const
769 {
770 IgnoreUnused(descriptor);
771
772 bool supported = true;
773
774 std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
775 std::string inputTensorStr = "input";
776 std::string outputTensorStr = "output";
777
778 // Define supported types.
779 std::array<DataType,6> supportedTypes =
780 {
781 DataType::Float32,
782 DataType::Float16,
783 DataType::QAsymmS8,
784 DataType::QAsymmU8,
785 DataType::QSymmS16
786 };
787
788 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
789 "Reference BatchToSpaceNd: input type not supported.");
790
791 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
792 "Reference BatchToSpaceNd: output type not supported.");
793
794 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
795 "Reference BatchToSpaceNd: input and output types mismatched.");
796
797 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
798 reasonIfUnsupported,
799 CreateIncorrectDimensionsErrorMsg(4,
800 output.GetNumDimensions(),
801 batchToSpaceNdLayerStr,
802 outputTensorStr).data());
803
804 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
805 reasonIfUnsupported,
806 CreateIncorrectDimensionsErrorMsg(4,
807 input.GetNumDimensions(),
808 batchToSpaceNdLayerStr,
809 inputTensorStr).data());
810
811 return supported;
812 }
813
IsCastSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const814 bool RefLayerSupport::IsCastSupported(const TensorInfo& input,
815 const TensorInfo& output,
816 Optional<std::string&> reasonIfUnsupported) const
817 {
818 std::array<DataType, 9> supportedInputTypes =
819 {
820 DataType::Float32,
821 DataType::Float16,
822 DataType::QSymmS8,
823 DataType::QAsymmS8,
824 DataType::QAsymmU8,
825 DataType::QSymmS16,
826 DataType::Signed32
827 };
828
829 bool supported = true;
830 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
831 "Reference cast: input is not a supported type");
832
833
834 supported &= CheckSupportRule(TypeAnyOf(output, supportedInputTypes), reasonIfUnsupported,
835 "Reference cast: output is not a supported type");
836
837 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
838 "Reference cast: input and output shapes have different number of total elements");
839
840 return supported;
841 }
842
IsChannelShuffleSupported(const TensorInfo & input,const TensorInfo & output,const ChannelShuffleDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const843 bool RefLayerSupport::IsChannelShuffleSupported(const TensorInfo& input,
844 const TensorInfo& output,
845 const ChannelShuffleDescriptor& descriptor,
846 Optional<std::string&> reasonIfUnsupported) const
847 {
848 IgnoreUnused(descriptor);
849 bool supported = true;
850
851 // Define supported output and inputs types.
852 std::array<DataType, 7> supportedTypes =
853 {
854 DataType::Float32,
855 DataType::Float16,
856 DataType::QAsymmS8,
857 DataType::QAsymmU8,
858 DataType::QSymmS8,
859 DataType::QSymmS16
860 };
861
862 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
863 "Reference ChannelShuffle: input is not a supported type.");
864
865 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
866 "Reference ChannelShuffle: output is not a supported type.");
867
868 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
869 "Reference ChannelShuffle: input and output types are mismatched.");
870
871 return supported;
872 }
873
874
IsComparisonSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,const ComparisonDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const875 bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
876 const TensorInfo& input1,
877 const TensorInfo& output,
878 const ComparisonDescriptor& descriptor,
879 Optional<std::string&> reasonIfUnsupported) const
880 {
881 IgnoreUnused(descriptor);
882 std::array<DataType, 8> supportedInputTypes =
883 {
884 DataType::Boolean,
885 DataType::Float32,
886 DataType::Float16,
887 DataType::QAsymmS8,
888 DataType::QAsymmU8,
889 DataType::QSymmS16,
890 DataType::Signed32
891 };
892
893 bool supported = true;
894 supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
895 "Reference comparison: input 0 is not a supported type");
896
897 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
898 "Reference comparison: input 0 and Input 1 types are mismatched");
899
900 supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
901 "Reference comparison: output is not of type Boolean");
902
903 return supported;
904 }
905
IsConcatSupported(const std::vector<const TensorInfo * > inputs,const TensorInfo & output,const OriginsDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const906 bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
907 const TensorInfo& output,
908 const OriginsDescriptor& descriptor,
909 Optional<std::string&> reasonIfUnsupported) const
910 {
911 IgnoreUnused(descriptor);
912
913 bool supported = true;
914 std::array<DataType,7> supportedTypes =
915 {
916 DataType::Float32,
917 DataType::Float16,
918 DataType::QAsymmS8,
919 DataType::QAsymmU8,
920 DataType::QSymmS16,
921 DataType::Signed32
922 };
923
924 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
925 "Reference concatenation: output type not supported");
926 for (const TensorInfo* input : inputs)
927 {
928 ARMNN_ASSERT(input != nullptr);
929 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
930 "Reference concatenation: input type not supported");
931
932 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
933 "Reference concatenation: input and output types mismatched.");
934 }
935
936 return supported;
937 }
938
IsConstantSupported(const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const939 bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
940 Optional<std::string&> reasonIfUnsupported) const
941 {
942 std::array<DataType,8> supportedTypes =
943 {
944 DataType::Float16,
945 DataType::Float32,
946 DataType::QAsymmS8,
947 DataType::QAsymmU8,
948 DataType::QSymmS8,
949 DataType::QSymmS16,
950 DataType::Signed32
951 };
952
953 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
954 "Reference constant: output is not a supported type.");
955 }
956
bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    // The generic helper dispatches on data type through ordered function-pointer
    // slots; judging by the helper names (FalseInputFuncF32, FalseOutputFuncF16),
    // the slot order is (Float16, Float32, Uint8, Int32, Boolean). The input check
    // passes TrueFunc only in the first slot (input must be Float16), the output
    // check passes TrueFunc only in the second slot (output must be Float32); the
    // False* variants record a failure reason in reasonIfUnsupported.
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseInputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
976
bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    // Mirror image of IsConvertFp16ToFp32Supported: TrueFunc occupies the
    // Float32 slot for the input dispatch and the Float16 slot for the output
    // dispatch, so only a Float32 -> Float16 conversion is accepted; every
    // other type hits a False* helper that records the rejection reason.
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseOutputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
996
IsConvolution2dSupported(const TensorInfo & input,const TensorInfo & output,const Convolution2dDescriptor & descriptor,const TensorInfo & weights,const Optional<TensorInfo> & biases,Optional<std::string &> reasonIfUnsupported) const997 bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
998 const TensorInfo& output,
999 const Convolution2dDescriptor& descriptor,
1000 const TensorInfo& weights,
1001 const Optional<TensorInfo>& biases,
1002 Optional<std::string&> reasonIfUnsupported) const
1003 {
1004 bool supported = true;
1005
1006 // Define supported types.
1007 std::array<DataType,7> supportedTypes =
1008 {
1009 DataType::Float32,
1010 DataType::Float16,
1011 DataType::QAsymmS8,
1012 DataType::QAsymmU8,
1013 DataType::QSymmS8,
1014 DataType::QSymmS16
1015 };
1016
1017 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1018 "Reference Convolution2d: input is not a supported type.");
1019
1020 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1021 "Reference Convolution2d: output is not a supported type.");
1022
1023 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1024 "Reference Convolution2d: input and output types mismatched.");
1025
1026
1027 const DataType inputType = input.GetDataType();
1028 if (IsQuantized8BitType(inputType))
1029 {
1030 std::array<DataType, 3> supportedWeightTypes =
1031 {
1032 DataType::QAsymmS8,
1033 DataType::QAsymmU8,
1034 DataType::QSymmS8
1035 };
1036
1037 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1038 "Reference Convolution2d: weights type not supported for quantized input.");
1039 }
1040 else
1041 {
1042 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1043 "Reference Convolution2d: weights is not a supported type.");
1044
1045 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1046 "Reference Convolution2d: input and weights types mismatched.");
1047 }
1048
1049 if (biases.has_value())
1050 {
1051 std::array<DataType,4> biasesSupportedTypes =
1052 {
1053 DataType::Float32,
1054 DataType::Float16,
1055 DataType::Signed32
1056 };
1057
1058 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1059 "Reference Convolution2d: biases is not a supported type.");
1060 }
1061 IgnoreUnused(descriptor);
1062
1063 return supported;
1064 }
1065
IsConvolution3dSupported(const TensorInfo & input,const TensorInfo & output,const Convolution3dDescriptor & descriptor,const TensorInfo & weights,const Optional<TensorInfo> & biases,Optional<std::string &> reasonIfUnsupported) const1066 bool RefLayerSupport::IsConvolution3dSupported(const TensorInfo& input,
1067 const TensorInfo& output,
1068 const Convolution3dDescriptor& descriptor,
1069 const TensorInfo& weights,
1070 const Optional<TensorInfo>& biases,
1071 Optional<std::string&> reasonIfUnsupported) const
1072 {
1073 bool supported = true;
1074
1075 // Define supported types.
1076 std::array<DataType,7> supportedTypes =
1077 {
1078 DataType::Float32,
1079 DataType::Float16,
1080 DataType::QAsymmS8,
1081 DataType::QAsymmU8,
1082 DataType::QSymmS8,
1083 DataType::QSymmS16
1084 };
1085
1086 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1087 "Reference Convolution3d: input is not a supported type.");
1088
1089 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1090 "Reference Convolution3d: output is not a supported type.");
1091
1092 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1093 "Reference Convolution3d: input and output types mismatched.");
1094
1095 const DataType inputType = input.GetDataType();
1096 if (IsQuantized8BitType(inputType))
1097 {
1098 std::array<DataType, 3> supportedWeightTypes =
1099 {
1100 DataType::QAsymmS8,
1101 DataType::QAsymmU8,
1102 DataType::QSymmS8
1103 };
1104
1105 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1106 "Reference Convolution3d: weights type not supported for quantized input.");
1107 }
1108 else
1109 {
1110 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1111 "Reference Convolution3d: weights is not a supported type.");
1112
1113 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1114 "Reference Convolution3d: input and weights types mismatched.");
1115 }
1116
1117 if (biases.has_value())
1118 {
1119 std::array<DataType,4> biasesSupportedTypes =
1120 {
1121 DataType::Float32,
1122 DataType::Float16,
1123 DataType::Signed32
1124 };
1125
1126 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1127 "Reference Convolution3d: biases is not a supported type.");
1128 }
1129 IgnoreUnused(descriptor);
1130
1131 return supported;
1132 }
1133
IsDebugSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1134 bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
1135 const TensorInfo& output,
1136 Optional<std::string&> reasonIfUnsupported) const
1137 {
1138 bool supported = true;
1139
1140 std::array<DataType, 8> supportedTypes =
1141 {
1142 DataType::BFloat16,
1143 DataType::Float16,
1144 DataType::Float32,
1145 DataType::QAsymmS8,
1146 DataType::QAsymmU8,
1147 DataType::QSymmS8,
1148 DataType::QSymmS16,
1149 DataType::Signed32
1150 };
1151
1152 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1153 "Reference for Debug layer: input type not supported");
1154
1155 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1156 "Reference for Debug layer: output type not supported");
1157
1158 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1159 "Reference for Debug layer: input and output types are mismatched");
1160
1161 return supported;
1162 }
1163
IsDepthToSpaceSupported(const TensorInfo & input,const TensorInfo & output,const DepthToSpaceDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1164 bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
1165 const TensorInfo& output,
1166 const DepthToSpaceDescriptor& descriptor,
1167 Optional<std::string&> reasonIfUnsupported) const
1168 {
1169 IgnoreUnused(descriptor);
1170 bool supported = true;
1171
1172 std::array<DataType,6> supportedTypes =
1173 {
1174 DataType::Float32,
1175 DataType::Float16,
1176 DataType::QAsymmS8,
1177 DataType::QAsymmU8,
1178 DataType::QSymmS16
1179 };
1180
1181 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1182 "Reference DepthToSpace: input type not supported");
1183
1184 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1185 "Reference DepthToSpace: output type not supported");
1186
1187 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1188 "Reference DepthToSpace: input and output types are mismatched");
1189
1190 return supported;
1191 }
1192
IsDepthwiseConvolutionSupported(const TensorInfo & input,const TensorInfo & output,const DepthwiseConvolution2dDescriptor & descriptor,const TensorInfo & weights,const Optional<TensorInfo> & biases,Optional<std::string &> reasonIfUnsupported) const1193 bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
1194 const TensorInfo& output,
1195 const DepthwiseConvolution2dDescriptor& descriptor,
1196 const TensorInfo& weights,
1197 const Optional<TensorInfo>& biases,
1198 Optional<std::string&> reasonIfUnsupported) const
1199 {
1200 IgnoreUnused(descriptor);
1201 bool supported = true;
1202
1203 // Define supported types.
1204 std::array<DataType,7> supportedTypes =
1205 {
1206 DataType::Float32,
1207 DataType::Float16,
1208 DataType::QAsymmS8,
1209 DataType::QAsymmU8,
1210 DataType::QSymmS8,
1211 DataType::QSymmS16
1212 };
1213
1214 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1215 "Reference DepthwiseConvolution2d: input is not a supported type.");
1216
1217 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1218 "Reference DepthwiseConvolution2d: output is not a supported type.");
1219
1220 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1221 "Reference DepthwiseConvolution2d: input and output types mismatched.");
1222
1223 const DataType inputType = input.GetDataType();
1224 if (IsQuantized8BitType(inputType))
1225 {
1226 std::array<DataType, 3> supportedWeightTypes =
1227 {
1228 DataType::QAsymmS8,
1229 DataType::QAsymmU8,
1230 DataType::QSymmS8,
1231 };
1232
1233 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1234 "Reference DepthwiseConvolution2d: weights type not supported for "
1235 "quantized input.");
1236 }
1237 else
1238 {
1239 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1240 "Reference DepthwiseConvolution2d: weights is not a supported type.");
1241
1242 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1243 "Reference DepthwiseConvolution2d: input and weights types mismatched.");
1244 }
1245
1246 if (biases.has_value())
1247 {
1248 std::array<DataType,4> biasesSupportedTypes =
1249 {
1250 DataType::Float32,
1251 DataType::Float16,
1252 DataType::Signed32
1253 };
1254 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1255 "Reference DepthwiseConvolution2d: biases is not a supported type.");
1256 }
1257
1258 return supported;
1259
1260 }
1261
IsDequantizeSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1262 bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
1263 const TensorInfo& output,
1264 Optional<std::string&> reasonIfUnsupported) const
1265 {
1266 bool supported = true;
1267
1268 std::array<DataType,5> supportedInputTypes = {
1269 DataType::QAsymmS8,
1270 DataType::QAsymmU8,
1271 DataType::QSymmS8,
1272 DataType::QSymmS16,
1273 DataType::Float16
1274 };
1275
1276 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1277 "Reference for Dequantize layer: input type not supported.");
1278
1279 supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
1280 "Reference for Dequantize layer: per-axis quantized input not supported.");
1281
1282 std::array<DataType,3> supportedOutputTypes = {
1283 DataType::Float32,
1284 DataType::Float16
1285 };
1286
1287 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1288 "Reference for Dequantize layer: output type not supported.");
1289
1290 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1291 "Reference for Dequantize layer: input/output shapes have different num total "
1292 "elements.");
1293
1294 return supported;
1295 }
1296
IsDetectionPostProcessSupported(const TensorInfo & boxEncodings,const TensorInfo & scores,const TensorInfo & anchors,const TensorInfo & detectionBoxes,const TensorInfo & detectionClasses,const TensorInfo & detectionScores,const TensorInfo & numDetections,const DetectionPostProcessDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1297 bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
1298 const TensorInfo& scores,
1299 const TensorInfo& anchors,
1300 const TensorInfo& detectionBoxes,
1301 const TensorInfo& detectionClasses,
1302 const TensorInfo& detectionScores,
1303 const TensorInfo& numDetections,
1304 const DetectionPostProcessDescriptor& descriptor,
1305 Optional<std::string&> reasonIfUnsupported) const
1306 {
1307 IgnoreUnused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);
1308
1309 bool supported = true;
1310
1311 std::array<DataType,6> supportedInputTypes =
1312 {
1313 DataType::Float32,
1314 DataType::Float16,
1315 DataType::QAsymmS8,
1316 DataType::QAsymmU8,
1317 DataType::QSymmS16
1318 };
1319
1320 supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
1321 "Reference DetectionPostProcess: input 0 is not a supported type.");
1322
1323 supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
1324 "Reference DetectionPostProcess: input 1 is not a supported type.");
1325
1326 return supported;
1327 }
1328
bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                             const TensorInfo& output,
                                                             const DepthwiseConvolution2dDescriptor& descriptor,
                                                             const TensorInfo& weights,
                                                             const Optional<TensorInfo>& biases,
                                                             Optional<std::string&> reasonIfUnsupported) const
{
    // Delegates to the non-dilated check: dilation adds no extra type
    // constraints here — presumably the reference depthwise implementation
    // handles dilation itself (confirm against the workload implementation).
    return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
}
1338
IsDivisionSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1339 bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
1340 const TensorInfo& input1,
1341 const TensorInfo& output,
1342 Optional<std::string&> reasonIfUnsupported) const
1343 {
1344 bool supported = true;
1345
1346 std::array<DataType,7> supportedTypes = {
1347 DataType::Float32,
1348 DataType::Float16,
1349 DataType::QAsymmS8,
1350 DataType::QAsymmU8,
1351 DataType::QSymmS16,
1352 DataType::Signed32
1353 };
1354
1355 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1356 "Reference division: input 0 is not a supported type.");
1357
1358 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1359 "Reference division: input 1 is not a supported type.");
1360
1361 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1362 "Reference division: output is not a supported type.");
1363
1364 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1365 "Reference division: input 0 and Input 1 types are mismatched");
1366
1367 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1368 "Reference division: input and output types are mismatched");
1369
1370 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1371 "Reference division: shapes are not suitable for implicit broadcast.");
1372
1373 return supported;
1374 }
1375
IsElementwiseUnarySupported(const TensorInfo & input,const TensorInfo & output,const ElementwiseUnaryDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1376 bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
1377 const TensorInfo& output,
1378 const ElementwiseUnaryDescriptor& descriptor,
1379 Optional<std::string&> reasonIfUnsupported) const
1380 {
1381 IgnoreUnused(descriptor);
1382
1383 std::array<DataType, 7> supportedTypes =
1384 {
1385 DataType::Float32,
1386 DataType::Float16,
1387 DataType::QAsymmS8,
1388 DataType::QAsymmU8,
1389 DataType::QSymmS16,
1390 DataType::Signed32
1391 };
1392
1393 std::array<DataType, 1> logicalSupportedTypes =
1394 {
1395 DataType::Boolean
1396 };
1397
1398 bool supported = true;
1399
1400 if (descriptor.m_Operation == UnaryOperation::LogicalNot)
1401 {
1402 supported &= CheckSupportRule(TypeAnyOf(input, logicalSupportedTypes), reasonIfUnsupported,
1403 "Reference elementwise unary: input type not supported");
1404
1405 supported &= CheckSupportRule(TypeAnyOf(output, logicalSupportedTypes), reasonIfUnsupported,
1406 "Reference elementwise unary: output type not supported");
1407 }
1408 else
1409 {
1410 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1411 "Reference elementwise unary: input type not supported");
1412
1413 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1414 "Reference elementwise unary: output type not supported");
1415 }
1416
1417 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1418 "Reference elementwise unary: input and output types not matching");
1419
1420 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1421 "Reference elementwise unary: input and output shapes"
1422 "have different number of total elements");
1423
1424 return supported;
1425 }
1426
IsFakeQuantizationSupported(const TensorInfo & input,const FakeQuantizationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1427 bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
1428 const FakeQuantizationDescriptor& descriptor,
1429 Optional<std::string&> reasonIfUnsupported) const
1430 {
1431 IgnoreUnused(descriptor);
1432 bool supported = true;
1433
1434 std::array<DataType,1> supportedTypes =
1435 {
1436 DataType::Float32
1437 };
1438
1439 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1440 "Reference fake quantization: input type not supported.");
1441
1442 return supported;
1443 }
1444
IsFillSupported(const TensorInfo & input,const TensorInfo & output,const FillDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1445 bool RefLayerSupport::IsFillSupported(const TensorInfo& input,
1446 const TensorInfo& output,
1447 const FillDescriptor& descriptor,
1448 Optional<std::string&> reasonIfUnsupported) const
1449 {
1450 IgnoreUnused(descriptor);
1451 IgnoreUnused(output);
1452
1453 bool supported = true;
1454
1455 std::array<DataType,3> supportedTypes =
1456 {
1457 DataType::Float32,
1458 DataType::Float16,
1459 DataType::Signed32
1460 };
1461
1462 supported &= CheckSupportRule(TypeIs(input, DataType::Signed32), reasonIfUnsupported,
1463 "Reference Fill: input type not supported.");
1464
1465 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1466 "Reference Fill: output type not supported.");
1467 return supported;
1468 }
1469
IsFloorSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1470 bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
1471 const TensorInfo& output,
1472 Optional<std::string&> reasonIfUnsupported) const
1473 {
1474 IgnoreUnused(output);
1475 bool supported = true;
1476
1477 std::array<DataType,3> supportedTypes =
1478 {
1479 DataType::Float32,
1480 DataType::Float16
1481 };
1482
1483 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1484 "Reference Floor: input type not supported.");
1485
1486 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1487 "Reference Floor: output type not supported.");
1488
1489 return supported;
1490 }
1491
IsFullyConnectedSupported(const TensorInfo & input,const TensorInfo & output,const TensorInfo & weights,const TensorInfo & biases,const FullyConnectedDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1492 bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
1493 const TensorInfo& output,
1494 const TensorInfo& weights,
1495 const TensorInfo& biases,
1496 const FullyConnectedDescriptor& descriptor,
1497 Optional<std::string&> reasonIfUnsupported) const
1498 {
1499 bool supported = true;
1500
1501 // Define supported types.
1502 std::array<DataType,6> supportedTypes =
1503 {
1504 DataType::Float32,
1505 DataType::Float16,
1506 DataType::QAsymmS8,
1507 DataType::QAsymmU8,
1508 DataType::QSymmS16
1509 };
1510
1511 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1512 "Reference Fully Connected: input type not supported.");
1513
1514 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1515 "Reference Fully Connected: output type not supported.");
1516
1517 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1518 "Reference Fully Connected: weights type not supported.");
1519
1520 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1521 "Reference Fully Connected: input and output types mismatched.");
1522
1523 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1524 "Reference Fully Connected: weights is not a supported type.");
1525
1526 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1527 "Reference Fully Connected: input and weights types mismatched.");
1528
1529 if (descriptor.m_BiasEnabled)
1530 {
1531 // Defined supported types for bias
1532 std::array<DataType, 5>
1533 supportedBiasTypes =
1534 {
1535 DataType::Float32,
1536 DataType::Float16,
1537 DataType::Signed32,
1538 DataType::QAsymmS8
1539 };
1540
1541 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
1542 "Reference Fully Connected: bias type not supported.");
1543
1544 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
1545 "Reference Fully Connected: bias and weight types mismatch.");
1546
1547 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
1548 "Reference Fully Connected: bias type inferred from weights is incompatible.");
1549
1550 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(biases, 1U), reasonIfUnsupported,
1551 "Reference Fully Connected: bias must have 1 dimension.");
1552
1553 }
1554
1555 return supported;
1556 }
1557
IsGatherNdSupported(const armnn::TensorInfo & input0,const armnn::TensorInfo & input1,const armnn::TensorInfo & output,armnn::Optional<std::string &> reasonIfUnsupported) const1558 bool RefLayerSupport::IsGatherNdSupported(const armnn::TensorInfo& input0,
1559 const armnn::TensorInfo& input1,
1560 const armnn::TensorInfo& output,
1561 armnn::Optional<std::string&> reasonIfUnsupported) const
1562 {
1563 bool supported = true;
1564 std::array<DataType,7> supportedTypes =
1565 {
1566 DataType::Float32,
1567 DataType::Float16,
1568 DataType::QAsymmS8,
1569 DataType::QAsymmU8,
1570 DataType::QSymmS16,
1571 DataType::Signed32
1572 };
1573
1574 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1575 "Reference GatherNd: input type not supported");
1576
1577 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1578 "Reference GatherNd: output type not supported");
1579
1580 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1581 "Reference GatherNd: indices (input1) type not supported");
1582
1583 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1584 "Reference GatherNd: input and output types not matching");
1585
1586 return supported;
1587 }
1588
IsGatherSupported(const armnn::TensorInfo & input0,const armnn::TensorInfo & input1,const armnn::TensorInfo & output,const GatherDescriptor & descriptor,armnn::Optional<std::string &> reasonIfUnsupported) const1589 bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
1590 const armnn::TensorInfo& input1,
1591 const armnn::TensorInfo& output,
1592 const GatherDescriptor& descriptor,
1593 armnn::Optional<std::string&> reasonIfUnsupported) const
1594 {
1595 bool supported = true;
1596 std::array<DataType,7> supportedTypes =
1597 {
1598 DataType::Float32,
1599 DataType::Float16,
1600 DataType::QAsymmS8,
1601 DataType::QAsymmU8,
1602 DataType::QSymmS16,
1603 DataType::Signed32
1604 };
1605
1606 IgnoreUnused(descriptor);
1607 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1608 "Reference Gather: input type not supported");
1609
1610 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1611 "Reference Gather: output type not supported");
1612
1613 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1614 "Reference Gather: indices (input1) type not supported");
1615
1616 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1617 "Reference Gather: input and output types not matching");
1618
1619 return supported;
1620 }
1621
bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
                                       Optional<std::string&> /*reasonIfUnsupported*/) const
{
    // Graph input binding points perform no computation, so any tensor
    // configuration is accepted unconditionally.
    return true;
}
1627
IsInstanceNormalizationSupported(const TensorInfo & input,const TensorInfo & output,const InstanceNormalizationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1628 bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
1629 const TensorInfo& output,
1630 const InstanceNormalizationDescriptor& descriptor,
1631 Optional<std::string&> reasonIfUnsupported) const
1632 {
1633 IgnoreUnused(descriptor);
1634 // Define supported types
1635 std::array<DataType, 3> supportedTypes =
1636 {
1637 DataType::Float32,
1638 DataType::Float16
1639 };
1640
1641 bool supported = true;
1642
1643 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1644 "Reference Instance Normalization: input type not supported.");
1645
1646 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1647 "Reference Instance Normalization: output type not supported.");
1648
1649 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1650 "Reference Instance Normalization: input and output types mismatched.");
1651
1652 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1653 "Reference Instance Normalization: input and output shapes have different "
1654 "num total elements.");
1655
1656 return supported;
1657 }
1658
IsL2NormalizationSupported(const TensorInfo & input,const TensorInfo & output,const L2NormalizationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1659 bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
1660 const TensorInfo& output,
1661 const L2NormalizationDescriptor& descriptor,
1662 Optional<std::string&> reasonIfUnsupported) const
1663 {
1664 IgnoreUnused(descriptor);
1665 // Define supported types
1666 std::array<DataType, 6> supportedTypes =
1667 {
1668 DataType::Float32,
1669 DataType::Float16,
1670 DataType::QAsymmS8,
1671 DataType::QAsymmU8,
1672 DataType::QSymmS16
1673 };
1674
1675 bool supported = true;
1676
1677 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1678 "Reference L2normalization: input type not supported.");
1679
1680 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1681 "Reference L2normalization: output type not supported.");
1682
1683 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1684 "Reference L2normalization: input and output types mismatched.");
1685
1686 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1687 "Reference L2normalization: input and output shapes have different "
1688 "num total elements.");
1689
1690 return supported;
1691 }
1692
IsLogicalBinarySupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,const LogicalBinaryDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1693 bool RefLayerSupport::IsLogicalBinarySupported(const TensorInfo& input0,
1694 const TensorInfo& input1,
1695 const TensorInfo& output,
1696 const LogicalBinaryDescriptor& descriptor,
1697 Optional<std::string&> reasonIfUnsupported) const
1698 {
1699 IgnoreUnused(descriptor);
1700
1701 std::array<DataType, 1> supportedTypes =
1702 {
1703 DataType::Boolean
1704 };
1705
1706 bool supported = true;
1707 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1708 "Reference LogicalBinary: input 0 type not supported");
1709 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1710 "Reference LogicalBinary: input 1 type not supported");
1711
1712 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1713 "Reference LogicalBinary: input and output types do not match");
1714
1715 return supported;
1716 }
1717
IsLogSoftmaxSupported(const TensorInfo & input,const TensorInfo & output,const LogSoftmaxDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1718 bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
1719 const TensorInfo& output,
1720 const LogSoftmaxDescriptor& descriptor,
1721 Optional<std::string&> reasonIfUnsupported) const
1722 {
1723 IgnoreUnused(descriptor);
1724
1725 std::array<DataType, 3> supportedTypes =
1726 {
1727 DataType::Float32,
1728 DataType::Float16
1729 };
1730
1731 bool supported = true;
1732 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1733 "Reference LogSoftmax: input type not supported");
1734
1735 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1736 "Reference LogSoftmax: output type not supported");
1737
1738 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1739 "Reference LogSoftmax: input and output types do not match");
1740
1741 return supported;
1742 }
1743
IsLstmSupported(const TensorInfo & input,const TensorInfo & outputStateIn,const TensorInfo & cellStateIn,const TensorInfo & scratchBuffer,const TensorInfo & outputStateOut,const TensorInfo & cellStateOut,const TensorInfo & output,const LstmDescriptor & descriptor,const LstmInputParamsInfo & paramsInfo,Optional<std::string &> reasonIfUnsupported) const1744 bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
1745 const TensorInfo& outputStateIn,
1746 const TensorInfo& cellStateIn,
1747 const TensorInfo& scratchBuffer,
1748 const TensorInfo& outputStateOut,
1749 const TensorInfo& cellStateOut,
1750 const TensorInfo& output,
1751 const LstmDescriptor& descriptor,
1752 const LstmInputParamsInfo& paramsInfo,
1753 Optional<std::string&> reasonIfUnsupported) const
1754 {
1755 IgnoreUnused(descriptor);
1756 IgnoreUnused(paramsInfo);
1757
1758 bool supported = true;
1759
1760 std::array<DataType,3> supportedTypes = {
1761 DataType::Float32,
1762 DataType::QSymmS16
1763 };
1764
1765 // check inputs and outputs
1766 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1767 "Reference Lstm: input is not a supported type.");
1768 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
1769 "Reference Lstm: input and outputStateIn types are mismatched");
1770 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
1771 "Reference Lstm: input and cellStateIn types are mismatched");
1772 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
1773 "Reference Lstm: input and scratchBuffer types are mismatched");
1774 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
1775 "Reference Lstm: input and outputStateOut types are mismatched");
1776 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
1777 "Reference Lstm: input and cellStateOut types are mismatched");
1778
1779 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1780 "Reference Lstm: input and output types are mismatched");
1781 // check layer parameters
1782 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
1783 "Reference Lstm: input and InputToForgetWeights types are mismatched");
1784 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
1785 "Reference Lstm: input and InputToCellWeights types are mismatched");
1786 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
1787 "Reference Lstm: input and InputToOutputWeights types are mismatched");
1788 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
1789 "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
1790 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
1791 "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
1792 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
1793 "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
1794 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
1795 "Reference Lstm: input and ForgetGateBias types are mismatched");
1796 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
1797 "Reference Lstm: input and CellBias types are mismatched");
1798 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
1799 "Reference Lstm: input and OutputGateBias types are mismatched");
1800 if (!descriptor.m_CifgEnabled)
1801 {
1802 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
1803 "Reference Lstm: input and InputToInputWeights types are mismatched");
1804 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
1805 reasonIfUnsupported,
1806 "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
1807 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
1808 "Reference Lstm: input and InputGateBias types are mismatched");
1809 if (descriptor.m_PeepholeEnabled)
1810 {
1811 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
1812 reasonIfUnsupported,
1813 "Reference Lstm: input and CellToInputWeights types are mismatched");
1814 }
1815 }
1816 if (descriptor.m_PeepholeEnabled)
1817 {
1818 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
1819 "Reference Lstm: input and CellToForgetWeights types are mismatched");
1820 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
1821 "Reference Lstm: input and CellToOutputWeights types are mismatched");
1822 }
1823 if (descriptor.m_ProjectionEnabled)
1824 {
1825 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
1826 "Reference Lstm: input and mProjectionWeights types are mismatched");
1827 if (paramsInfo.m_ProjectionBias != nullptr)
1828 {
1829 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
1830 "Reference Lstm: input and ProjectionBias types are mismatched");
1831 }
1832 }
1833 if (descriptor.m_LayerNormEnabled)
1834 {
1835 if (!descriptor.m_CifgEnabled)
1836 {
1837 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
1838 reasonIfUnsupported,
1839 "Reference Lstm: input and InputLayerNormWeights types are mismatched");
1840 }
1841 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
1842 reasonIfUnsupported,
1843 "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
1844 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
1845 reasonIfUnsupported,
1846 "Reference Lstm: input and CellLayerNormWeights types are mismatched");
1847 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
1848 reasonIfUnsupported,
1849 "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
1850 }
1851
1852 return supported;
1853 }
1854
IsMaximumSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1855 bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1856 const TensorInfo& input1,
1857 const TensorInfo& output,
1858 Optional<std::string&> reasonIfUnsupported) const
1859 {
1860 bool supported = true;
1861
1862 std::array<DataType,7> supportedTypes = {
1863 DataType::Float32,
1864 DataType::Float16,
1865 DataType::QAsymmS8,
1866 DataType::QAsymmU8,
1867 DataType::QSymmS16,
1868 DataType::Signed32
1869 };
1870
1871 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1872 "Reference maximum: input 0 is not a supported type.");
1873
1874 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1875 "Reference maximum: input 1 is not a supported type.");
1876
1877 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1878 "Reference maximum: output is not a supported type.");
1879
1880 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1881 "Reference maximum: input 0 and Input 1 types are mismatched");
1882
1883 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1884 "Reference maximum: input and output types are mismatched");
1885
1886 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1887 "Reference maximum: shapes are not suitable for implicit broadcast.");
1888
1889 return supported;
1890 }
1891
IsMeanSupported(const TensorInfo & input,const TensorInfo & output,const MeanDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1892 bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
1893 const TensorInfo& output,
1894 const MeanDescriptor& descriptor,
1895 Optional<std::string&> reasonIfUnsupported) const
1896 {
1897 bool supported = true;
1898 std::string meanLayerStr = "Mean";
1899 std::string outputTensorStr = "output";
1900
1901 std::array<DataType,6> supportedTypes =
1902 {
1903 DataType::Float32,
1904 DataType::Float16,
1905 DataType::QAsymmS8,
1906 DataType::QAsymmU8,
1907 DataType::QSymmS16
1908 };
1909
1910 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1911 "Reference Mean: input type not supported.");
1912
1913 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1914 "Reference Mean: input and output types are mismatched");
1915
1916 if (descriptor.m_KeepDims)
1917 {
1918 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1919 reasonIfUnsupported,
1920 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1921 output.GetNumDimensions(),
1922 meanLayerStr, outputTensorStr).data());
1923 }
1924 else if (descriptor.m_Axis.empty())
1925 {
1926 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1927 reasonIfUnsupported,
1928 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1929 meanLayerStr, outputTensorStr).data());
1930 }
1931 else
1932 {
1933 auto outputDim = input.GetNumDimensions() - armnn::numeric_cast<unsigned int>(descriptor.m_Axis.size());
1934
1935 if (outputDim > 0)
1936 {
1937 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1938 reasonIfUnsupported,
1939 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1940 meanLayerStr, outputTensorStr).data());
1941 }
1942 else
1943 {
1944 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1945 reasonIfUnsupported,
1946 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1947 meanLayerStr, outputTensorStr).data());
1948 }
1949 }
1950
1951 return supported;
1952 }
1953
IsMemCopySupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1954 bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1955 const TensorInfo &output,
1956 Optional<std::string &> reasonIfUnsupported) const
1957 {
1958 bool supported = true;
1959
1960 std::array<DataType,7> supportedTypes =
1961 {
1962 DataType::BFloat16,
1963 DataType::Float32,
1964 DataType::Float16,
1965 DataType::QAsymmS8,
1966 DataType::QAsymmU8,
1967 DataType::QSymmS16,
1968 DataType::Boolean
1969 };
1970
1971 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1972 "Reference MemCopy: input type not supported");
1973
1974 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1975 "Reference MemCopy: output type not supported");
1976
1977 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1978 "Reference MemCopy: input and output types are mismatched");
1979
1980 return supported;
1981 }
1982
IsMinimumSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1983 bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1984 const TensorInfo& input1,
1985 const TensorInfo& output,
1986 Optional<std::string&> reasonIfUnsupported) const
1987 {
1988 bool supported = true;
1989
1990 std::array<DataType,7> supportedTypes = {
1991 DataType::Float32,
1992 DataType::Float16,
1993 DataType::QAsymmS8,
1994 DataType::QAsymmU8,
1995 DataType::QSymmS16,
1996 DataType::Signed32
1997 };
1998
1999 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2000 "Reference minimum: input 0 is not a supported type.");
2001
2002 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2003 "Reference minimum: input 1 is not a supported type.");
2004
2005 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2006 "Reference minimum: output is not a supported type.");
2007
2008 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2009 "Reference minimum: input 0 and Input 1 types are mismatched");
2010
2011 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2012 "Reference minimum: input and output types are mismatched");
2013
2014 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2015 "Reference minimum: shapes are not suitable for implicit broadcast.");
2016
2017 return supported;
2018 }
2019
IsMultiplicationSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const2020 bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
2021 const TensorInfo& input1,
2022 const TensorInfo& output,
2023 Optional<std::string&> reasonIfUnsupported) const
2024 {
2025 bool supported = true;
2026
2027 std::array<DataType,7> supportedTypes = {
2028 DataType::Float32,
2029 DataType::Float16,
2030 DataType::QAsymmS8,
2031 DataType::QAsymmU8,
2032 DataType::QSymmS16,
2033 DataType::Signed32
2034 };
2035
2036 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2037 "Reference multiplication: input 0 is not a supported type.");
2038
2039 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2040 "Reference multiplication: input 1 is not a supported type.");
2041
2042 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2043 "Reference multiplication: output is not a supported type.");
2044
2045 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2046 "Reference multiplication: input 0 and Input 1 types are mismatched");
2047
2048 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2049 "Reference multiplication: input and output types are mismatched");
2050
2051 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2052 "Reference multiplication: shapes are not suitable for implicit broadcast.");
2053
2054 return supported;
2055 }
2056
IsNormalizationSupported(const TensorInfo & input,const TensorInfo & output,const NormalizationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2057 bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
2058 const TensorInfo& output,
2059 const NormalizationDescriptor& descriptor,
2060 Optional<std::string&> reasonIfUnsupported) const
2061 {
2062 IgnoreUnused(descriptor);
2063
2064 // Define supported types
2065 std::array<DataType, 6> supportedTypes =
2066 {
2067 DataType::Float16,
2068 DataType::Float32,
2069 DataType::QAsymmS8,
2070 DataType::QAsymmU8,
2071 DataType::QSymmS16
2072 };
2073
2074 bool supported = true;
2075
2076 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2077 "Reference normalization: input type not supported.");
2078
2079 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2080 "Reference normalization: output type not supported.");
2081
2082 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2083 "Reference normalization: input and output shapes have different "
2084 "num total elements.");
2085
2086 return supported;
2087 }
2088
IsOutputSupported(const TensorInfo &,Optional<std::string &>) const2089 bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
2090 Optional<std::string&> /*reasonIfUnsupported*/) const
2091 {
2092 return true;
2093 }
2094
IsPadSupported(const TensorInfo & input,const TensorInfo & output,const PadDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2095 bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
2096 const TensorInfo& output,
2097 const PadDescriptor& descriptor,
2098 Optional<std::string&> reasonIfUnsupported) const
2099 {
2100 IgnoreUnused(descriptor);
2101 bool supported = true;
2102
2103 // Define supported output and inputs types.
2104 std::array<DataType,6> supportedTypes =
2105 {
2106 DataType::Float32,
2107 DataType::Float16,
2108 DataType::QAsymmS8,
2109 DataType::QAsymmU8,
2110 DataType::QSymmS16
2111 };
2112
2113 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2114 "Reference pad: input is not a supported type.");
2115
2116 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2117 "Reference pad: output is not a supported type.");
2118
2119 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2120 "Reference pad: input and output types are mismatched.");
2121
2122 return supported;
2123 }
2124
IsPermuteSupported(const TensorInfo & input,const TensorInfo & output,const PermuteDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2125 bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
2126 const TensorInfo& output,
2127 const PermuteDescriptor& descriptor,
2128 Optional<std::string&> reasonIfUnsupported) const
2129 {
2130 IgnoreUnused(descriptor);
2131 bool supported = true;
2132
2133 // Define supported output and inputs types.
2134 std::array<DataType, 6> supportedTypes =
2135 {
2136 DataType::BFloat16,
2137 DataType::Float32,
2138 DataType::Float16,
2139 DataType::QAsymmS8,
2140 DataType::QAsymmU8,
2141 DataType::QSymmS16
2142 };
2143
2144 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2145 "Reference permute: input is not a supported type.");
2146
2147 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2148 "Reference permute: output is not a supported type.");
2149
2150 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2151 "Reference permute: input and output types are mismatched.");
2152
2153 return supported;
2154 }
2155
IsPooling2dSupported(const TensorInfo & input,const TensorInfo & output,const Pooling2dDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2156 bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
2157 const TensorInfo& output,
2158 const Pooling2dDescriptor& descriptor,
2159 Optional<std::string&> reasonIfUnsupported) const
2160 {
2161 IgnoreUnused(descriptor);
2162 bool supported = true;
2163
2164 // Define supported output and inputs types.
2165 std::array<DataType,6> supportedTypes =
2166 {
2167 DataType::Float32,
2168 DataType::Float16,
2169 DataType::QAsymmS8,
2170 DataType::QAsymmU8,
2171 DataType::QSymmS16
2172 };
2173
2174 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2175 "Reference poolind2d: input is not a supported type.");
2176
2177 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2178 "Reference poolind2d: output is not a supported type.");
2179
2180 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2181 "Reference poolind2d: input and output types are mismatched.");
2182
2183 return supported;
2184 }
2185
IsPooling3dSupported(const TensorInfo & input,const TensorInfo & output,const Pooling3dDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2186 bool RefLayerSupport::IsPooling3dSupported(const TensorInfo& input,
2187 const TensorInfo& output,
2188 const Pooling3dDescriptor& descriptor,
2189 Optional<std::string&> reasonIfUnsupported) const
2190 {
2191 IgnoreUnused(descriptor);
2192 bool supported = true;
2193
2194 // Define supported output and inputs types.
2195 std::array<DataType,6> supportedTypes =
2196 {
2197 DataType::Float32,
2198 DataType::Float16,
2199 DataType::QAsymmS8,
2200 DataType::QAsymmU8,
2201 DataType::QSymmS16
2202 };
2203
2204 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2205 "Reference poolind3d: input is not a supported type.");
2206
2207 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2208 "Reference poolind3d: output is not a supported type.");
2209
2210 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2211 "Reference poolind3d: input and output types are mismatched.");
2212
2213 return supported;
2214 }
2215
2216
IsQLstmSupported(const TensorInfo & input,const TensorInfo & previousOutputIn,const TensorInfo & previousCellStateIn,const TensorInfo & outputStateOut,const TensorInfo & cellStateOut,const TensorInfo & output,const QLstmDescriptor & descriptor,const LstmInputParamsInfo & paramsInfo,Optional<std::string &> reasonIfUnsupported) const2217 bool RefLayerSupport::IsQLstmSupported(const TensorInfo& input,
2218 const TensorInfo& previousOutputIn,
2219 const TensorInfo& previousCellStateIn,
2220 const TensorInfo& outputStateOut,
2221 const TensorInfo& cellStateOut,
2222 const TensorInfo& output,
2223 const QLstmDescriptor& descriptor,
2224 const LstmInputParamsInfo& paramsInfo,
2225 Optional<std::string&> reasonIfUnsupported) const
2226 {
2227 IgnoreUnused(input);
2228 IgnoreUnused(previousOutputIn);
2229 IgnoreUnused(previousCellStateIn);
2230 IgnoreUnused(outputStateOut);
2231 IgnoreUnused(cellStateOut);
2232 IgnoreUnused(output);
2233 IgnoreUnused(descriptor);
2234 IgnoreUnused(paramsInfo);
2235
2236 IgnoreUnused(reasonIfUnsupported);
2237
2238 return true;
2239 }
2240
IsQuantizeSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const2241 bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
2242 const TensorInfo& output,
2243 Optional<std::string&> reasonIfUnsupported) const
2244 {
2245 bool supported = true;
2246
2247 // Define supported input types.
2248 std::array<DataType,7> supportedInputTypes = {
2249 DataType::Float32,
2250 DataType::Float16,
2251 DataType::QAsymmS8,
2252 DataType::QAsymmU8,
2253 DataType::QSymmS8,
2254 DataType::QSymmS16
2255 };
2256
2257 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
2258 "Reference quantize: input type not supported.");
2259
2260 // Define supported output types.
2261 std::array<DataType,4> supportedOutputTypes = {
2262 DataType::QAsymmS8,
2263 DataType::QAsymmU8,
2264 DataType::QSymmS8,
2265 DataType::QSymmS16
2266 };
2267 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2268 "Reference quantize: output type not supported.");
2269
2270 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2271 "Reference quantize: input and output shapes have different num total elements.");
2272
2273 return supported;
2274 }
2275
IsRankSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const2276 bool RefLayerSupport::IsRankSupported(const TensorInfo& input,
2277 const TensorInfo& output,
2278 Optional<std::string&> reasonIfUnsupported) const
2279 {
2280 IgnoreUnused(input);
2281 // Define supported output types.
2282 std::array<DataType,1> supportedOutputTypes =
2283 {
2284 DataType::Signed32,
2285 };
2286
2287 return CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2288 "Reference rank: input type not supported.");
2289 }
2290
IsReduceSupported(const TensorInfo & input,const TensorInfo & output,const ReduceDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2291 bool RefLayerSupport::IsReduceSupported(const TensorInfo& input,
2292 const TensorInfo& output,
2293 const ReduceDescriptor& descriptor,
2294 Optional<std::string&> reasonIfUnsupported) const
2295 {
2296 IgnoreUnused(descriptor);
2297 bool supported = true;
2298 std::array<DataType,7> supportedTypes =
2299 {
2300 DataType::Float32,
2301 DataType::Float16,
2302 DataType::QAsymmS8,
2303 DataType::QAsymmU8,
2304 DataType::QSymmS16,
2305 DataType::Signed32
2306 };
2307
2308 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2309 "Reference Reduce: input type not supported");
2310
2311 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2312 "Reference Reduce: output type not supported");
2313
2314 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2315 "Reference Reduce: input and output types not matching");
2316
2317 return supported;
2318 }
2319
IsReshapeSupported(const TensorInfo & input,const TensorInfo & output,const ReshapeDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2320 bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
2321 const TensorInfo& output,
2322 const ReshapeDescriptor& descriptor,
2323 Optional<std::string&> reasonIfUnsupported) const
2324 {
2325 IgnoreUnused(output);
2326 IgnoreUnused(descriptor);
2327 // Define supported output types.
2328 std::array<DataType,8> supportedOutputTypes =
2329 {
2330 DataType::BFloat16,
2331 DataType::Float32,
2332 DataType::Float16,
2333 DataType::Signed32,
2334 DataType::QAsymmS8,
2335 DataType::QAsymmU8,
2336 DataType::QSymmS16,
2337 DataType::Boolean
2338 };
2339
2340 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
2341 "Reference reshape: input type not supported.");
2342 }
2343
IsResizeSupported(const TensorInfo & input,const TensorInfo & output,const ResizeDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2344 bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
2345 const TensorInfo& output,
2346 const ResizeDescriptor& descriptor,
2347 Optional<std::string&> reasonIfUnsupported) const
2348 {
2349 IgnoreUnused(descriptor);
2350 bool supported = true;
2351 std::array<DataType,6> supportedTypes =
2352 {
2353 DataType::BFloat16,
2354 DataType::Float32,
2355 DataType::Float16,
2356 DataType::QAsymmS8,
2357 DataType::QAsymmU8,
2358 DataType::QSymmS16
2359 };
2360
2361 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2362 "Reference Resize: input type not supported");
2363
2364 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2365 "Reference Resize: output type not supported");
2366
2367 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2368 "Reference Resize: input and output types not matching");
2369
2370 return supported;
2371 }
2372
IsShapeSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const2373 bool RefLayerSupport::IsShapeSupported(const TensorInfo& input,
2374 const TensorInfo& output,
2375 Optional<std::string&> reasonIfUnsupported) const
2376 {
2377 IgnoreUnused(input);
2378 bool supported = true;
2379
2380 std::array<DataType, 1> supportedTypes =
2381 {
2382 DataType::Signed32
2383 };
2384
2385 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2386 "Reference Shape: output type not supported");
2387
2388 return supported;
2389 }
2390
IsSliceSupported(const TensorInfo & input,const TensorInfo & output,const SliceDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2391 bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
2392 const TensorInfo& output,
2393 const SliceDescriptor& descriptor,
2394 Optional<std::string&> reasonIfUnsupported) const
2395 {
2396 IgnoreUnused(descriptor);
2397 bool supported = true;
2398
2399 std::array<DataType, 5> supportedTypes =
2400 {
2401 DataType::Float32,
2402 DataType::QAsymmS8,
2403 DataType::QAsymmU8,
2404 DataType::QSymmS16
2405 };
2406
2407 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2408 "Reference Slice: input type not supported");
2409
2410 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2411 "Reference Slice: output type not supported");
2412
2413 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2414 "Reference Slice: input and output types are mismatched");
2415
2416 return supported;
2417 }
2418
IsSoftmaxSupported(const TensorInfo & input,const TensorInfo & output,const SoftmaxDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2419 bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
2420 const TensorInfo& output,
2421 const SoftmaxDescriptor& descriptor,
2422 Optional<std::string&> reasonIfUnsupported) const
2423 {
2424 IgnoreUnused(descriptor);
2425 bool supported = true;
2426 std::array<DataType,7> supportedTypes =
2427 {
2428 DataType::Float32,
2429 DataType::Float16,
2430 DataType::QSymmS8,
2431 DataType::QAsymmS8,
2432 DataType::QAsymmU8,
2433 DataType::QSymmS16
2434 };
2435
2436 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2437 "Reference Softmax: output type not supported");
2438
2439 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2440 "Reference Softmax: input type not supported");
2441
2442 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2443 "Reference Softmax: input type not supported");
2444
2445 return supported;
2446 }
2447
IsSpaceToBatchNdSupported(const TensorInfo & input,const TensorInfo & output,const SpaceToBatchNdDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2448 bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
2449 const TensorInfo& output,
2450 const SpaceToBatchNdDescriptor& descriptor,
2451 Optional<std::string&> reasonIfUnsupported) const
2452 {
2453 IgnoreUnused(descriptor);
2454 bool supported = true;
2455 std::array<DataType,6> supportedTypes =
2456 {
2457 DataType::Float32,
2458 DataType::Float16,
2459 DataType::QAsymmS8,
2460 DataType::QAsymmU8,
2461 DataType::QSymmS16
2462 };
2463
2464 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2465 "Reference SpaceToBatchNd: input type not supported");
2466
2467 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2468 "Reference SpaceToBatchNd: output type not supported");
2469
2470 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2471 "Reference SpaceToBatchNd: input and output types are mismatched");
2472
2473 return supported;
2474 }
2475
IsSpaceToDepthSupported(const TensorInfo & input,const TensorInfo & output,const SpaceToDepthDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2476 bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
2477 const TensorInfo& output,
2478 const SpaceToDepthDescriptor& descriptor,
2479 Optional<std::string&> reasonIfUnsupported) const
2480 {
2481
2482 IgnoreUnused(descriptor);
2483 bool supported = true;
2484
2485 std::array<DataType,6> supportedTypes =
2486 {
2487 DataType::Float32,
2488 DataType::Float16,
2489 DataType::QAsymmS8,
2490 DataType::QAsymmU8,
2491 DataType::QSymmS16
2492 };
2493
2494 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2495 "Reference SpaceToDepth: input type not supported");
2496
2497 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2498 "Reference SpaceToDepth: output type not supported");
2499
2500 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2501 "Reference SpaceToDepth: input and output types are mismatched");
2502
2503 return supported;
2504 }
2505
IsSplitterSupported(const TensorInfo & input,const std::vector<std::reference_wrapper<TensorInfo>> & outputs,const ViewsDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2506 bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
2507 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
2508 const ViewsDescriptor& descriptor,
2509 Optional<std::string&> reasonIfUnsupported) const
2510 {
2511 IgnoreUnused(descriptor);
2512 bool supported = true;
2513 std::array<DataType,6> supportedTypes =
2514 {
2515 DataType::Float32,
2516 DataType::Float16,
2517 DataType::QAsymmS8,
2518 DataType::QAsymmU8,
2519 DataType::QSymmS16
2520 };
2521
2522 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2523 "Reference splitter: output type not supported");
2524 for (const TensorInfo& output : outputs)
2525 {
2526 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2527 "Reference splitter: input type not supported");
2528
2529 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2530 "Reference splitter: input and output types mismatched.");
2531 }
2532
2533 return supported;
2534 }
2535
IsStackSupported(const std::vector<const TensorInfo * > & inputs,const TensorInfo & output,const StackDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2536 bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
2537 const TensorInfo& output,
2538 const StackDescriptor& descriptor,
2539 Optional<std::string&> reasonIfUnsupported) const
2540 {
2541 IgnoreUnused(descriptor);
2542
2543 bool supported = true;
2544 std::array<DataType,7> supportedTypes =
2545 {
2546 DataType::Float32,
2547 DataType::Float16,
2548 DataType::QAsymmS8,
2549 DataType::QAsymmU8,
2550 DataType::QSymmS16,
2551 DataType::Signed32
2552 };
2553
2554 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2555 "Reference stack: output type not supported");
2556 for (const TensorInfo* input : inputs)
2557 {
2558 ARMNN_ASSERT(input != nullptr);
2559 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
2560 "Reference stack: input type not supported");
2561
2562 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
2563 "Reference stack: input and output types mismatched.");
2564 }
2565
2566 return supported;
2567 }
2568
IsStridedSliceSupported(const TensorInfo & input,const TensorInfo & output,const StridedSliceDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2569 bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
2570 const TensorInfo& output,
2571 const StridedSliceDescriptor& descriptor,
2572 Optional<std::string&> reasonIfUnsupported) const
2573 {
2574 IgnoreUnused(descriptor);
2575 bool supported = true;
2576
2577 std::array<DataType,5> supportedTypes =
2578 {
2579 DataType::Float32,
2580 DataType::QAsymmS8,
2581 DataType::QAsymmU8,
2582 DataType::QSymmS16
2583 };
2584
2585 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2586 "Reference StridedSlice: input type not supported");
2587
2588 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2589 "Reference StridedSlice: output type not supported");
2590
2591 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2592 "Reference StridedSlice: input and output types are mismatched");
2593
2594 return supported;
2595 }
2596
IsSubtractionSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const2597 bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
2598 const TensorInfo& input1,
2599 const TensorInfo& output,
2600 Optional<std::string&> reasonIfUnsupported) const
2601 {
2602 bool supported = true;
2603
2604 std::array<DataType,7> supportedTypes = {
2605 DataType::Float32,
2606 DataType::Float16,
2607 DataType::QAsymmS8,
2608 DataType::QAsymmU8,
2609 DataType::QSymmS16,
2610 DataType::Signed32
2611 };
2612
2613 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2614 "Reference subtraction: input 0 is not a supported type.");
2615
2616 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2617 "Reference subtraction: input 1 is not a supported type.");
2618
2619 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2620 "Reference subtraction: output is not a supported type.");
2621
2622 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2623 "Reference subtraction: input 0 and Input 1 types are mismatched");
2624
2625 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2626 "Reference subtraction: input and output types are mismatched");
2627
2628 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2629 "Reference subtraction: shapes are not suitable for implicit broadcast.");
2630
2631 return supported;
2632 }
2633
IsPreluSupported(const TensorInfo & input,const TensorInfo & alpha,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const2634 bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
2635 const TensorInfo& alpha,
2636 const TensorInfo& output,
2637 Optional<std::string&> reasonIfUnsupported) const
2638 {
2639 bool supported = true;
2640
2641 std::array<DataType, 6> supportedTypes
2642 {
2643 DataType::Float32,
2644 DataType::Float16,
2645 DataType::QAsymmS8,
2646 DataType::QAsymmU8,
2647 DataType::QSymmS16
2648 };
2649
2650 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2651 "PReLU: input is not a supported type.");
2652
2653 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
2654 "PReLU: alpha is not a supported type.");
2655
2656 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2657 "PReLU: output is not a supported type.");
2658
2659 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
2660 "PReLU: input, alpha and output types are mismatched");
2661
2662 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
2663 "PReLU: shapes are not suitable for implicit broadcast");
2664
2665 return supported;
2666 }
2667
IsTransposeConvolution2dSupported(const TensorInfo & input,const TensorInfo & output,const TransposeConvolution2dDescriptor & descriptor,const TensorInfo & weights,const Optional<TensorInfo> & biases,Optional<std::string &> reasonIfUnsupported) const2668 bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
2669 const TensorInfo& output,
2670 const TransposeConvolution2dDescriptor& descriptor,
2671 const TensorInfo& weights,
2672 const Optional<TensorInfo>& biases,
2673 Optional<std::string&> reasonIfUnsupported) const
2674 {
2675 IgnoreUnused(descriptor);
2676 bool supported = true;
2677
2678 std::array<DataType,7> supportedTypes =
2679 {
2680 DataType::Float32,
2681 DataType::Float16,
2682 DataType::QAsymmS8,
2683 DataType::QAsymmU8,
2684 DataType::QSymmS8,
2685 DataType::QSymmS16
2686 };
2687
2688 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2689 "Reference TransposeConvolution2d: input is not a supported type.");
2690
2691 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2692 "Reference TransposeConvolution2d: output is not a supported type.");
2693
2694 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2695 "Reference TransposeConvolution2d: input and output types mismatched.");
2696
2697
2698 const DataType inputType = input.GetDataType();
2699 if (IsQuantized8BitType(inputType))
2700 {
2701 std::array<DataType, 3> supportedWeightTypes =
2702 {
2703 DataType::QAsymmS8,
2704 DataType::QAsymmU8,
2705 DataType::QSymmS8
2706 };
2707
2708 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
2709 "Reference TransposeConvolution2d: weights type not supported for "
2710 "quantized input.");
2711 }
2712 else
2713 {
2714 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
2715 "Reference TransposeConvolution2d: weights is not a supported type.");
2716
2717 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
2718 "Reference TransposeConvolution2d: input and weights types mismatched.");
2719 }
2720
2721 if (biases.has_value())
2722 {
2723 std::array<DataType,4> biasesSupportedTypes =
2724 {
2725 DataType::Float32,
2726 DataType::Float16,
2727 DataType::Signed32
2728 };
2729 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
2730 "Reference TransposeConvolution2d: biases is not a supported type.");
2731 }
2732
2733 return supported;
2734 }
2735
IsTransposeSupported(const TensorInfo & input,const TensorInfo & output,const TransposeDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2736 bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input,
2737 const TensorInfo& output,
2738 const TransposeDescriptor& descriptor,
2739 Optional<std::string&> reasonIfUnsupported) const
2740 {
2741 IgnoreUnused(descriptor);
2742 bool supported = true;
2743
2744 // Define supported output and inputs types.
2745 std::array<DataType, 6> supportedTypes =
2746 {
2747 DataType::BFloat16,
2748 DataType::Float32,
2749 DataType::Float16,
2750 DataType::QAsymmS8,
2751 DataType::QAsymmU8,
2752 DataType::QSymmS16
2753 };
2754
2755 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2756 "Reference transpose: input is not a supported type.");
2757
2758 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2759 "Reference transpose: output is not a supported type.");
2760
2761 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2762 "Reference transpose: input and output types are mismatched.");
2763
2764 return supported;
2765 }
2766
IsUnidirectionalSequenceLstmSupported(const TensorInfo & input,const TensorInfo & outputStateIn,const TensorInfo & cellStateIn,const TensorInfo & outputStateOut,const TensorInfo & cellStateOut,const TensorInfo & output,const UnidirectionalSequenceLstmDescriptor & descriptor,const LstmInputParamsInfo & paramsInfo,Optional<std::string &> reasonIfUnsupported) const2767 bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported(
2768 const TensorInfo& input,
2769 const TensorInfo& outputStateIn,
2770 const TensorInfo& cellStateIn,
2771 const TensorInfo& outputStateOut,
2772 const TensorInfo& cellStateOut,
2773 const TensorInfo& output,
2774 const UnidirectionalSequenceLstmDescriptor& descriptor,
2775 const LstmInputParamsInfo& paramsInfo,
2776 Optional<std::string&> reasonIfUnsupported) const
2777 {
2778 IgnoreUnused(descriptor);
2779 IgnoreUnused(paramsInfo);
2780 IgnoreUnused(outputStateIn);
2781 IgnoreUnused(cellStateIn);
2782 IgnoreUnused(outputStateOut);
2783 IgnoreUnused(cellStateOut);
2784 bool supported = true;
2785
2786 std::array<DataType, 2> supportedTypes =
2787 {
2788 DataType::Float32,
2789 DataType::QAsymmS8
2790 };
2791
2792 std::array<DataType, 2> supportedWeightTypes =
2793 {
2794 DataType::Float32,
2795 DataType::QAsymmS8
2796 };
2797
2798 std::array<DataType, 3> supportedBiasTypes =
2799 {
2800 DataType::Float32,
2801 DataType::QAsymmS8,
2802 DataType::Signed32
2803 };
2804
2805 // check inputs and outputs
2806 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2807 "Reference UnidirectionalSequenceLstm: input is not a supported type.");
2808 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2809 "Reference UnidirectionalSequenceLstm: output is not a supported type.");
2810
2811 // check layer parameters
2812 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToForgetWeights(), supportedWeightTypes),
2813 reasonIfUnsupported,
2814 "Reference UnidirectionalSequenceLstm: InputToForgetWeights "
2815 "is not a supported type.");
2816 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToCellWeights(), supportedWeightTypes),
2817 reasonIfUnsupported,
2818 "Reference UnidirectionalSequenceLstm: InputToCellWeights is not a supported type.");
2819 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToOutputWeights(), supportedWeightTypes),
2820 reasonIfUnsupported,
2821 "Reference UnidirectionalSequenceLstm: InputToOutputWeights "
2822 "is not a supported type.");
2823 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToForgetWeights(), supportedWeightTypes),
2824 reasonIfUnsupported,
2825 "Reference UnidirectionalSequenceLstm: RecurrentToForgetWeights "
2826 "is not a supported type.");
2827 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToCellWeights(), supportedWeightTypes),
2828 reasonIfUnsupported,
2829 "Reference UnidirectionalSequenceLstm: RecurrentToCellWeights "
2830 "is not a supported type.");
2831 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToOutputWeights(), supportedWeightTypes),
2832 reasonIfUnsupported,
2833 "Reference UnidirectionalSequenceLstm: RecurrentToOutputWeights "
2834 "is not a supported type.");
2835
2836 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetGateBias(), supportedBiasTypes), reasonIfUnsupported,
2837 "Reference UnidirectionalSequenceLstm: ForgetGateBias is not a supported type.");
2838 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellBias(), supportedBiasTypes), reasonIfUnsupported,
2839 "Reference UnidirectionalSequenceLstm: CellBias is not a supported type.");
2840 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputGateBias(), supportedBiasTypes), reasonIfUnsupported,
2841 "Reference UnidirectionalSequenceLstm: OutputGateBias is not a supported type.");
2842 if (!descriptor.m_CifgEnabled)
2843 {
2844 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToInputWeights(), supportedWeightTypes),
2845 reasonIfUnsupported,
2846 "Reference UnidirectionalSequenceLstm: InputToInputWeights "
2847 "is not a supported type.");
2848 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToInputWeights(), supportedWeightTypes),
2849 reasonIfUnsupported,
2850 "Reference UnidirectionalSequenceLstm: RecurrentToInputWeights "
2851 "is not a supported type.");
2852 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputGateBias(), supportedBiasTypes), reasonIfUnsupported,
2853 "Reference UnidirectionalSequenceLstm: InputGateBias is not a supported type.");
2854 if (descriptor.m_PeepholeEnabled)
2855 {
2856 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToInputWeights(), supportedWeightTypes),
2857 reasonIfUnsupported,
2858 "Reference UnidirectionalSequenceLstm: CellToInputWeights "
2859 "is not a supported type.");
2860 }
2861 }
2862 if (descriptor.m_PeepholeEnabled)
2863 {
2864 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToForgetWeights(), supportedWeightTypes),
2865 reasonIfUnsupported,
2866 "Reference UnidirectionalSequenceLstm: CellToForgetWeights "
2867 "is not a supported type.");
2868 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToOutputWeights(), supportedWeightTypes),
2869 reasonIfUnsupported,
2870 "Reference UnidirectionalSequenceLstm: CellToOutputWeights "
2871 "is not a supported type.");
2872 }
2873 if (descriptor.m_ProjectionEnabled)
2874 {
2875 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetProjectionWeights(), supportedWeightTypes),
2876 reasonIfUnsupported,
2877 "Reference UnidirectionalSequenceLstm: ProjectionWeights "
2878 "is not a supported type.");
2879 if (paramsInfo.m_ProjectionBias != nullptr)
2880 {
2881 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
2882 "Reference UnidirectionalSequenceLstm: input and ProjectionBias types "
2883 "are mismatched");
2884 }
2885 }
2886 if (descriptor.m_LayerNormEnabled)
2887 {
2888 if (!descriptor.m_CifgEnabled)
2889 {
2890 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputLayerNormWeights(), supportedWeightTypes),
2891 reasonIfUnsupported,
2892 "Reference UnidirectionalSequenceLstm: InputLayerNormWeights "
2893 "is not a supported type.");
2894 }
2895 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetLayerNormWeights(), supportedWeightTypes),
2896 reasonIfUnsupported,
2897 "Reference UnidirectionalSequenceLstm: ForgetLayerNormWeights "
2898 "is not a supported type.");
2899 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellLayerNormWeights(), supportedWeightTypes),
2900 reasonIfUnsupported,
2901 "Reference UnidirectionalSequenceLstm: CellLayerNormWeights "
2902 "is not a supported type.");
2903 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputLayerNormWeights(), supportedWeightTypes),
2904 reasonIfUnsupported,
2905 "Reference UnidirectionalSequenceLstm: OutputLayerNormWeights "
2906 "is not a supported type.");
2907 }
2908
2909 return supported;
2910 }
2911
2912 } // namespace armnn
2913