1 // 2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "ParserFlatbuffersFixture.hpp" 7 8 #include <armnnUtils/Filesystem.hpp> 9 10 using armnnTfLiteParser::TfLiteParserImpl; 11 using ModelPtr = TfLiteParserImpl::ModelPtr; 12 using SubgraphPtr = TfLiteParserImpl::SubgraphPtr; 13 using OperatorPtr = TfLiteParserImpl::OperatorPtr; 14 15 TEST_SUITE("TensorflowLiteParser_LoadModel") 16 { 17 struct LoadModelFixture : public ParserFlatbuffersFixture 18 { LoadModelFixtureLoadModelFixture19 explicit LoadModelFixture() 20 { 21 m_JsonString = R"( 22 { 23 "version": 3, 24 "operator_codes": [ { "builtin_code": "AVERAGE_POOL_2D" }, { "builtin_code": "CONV_2D" } ], 25 "subgraphs": [ 26 { 27 "tensors": [ 28 { 29 "shape": [ 1, 1, 1, 1 ] , 30 "type": "UINT8", 31 "buffer": 0, 32 "name": "OutputTensor", 33 "quantization": { 34 "min": [ 0.0 ], 35 "max": [ 255.0 ], 36 "scale": [ 1.0 ], 37 "zero_point": [ 0 ] 38 } 39 }, 40 { 41 "shape": [ 1, 2, 2, 1 ] , 42 "type": "UINT8", 43 "buffer": 1, 44 "name": "InputTensor", 45 "quantization": { 46 "min": [ 0.0 ], 47 "max": [ 255.0 ], 48 "scale": [ 1.0 ], 49 "zero_point": [ 0 ] 50 } 51 } 52 ], 53 "inputs": [ 1 ], 54 "outputs": [ 0 ], 55 "operators": [ { 56 "opcode_index": 0, 57 "inputs": [ 1 ], 58 "outputs": [ 0 ], 59 "builtin_options_type": "Pool2DOptions", 60 "builtin_options": 61 { 62 "padding": "VALID", 63 "stride_w": 2, 64 "stride_h": 2, 65 "filter_width": 2, 66 "filter_height": 2, 67 "fused_activation_function": "NONE" 68 }, 69 "custom_options_format": "FLEXBUFFERS" 70 } ] 71 }, 72 { 73 "tensors": [ 74 { 75 "shape": [ 1, 3, 3, 1 ], 76 "type": "UINT8", 77 "buffer": 0, 78 "name": "ConvInputTensor", 79 "quantization": { 80 "scale": [ 1.0 ], 81 "zero_point": [ 0 ], 82 } 83 }, 84 { 85 "shape": [ 1, 1, 1, 1 ], 86 "type": "UINT8", 87 "buffer": 1, 88 "name": "ConvOutputTensor", 89 "quantization": { 90 "min": [ 0.0 ], 91 "max": [ 511.0 ], 92 "scale": [ 2.0 ], 93 "zero_point": [ 0 ], 94 } 95 }, 96 { 97 "shape": [ 1, 3, 3, 1 ], 98 "type": "UINT8", 99 "buffer": 2, 100 "name": "filterTensor", 101 "quantization": { 102 "min": [ 0.0 ], 103 "max": [ 255.0 ], 104 "scale": [ 1.0 ], 105 "zero_point": [ 0 ], 106 } 107 } 108 ], 109 "inputs": [ 0 ], 110 "outputs": [ 1 ], 111 "operators": [ 112 { 113 "opcode_index": 1, 114 "inputs": [ 0, 2 ], 115 "outputs": [ 1 ], 116 "builtin_options_type": "Conv2DOptions", 117 "builtin_options": { 118 "padding": "VALID", 119 "stride_w": 1, 120 "stride_h": 1, 121 "fused_activation_function": "NONE" 122 }, 123 "custom_options_format": "FLEXBUFFERS" 124 } 125 ], 126 } 127 ], 128 "description": "Test loading a model", 129 "buffers" : [ {}, {} ] 130 })"; 131 132 ReadStringToBinary(); 133 } 134 CheckModelLoadModelFixture135 void CheckModel(const ModelPtr& model, uint32_t version, size_t opcodeSize, 136 const std::vector<tflite::BuiltinOperator>& opcodes, 137 size_t subgraphs, const std::string desc, size_t buffers) 138 { 139 CHECK(model); 140 CHECK_EQ(version, model->version); 141 CHECK_EQ(opcodeSize, model->operator_codes.size()); 142 CheckBuiltinOperators(opcodes, model->operator_codes); 143 CHECK_EQ(subgraphs, model->subgraphs.size()); 144 CHECK_EQ(desc, model->description); 145 CHECK_EQ(buffers, model->buffers.size()); 146 } 147 CheckBuiltinOperatorsLoadModelFixture148 void CheckBuiltinOperators(const std::vector<tflite::BuiltinOperator>& expectedOperators, 149 const std::vector<std::unique_ptr<tflite::OperatorCodeT>>& result) 150 { 151 CHECK_EQ(expectedOperators.size(), result.size()); 152 for (size_t i = 0; i < expectedOperators.size(); i++) 153 { 154 CHECK_EQ(expectedOperators[i], result[i]->builtin_code); 155 } 156 } 157 CheckSubgraphLoadModelFixture158 void CheckSubgraph(const SubgraphPtr& subgraph, size_t tensors, const std::vector<int32_t>& inputs, 159 const std::vector<int32_t>& outputs, size_t operators, const std::string& name) 160 { 161 CHECK(subgraph); 162 CHECK_EQ(tensors, subgraph->tensors.size()); 163 CHECK(std::equal(inputs.begin(), inputs.end(), subgraph->inputs.begin(), subgraph->inputs.end())); 164 CHECK(std::equal(outputs.begin(), outputs.end(), 165 subgraph->outputs.begin(), subgraph->outputs.end())); 166 CHECK_EQ(operators, subgraph->operators.size()); 167 CHECK_EQ(name, subgraph->name); 168 } 169 CheckOperatorLoadModelFixture170 void CheckOperator(const OperatorPtr& operatorPtr, uint32_t opcode, const std::vector<int32_t>& inputs, 171 const std::vector<int32_t>& outputs, tflite::BuiltinOptions optionType, 172 tflite::CustomOptionsFormat custom_options_format) 173 { 174 CHECK(operatorPtr); 175 CHECK_EQ(opcode, operatorPtr->opcode_index); 176 CHECK(std::equal(inputs.begin(), inputs.end(), 177 operatorPtr->inputs.begin(), operatorPtr->inputs.end())); 178 CHECK(std::equal(outputs.begin(), outputs.end(), 179 operatorPtr->outputs.begin(), operatorPtr->outputs.end())); 180 CHECK_EQ(optionType, operatorPtr->builtin_options.type); 181 CHECK_EQ(custom_options_format, operatorPtr->custom_options_format); 182 } 183 }; 184 185 TEST_CASE_FIXTURE(LoadModelFixture, "LoadModelFromBinary") 186 { 187 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 188 m_GraphBinary.size()); 189 CheckModel(model, 3, 2, { tflite::BuiltinOperator_AVERAGE_POOL_2D, tflite::BuiltinOperator_CONV_2D }, 190 2, "Test loading a model", 2); 191 CheckSubgraph(model->subgraphs[0], 2, { 1 }, { 0 }, 1, ""); 192 CheckSubgraph(model->subgraphs[1], 3, { 0 }, { 1 }, 1, ""); 193 CheckOperator(model->subgraphs[0]->operators[0], 0, { 1 }, { 0 }, tflite::BuiltinOptions_Pool2DOptions, 194 tflite::CustomOptionsFormat_FLEXBUFFERS); 195 CheckOperator(model->subgraphs[1]->operators[0], 1, { 0, 2 }, { 1 }, tflite::BuiltinOptions_Conv2DOptions, 196 tflite::CustomOptionsFormat_FLEXBUFFERS); 197 } 198 199 TEST_CASE_FIXTURE(LoadModelFixture, "LoadModelFromFile") 200 { 201 using namespace fs; 202 fs::path fname = armnnUtils::Filesystem::NamedTempFile("Armnn-tfLite-LoadModelFromFile-TempFile.csv"); 203 bool saved = flatbuffers::SaveFile(fname.c_str(), 204 reinterpret_cast<char *>(m_GraphBinary.data()), 205 m_GraphBinary.size(), true); 206 CHECK_MESSAGE(saved, "Cannot save test file"); 207 208 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromFile(fname.c_str()); 209 CheckModel(model, 3, 2, { tflite::BuiltinOperator_AVERAGE_POOL_2D, tflite::BuiltinOperator_CONV_2D }, 210 2, "Test loading a model", 2); 211 CheckSubgraph(model->subgraphs[0], 2, { 1 }, { 0 }, 1, ""); 212 CheckSubgraph(model->subgraphs[1], 3, { 0 }, { 1 }, 1, ""); 213 CheckOperator(model->subgraphs[0]->operators[0], 0, { 1 }, { 0 }, tflite::BuiltinOptions_Pool2DOptions, 214 tflite::CustomOptionsFormat_FLEXBUFFERS); 215 CheckOperator(model->subgraphs[1]->operators[0], 1, { 0, 2 }, { 1 }, tflite::BuiltinOptions_Conv2DOptions, 216 tflite::CustomOptionsFormat_FLEXBUFFERS); 217 remove(fname); 218 } 219 220 TEST_CASE("LoadNullBinary") 221 { 222 CHECK_THROWS_AS(TfLiteParserImpl::LoadModelFromBinary(nullptr, 0), armnn::InvalidArgumentException); 223 } 224 225 TEST_CASE("LoadInvalidBinary") 226 { 227 std::string testData = "invalid data"; 228 CHECK_THROWS_AS(TfLiteParserImpl::LoadModelFromBinary(reinterpret_cast<const uint8_t*>(&testData), 229 testData.length()), armnn::ParseException); 230 } 231 232 TEST_CASE("LoadFileNotFound") 233 { 234 CHECK_THROWS_AS(TfLiteParserImpl::LoadModelFromFile("invalidfile.tflite"), armnn::FileNotFoundException); 235 } 236 237 TEST_CASE("LoadNullPtrFile") 238 { 239 CHECK_THROWS_AS(TfLiteParserImpl::LoadModelFromFile(nullptr), armnn::InvalidArgumentException); 240 } 241 242 } 243