1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Assert.hpp>
9*89c4ff92SAndroid Build Coastguard Worker #include <algorithm>
10*89c4ff92SAndroid Build Coastguard Worker
11*89c4ff92SAndroid Build Coastguard Worker namespace armnn
12*89c4ff92SAndroid Build Coastguard Worker {
13*89c4ff92SAndroid Build Coastguard Worker
GetBiasTypeFromWeightsType(armnn::Optional<armnn::DataType> weightsType)14*89c4ff92SAndroid Build Coastguard Worker inline armnn::Optional<armnn::DataType> GetBiasTypeFromWeightsType(armnn::Optional<armnn::DataType> weightsType)
15*89c4ff92SAndroid Build Coastguard Worker {
16*89c4ff92SAndroid Build Coastguard Worker if (!weightsType)
17*89c4ff92SAndroid Build Coastguard Worker {
18*89c4ff92SAndroid Build Coastguard Worker return weightsType;
19*89c4ff92SAndroid Build Coastguard Worker }
20*89c4ff92SAndroid Build Coastguard Worker
21*89c4ff92SAndroid Build Coastguard Worker switch(weightsType.value())
22*89c4ff92SAndroid Build Coastguard Worker {
23*89c4ff92SAndroid Build Coastguard Worker case armnn::DataType::Float16:
24*89c4ff92SAndroid Build Coastguard Worker case armnn::DataType::Float32:
25*89c4ff92SAndroid Build Coastguard Worker return weightsType;
26*89c4ff92SAndroid Build Coastguard Worker case armnn::DataType::QAsymmS8:
27*89c4ff92SAndroid Build Coastguard Worker case armnn::DataType::QAsymmU8:
28*89c4ff92SAndroid Build Coastguard Worker case armnn::DataType::QSymmS8:
29*89c4ff92SAndroid Build Coastguard Worker case armnn::DataType::QSymmS16:
30*89c4ff92SAndroid Build Coastguard Worker return armnn::DataType::Signed32;
31*89c4ff92SAndroid Build Coastguard Worker default:
32*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
33*89c4ff92SAndroid Build Coastguard Worker }
34*89c4ff92SAndroid Build Coastguard Worker return armnn::EmptyOptional();
35*89c4ff92SAndroid Build Coastguard Worker }
36*89c4ff92SAndroid Build Coastguard Worker
37*89c4ff92SAndroid Build Coastguard Worker template<typename F>
CheckSupportRule(F rule,Optional<std::string &> reasonIfUnsupported,const char * reason)38*89c4ff92SAndroid Build Coastguard Worker bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
39*89c4ff92SAndroid Build Coastguard Worker {
40*89c4ff92SAndroid Build Coastguard Worker bool supported = rule();
41*89c4ff92SAndroid Build Coastguard Worker if (!supported && reason)
42*89c4ff92SAndroid Build Coastguard Worker {
43*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
44*89c4ff92SAndroid Build Coastguard Worker }
45*89c4ff92SAndroid Build Coastguard Worker return supported;
46*89c4ff92SAndroid Build Coastguard Worker }
47*89c4ff92SAndroid Build Coastguard Worker
/// Base class for the support rules in this file. A derived constructor evaluates its condition
/// and stores the outcome in m_Res; invoking the rule returns that stored outcome.
struct Rule
{
    /// Stored outcome of the rule; true until a derived constructor overwrites it.
    bool m_Res{true};

    /// @return The outcome stored in m_Res.
    bool operator()() const
    {
        const bool outcome = m_Res;
        return outcome;
    }
};
57*89c4ff92SAndroid Build Coastguard Worker
/// Base case of the AllTypesAreEqualImpl recursion: a single value trivially matches itself.
/// The argument is taken by const reference so no TensorInfo is copied just to be ignored.
template<typename T>
bool AllTypesAreEqualImpl(const T& /*only*/)
{
    return true;
}
63*89c4ff92SAndroid Build Coastguard Worker
64*89c4ff92SAndroid Build Coastguard Worker template<typename T, typename... Rest>
AllTypesAreEqualImpl(T t1,T t2,Rest...rest)65*89c4ff92SAndroid Build Coastguard Worker bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
66*89c4ff92SAndroid Build Coastguard Worker {
67*89c4ff92SAndroid Build Coastguard Worker static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");
68*89c4ff92SAndroid Build Coastguard Worker
69*89c4ff92SAndroid Build Coastguard Worker return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
70*89c4ff92SAndroid Build Coastguard Worker }
71*89c4ff92SAndroid Build Coastguard Worker
72*89c4ff92SAndroid Build Coastguard Worker struct TypesAreEqual : public Rule
73*89c4ff92SAndroid Build Coastguard Worker {
74*89c4ff92SAndroid Build Coastguard Worker template<typename ... Ts>
TypesAreEqualarmnn::TypesAreEqual75*89c4ff92SAndroid Build Coastguard Worker TypesAreEqual(const Ts&... ts)
76*89c4ff92SAndroid Build Coastguard Worker {
77*89c4ff92SAndroid Build Coastguard Worker m_Res = AllTypesAreEqualImpl(ts...);
78*89c4ff92SAndroid Build Coastguard Worker }
79*89c4ff92SAndroid Build Coastguard Worker };
80*89c4ff92SAndroid Build Coastguard Worker
81*89c4ff92SAndroid Build Coastguard Worker struct QuantizationParametersAreEqual : public Rule
82*89c4ff92SAndroid Build Coastguard Worker {
QuantizationParametersAreEqualarmnn::QuantizationParametersAreEqual83*89c4ff92SAndroid Build Coastguard Worker QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
84*89c4ff92SAndroid Build Coastguard Worker {
85*89c4ff92SAndroid Build Coastguard Worker m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
86*89c4ff92SAndroid Build Coastguard Worker info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
87*89c4ff92SAndroid Build Coastguard Worker }
88*89c4ff92SAndroid Build Coastguard Worker };
89*89c4ff92SAndroid Build Coastguard Worker
90*89c4ff92SAndroid Build Coastguard Worker struct TypeAnyOf : public Rule
91*89c4ff92SAndroid Build Coastguard Worker {
92*89c4ff92SAndroid Build Coastguard Worker template<typename Container>
TypeAnyOfarmnn::TypeAnyOf93*89c4ff92SAndroid Build Coastguard Worker TypeAnyOf(const TensorInfo& info, const Container& c)
94*89c4ff92SAndroid Build Coastguard Worker {
95*89c4ff92SAndroid Build Coastguard Worker m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
96*89c4ff92SAndroid Build Coastguard Worker {
97*89c4ff92SAndroid Build Coastguard Worker return dt == info.GetDataType();
98*89c4ff92SAndroid Build Coastguard Worker });
99*89c4ff92SAndroid Build Coastguard Worker }
100*89c4ff92SAndroid Build Coastguard Worker };
101*89c4ff92SAndroid Build Coastguard Worker
102*89c4ff92SAndroid Build Coastguard Worker struct TypeIs : public Rule
103*89c4ff92SAndroid Build Coastguard Worker {
TypeIsarmnn::TypeIs104*89c4ff92SAndroid Build Coastguard Worker TypeIs(const TensorInfo& info, DataType dt)
105*89c4ff92SAndroid Build Coastguard Worker {
106*89c4ff92SAndroid Build Coastguard Worker m_Res = dt == info.GetDataType();
107*89c4ff92SAndroid Build Coastguard Worker }
108*89c4ff92SAndroid Build Coastguard Worker };
109*89c4ff92SAndroid Build Coastguard Worker
110*89c4ff92SAndroid Build Coastguard Worker struct TypeNotPerAxisQuantized : public Rule
111*89c4ff92SAndroid Build Coastguard Worker {
TypeNotPerAxisQuantizedarmnn::TypeNotPerAxisQuantized112*89c4ff92SAndroid Build Coastguard Worker TypeNotPerAxisQuantized(const TensorInfo& info)
113*89c4ff92SAndroid Build Coastguard Worker {
114*89c4ff92SAndroid Build Coastguard Worker m_Res = !info.IsQuantized() || !info.HasPerAxisQuantization();
115*89c4ff92SAndroid Build Coastguard Worker }
116*89c4ff92SAndroid Build Coastguard Worker };
117*89c4ff92SAndroid Build Coastguard Worker
118*89c4ff92SAndroid Build Coastguard Worker struct BiasAndWeightsTypesMatch : public Rule
119*89c4ff92SAndroid Build Coastguard Worker {
BiasAndWeightsTypesMatcharmnn::BiasAndWeightsTypesMatch120*89c4ff92SAndroid Build Coastguard Worker BiasAndWeightsTypesMatch(const TensorInfo& biases, const TensorInfo& weights)
121*89c4ff92SAndroid Build Coastguard Worker {
122*89c4ff92SAndroid Build Coastguard Worker m_Res = biases.GetDataType() == GetBiasTypeFromWeightsType(weights.GetDataType()).value();
123*89c4ff92SAndroid Build Coastguard Worker }
124*89c4ff92SAndroid Build Coastguard Worker };
125*89c4ff92SAndroid Build Coastguard Worker
126*89c4ff92SAndroid Build Coastguard Worker struct BiasAndWeightsTypesCompatible : public Rule
127*89c4ff92SAndroid Build Coastguard Worker {
128*89c4ff92SAndroid Build Coastguard Worker template<typename Container>
BiasAndWeightsTypesCompatiblearmnn::BiasAndWeightsTypesCompatible129*89c4ff92SAndroid Build Coastguard Worker BiasAndWeightsTypesCompatible(const TensorInfo& info, const Container& c)
130*89c4ff92SAndroid Build Coastguard Worker {
131*89c4ff92SAndroid Build Coastguard Worker m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
132*89c4ff92SAndroid Build Coastguard Worker {
133*89c4ff92SAndroid Build Coastguard Worker return dt == GetBiasTypeFromWeightsType(info.GetDataType()).value();
134*89c4ff92SAndroid Build Coastguard Worker });
135*89c4ff92SAndroid Build Coastguard Worker }
136*89c4ff92SAndroid Build Coastguard Worker };
137*89c4ff92SAndroid Build Coastguard Worker
138*89c4ff92SAndroid Build Coastguard Worker struct ShapesAreSameRank : public Rule
139*89c4ff92SAndroid Build Coastguard Worker {
ShapesAreSameRankarmnn::ShapesAreSameRank140*89c4ff92SAndroid Build Coastguard Worker ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
141*89c4ff92SAndroid Build Coastguard Worker {
142*89c4ff92SAndroid Build Coastguard Worker m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
143*89c4ff92SAndroid Build Coastguard Worker }
144*89c4ff92SAndroid Build Coastguard Worker };
145*89c4ff92SAndroid Build Coastguard Worker
146*89c4ff92SAndroid Build Coastguard Worker struct ShapesAreSameTotalSize : public Rule
147*89c4ff92SAndroid Build Coastguard Worker {
ShapesAreSameTotalSizearmnn::ShapesAreSameTotalSize148*89c4ff92SAndroid Build Coastguard Worker ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
149*89c4ff92SAndroid Build Coastguard Worker {
150*89c4ff92SAndroid Build Coastguard Worker m_Res = info0.GetNumElements() == info1.GetNumElements();
151*89c4ff92SAndroid Build Coastguard Worker }
152*89c4ff92SAndroid Build Coastguard Worker };
153*89c4ff92SAndroid Build Coastguard Worker
154*89c4ff92SAndroid Build Coastguard Worker struct ShapesAreBroadcastCompatible : public Rule
155*89c4ff92SAndroid Build Coastguard Worker {
CalcInputSizearmnn::ShapesAreBroadcastCompatible156*89c4ff92SAndroid Build Coastguard Worker unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
157*89c4ff92SAndroid Build Coastguard Worker {
158*89c4ff92SAndroid Build Coastguard Worker unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
159*89c4ff92SAndroid Build Coastguard Worker unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
160*89c4ff92SAndroid Build Coastguard Worker return sizeIn;
161*89c4ff92SAndroid Build Coastguard Worker }
162*89c4ff92SAndroid Build Coastguard Worker
ShapesAreBroadcastCompatiblearmnn::ShapesAreBroadcastCompatible163*89c4ff92SAndroid Build Coastguard Worker ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
164*89c4ff92SAndroid Build Coastguard Worker {
165*89c4ff92SAndroid Build Coastguard Worker const TensorShape& shape0 = in0.GetShape();
166*89c4ff92SAndroid Build Coastguard Worker const TensorShape& shape1 = in1.GetShape();
167*89c4ff92SAndroid Build Coastguard Worker const TensorShape& outShape = out.GetShape();
168*89c4ff92SAndroid Build Coastguard Worker
169*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
170*89c4ff92SAndroid Build Coastguard Worker {
171*89c4ff92SAndroid Build Coastguard Worker unsigned int sizeOut = outShape[i];
172*89c4ff92SAndroid Build Coastguard Worker unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
173*89c4ff92SAndroid Build Coastguard Worker unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);
174*89c4ff92SAndroid Build Coastguard Worker
175*89c4ff92SAndroid Build Coastguard Worker m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
176*89c4ff92SAndroid Build Coastguard Worker ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
177*89c4ff92SAndroid Build Coastguard Worker }
178*89c4ff92SAndroid Build Coastguard Worker }
179*89c4ff92SAndroid Build Coastguard Worker };
180*89c4ff92SAndroid Build Coastguard Worker
181*89c4ff92SAndroid Build Coastguard Worker struct TensorNumDimensionsAreCorrect : public Rule
182*89c4ff92SAndroid Build Coastguard Worker {
TensorNumDimensionsAreCorrectarmnn::TensorNumDimensionsAreCorrect183*89c4ff92SAndroid Build Coastguard Worker TensorNumDimensionsAreCorrect(const TensorInfo& info, unsigned int expectedNumDimensions)
184*89c4ff92SAndroid Build Coastguard Worker {
185*89c4ff92SAndroid Build Coastguard Worker m_Res = info.GetNumDimensions() == expectedNumDimensions;
186*89c4ff92SAndroid Build Coastguard Worker }
187*89c4ff92SAndroid Build Coastguard Worker };
188*89c4ff92SAndroid Build Coastguard Worker
189*89c4ff92SAndroid Build Coastguard Worker struct TensorNumDimensionsAreGreaterOrEqualTo : public Rule
190*89c4ff92SAndroid Build Coastguard Worker {
TensorNumDimensionsAreGreaterOrEqualToarmnn::TensorNumDimensionsAreGreaterOrEqualTo191*89c4ff92SAndroid Build Coastguard Worker TensorNumDimensionsAreGreaterOrEqualTo(const TensorInfo& info, unsigned int numDimensionsToCompare)
192*89c4ff92SAndroid Build Coastguard Worker {
193*89c4ff92SAndroid Build Coastguard Worker m_Res = info.GetNumDimensions() >= numDimensionsToCompare;
194*89c4ff92SAndroid Build Coastguard Worker }
195*89c4ff92SAndroid Build Coastguard Worker };
196*89c4ff92SAndroid Build Coastguard Worker
197*89c4ff92SAndroid Build Coastguard Worker } //namespace armnn