xref: /aosp_15_r20/external/armnn/src/backends/tosaCommon/TosaLayerSupportRules.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker // List of Layer Support Rules common to TOSA backends only, for use with CheckSupportRule()
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker struct TosaOperatorAttributeOfAny : public Rule
11*89c4ff92SAndroid Build Coastguard Worker {
12*89c4ff92SAndroid Build Coastguard Worker     template<typename Container>
TosaOperatorAttributeOfAnyTosaOperatorAttributeOfAny13*89c4ff92SAndroid Build Coastguard Worker     explicit TosaOperatorAttributeOfAny(TosaSerializationOperator* op, const Container& c)
14*89c4ff92SAndroid Build Coastguard Worker     {
15*89c4ff92SAndroid Build Coastguard Worker         m_Res = std::any_of(c.begin(), c.end(), [&op](Attribute attribute)
16*89c4ff92SAndroid Build Coastguard Worker         {
17*89c4ff92SAndroid Build Coastguard Worker             return attribute == op->GetAttributeType();
18*89c4ff92SAndroid Build Coastguard Worker         });
19*89c4ff92SAndroid Build Coastguard Worker     }
20*89c4ff92SAndroid Build Coastguard Worker };
21*89c4ff92SAndroid Build Coastguard Worker 
22*89c4ff92SAndroid Build Coastguard Worker struct TosaTypeAnyOf : public Rule
23*89c4ff92SAndroid Build Coastguard Worker {
24*89c4ff92SAndroid Build Coastguard Worker     template<typename Container>
TosaTypeAnyOfTosaTypeAnyOf25*89c4ff92SAndroid Build Coastguard Worker     TosaTypeAnyOf(TosaSerializationTensor* tensor, const Container& c)
26*89c4ff92SAndroid Build Coastguard Worker     {
27*89c4ff92SAndroid Build Coastguard Worker         m_Res = std::any_of(c.begin(), c.end(), [&tensor](DType dt)
28*89c4ff92SAndroid Build Coastguard Worker         {
29*89c4ff92SAndroid Build Coastguard Worker             return dt == tensor->GetDtype();
30*89c4ff92SAndroid Build Coastguard Worker         });
31*89c4ff92SAndroid Build Coastguard Worker     }
32*89c4ff92SAndroid Build Coastguard Worker };
33*89c4ff92SAndroid Build Coastguard Worker 
34*89c4ff92SAndroid Build Coastguard Worker struct TosaTensorNumDimensionsWithinBounds : public Rule
35*89c4ff92SAndroid Build Coastguard Worker {
TosaTensorNumDimensionsWithinBoundsTosaTensorNumDimensionsWithinBounds36*89c4ff92SAndroid Build Coastguard Worker     explicit TosaTensorNumDimensionsWithinBounds(TosaSerializationTensor* tensor)
37*89c4ff92SAndroid Build Coastguard Worker     {
38*89c4ff92SAndroid Build Coastguard Worker         m_Res = (tensor->GetShape().size() <= MaxNumOfTensorDimensions) || (!tensor->GetShape().empty());
39*89c4ff92SAndroid Build Coastguard Worker     }
40*89c4ff92SAndroid Build Coastguard Worker };
41*89c4ff92SAndroid Build Coastguard Worker 
42*89c4ff92SAndroid Build Coastguard Worker struct TosaAssertSize : public Rule
43*89c4ff92SAndroid Build Coastguard Worker {
44*89c4ff92SAndroid Build Coastguard Worker     template<typename Container>
TosaAssertSizeTosaAssertSize45*89c4ff92SAndroid Build Coastguard Worker     explicit TosaAssertSize(const Container& c1, const Container& c2)
46*89c4ff92SAndroid Build Coastguard Worker     {
47*89c4ff92SAndroid Build Coastguard Worker         m_Res = (c1.size() == c2.size());
48*89c4ff92SAndroid Build Coastguard Worker     }
49*89c4ff92SAndroid Build Coastguard Worker };
50*89c4ff92SAndroid Build Coastguard Worker 
51*89c4ff92SAndroid Build Coastguard Worker struct TosaContainerContainsTwoTypes : public Rule
52*89c4ff92SAndroid Build Coastguard Worker {
TosaContainerContainsTwoTypesTosaContainerContainsTwoTypes53*89c4ff92SAndroid Build Coastguard Worker     explicit TosaContainerContainsTwoTypes(std::tuple<DType, DType>& check,
54*89c4ff92SAndroid Build Coastguard Worker                                            const std::vector<std::tuple<DType, DType>>& c)
55*89c4ff92SAndroid Build Coastguard Worker     {
56*89c4ff92SAndroid Build Coastguard Worker         for (auto item: c)
57*89c4ff92SAndroid Build Coastguard Worker         {
58*89c4ff92SAndroid Build Coastguard Worker             if (std::get<0>(check) == std::get<0>(item) &&
59*89c4ff92SAndroid Build Coastguard Worker                 std::get<1>(check) == std::get<1>(item))
60*89c4ff92SAndroid Build Coastguard Worker             {
61*89c4ff92SAndroid Build Coastguard Worker                 m_Res = true;
62*89c4ff92SAndroid Build Coastguard Worker                 return;
63*89c4ff92SAndroid Build Coastguard Worker             }
64*89c4ff92SAndroid Build Coastguard Worker         }
65*89c4ff92SAndroid Build Coastguard Worker         m_Res = false;
66*89c4ff92SAndroid Build Coastguard Worker     }
67*89c4ff92SAndroid Build Coastguard Worker };
68*89c4ff92SAndroid Build Coastguard Worker 
69*89c4ff92SAndroid Build Coastguard Worker struct TosaContainerContainsThreeTypes : public Rule
70*89c4ff92SAndroid Build Coastguard Worker {
TosaContainerContainsThreeTypesTosaContainerContainsThreeTypes71*89c4ff92SAndroid Build Coastguard Worker     explicit TosaContainerContainsThreeTypes(std::tuple<DType, DType, DType>& check,
72*89c4ff92SAndroid Build Coastguard Worker                                              const std::vector<std::tuple<DType, DType, DType>>& c)
73*89c4ff92SAndroid Build Coastguard Worker     {
74*89c4ff92SAndroid Build Coastguard Worker         for (auto item: c)
75*89c4ff92SAndroid Build Coastguard Worker         {
76*89c4ff92SAndroid Build Coastguard Worker             if (std::get<0>(check) == std::get<0>(item) &&
77*89c4ff92SAndroid Build Coastguard Worker                 std::get<1>(check) == std::get<1>(item) &&
78*89c4ff92SAndroid Build Coastguard Worker                 std::get<2>(check) == std::get<2>(item))
79*89c4ff92SAndroid Build Coastguard Worker             {
80*89c4ff92SAndroid Build Coastguard Worker                 m_Res = true;
81*89c4ff92SAndroid Build Coastguard Worker                 return;
82*89c4ff92SAndroid Build Coastguard Worker             }
83*89c4ff92SAndroid Build Coastguard Worker         }
84*89c4ff92SAndroid Build Coastguard Worker         m_Res = false;
85*89c4ff92SAndroid Build Coastguard Worker     }
86*89c4ff92SAndroid Build Coastguard Worker };
87