1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2019 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersFixture.hpp" 7*89c4ff92SAndroid Build Coastguard Worker 8*89c4ff92SAndroid Build Coastguard Worker #include <armnn/StrategyBase.hpp> 9*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Assert.hpp> 10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/NumericCast.hpp> 11*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/PolymorphicDowncast.hpp> 12*89c4ff92SAndroid Build Coastguard Worker 13*89c4ff92SAndroid Build Coastguard Worker #include <layers/StandInLayer.hpp> 14*89c4ff92SAndroid Build Coastguard Worker 15*89c4ff92SAndroid Build Coastguard Worker #include <sstream> 16*89c4ff92SAndroid Build Coastguard Worker #include <vector> 17*89c4ff92SAndroid Build Coastguard Worker 18*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("TensorflowLiteParser_Unsupported") 19*89c4ff92SAndroid Build Coastguard Worker { 20*89c4ff92SAndroid Build Coastguard Worker using namespace armnn; 21*89c4ff92SAndroid Build Coastguard Worker 22*89c4ff92SAndroid Build Coastguard Worker class StandInLayerVerifier : public StrategyBase<NoThrowStrategy> 23*89c4ff92SAndroid Build Coastguard Worker { 24*89c4ff92SAndroid Build Coastguard Worker public: StandInLayerVerifier(const std::vector<TensorInfo> & inputInfos,const std::vector<TensorInfo> & outputInfos)25*89c4ff92SAndroid Build Coastguard Worker StandInLayerVerifier(const std::vector<TensorInfo>& inputInfos, 26*89c4ff92SAndroid Build Coastguard Worker const std::vector<TensorInfo>& outputInfos) 27*89c4ff92SAndroid Build Coastguard Worker : m_InputInfos(inputInfos) 28*89c4ff92SAndroid Build Coastguard Worker , m_OutputInfos(outputInfos) {} 29*89c4ff92SAndroid Build Coastguard Worker ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)30*89c4ff92SAndroid Build Coastguard Worker void ExecuteStrategy(const armnn::IConnectableLayer* layer, 31*89c4ff92SAndroid Build Coastguard Worker const armnn::BaseDescriptor& descriptor, 32*89c4ff92SAndroid Build Coastguard Worker const std::vector<armnn::ConstTensor>& constants, 33*89c4ff92SAndroid Build Coastguard Worker const char* name, 34*89c4ff92SAndroid Build Coastguard Worker const armnn::LayerBindingId id = 0) override 35*89c4ff92SAndroid Build Coastguard Worker { 36*89c4ff92SAndroid Build Coastguard Worker armnn::IgnoreUnused(descriptor, constants, id); 37*89c4ff92SAndroid Build Coastguard Worker switch (layer->GetType()) 38*89c4ff92SAndroid Build Coastguard Worker { 39*89c4ff92SAndroid Build Coastguard Worker case armnn::LayerType::StandIn: 40*89c4ff92SAndroid Build Coastguard Worker { 41*89c4ff92SAndroid Build Coastguard Worker auto standInDescriptor = static_cast<const armnn::StandInDescriptor&>(descriptor); 42*89c4ff92SAndroid Build Coastguard Worker unsigned int numInputs = armnn::numeric_cast<unsigned int>(m_InputInfos.size()); 43*89c4ff92SAndroid Build Coastguard Worker CHECK(standInDescriptor.m_NumInputs == numInputs); 44*89c4ff92SAndroid Build Coastguard Worker CHECK(layer->GetNumInputSlots() == numInputs); 45*89c4ff92SAndroid Build Coastguard Worker 46*89c4ff92SAndroid Build Coastguard Worker unsigned int numOutputs = armnn::numeric_cast<unsigned int>(m_OutputInfos.size()); 47*89c4ff92SAndroid Build Coastguard Worker CHECK(standInDescriptor.m_NumOutputs == numOutputs); 48*89c4ff92SAndroid Build Coastguard Worker CHECK(layer->GetNumOutputSlots() == numOutputs); 49*89c4ff92SAndroid Build Coastguard Worker 50*89c4ff92SAndroid Build Coastguard Worker const StandInLayer* standInLayer = PolymorphicDowncast<const StandInLayer*>(layer); 51*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0u; i < numInputs; ++i) 52*89c4ff92SAndroid Build Coastguard Worker { 53*89c4ff92SAndroid Build Coastguard Worker const OutputSlot* connectedSlot = standInLayer->GetInputSlot(i).GetConnectedOutputSlot(); 54*89c4ff92SAndroid Build Coastguard Worker CHECK(connectedSlot != nullptr); 55*89c4ff92SAndroid Build Coastguard Worker 56*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputInfo = connectedSlot->GetTensorInfo(); 57*89c4ff92SAndroid Build Coastguard Worker CHECK(inputInfo == m_InputInfos[i]); 58*89c4ff92SAndroid Build Coastguard Worker } 59*89c4ff92SAndroid Build Coastguard Worker 60*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0u; i < numOutputs; ++i) 61*89c4ff92SAndroid Build Coastguard Worker { 62*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputInfo = layer->GetOutputSlot(i).GetTensorInfo(); 63*89c4ff92SAndroid Build Coastguard Worker CHECK(outputInfo == m_OutputInfos[i]); 64*89c4ff92SAndroid Build Coastguard Worker } 65*89c4ff92SAndroid Build Coastguard Worker break; 66*89c4ff92SAndroid Build Coastguard Worker } 67*89c4ff92SAndroid Build Coastguard Worker default: 68*89c4ff92SAndroid Build Coastguard Worker { 69*89c4ff92SAndroid Build Coastguard Worker m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType())); 70*89c4ff92SAndroid Build Coastguard Worker } 71*89c4ff92SAndroid Build Coastguard Worker } 72*89c4ff92SAndroid Build Coastguard Worker } 73*89c4ff92SAndroid Build Coastguard Worker 74*89c4ff92SAndroid Build Coastguard Worker private: 75*89c4ff92SAndroid Build Coastguard Worker std::vector<TensorInfo> m_InputInfos; 76*89c4ff92SAndroid Build Coastguard Worker std::vector<TensorInfo> m_OutputInfos; 77*89c4ff92SAndroid Build Coastguard Worker }; 78*89c4ff92SAndroid Build Coastguard Worker 79*89c4ff92SAndroid Build Coastguard Worker class DummyCustomFixture : public ParserFlatbuffersFixture 80*89c4ff92SAndroid Build Coastguard Worker { 81*89c4ff92SAndroid Build Coastguard Worker public: DummyCustomFixture(const std::vector<TensorInfo> & inputInfos,const std::vector<TensorInfo> & outputInfos)82*89c4ff92SAndroid Build Coastguard Worker explicit DummyCustomFixture(const std::vector<TensorInfo>& inputInfos, 83*89c4ff92SAndroid Build Coastguard Worker const std::vector<TensorInfo>& outputInfos) 84*89c4ff92SAndroid Build Coastguard Worker : ParserFlatbuffersFixture() 85*89c4ff92SAndroid Build Coastguard Worker , m_StandInLayerVerifier(inputInfos, outputInfos) 86*89c4ff92SAndroid Build Coastguard Worker { 87*89c4ff92SAndroid Build Coastguard Worker const unsigned int numInputs = armnn::numeric_cast<unsigned int>(inputInfos.size()); 88*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(numInputs > 0); 89*89c4ff92SAndroid Build Coastguard Worker 90*89c4ff92SAndroid Build Coastguard Worker const unsigned int numOutputs = armnn::numeric_cast<unsigned int>(outputInfos.size()); 91*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(numOutputs > 0); 92*89c4ff92SAndroid Build Coastguard Worker 93*89c4ff92SAndroid Build Coastguard Worker m_JsonString = R"( 94*89c4ff92SAndroid Build Coastguard Worker { 95*89c4ff92SAndroid Build Coastguard Worker "version": 3, 96*89c4ff92SAndroid Build Coastguard Worker "operator_codes": [{ 97*89c4ff92SAndroid Build Coastguard Worker "builtin_code": "CUSTOM", 98*89c4ff92SAndroid Build Coastguard Worker "custom_code": "DummyCustomOperator" 99*89c4ff92SAndroid Build Coastguard Worker }], 100*89c4ff92SAndroid Build Coastguard Worker "subgraphs": [ { 101*89c4ff92SAndroid Build Coastguard Worker "tensors": [)"; 102*89c4ff92SAndroid Build Coastguard Worker 103*89c4ff92SAndroid Build Coastguard Worker // Add input tensors 104*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0u; i < numInputs; ++i) 105*89c4ff92SAndroid Build Coastguard Worker { 106*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputInfo = inputInfos[i]; 107*89c4ff92SAndroid Build Coastguard Worker m_JsonString += R"( 108*89c4ff92SAndroid Build Coastguard Worker { 109*89c4ff92SAndroid Build Coastguard Worker "shape": )" + GetTensorShapeAsString(inputInfo.GetShape()) + R"(, 110*89c4ff92SAndroid Build Coastguard Worker "type": )" + GetDataTypeAsString(inputInfo.GetDataType()) + R"(, 111*89c4ff92SAndroid Build Coastguard Worker "buffer": 0, 112*89c4ff92SAndroid Build Coastguard Worker "name": "inputTensor)" + std::to_string(i) + R"(", 113*89c4ff92SAndroid Build Coastguard Worker "quantization": { 114*89c4ff92SAndroid Build Coastguard Worker "min": [ 0.0 ], 115*89c4ff92SAndroid Build Coastguard Worker "max": [ 255.0 ], 116*89c4ff92SAndroid Build Coastguard Worker "scale": [ )" + std::to_string(inputInfo.GetQuantizationScale()) + R"( ], 117*89c4ff92SAndroid Build Coastguard Worker "zero_point": [ )" + std::to_string(inputInfo.GetQuantizationOffset()) + R"( ], 118*89c4ff92SAndroid Build Coastguard Worker } 119*89c4ff92SAndroid Build Coastguard Worker },)"; 120*89c4ff92SAndroid Build Coastguard Worker } 121*89c4ff92SAndroid Build Coastguard Worker 122*89c4ff92SAndroid Build Coastguard Worker // Add output tensors 123*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0u; i < numOutputs; ++i) 124*89c4ff92SAndroid Build Coastguard Worker { 125*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputInfo = outputInfos[i]; 126*89c4ff92SAndroid Build Coastguard Worker m_JsonString += R"( 127*89c4ff92SAndroid Build Coastguard Worker { 128*89c4ff92SAndroid Build Coastguard Worker "shape": )" + GetTensorShapeAsString(outputInfo.GetShape()) + R"(, 129*89c4ff92SAndroid Build Coastguard Worker "type": )" + GetDataTypeAsString(outputInfo.GetDataType()) + R"(, 130*89c4ff92SAndroid Build Coastguard Worker "buffer": 0, 131*89c4ff92SAndroid Build Coastguard Worker "name": "outputTensor)" + std::to_string(i) + R"(", 132*89c4ff92SAndroid Build Coastguard Worker "quantization": { 133*89c4ff92SAndroid Build Coastguard Worker "min": [ 0.0 ], 134*89c4ff92SAndroid Build Coastguard Worker "max": [ 255.0 ], 135*89c4ff92SAndroid Build Coastguard Worker "scale": [ )" + std::to_string(outputInfo.GetQuantizationScale()) + R"( ], 136*89c4ff92SAndroid Build Coastguard Worker "zero_point": [ )" + std::to_string(outputInfo.GetQuantizationOffset()) + R"( ], 137*89c4ff92SAndroid Build Coastguard Worker } 138*89c4ff92SAndroid Build Coastguard Worker })"; 139*89c4ff92SAndroid Build Coastguard Worker 140*89c4ff92SAndroid Build Coastguard Worker if (i + 1 < numOutputs) 141*89c4ff92SAndroid Build Coastguard Worker { 142*89c4ff92SAndroid Build Coastguard Worker m_JsonString += ","; 143*89c4ff92SAndroid Build Coastguard Worker } 144*89c4ff92SAndroid Build Coastguard Worker } 145*89c4ff92SAndroid Build Coastguard Worker 146*89c4ff92SAndroid Build Coastguard Worker const std::string inputIndices = GetIndicesAsString(0u, numInputs - 1u); 147*89c4ff92SAndroid Build Coastguard Worker const std::string outputIndices = GetIndicesAsString(numInputs, numInputs + numOutputs - 1u); 148*89c4ff92SAndroid Build Coastguard Worker 149*89c4ff92SAndroid Build Coastguard Worker // Add dummy custom operator 150*89c4ff92SAndroid Build Coastguard Worker m_JsonString += R"(], 151*89c4ff92SAndroid Build Coastguard Worker "inputs": )" + inputIndices + R"(, 152*89c4ff92SAndroid Build Coastguard Worker "outputs": )" + outputIndices + R"(, 153*89c4ff92SAndroid Build Coastguard Worker "operators": [ 154*89c4ff92SAndroid Build Coastguard Worker { 155*89c4ff92SAndroid Build Coastguard Worker "opcode_index": 0, 156*89c4ff92SAndroid Build Coastguard Worker "inputs": )" + inputIndices + R"(, 157*89c4ff92SAndroid Build Coastguard Worker "outputs": )" + outputIndices + R"(, 158*89c4ff92SAndroid Build Coastguard Worker "builtin_options_type": 0, 159*89c4ff92SAndroid Build Coastguard Worker "custom_options": [ ], 160*89c4ff92SAndroid Build Coastguard Worker "custom_options_format": "FLEXBUFFERS" 161*89c4ff92SAndroid Build Coastguard Worker } 162*89c4ff92SAndroid Build Coastguard Worker ], 163*89c4ff92SAndroid Build Coastguard Worker } ], 164*89c4ff92SAndroid Build Coastguard Worker "buffers" : [ 165*89c4ff92SAndroid Build Coastguard Worker { }, 166*89c4ff92SAndroid Build Coastguard Worker { } 167*89c4ff92SAndroid Build Coastguard Worker ] 168*89c4ff92SAndroid Build Coastguard Worker } 169*89c4ff92SAndroid Build Coastguard Worker )"; 170*89c4ff92SAndroid Build Coastguard Worker 171*89c4ff92SAndroid Build Coastguard Worker ReadStringToBinary(); 172*89c4ff92SAndroid Build Coastguard Worker } 173*89c4ff92SAndroid Build Coastguard Worker RunTest()174*89c4ff92SAndroid Build Coastguard Worker void RunTest() 175*89c4ff92SAndroid Build Coastguard Worker { 176*89c4ff92SAndroid Build Coastguard Worker INetworkPtr network = m_Parser->CreateNetworkFromBinary(m_GraphBinary); 177*89c4ff92SAndroid Build Coastguard Worker network->ExecuteStrategy(m_StandInLayerVerifier); 178*89c4ff92SAndroid Build Coastguard Worker } 179*89c4ff92SAndroid Build Coastguard Worker 180*89c4ff92SAndroid Build Coastguard Worker private: GetTensorShapeAsString(const TensorShape & tensorShape)181*89c4ff92SAndroid Build Coastguard Worker static std::string GetTensorShapeAsString(const TensorShape& tensorShape) 182*89c4ff92SAndroid Build Coastguard Worker { 183*89c4ff92SAndroid Build Coastguard Worker std::stringstream stream; 184*89c4ff92SAndroid Build Coastguard Worker stream << "[ "; 185*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0u; i < tensorShape.GetNumDimensions(); ++i) 186*89c4ff92SAndroid Build Coastguard Worker { 187*89c4ff92SAndroid Build Coastguard Worker stream << tensorShape[i]; 188*89c4ff92SAndroid Build Coastguard Worker if (i + 1 < tensorShape.GetNumDimensions()) 189*89c4ff92SAndroid Build Coastguard Worker { 190*89c4ff92SAndroid Build Coastguard Worker stream << ","; 191*89c4ff92SAndroid Build Coastguard Worker } 192*89c4ff92SAndroid Build Coastguard Worker stream << " "; 193*89c4ff92SAndroid Build Coastguard Worker } 194*89c4ff92SAndroid Build Coastguard Worker stream << "]"; 195*89c4ff92SAndroid Build Coastguard Worker 196*89c4ff92SAndroid Build Coastguard Worker return stream.str(); 197*89c4ff92SAndroid Build Coastguard Worker } 198*89c4ff92SAndroid Build Coastguard Worker GetDataTypeAsString(DataType dataType)199*89c4ff92SAndroid Build Coastguard Worker static std::string GetDataTypeAsString(DataType dataType) 200*89c4ff92SAndroid Build Coastguard Worker { 201*89c4ff92SAndroid Build Coastguard Worker switch (dataType) 202*89c4ff92SAndroid Build Coastguard Worker { 203*89c4ff92SAndroid Build Coastguard Worker case DataType::Float32: return "FLOAT32"; 204*89c4ff92SAndroid Build Coastguard Worker case DataType::QAsymmU8: return "UINT8"; 205*89c4ff92SAndroid Build Coastguard Worker default: return "UNKNOWN"; 206*89c4ff92SAndroid Build Coastguard Worker } 207*89c4ff92SAndroid Build Coastguard Worker } 208*89c4ff92SAndroid Build Coastguard Worker GetIndicesAsString(unsigned int first,unsigned int last)209*89c4ff92SAndroid Build Coastguard Worker static std::string GetIndicesAsString(unsigned int first, unsigned int last) 210*89c4ff92SAndroid Build Coastguard Worker { 211*89c4ff92SAndroid Build Coastguard Worker std::stringstream stream; 212*89c4ff92SAndroid Build Coastguard Worker stream << "[ "; 213*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = first; i <= last ; ++i) 214*89c4ff92SAndroid Build Coastguard Worker { 215*89c4ff92SAndroid Build Coastguard Worker stream << i; 216*89c4ff92SAndroid Build Coastguard Worker if (i + 1 <= last) 217*89c4ff92SAndroid Build Coastguard Worker { 218*89c4ff92SAndroid Build Coastguard Worker stream << ","; 219*89c4ff92SAndroid Build Coastguard Worker } 220*89c4ff92SAndroid Build Coastguard Worker stream << " "; 221*89c4ff92SAndroid Build Coastguard Worker } 222*89c4ff92SAndroid Build Coastguard Worker stream << "]"; 223*89c4ff92SAndroid Build Coastguard Worker 224*89c4ff92SAndroid Build Coastguard Worker return stream.str(); 225*89c4ff92SAndroid Build Coastguard Worker } 226*89c4ff92SAndroid Build Coastguard Worker 227*89c4ff92SAndroid Build Coastguard Worker StandInLayerVerifier m_StandInLayerVerifier; 228*89c4ff92SAndroid Build Coastguard Worker }; 229*89c4ff92SAndroid Build Coastguard Worker 230*89c4ff92SAndroid Build Coastguard Worker class DummyCustom1Input1OutputFixture : public DummyCustomFixture 231*89c4ff92SAndroid Build Coastguard Worker { 232*89c4ff92SAndroid Build Coastguard Worker public: DummyCustom1Input1OutputFixture()233*89c4ff92SAndroid Build Coastguard Worker DummyCustom1Input1OutputFixture() 234*89c4ff92SAndroid Build Coastguard Worker : DummyCustomFixture({ TensorInfo({ 1, 1 }, DataType::Float32) }, 235*89c4ff92SAndroid Build Coastguard Worker { TensorInfo({ 2, 2 }, DataType::Float32) }) {} 236*89c4ff92SAndroid Build Coastguard Worker }; 237*89c4ff92SAndroid Build Coastguard Worker 238*89c4ff92SAndroid Build Coastguard Worker class DummyCustom2Inputs1OutputFixture : public DummyCustomFixture 239*89c4ff92SAndroid Build Coastguard Worker { 240*89c4ff92SAndroid Build Coastguard Worker public: DummyCustom2Inputs1OutputFixture()241*89c4ff92SAndroid Build Coastguard Worker DummyCustom2Inputs1OutputFixture() 242*89c4ff92SAndroid Build Coastguard Worker : DummyCustomFixture({ TensorInfo({ 1, 1 }, DataType::Float32), TensorInfo({ 2, 2 }, DataType::Float32) }, 243*89c4ff92SAndroid Build Coastguard Worker { TensorInfo({ 3, 3 }, DataType::Float32) }) {} 244*89c4ff92SAndroid Build Coastguard Worker }; 245*89c4ff92SAndroid Build Coastguard Worker 246*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(DummyCustom1Input1OutputFixture, "UnsupportedCustomOperator1Input1Output") 247*89c4ff92SAndroid Build Coastguard Worker { 248*89c4ff92SAndroid Build Coastguard Worker RunTest(); 249*89c4ff92SAndroid Build Coastguard Worker } 250*89c4ff92SAndroid Build Coastguard Worker 251*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(DummyCustom2Inputs1OutputFixture, "UnsupportedCustomOperator2Inputs1Output") 252*89c4ff92SAndroid Build Coastguard Worker { 253*89c4ff92SAndroid Build Coastguard Worker RunTest(); 254*89c4ff92SAndroid Build Coastguard Worker } 255*89c4ff92SAndroid Build Coastguard Worker 256*89c4ff92SAndroid Build Coastguard Worker } 257