//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "ParserFlatbuffersFixture.hpp"

using armnnTfLiteParser::TfLiteParserImpl;
using ModelPtr = TfLiteParserImpl::ModelPtr;

TEST_SUITE("TensorflowLiteParser_GetInputsOutputs")
{
struct GetInputsOutputsMainFixture : public ParserFlatbuffersFixture
{
    explicit GetInputsOutputsMainFixture(const std::string& inputs, const std::string& outputs)
    {
        m_JsonString = R"(
            {
                "version": 3,
                "operator_codes": [ { "builtin_code": "AVERAGE_POOL_2D" }, { "builtin_code": "CONV_2D" } ],
                "subgraphs": [
                {
                    "tensors": [
                    {
                        "shape": [ 1, 1, 1, 1 ] ,
                        "type": "UINT8",
                        "buffer": 0,
                        "name": "OutputTensor",
                        "quantization": {
                            "min": [ 0.0 ],
                            "max": [ 255.0 ],
                            "scale": [ 1.0 ],
                            "zero_point": [ 0 ]
                        }
                    },
                    {
                        "shape": [ 1, 2, 2, 1 ] ,
                        "type": "UINT8",
                        "buffer": 1,
                        "name": "InputTensor",
                        "quantization": {
                            "min": [ -1.2 ],
                            "max": [ 25.5 ],
                            "scale": [ 0.25 ],
                            "zero_point": [ 10 ]
                        }
                    }
                    ],
                    "inputs": [ 1 ],
                    "outputs": [ 0 ],
                    "operators": [ {
                        "opcode_index": 0,
                        "inputs": )"
                            + inputs
                            + R"(,
                        "outputs": )"
                            + outputs
                            + R"(,
                        "builtin_options_type": "Pool2DOptions",
                        "builtin_options":
                        {
                            "padding": "VALID",
                            "stride_w": 2,
                            "stride_h": 2,
                            "filter_width": 2,
                            "filter_height": 2,
                            "fused_activation_function": "NONE"
                        },
                        "custom_options_format": "FLEXBUFFERS"
                    } ]
                },
                {
                    "tensors": [
                        {
                            "shape": [ 1, 3, 3, 1 ],
                            "type": "UINT8",
                            "buffer": 0,
                            "name": "ConvInputTensor",
                            "quantization": {
                                "scale": [ 1.0 ],
                                "zero_point": [ 0 ],
                            }
                        },
                        {
                            "shape": [ 1, 1, 1, 1 ],
                            "type": "UINT8",
                            "buffer": 1,
                            "name": "ConvOutputTensor",
                            "quantization": {
                                "min": [ 0.0 ],
                                "max": [ 511.0 ],
                                "scale": [ 2.0 ],
                                "zero_point": [ 0 ],
                            }
                        },
                        {
                            "shape": [ 1, 3, 3, 1 ],
                            "type": "UINT8",
                            "buffer": 2,
                            "name": "filterTensor",
                            "quantization": {
                                "min": [ 0.0 ],
                                "max": [ 255.0 ],
                                "scale": [ 1.0 ],
                                "zero_point": [ 0 ],
                            }
                        }
                    ],
                    "inputs": [ 0 ],
                    "outputs": [ 1 ],
                    "operators": [
                        {
                            "opcode_index": 0,
                            "inputs": [ 0, 2 ],
                            "outputs": [ 1 ],
                            "builtin_options_type": "Conv2DOptions",
                            "builtin_options": {
                                "padding": "VALID",
                                "stride_w": 1,
                                "stride_h": 1,
                                "fused_activation_function": "NONE"
                            },
                            "custom_options_format": "FLEXBUFFERS"
                        }
                    ],
                }
                ],
                "description": "Test Subgraph Inputs Outputs",
                "buffers" : [
                    { },
                    { },
                    { "data": [ 2,1,0, 6,2,1, 4,1,2 ], },
                    { },
                ]
            })";

        ReadStringToBinary();
    }

};

struct GetEmptyInputsOutputsFixture : GetInputsOutputsMainFixture
{
    GetEmptyInputsOutputsFixture() : GetInputsOutputsMainFixture("[ ]", "[ ]") {}
};

struct GetInputsOutputsFixture : GetInputsOutputsMainFixture
{
    GetInputsOutputsFixture() : GetInputsOutputsMainFixture("[ 1 ]", "[ 0 ]") {}
};

TEST_CASE_FIXTURE(GetEmptyInputsOutputsFixture, "GetEmptyInputs")
{
    TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
                                                                             m_GraphBinary.size());
    TfLiteParserImpl::TensorRawPtrVector tensors = TfLiteParserImpl::GetInputs(model, 0, 0);
    CHECK_EQ(0, tensors.size());
}

TEST_CASE_FIXTURE(GetEmptyInputsOutputsFixture, "GetEmptyOutputs")
{
    TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
                                                                             m_GraphBinary.size());
    TfLiteParserImpl::TensorRawPtrVector tensors = TfLiteParserImpl::GetOutputs(model, 0, 0);
    CHECK_EQ(0, tensors.size());
}

TEST_CASE_FIXTURE(GetInputsOutputsFixture, "GetInputs")
{
    TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
                                                                             m_GraphBinary.size());
    TfLiteParserImpl::TensorRawPtrVector tensors = TfLiteParserImpl::GetInputs(model, 0, 0);
    CHECK_EQ(1, tensors.size());
    CheckTensors(tensors[0], 4, { 1, 2, 2, 1 }, tflite::TensorType::TensorType_UINT8, 1,
                 "InputTensor", { -1.2f }, { 25.5f }, { 0.25f }, { 10 });
}

TEST_CASE_FIXTURE(GetInputsOutputsFixture, "GetOutputs")
{
    TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
                                                                             m_GraphBinary.size());
    TfLiteParserImpl::TensorRawPtrVector tensors = TfLiteParserImpl::GetOutputs(model, 0, 0);
    CHECK_EQ(1, tensors.size());
    CheckTensors(tensors[0], 4, { 1, 1, 1, 1 }, tflite::TensorType::TensorType_UINT8, 0,
                 "OutputTensor", { 0.0f }, { 255.0f }, { 1.0f }, { 0 });
}

TEST_CASE_FIXTURE(GetInputsOutputsFixture, "GetInputsMultipleInputs")
{
    TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
                                                                             m_GraphBinary.size());
    TfLiteParserImpl::TensorRawPtrVector tensors = TfLiteParserImpl::GetInputs(model, 1, 0);
    CHECK_EQ(2, tensors.size());
    CheckTensors(tensors[0], 4, { 1, 3, 3, 1 }, tflite::TensorType::TensorType_UINT8, 0,
                 "ConvInputTensor", { }, { }, { 1.0f }, { 0 });
    CheckTensors(tensors[1], 4, { 1, 3, 3, 1 }, tflite::TensorType::TensorType_UINT8, 2,
                 "filterTensor", { 0.0f }, { 255.0f }, { 1.0f }, { 0 });
}

TEST_CASE_FIXTURE(GetInputsOutputsFixture, "GetOutputs2")
{
    TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
                                                                             m_GraphBinary.size());
    TfLiteParserImpl::TensorRawPtrVector tensors = TfLiteParserImpl::GetOutputs(model, 1, 0);
    CHECK_EQ(1, tensors.size());
    CheckTensors(tensors[0], 4, { 1, 1, 1, 1 }, tflite::TensorType::TensorType_UINT8, 1,
                 "ConvOutputTensor", { 0.0f }, { 511.0f }, { 2.0f }, { 0 });
}

TEST_CASE("GetInputsNullModel")
{
    CHECK_THROWS_AS(TfLiteParserImpl::GetInputs(nullptr, 0, 0), armnn::ParseException);
}

TEST_CASE("GetOutputsNullModel")
{
    CHECK_THROWS_AS(TfLiteParserImpl::GetOutputs(nullptr, 0, 0), armnn::ParseException);
}

TEST_CASE_FIXTURE(GetInputsOutputsFixture, "GetInputsInvalidSubgraph")
{
    TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
                                                                             m_GraphBinary.size());
    CHECK_THROWS_AS(TfLiteParserImpl::GetInputs(model, 2, 0), armnn::ParseException);
}

TEST_CASE_FIXTURE(GetInputsOutputsFixture, "GetOutputsInvalidSubgraph")
{
    TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
                                                                             m_GraphBinary.size());
    CHECK_THROWS_AS(TfLiteParserImpl::GetOutputs(model, 2, 0), armnn::ParseException);
}

TEST_CASE_FIXTURE(GetInputsOutputsFixture, "GetInputsInvalidOperator")
{
    TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
                                                                             m_GraphBinary.size());
    CHECK_THROWS_AS(TfLiteParserImpl::GetInputs(model, 0, 1), armnn::ParseException);
}

TEST_CASE_FIXTURE(GetInputsOutputsFixture, "GetOutputsInvalidOperator")
{
    TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
                                                                             m_GraphBinary.size());
    CHECK_THROWS_AS(TfLiteParserImpl::GetOutputs(model, 0, 1), armnn::ParseException);
}
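
// Additional sketch (not part of the original suite): exercises the operator-index bounds
// check against the second (CONV_2D) subgraph, which contains a single operator, so index 1
// should be out of range. The test name below is new, and the expected armnn::ParseException
// is assumed to mirror the behaviour already asserted for subgraph 0 above.
TEST_CASE_FIXTURE(GetInputsOutputsFixture, "GetOutputsInvalidOperatorSecondSubgraph")
{
    TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
                                                                             m_GraphBinary.size());
    // Subgraph 1 holds only the CONV_2D operator at index 0, so operator index 1 is invalid.
    CHECK_THROWS_AS(TfLiteParserImpl::GetOutputs(model, 1, 1), armnn::ParseException);
}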

}