//
// Copyright © 2020 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "SampleDynamicLayerSupport.hpp"

#include <cstddef>
#include <string>
#include <vector>

#include <InternalTypes.hpp>
#include <LayerSupportCommon.hpp>
#include <armnn/Types.hpp>

12*89c4ff92SAndroid Build Coastguard Worker namespace sdb // sample dynamic backend
13*89c4ff92SAndroid Build Coastguard Worker {
14*89c4ff92SAndroid Build Coastguard Worker
IsLayerSupported(const armnn::LayerType & type,const std::vector<armnn::TensorInfo> & infos,const armnn::BaseDescriptor &,const armnn::Optional<armnn::LstmInputParamsInfo> &,const armnn::Optional<armnn::QuantizedLstmInputParamsInfo> &,armnn::Optional<std::string &> reasonIfUnsupported) const15*89c4ff92SAndroid Build Coastguard Worker bool SampleDynamicLayerSupport::IsLayerSupported(const armnn::LayerType& type,
16*89c4ff92SAndroid Build Coastguard Worker const std::vector<armnn::TensorInfo>& infos,
17*89c4ff92SAndroid Build Coastguard Worker const armnn::BaseDescriptor& /*descriptor*/,
18*89c4ff92SAndroid Build Coastguard Worker const armnn::Optional<armnn::LstmInputParamsInfo>&
19*89c4ff92SAndroid Build Coastguard Worker /*lstmParamsInfo*/,
20*89c4ff92SAndroid Build Coastguard Worker const armnn::Optional<armnn::QuantizedLstmInputParamsInfo>&
21*89c4ff92SAndroid Build Coastguard Worker /*quantizedLstmParamsInfo*/,
22*89c4ff92SAndroid Build Coastguard Worker armnn::Optional<std::string&> reasonIfUnsupported) const
23*89c4ff92SAndroid Build Coastguard Worker {
24*89c4ff92SAndroid Build Coastguard Worker switch (type)
25*89c4ff92SAndroid Build Coastguard Worker {
26*89c4ff92SAndroid Build Coastguard Worker case armnn::LayerType::Input:
27*89c4ff92SAndroid Build Coastguard Worker return IsInputSupported(infos[0], reasonIfUnsupported);
28*89c4ff92SAndroid Build Coastguard Worker case armnn::LayerType::Output:
29*89c4ff92SAndroid Build Coastguard Worker return IsOutputSupported(infos[0], reasonIfUnsupported);
30*89c4ff92SAndroid Build Coastguard Worker case armnn::LayerType::Addition:
31*89c4ff92SAndroid Build Coastguard Worker return IsAdditionSupported(infos[0],
32*89c4ff92SAndroid Build Coastguard Worker infos[1],
33*89c4ff92SAndroid Build Coastguard Worker infos[2],
34*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
35*89c4ff92SAndroid Build Coastguard Worker default:
36*89c4ff92SAndroid Build Coastguard Worker return false;
37*89c4ff92SAndroid Build Coastguard Worker }
38*89c4ff92SAndroid Build Coastguard Worker }
39*89c4ff92SAndroid Build Coastguard Worker
IsInputSupported(const armnn::TensorInfo & input,armnn::Optional<std::string &> reasonIfUnsupported) const40*89c4ff92SAndroid Build Coastguard Worker bool SampleDynamicLayerSupport::IsInputSupported(const armnn::TensorInfo& input,
41*89c4ff92SAndroid Build Coastguard Worker armnn::Optional<std::string&> reasonIfUnsupported) const
42*89c4ff92SAndroid Build Coastguard Worker {
43*89c4ff92SAndroid Build Coastguard Worker return true;
44*89c4ff92SAndroid Build Coastguard Worker }
45*89c4ff92SAndroid Build Coastguard Worker
IsOutputSupported(const armnn::TensorInfo & output,armnn::Optional<std::string &> reasonIfUnsupported) const46*89c4ff92SAndroid Build Coastguard Worker bool SampleDynamicLayerSupport::IsOutputSupported(const armnn::TensorInfo& output,
47*89c4ff92SAndroid Build Coastguard Worker armnn::Optional<std::string&> reasonIfUnsupported) const
48*89c4ff92SAndroid Build Coastguard Worker {
49*89c4ff92SAndroid Build Coastguard Worker return true;
50*89c4ff92SAndroid Build Coastguard Worker }
51*89c4ff92SAndroid Build Coastguard Worker
IsAdditionSupported(const armnn::TensorInfo & input0,const armnn::TensorInfo & input1,const armnn::TensorInfo & output,armnn::Optional<std::string &> reasonIfUnsupported) const52*89c4ff92SAndroid Build Coastguard Worker bool SampleDynamicLayerSupport::IsAdditionSupported(const armnn::TensorInfo& input0,
53*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& input1,
54*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& output,
55*89c4ff92SAndroid Build Coastguard Worker armnn::Optional<std::string&> reasonIfUnsupported) const
56*89c4ff92SAndroid Build Coastguard Worker {
57*89c4ff92SAndroid Build Coastguard Worker
58*89c4ff92SAndroid Build Coastguard Worker if (input0.GetDataType() != armnn::DataType::Float32)
59*89c4ff92SAndroid Build Coastguard Worker {
60*89c4ff92SAndroid Build Coastguard Worker return false;
61*89c4ff92SAndroid Build Coastguard Worker }
62*89c4ff92SAndroid Build Coastguard Worker
63*89c4ff92SAndroid Build Coastguard Worker if (input0.GetDataType() != input1.GetDataType())
64*89c4ff92SAndroid Build Coastguard Worker {
65*89c4ff92SAndroid Build Coastguard Worker return false;
66*89c4ff92SAndroid Build Coastguard Worker }
67*89c4ff92SAndroid Build Coastguard Worker
68*89c4ff92SAndroid Build Coastguard Worker if (input0.GetDataType() != output.GetDataType())
69*89c4ff92SAndroid Build Coastguard Worker {
70*89c4ff92SAndroid Build Coastguard Worker return false;
71*89c4ff92SAndroid Build Coastguard Worker }
72*89c4ff92SAndroid Build Coastguard Worker
73*89c4ff92SAndroid Build Coastguard Worker return true;
74*89c4ff92SAndroid Build Coastguard Worker }
75*89c4ff92SAndroid Build Coastguard Worker
76*89c4ff92SAndroid Build Coastguard Worker } // namespace sdb
77