1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersFixture.hpp" 7*89c4ff92SAndroid Build Coastguard Worker 8*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Filesystem.hpp> 9*89c4ff92SAndroid Build Coastguard Worker 10*89c4ff92SAndroid Build Coastguard Worker using armnnTfLiteParser::TfLiteParserImpl; 11*89c4ff92SAndroid Build Coastguard Worker using ModelPtr = TfLiteParserImpl::ModelPtr; 12*89c4ff92SAndroid Build Coastguard Worker using SubgraphPtr = TfLiteParserImpl::SubgraphPtr; 13*89c4ff92SAndroid Build Coastguard Worker using OperatorPtr = TfLiteParserImpl::OperatorPtr; 14*89c4ff92SAndroid Build Coastguard Worker 15*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("TensorflowLiteParser_LoadModel") 16*89c4ff92SAndroid Build Coastguard Worker { 17*89c4ff92SAndroid Build Coastguard Worker struct LoadModelFixture : public ParserFlatbuffersFixture 18*89c4ff92SAndroid Build Coastguard Worker { LoadModelFixtureLoadModelFixture19*89c4ff92SAndroid Build Coastguard Worker explicit LoadModelFixture() 20*89c4ff92SAndroid Build Coastguard Worker { 21*89c4ff92SAndroid Build Coastguard Worker m_JsonString = R"( 22*89c4ff92SAndroid Build Coastguard Worker { 23*89c4ff92SAndroid Build Coastguard Worker "version": 3, 24*89c4ff92SAndroid Build Coastguard Worker "operator_codes": [ { "builtin_code": "AVERAGE_POOL_2D" }, { "builtin_code": "CONV_2D" } ], 25*89c4ff92SAndroid Build Coastguard Worker "subgraphs": [ 26*89c4ff92SAndroid Build Coastguard Worker { 27*89c4ff92SAndroid Build Coastguard Worker "tensors": [ 28*89c4ff92SAndroid Build Coastguard Worker { 29*89c4ff92SAndroid Build Coastguard Worker "shape": [ 1, 1, 1, 1 ] , 30*89c4ff92SAndroid Build Coastguard Worker "type": "UINT8", 31*89c4ff92SAndroid Build Coastguard Worker "buffer": 0, 32*89c4ff92SAndroid Build Coastguard Worker "name": "OutputTensor", 33*89c4ff92SAndroid Build Coastguard Worker "quantization": { 34*89c4ff92SAndroid Build Coastguard Worker "min": [ 0.0 ], 35*89c4ff92SAndroid Build Coastguard Worker "max": [ 255.0 ], 36*89c4ff92SAndroid Build Coastguard Worker "scale": [ 1.0 ], 37*89c4ff92SAndroid Build Coastguard Worker "zero_point": [ 0 ] 38*89c4ff92SAndroid Build Coastguard Worker } 39*89c4ff92SAndroid Build Coastguard Worker }, 40*89c4ff92SAndroid Build Coastguard Worker { 41*89c4ff92SAndroid Build Coastguard Worker "shape": [ 1, 2, 2, 1 ] , 42*89c4ff92SAndroid Build Coastguard Worker "type": "UINT8", 43*89c4ff92SAndroid Build Coastguard Worker "buffer": 1, 44*89c4ff92SAndroid Build Coastguard Worker "name": "InputTensor", 45*89c4ff92SAndroid Build Coastguard Worker "quantization": { 46*89c4ff92SAndroid Build Coastguard Worker "min": [ 0.0 ], 47*89c4ff92SAndroid Build Coastguard Worker "max": [ 255.0 ], 48*89c4ff92SAndroid Build Coastguard Worker "scale": [ 1.0 ], 49*89c4ff92SAndroid Build Coastguard Worker "zero_point": [ 0 ] 50*89c4ff92SAndroid Build Coastguard Worker } 51*89c4ff92SAndroid Build Coastguard Worker } 52*89c4ff92SAndroid Build Coastguard Worker ], 53*89c4ff92SAndroid Build Coastguard Worker "inputs": [ 1 ], 54*89c4ff92SAndroid Build Coastguard Worker "outputs": [ 0 ], 55*89c4ff92SAndroid Build Coastguard Worker "operators": [ { 56*89c4ff92SAndroid Build Coastguard Worker "opcode_index": 0, 57*89c4ff92SAndroid Build Coastguard Worker "inputs": [ 1 ], 58*89c4ff92SAndroid Build Coastguard Worker "outputs": [ 0 ], 59*89c4ff92SAndroid Build Coastguard Worker "builtin_options_type": "Pool2DOptions", 60*89c4ff92SAndroid Build Coastguard Worker "builtin_options": 61*89c4ff92SAndroid Build Coastguard Worker { 62*89c4ff92SAndroid Build Coastguard Worker "padding": "VALID", 63*89c4ff92SAndroid Build Coastguard Worker "stride_w": 2, 64*89c4ff92SAndroid Build Coastguard Worker "stride_h": 2, 65*89c4ff92SAndroid Build Coastguard Worker "filter_width": 2, 66*89c4ff92SAndroid Build Coastguard Worker "filter_height": 2, 67*89c4ff92SAndroid Build Coastguard Worker "fused_activation_function": "NONE" 68*89c4ff92SAndroid Build Coastguard Worker }, 69*89c4ff92SAndroid Build Coastguard Worker "custom_options_format": "FLEXBUFFERS" 70*89c4ff92SAndroid Build Coastguard Worker } ] 71*89c4ff92SAndroid Build Coastguard Worker }, 72*89c4ff92SAndroid Build Coastguard Worker { 73*89c4ff92SAndroid Build Coastguard Worker "tensors": [ 74*89c4ff92SAndroid Build Coastguard Worker { 75*89c4ff92SAndroid Build Coastguard Worker "shape": [ 1, 3, 3, 1 ], 76*89c4ff92SAndroid Build Coastguard Worker "type": "UINT8", 77*89c4ff92SAndroid Build Coastguard Worker "buffer": 0, 78*89c4ff92SAndroid Build Coastguard Worker "name": "ConvInputTensor", 79*89c4ff92SAndroid Build Coastguard Worker "quantization": { 80*89c4ff92SAndroid Build Coastguard Worker "scale": [ 1.0 ], 81*89c4ff92SAndroid Build Coastguard Worker "zero_point": [ 0 ], 82*89c4ff92SAndroid Build Coastguard Worker } 83*89c4ff92SAndroid Build Coastguard Worker }, 84*89c4ff92SAndroid Build Coastguard Worker { 85*89c4ff92SAndroid Build Coastguard Worker "shape": [ 1, 1, 1, 1 ], 86*89c4ff92SAndroid Build Coastguard Worker "type": "UINT8", 87*89c4ff92SAndroid Build Coastguard Worker "buffer": 1, 88*89c4ff92SAndroid Build Coastguard Worker "name": "ConvOutputTensor", 89*89c4ff92SAndroid Build Coastguard Worker "quantization": { 90*89c4ff92SAndroid Build Coastguard Worker "min": [ 0.0 ], 91*89c4ff92SAndroid Build Coastguard Worker "max": [ 511.0 ], 92*89c4ff92SAndroid Build Coastguard Worker "scale": [ 2.0 ], 93*89c4ff92SAndroid Build Coastguard Worker "zero_point": [ 0 ], 94*89c4ff92SAndroid Build Coastguard Worker } 95*89c4ff92SAndroid Build Coastguard Worker }, 96*89c4ff92SAndroid Build Coastguard Worker { 97*89c4ff92SAndroid Build Coastguard Worker "shape": [ 1, 3, 3, 1 ], 98*89c4ff92SAndroid Build Coastguard Worker "type": "UINT8", 99*89c4ff92SAndroid Build Coastguard Worker "buffer": 2, 100*89c4ff92SAndroid Build Coastguard Worker "name": "filterTensor", 101*89c4ff92SAndroid Build Coastguard Worker "quantization": { 102*89c4ff92SAndroid Build Coastguard Worker "min": [ 0.0 ], 103*89c4ff92SAndroid Build Coastguard Worker "max": [ 255.0 ], 104*89c4ff92SAndroid Build Coastguard Worker "scale": [ 1.0 ], 105*89c4ff92SAndroid Build Coastguard Worker "zero_point": [ 0 ], 106*89c4ff92SAndroid Build Coastguard Worker } 107*89c4ff92SAndroid Build Coastguard Worker } 108*89c4ff92SAndroid Build Coastguard Worker ], 109*89c4ff92SAndroid Build Coastguard Worker "inputs": [ 0 ], 110*89c4ff92SAndroid Build Coastguard Worker "outputs": [ 1 ], 111*89c4ff92SAndroid Build Coastguard Worker "operators": [ 112*89c4ff92SAndroid Build Coastguard Worker { 113*89c4ff92SAndroid Build Coastguard Worker "opcode_index": 1, 114*89c4ff92SAndroid Build Coastguard Worker "inputs": [ 0, 2 ], 115*89c4ff92SAndroid Build Coastguard Worker "outputs": [ 1 ], 116*89c4ff92SAndroid Build Coastguard Worker "builtin_options_type": "Conv2DOptions", 117*89c4ff92SAndroid Build Coastguard Worker "builtin_options": { 118*89c4ff92SAndroid Build Coastguard Worker "padding": "VALID", 119*89c4ff92SAndroid Build Coastguard Worker "stride_w": 1, 120*89c4ff92SAndroid Build Coastguard Worker "stride_h": 1, 121*89c4ff92SAndroid Build Coastguard Worker "fused_activation_function": "NONE" 122*89c4ff92SAndroid Build Coastguard Worker }, 123*89c4ff92SAndroid Build Coastguard Worker "custom_options_format": "FLEXBUFFERS" 124*89c4ff92SAndroid Build Coastguard Worker } 125*89c4ff92SAndroid Build Coastguard Worker ], 126*89c4ff92SAndroid Build Coastguard Worker } 127*89c4ff92SAndroid Build Coastguard Worker ], 128*89c4ff92SAndroid Build Coastguard Worker "description": "Test loading a model", 129*89c4ff92SAndroid Build Coastguard Worker "buffers" : [ {}, {} ] 130*89c4ff92SAndroid Build Coastguard Worker })"; 131*89c4ff92SAndroid Build Coastguard Worker 132*89c4ff92SAndroid Build Coastguard Worker ReadStringToBinary(); 133*89c4ff92SAndroid Build Coastguard Worker } 134*89c4ff92SAndroid Build Coastguard Worker CheckModelLoadModelFixture135*89c4ff92SAndroid Build Coastguard Worker void CheckModel(const ModelPtr& model, uint32_t version, size_t opcodeSize, 136*89c4ff92SAndroid Build Coastguard Worker const std::vector<tflite::BuiltinOperator>& opcodes, 137*89c4ff92SAndroid Build Coastguard Worker size_t subgraphs, const std::string desc, size_t buffers) 138*89c4ff92SAndroid Build Coastguard Worker { 139*89c4ff92SAndroid Build Coastguard Worker CHECK(model); 140*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(version, model->version); 141*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(opcodeSize, model->operator_codes.size()); 142*89c4ff92SAndroid Build Coastguard Worker CheckBuiltinOperators(opcodes, model->operator_codes); 143*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(subgraphs, model->subgraphs.size()); 144*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(desc, model->description); 145*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(buffers, model->buffers.size()); 146*89c4ff92SAndroid Build Coastguard Worker } 147*89c4ff92SAndroid Build Coastguard Worker CheckBuiltinOperatorsLoadModelFixture148*89c4ff92SAndroid Build Coastguard Worker void CheckBuiltinOperators(const std::vector<tflite::BuiltinOperator>& expectedOperators, 149*89c4ff92SAndroid Build Coastguard Worker const std::vector<std::unique_ptr<tflite::OperatorCodeT>>& result) 150*89c4ff92SAndroid Build Coastguard Worker { 151*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(expectedOperators.size(), result.size()); 152*89c4ff92SAndroid Build Coastguard Worker for (size_t i = 0; i < expectedOperators.size(); i++) 153*89c4ff92SAndroid Build Coastguard Worker { 154*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(expectedOperators[i], result[i]->builtin_code); 155*89c4ff92SAndroid Build Coastguard Worker } 156*89c4ff92SAndroid Build Coastguard Worker } 157*89c4ff92SAndroid Build Coastguard Worker CheckSubgraphLoadModelFixture158*89c4ff92SAndroid Build Coastguard Worker void CheckSubgraph(const SubgraphPtr& subgraph, size_t tensors, const std::vector<int32_t>& inputs, 159*89c4ff92SAndroid Build Coastguard Worker const std::vector<int32_t>& outputs, size_t operators, const std::string& name) 160*89c4ff92SAndroid Build Coastguard Worker { 161*89c4ff92SAndroid Build Coastguard Worker CHECK(subgraph); 162*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(tensors, subgraph->tensors.size()); 163*89c4ff92SAndroid Build Coastguard Worker CHECK(std::equal(inputs.begin(), inputs.end(), subgraph->inputs.begin(), subgraph->inputs.end())); 164*89c4ff92SAndroid Build Coastguard Worker CHECK(std::equal(outputs.begin(), outputs.end(), 165*89c4ff92SAndroid Build Coastguard Worker subgraph->outputs.begin(), subgraph->outputs.end())); 166*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(operators, subgraph->operators.size()); 167*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(name, subgraph->name); 168*89c4ff92SAndroid Build Coastguard Worker } 169*89c4ff92SAndroid Build Coastguard Worker CheckOperatorLoadModelFixture170*89c4ff92SAndroid Build Coastguard Worker void CheckOperator(const OperatorPtr& operatorPtr, uint32_t opcode, const std::vector<int32_t>& inputs, 171*89c4ff92SAndroid Build Coastguard Worker const std::vector<int32_t>& outputs, tflite::BuiltinOptions optionType, 172*89c4ff92SAndroid Build Coastguard Worker tflite::CustomOptionsFormat custom_options_format) 173*89c4ff92SAndroid Build Coastguard Worker { 174*89c4ff92SAndroid Build Coastguard Worker CHECK(operatorPtr); 175*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(opcode, operatorPtr->opcode_index); 176*89c4ff92SAndroid Build Coastguard Worker CHECK(std::equal(inputs.begin(), inputs.end(), 177*89c4ff92SAndroid Build Coastguard Worker operatorPtr->inputs.begin(), operatorPtr->inputs.end())); 178*89c4ff92SAndroid Build Coastguard Worker CHECK(std::equal(outputs.begin(), outputs.end(), 179*89c4ff92SAndroid Build Coastguard Worker operatorPtr->outputs.begin(), operatorPtr->outputs.end())); 180*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(optionType, operatorPtr->builtin_options.type); 181*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(custom_options_format, operatorPtr->custom_options_format); 182*89c4ff92SAndroid Build Coastguard Worker } 183*89c4ff92SAndroid Build Coastguard Worker }; 184*89c4ff92SAndroid Build Coastguard Worker 185*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(LoadModelFixture, "LoadModelFromBinary") 186*89c4ff92SAndroid Build Coastguard Worker { 187*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 188*89c4ff92SAndroid Build Coastguard Worker m_GraphBinary.size()); 189*89c4ff92SAndroid Build Coastguard Worker CheckModel(model, 3, 2, { tflite::BuiltinOperator_AVERAGE_POOL_2D, tflite::BuiltinOperator_CONV_2D }, 190*89c4ff92SAndroid Build Coastguard Worker 2, "Test loading a model", 2); 191*89c4ff92SAndroid Build Coastguard Worker CheckSubgraph(model->subgraphs[0], 2, { 1 }, { 0 }, 1, ""); 192*89c4ff92SAndroid Build Coastguard Worker CheckSubgraph(model->subgraphs[1], 3, { 0 }, { 1 }, 1, ""); 193*89c4ff92SAndroid Build Coastguard Worker CheckOperator(model->subgraphs[0]->operators[0], 0, { 1 }, { 0 }, tflite::BuiltinOptions_Pool2DOptions, 194*89c4ff92SAndroid Build Coastguard Worker tflite::CustomOptionsFormat_FLEXBUFFERS); 195*89c4ff92SAndroid Build Coastguard Worker CheckOperator(model->subgraphs[1]->operators[0], 1, { 0, 2 }, { 1 }, tflite::BuiltinOptions_Conv2DOptions, 196*89c4ff92SAndroid Build Coastguard Worker tflite::CustomOptionsFormat_FLEXBUFFERS); 197*89c4ff92SAndroid Build Coastguard Worker } 198*89c4ff92SAndroid Build Coastguard Worker 199*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(LoadModelFixture, "LoadModelFromFile") 200*89c4ff92SAndroid Build Coastguard Worker { 201*89c4ff92SAndroid Build Coastguard Worker using namespace fs; 202*89c4ff92SAndroid Build Coastguard Worker fs::path fname = armnnUtils::Filesystem::NamedTempFile("Armnn-tfLite-LoadModelFromFile-TempFile.csv"); 203*89c4ff92SAndroid Build Coastguard Worker bool saved = flatbuffers::SaveFile(fname.c_str(), 204*89c4ff92SAndroid Build Coastguard Worker reinterpret_cast<char *>(m_GraphBinary.data()), 205*89c4ff92SAndroid Build Coastguard Worker m_GraphBinary.size(), true); 206*89c4ff92SAndroid Build Coastguard Worker CHECK_MESSAGE(saved, "Cannot save test file"); 207*89c4ff92SAndroid Build Coastguard Worker 208*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromFile(fname.c_str()); 209*89c4ff92SAndroid Build Coastguard Worker CheckModel(model, 3, 2, { tflite::BuiltinOperator_AVERAGE_POOL_2D, tflite::BuiltinOperator_CONV_2D }, 210*89c4ff92SAndroid Build Coastguard Worker 2, "Test loading a model", 2); 211*89c4ff92SAndroid Build Coastguard Worker CheckSubgraph(model->subgraphs[0], 2, { 1 }, { 0 }, 1, ""); 212*89c4ff92SAndroid Build Coastguard Worker CheckSubgraph(model->subgraphs[1], 3, { 0 }, { 1 }, 1, ""); 213*89c4ff92SAndroid Build Coastguard Worker CheckOperator(model->subgraphs[0]->operators[0], 0, { 1 }, { 0 }, tflite::BuiltinOptions_Pool2DOptions, 214*89c4ff92SAndroid Build Coastguard Worker tflite::CustomOptionsFormat_FLEXBUFFERS); 215*89c4ff92SAndroid Build Coastguard Worker CheckOperator(model->subgraphs[1]->operators[0], 1, { 0, 2 }, { 1 }, tflite::BuiltinOptions_Conv2DOptions, 216*89c4ff92SAndroid Build Coastguard Worker tflite::CustomOptionsFormat_FLEXBUFFERS); 217*89c4ff92SAndroid Build Coastguard Worker remove(fname); 218*89c4ff92SAndroid Build Coastguard Worker } 219*89c4ff92SAndroid Build Coastguard Worker 220*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("LoadNullBinary") 221*89c4ff92SAndroid Build Coastguard Worker { 222*89c4ff92SAndroid Build Coastguard Worker CHECK_THROWS_AS(TfLiteParserImpl::LoadModelFromBinary(nullptr, 0), armnn::InvalidArgumentException); 223*89c4ff92SAndroid Build Coastguard Worker } 224*89c4ff92SAndroid Build Coastguard Worker 225*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("LoadInvalidBinary") 226*89c4ff92SAndroid Build Coastguard Worker { 227*89c4ff92SAndroid Build Coastguard Worker std::string testData = "invalid data"; 228*89c4ff92SAndroid Build Coastguard Worker CHECK_THROWS_AS(TfLiteParserImpl::LoadModelFromBinary(reinterpret_cast<const uint8_t*>(&testData), 229*89c4ff92SAndroid Build Coastguard Worker testData.length()), armnn::ParseException); 230*89c4ff92SAndroid Build Coastguard Worker } 231*89c4ff92SAndroid Build Coastguard Worker 232*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("LoadFileNotFound") 233*89c4ff92SAndroid Build Coastguard Worker { 234*89c4ff92SAndroid Build Coastguard Worker CHECK_THROWS_AS(TfLiteParserImpl::LoadModelFromFile("invalidfile.tflite"), armnn::FileNotFoundException); 235*89c4ff92SAndroid Build Coastguard Worker } 236*89c4ff92SAndroid Build Coastguard Worker 237*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("LoadNullPtrFile") 238*89c4ff92SAndroid Build Coastguard Worker { 239*89c4ff92SAndroid Build Coastguard Worker CHECK_THROWS_AS(TfLiteParserImpl::LoadModelFromFile(nullptr), armnn::InvalidArgumentException); 240*89c4ff92SAndroid Build Coastguard Worker } 241*89c4ff92SAndroid Build Coastguard Worker 242*89c4ff92SAndroid Build Coastguard Worker } 243