xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/Unsupported.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2019 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersFixture.hpp"
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn/StrategyBase.hpp>
9*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Assert.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/NumericCast.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/PolymorphicDowncast.hpp>
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker #include <layers/StandInLayer.hpp>
14*89c4ff92SAndroid Build Coastguard Worker 
15*89c4ff92SAndroid Build Coastguard Worker #include <sstream>
16*89c4ff92SAndroid Build Coastguard Worker #include <vector>
17*89c4ff92SAndroid Build Coastguard Worker 
18*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("TensorflowLiteParser_Unsupported")
19*89c4ff92SAndroid Build Coastguard Worker {
20*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
21*89c4ff92SAndroid Build Coastguard Worker 
22*89c4ff92SAndroid Build Coastguard Worker class StandInLayerVerifier : public StrategyBase<NoThrowStrategy>
23*89c4ff92SAndroid Build Coastguard Worker {
24*89c4ff92SAndroid Build Coastguard Worker public:
StandInLayerVerifier(const std::vector<TensorInfo> & inputInfos,const std::vector<TensorInfo> & outputInfos)25*89c4ff92SAndroid Build Coastguard Worker     StandInLayerVerifier(const std::vector<TensorInfo>& inputInfos,
26*89c4ff92SAndroid Build Coastguard Worker                          const std::vector<TensorInfo>& outputInfos)
27*89c4ff92SAndroid Build Coastguard Worker         : m_InputInfos(inputInfos)
28*89c4ff92SAndroid Build Coastguard Worker         , m_OutputInfos(outputInfos) {}
29*89c4ff92SAndroid Build Coastguard Worker 
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)30*89c4ff92SAndroid Build Coastguard Worker     void ExecuteStrategy(const armnn::IConnectableLayer* layer,
31*89c4ff92SAndroid Build Coastguard Worker                          const armnn::BaseDescriptor& descriptor,
32*89c4ff92SAndroid Build Coastguard Worker                          const std::vector<armnn::ConstTensor>& constants,
33*89c4ff92SAndroid Build Coastguard Worker                          const char* name,
34*89c4ff92SAndroid Build Coastguard Worker                          const armnn::LayerBindingId id = 0) override
35*89c4ff92SAndroid Build Coastguard Worker     {
36*89c4ff92SAndroid Build Coastguard Worker         armnn::IgnoreUnused(descriptor, constants, id);
37*89c4ff92SAndroid Build Coastguard Worker         switch (layer->GetType())
38*89c4ff92SAndroid Build Coastguard Worker         {
39*89c4ff92SAndroid Build Coastguard Worker             case armnn::LayerType::StandIn:
40*89c4ff92SAndroid Build Coastguard Worker             {
41*89c4ff92SAndroid Build Coastguard Worker                 auto standInDescriptor = static_cast<const armnn::StandInDescriptor&>(descriptor);
42*89c4ff92SAndroid Build Coastguard Worker                 unsigned int numInputs = armnn::numeric_cast<unsigned int>(m_InputInfos.size());
43*89c4ff92SAndroid Build Coastguard Worker                         CHECK(standInDescriptor.m_NumInputs    == numInputs);
44*89c4ff92SAndroid Build Coastguard Worker                         CHECK(layer->GetNumInputSlots() == numInputs);
45*89c4ff92SAndroid Build Coastguard Worker 
46*89c4ff92SAndroid Build Coastguard Worker                 unsigned int numOutputs = armnn::numeric_cast<unsigned int>(m_OutputInfos.size());
47*89c4ff92SAndroid Build Coastguard Worker                         CHECK(standInDescriptor.m_NumOutputs    == numOutputs);
48*89c4ff92SAndroid Build Coastguard Worker                         CHECK(layer->GetNumOutputSlots() == numOutputs);
49*89c4ff92SAndroid Build Coastguard Worker 
50*89c4ff92SAndroid Build Coastguard Worker                 const StandInLayer* standInLayer = PolymorphicDowncast<const StandInLayer*>(layer);
51*89c4ff92SAndroid Build Coastguard Worker                 for (unsigned int i = 0u; i < numInputs; ++i)
52*89c4ff92SAndroid Build Coastguard Worker                 {
53*89c4ff92SAndroid Build Coastguard Worker                     const OutputSlot* connectedSlot = standInLayer->GetInputSlot(i).GetConnectedOutputSlot();
54*89c4ff92SAndroid Build Coastguard Worker                             CHECK(connectedSlot != nullptr);
55*89c4ff92SAndroid Build Coastguard Worker 
56*89c4ff92SAndroid Build Coastguard Worker                     const TensorInfo& inputInfo = connectedSlot->GetTensorInfo();
57*89c4ff92SAndroid Build Coastguard Worker                             CHECK(inputInfo == m_InputInfos[i]);
58*89c4ff92SAndroid Build Coastguard Worker                 }
59*89c4ff92SAndroid Build Coastguard Worker 
60*89c4ff92SAndroid Build Coastguard Worker                 for (unsigned int i = 0u; i < numOutputs; ++i)
61*89c4ff92SAndroid Build Coastguard Worker                 {
62*89c4ff92SAndroid Build Coastguard Worker                     const TensorInfo& outputInfo = layer->GetOutputSlot(i).GetTensorInfo();
63*89c4ff92SAndroid Build Coastguard Worker                             CHECK(outputInfo == m_OutputInfos[i]);
64*89c4ff92SAndroid Build Coastguard Worker                 }
65*89c4ff92SAndroid Build Coastguard Worker                 break;
66*89c4ff92SAndroid Build Coastguard Worker             }
67*89c4ff92SAndroid Build Coastguard Worker             default:
68*89c4ff92SAndroid Build Coastguard Worker             {
69*89c4ff92SAndroid Build Coastguard Worker                 m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
70*89c4ff92SAndroid Build Coastguard Worker             }
71*89c4ff92SAndroid Build Coastguard Worker         }
72*89c4ff92SAndroid Build Coastguard Worker     }
73*89c4ff92SAndroid Build Coastguard Worker 
74*89c4ff92SAndroid Build Coastguard Worker private:
75*89c4ff92SAndroid Build Coastguard Worker     std::vector<TensorInfo> m_InputInfos;
76*89c4ff92SAndroid Build Coastguard Worker     std::vector<TensorInfo> m_OutputInfos;
77*89c4ff92SAndroid Build Coastguard Worker };
78*89c4ff92SAndroid Build Coastguard Worker 
79*89c4ff92SAndroid Build Coastguard Worker class DummyCustomFixture : public ParserFlatbuffersFixture
80*89c4ff92SAndroid Build Coastguard Worker {
81*89c4ff92SAndroid Build Coastguard Worker public:
DummyCustomFixture(const std::vector<TensorInfo> & inputInfos,const std::vector<TensorInfo> & outputInfos)82*89c4ff92SAndroid Build Coastguard Worker     explicit DummyCustomFixture(const std::vector<TensorInfo>& inputInfos,
83*89c4ff92SAndroid Build Coastguard Worker                                 const std::vector<TensorInfo>& outputInfos)
84*89c4ff92SAndroid Build Coastguard Worker         : ParserFlatbuffersFixture()
85*89c4ff92SAndroid Build Coastguard Worker         , m_StandInLayerVerifier(inputInfos, outputInfos)
86*89c4ff92SAndroid Build Coastguard Worker     {
87*89c4ff92SAndroid Build Coastguard Worker         const unsigned int numInputs = armnn::numeric_cast<unsigned int>(inputInfos.size());
88*89c4ff92SAndroid Build Coastguard Worker         ARMNN_ASSERT(numInputs > 0);
89*89c4ff92SAndroid Build Coastguard Worker 
90*89c4ff92SAndroid Build Coastguard Worker         const unsigned int numOutputs = armnn::numeric_cast<unsigned int>(outputInfos.size());
91*89c4ff92SAndroid Build Coastguard Worker         ARMNN_ASSERT(numOutputs > 0);
92*89c4ff92SAndroid Build Coastguard Worker 
93*89c4ff92SAndroid Build Coastguard Worker         m_JsonString = R"(
94*89c4ff92SAndroid Build Coastguard Worker             {
95*89c4ff92SAndroid Build Coastguard Worker                 "version": 3,
96*89c4ff92SAndroid Build Coastguard Worker                 "operator_codes": [{
97*89c4ff92SAndroid Build Coastguard Worker                     "builtin_code": "CUSTOM",
98*89c4ff92SAndroid Build Coastguard Worker                     "custom_code": "DummyCustomOperator"
99*89c4ff92SAndroid Build Coastguard Worker                 }],
100*89c4ff92SAndroid Build Coastguard Worker                 "subgraphs": [ {
101*89c4ff92SAndroid Build Coastguard Worker                     "tensors": [)";
102*89c4ff92SAndroid Build Coastguard Worker 
103*89c4ff92SAndroid Build Coastguard Worker         // Add input tensors
104*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int i = 0u; i < numInputs; ++i)
105*89c4ff92SAndroid Build Coastguard Worker         {
106*89c4ff92SAndroid Build Coastguard Worker             const TensorInfo& inputInfo = inputInfos[i];
107*89c4ff92SAndroid Build Coastguard Worker             m_JsonString += R"(
108*89c4ff92SAndroid Build Coastguard Worker                     {
109*89c4ff92SAndroid Build Coastguard Worker                         "shape": )" + GetTensorShapeAsString(inputInfo.GetShape()) + R"(,
110*89c4ff92SAndroid Build Coastguard Worker                         "type": )" + GetDataTypeAsString(inputInfo.GetDataType()) + R"(,
111*89c4ff92SAndroid Build Coastguard Worker                         "buffer": 0,
112*89c4ff92SAndroid Build Coastguard Worker                         "name": "inputTensor)" + std::to_string(i) + R"(",
113*89c4ff92SAndroid Build Coastguard Worker                         "quantization": {
114*89c4ff92SAndroid Build Coastguard Worker                             "min": [ 0.0 ],
115*89c4ff92SAndroid Build Coastguard Worker                             "max": [ 255.0 ],
116*89c4ff92SAndroid Build Coastguard Worker                             "scale": [ )" + std::to_string(inputInfo.GetQuantizationScale()) + R"( ],
117*89c4ff92SAndroid Build Coastguard Worker                             "zero_point": [ )" + std::to_string(inputInfo.GetQuantizationOffset()) + R"( ],
118*89c4ff92SAndroid Build Coastguard Worker                         }
119*89c4ff92SAndroid Build Coastguard Worker                     },)";
120*89c4ff92SAndroid Build Coastguard Worker         }
121*89c4ff92SAndroid Build Coastguard Worker 
122*89c4ff92SAndroid Build Coastguard Worker         // Add output tensors
123*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int i = 0u; i < numOutputs; ++i)
124*89c4ff92SAndroid Build Coastguard Worker         {
125*89c4ff92SAndroid Build Coastguard Worker             const TensorInfo& outputInfo = outputInfos[i];
126*89c4ff92SAndroid Build Coastguard Worker             m_JsonString += R"(
127*89c4ff92SAndroid Build Coastguard Worker                     {
128*89c4ff92SAndroid Build Coastguard Worker                         "shape": )" + GetTensorShapeAsString(outputInfo.GetShape()) + R"(,
129*89c4ff92SAndroid Build Coastguard Worker                         "type": )" + GetDataTypeAsString(outputInfo.GetDataType()) + R"(,
130*89c4ff92SAndroid Build Coastguard Worker                         "buffer": 0,
131*89c4ff92SAndroid Build Coastguard Worker                         "name": "outputTensor)" + std::to_string(i) + R"(",
132*89c4ff92SAndroid Build Coastguard Worker                         "quantization": {
133*89c4ff92SAndroid Build Coastguard Worker                             "min": [ 0.0 ],
134*89c4ff92SAndroid Build Coastguard Worker                             "max": [ 255.0 ],
135*89c4ff92SAndroid Build Coastguard Worker                             "scale": [ )" + std::to_string(outputInfo.GetQuantizationScale()) + R"( ],
136*89c4ff92SAndroid Build Coastguard Worker                             "zero_point": [ )" + std::to_string(outputInfo.GetQuantizationOffset()) + R"( ],
137*89c4ff92SAndroid Build Coastguard Worker                         }
138*89c4ff92SAndroid Build Coastguard Worker                     })";
139*89c4ff92SAndroid Build Coastguard Worker 
140*89c4ff92SAndroid Build Coastguard Worker             if (i + 1 < numOutputs)
141*89c4ff92SAndroid Build Coastguard Worker             {
142*89c4ff92SAndroid Build Coastguard Worker                 m_JsonString += ",";
143*89c4ff92SAndroid Build Coastguard Worker             }
144*89c4ff92SAndroid Build Coastguard Worker         }
145*89c4ff92SAndroid Build Coastguard Worker 
146*89c4ff92SAndroid Build Coastguard Worker         const std::string inputIndices  = GetIndicesAsString(0u, numInputs - 1u);
147*89c4ff92SAndroid Build Coastguard Worker         const std::string outputIndices = GetIndicesAsString(numInputs, numInputs + numOutputs - 1u);
148*89c4ff92SAndroid Build Coastguard Worker 
149*89c4ff92SAndroid Build Coastguard Worker         // Add dummy custom operator
150*89c4ff92SAndroid Build Coastguard Worker         m_JsonString +=  R"(],
151*89c4ff92SAndroid Build Coastguard Worker                     "inputs": )" + inputIndices + R"(,
152*89c4ff92SAndroid Build Coastguard Worker                     "outputs": )" + outputIndices + R"(,
153*89c4ff92SAndroid Build Coastguard Worker                     "operators": [
154*89c4ff92SAndroid Build Coastguard Worker                         {
155*89c4ff92SAndroid Build Coastguard Worker                             "opcode_index": 0,
156*89c4ff92SAndroid Build Coastguard Worker                             "inputs": )" + inputIndices + R"(,
157*89c4ff92SAndroid Build Coastguard Worker                             "outputs": )" + outputIndices + R"(,
158*89c4ff92SAndroid Build Coastguard Worker                             "builtin_options_type": 0,
159*89c4ff92SAndroid Build Coastguard Worker                             "custom_options": [ ],
160*89c4ff92SAndroid Build Coastguard Worker                             "custom_options_format": "FLEXBUFFERS"
161*89c4ff92SAndroid Build Coastguard Worker                         }
162*89c4ff92SAndroid Build Coastguard Worker                     ],
163*89c4ff92SAndroid Build Coastguard Worker                 } ],
164*89c4ff92SAndroid Build Coastguard Worker                 "buffers" : [
165*89c4ff92SAndroid Build Coastguard Worker                     { },
166*89c4ff92SAndroid Build Coastguard Worker                     { }
167*89c4ff92SAndroid Build Coastguard Worker                 ]
168*89c4ff92SAndroid Build Coastguard Worker             }
169*89c4ff92SAndroid Build Coastguard Worker         )";
170*89c4ff92SAndroid Build Coastguard Worker 
171*89c4ff92SAndroid Build Coastguard Worker         ReadStringToBinary();
172*89c4ff92SAndroid Build Coastguard Worker     }
173*89c4ff92SAndroid Build Coastguard Worker 
RunTest()174*89c4ff92SAndroid Build Coastguard Worker     void RunTest()
175*89c4ff92SAndroid Build Coastguard Worker     {
176*89c4ff92SAndroid Build Coastguard Worker         INetworkPtr network = m_Parser->CreateNetworkFromBinary(m_GraphBinary);
177*89c4ff92SAndroid Build Coastguard Worker         network->ExecuteStrategy(m_StandInLayerVerifier);
178*89c4ff92SAndroid Build Coastguard Worker     }
179*89c4ff92SAndroid Build Coastguard Worker 
180*89c4ff92SAndroid Build Coastguard Worker private:
GetTensorShapeAsString(const TensorShape & tensorShape)181*89c4ff92SAndroid Build Coastguard Worker     static std::string GetTensorShapeAsString(const TensorShape& tensorShape)
182*89c4ff92SAndroid Build Coastguard Worker     {
183*89c4ff92SAndroid Build Coastguard Worker         std::stringstream stream;
184*89c4ff92SAndroid Build Coastguard Worker         stream << "[ ";
185*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int i = 0u; i < tensorShape.GetNumDimensions(); ++i)
186*89c4ff92SAndroid Build Coastguard Worker         {
187*89c4ff92SAndroid Build Coastguard Worker             stream << tensorShape[i];
188*89c4ff92SAndroid Build Coastguard Worker             if (i + 1 < tensorShape.GetNumDimensions())
189*89c4ff92SAndroid Build Coastguard Worker             {
190*89c4ff92SAndroid Build Coastguard Worker                 stream << ",";
191*89c4ff92SAndroid Build Coastguard Worker             }
192*89c4ff92SAndroid Build Coastguard Worker             stream << " ";
193*89c4ff92SAndroid Build Coastguard Worker         }
194*89c4ff92SAndroid Build Coastguard Worker         stream << "]";
195*89c4ff92SAndroid Build Coastguard Worker 
196*89c4ff92SAndroid Build Coastguard Worker         return stream.str();
197*89c4ff92SAndroid Build Coastguard Worker     }
198*89c4ff92SAndroid Build Coastguard Worker 
GetDataTypeAsString(DataType dataType)199*89c4ff92SAndroid Build Coastguard Worker     static std::string GetDataTypeAsString(DataType dataType)
200*89c4ff92SAndroid Build Coastguard Worker     {
201*89c4ff92SAndroid Build Coastguard Worker         switch (dataType)
202*89c4ff92SAndroid Build Coastguard Worker         {
203*89c4ff92SAndroid Build Coastguard Worker             case DataType::Float32:         return "FLOAT32";
204*89c4ff92SAndroid Build Coastguard Worker             case DataType::QAsymmU8: return "UINT8";
205*89c4ff92SAndroid Build Coastguard Worker             default:                        return "UNKNOWN";
206*89c4ff92SAndroid Build Coastguard Worker         }
207*89c4ff92SAndroid Build Coastguard Worker     }
208*89c4ff92SAndroid Build Coastguard Worker 
GetIndicesAsString(unsigned int first,unsigned int last)209*89c4ff92SAndroid Build Coastguard Worker     static std::string GetIndicesAsString(unsigned int first, unsigned int last)
210*89c4ff92SAndroid Build Coastguard Worker     {
211*89c4ff92SAndroid Build Coastguard Worker         std::stringstream stream;
212*89c4ff92SAndroid Build Coastguard Worker         stream << "[ ";
213*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int i = first; i <= last ; ++i)
214*89c4ff92SAndroid Build Coastguard Worker         {
215*89c4ff92SAndroid Build Coastguard Worker             stream << i;
216*89c4ff92SAndroid Build Coastguard Worker             if (i + 1 <= last)
217*89c4ff92SAndroid Build Coastguard Worker             {
218*89c4ff92SAndroid Build Coastguard Worker                 stream << ",";
219*89c4ff92SAndroid Build Coastguard Worker             }
220*89c4ff92SAndroid Build Coastguard Worker             stream << " ";
221*89c4ff92SAndroid Build Coastguard Worker         }
222*89c4ff92SAndroid Build Coastguard Worker         stream << "]";
223*89c4ff92SAndroid Build Coastguard Worker 
224*89c4ff92SAndroid Build Coastguard Worker         return stream.str();
225*89c4ff92SAndroid Build Coastguard Worker     }
226*89c4ff92SAndroid Build Coastguard Worker 
227*89c4ff92SAndroid Build Coastguard Worker     StandInLayerVerifier m_StandInLayerVerifier;
228*89c4ff92SAndroid Build Coastguard Worker };
229*89c4ff92SAndroid Build Coastguard Worker 
230*89c4ff92SAndroid Build Coastguard Worker class DummyCustom1Input1OutputFixture : public DummyCustomFixture
231*89c4ff92SAndroid Build Coastguard Worker {
232*89c4ff92SAndroid Build Coastguard Worker public:
DummyCustom1Input1OutputFixture()233*89c4ff92SAndroid Build Coastguard Worker     DummyCustom1Input1OutputFixture()
234*89c4ff92SAndroid Build Coastguard Worker         : DummyCustomFixture({ TensorInfo({ 1, 1 }, DataType::Float32) },
235*89c4ff92SAndroid Build Coastguard Worker                              { TensorInfo({ 2, 2 }, DataType::Float32) }) {}
236*89c4ff92SAndroid Build Coastguard Worker };
237*89c4ff92SAndroid Build Coastguard Worker 
238*89c4ff92SAndroid Build Coastguard Worker class DummyCustom2Inputs1OutputFixture : public DummyCustomFixture
239*89c4ff92SAndroid Build Coastguard Worker {
240*89c4ff92SAndroid Build Coastguard Worker public:
DummyCustom2Inputs1OutputFixture()241*89c4ff92SAndroid Build Coastguard Worker     DummyCustom2Inputs1OutputFixture()
242*89c4ff92SAndroid Build Coastguard Worker         : DummyCustomFixture({ TensorInfo({ 1, 1 }, DataType::Float32), TensorInfo({ 2, 2 }, DataType::Float32) },
243*89c4ff92SAndroid Build Coastguard Worker                              { TensorInfo({ 3, 3 }, DataType::Float32) }) {}
244*89c4ff92SAndroid Build Coastguard Worker };
245*89c4ff92SAndroid Build Coastguard Worker 
246*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(DummyCustom1Input1OutputFixture, "UnsupportedCustomOperator1Input1Output")
247*89c4ff92SAndroid Build Coastguard Worker {
248*89c4ff92SAndroid Build Coastguard Worker     RunTest();
249*89c4ff92SAndroid Build Coastguard Worker }
250*89c4ff92SAndroid Build Coastguard Worker 
251*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(DummyCustom2Inputs1OutputFixture, "UnsupportedCustomOperator2Inputs1Output")
252*89c4ff92SAndroid Build Coastguard Worker {
253*89c4ff92SAndroid Build Coastguard Worker     RunTest();
254*89c4ff92SAndroid Build Coastguard Worker }
255*89c4ff92SAndroid Build Coastguard Worker 
256*89c4ff92SAndroid Build Coastguard Worker }
257