1 // 2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "ParserFlatbuffersFixture.hpp" 7 8 using armnnTfLiteParser::TfLiteParserImpl; 9 using ModelPtr = TfLiteParserImpl::ModelPtr; 10 using TensorRawPtr = TfLiteParserImpl::TensorRawPtr; 11 12 TEST_SUITE("TensorflowLiteParser_GetSubgraphInputsOutputs") 13 { 14 struct GetSubgraphInputsOutputsMainFixture : public ParserFlatbuffersFixture 15 { GetSubgraphInputsOutputsMainFixtureGetSubgraphInputsOutputsMainFixture16 explicit GetSubgraphInputsOutputsMainFixture(const std::string& inputs, const std::string& outputs) 17 { 18 m_JsonString = R"( 19 { 20 "version": 3, 21 "operator_codes": [ { "builtin_code": "AVERAGE_POOL_2D" }, { "builtin_code": "CONV_2D" } ], 22 "subgraphs": [ 23 { 24 "tensors": [ 25 { 26 "shape": [ 1, 1, 1, 1 ] , 27 "type": "UINT8", 28 "buffer": 0, 29 "name": "OutputTensor", 30 "quantization": { 31 "min": [ 0.0 ], 32 "max": [ 255.0 ], 33 "scale": [ 1.0 ], 34 "zero_point": [ 0 ] 35 } 36 }, 37 { 38 "shape": [ 1, 2, 2, 1 ] , 39 "type": "UINT8", 40 "buffer": 1, 41 "name": "InputTensor", 42 "quantization": { 43 "min": [ -1.2 ], 44 "max": [ 25.5 ], 45 "scale": [ 0.25 ], 46 "zero_point": [ 10 ] 47 } 48 } 49 ], 50 "inputs": )" 51 + inputs 52 + R"(, 53 "outputs": )" 54 + outputs 55 + R"(, 56 "operators": [ { 57 "opcode_index": 0, 58 "inputs": [ 1 ], 59 "outputs": [ 0 ], 60 "builtin_options_type": "Pool2DOptions", 61 "builtin_options": 62 { 63 "padding": "VALID", 64 "stride_w": 2, 65 "stride_h": 2, 66 "filter_width": 2, 67 "filter_height": 2, 68 "fused_activation_function": "NONE" 69 }, 70 "custom_options_format": "FLEXBUFFERS" 71 } ] 72 }, 73 { 74 "tensors": [ 75 { 76 "shape": [ 1, 3, 3, 1 ], 77 "type": "UINT8", 78 "buffer": 0, 79 "name": "ConvInputTensor", 80 "quantization": { 81 "scale": [ 1.0 ], 82 "zero_point": [ 0 ], 83 } 84 }, 85 { 86 "shape": [ 1, 1, 1, 1 ], 87 "type": "UINT8", 88 "buffer": 1, 89 "name": "ConvOutputTensor", 90 "quantization": { 91 "min": [ 0.0 ], 92 "max": [ 511.0 ], 93 "scale": [ 2.0 ], 94 "zero_point": [ 0 ], 95 } 96 }, 97 { 98 "shape": [ 1, 3, 3, 1 ], 99 "type": "UINT8", 100 "buffer": 2, 101 "name": "filterTensor", 102 "quantization": { 103 "min": [ 0.0 ], 104 "max": [ 255.0 ], 105 "scale": [ 1.0 ], 106 "zero_point": [ 0 ], 107 } 108 } 109 ], 110 "inputs": [ 0 ], 111 "outputs": [ 1 ], 112 "operators": [ 113 { 114 "opcode_index": 0, 115 "inputs": [ 0, 2 ], 116 "outputs": [ 1 ], 117 "builtin_options_type": "Conv2DOptions", 118 "builtin_options": { 119 "padding": "VALID", 120 "stride_w": 1, 121 "stride_h": 1, 122 "fused_activation_function": "NONE" 123 }, 124 "custom_options_format": "FLEXBUFFERS" 125 } 126 ], 127 } 128 ], 129 "description": "Test Subgraph Inputs Outputs", 130 "buffers" : [ 131 { }, 132 { }, 133 { "data": [ 2,1,0, 6,2,1, 4,1,2 ], }, 134 { }, 135 ] 136 })"; 137 138 ReadStringToBinary(); 139 } 140 141 }; 142 143 struct GetEmptySubgraphInputsOutputsFixture : GetSubgraphInputsOutputsMainFixture 144 { GetEmptySubgraphInputsOutputsFixtureGetEmptySubgraphInputsOutputsFixture145 GetEmptySubgraphInputsOutputsFixture() : GetSubgraphInputsOutputsMainFixture("[ ]", "[ ]") {} 146 }; 147 148 struct GetSubgraphInputsOutputsFixture : GetSubgraphInputsOutputsMainFixture 149 { GetSubgraphInputsOutputsFixtureGetSubgraphInputsOutputsFixture150 GetSubgraphInputsOutputsFixture() : GetSubgraphInputsOutputsMainFixture("[ 1 ]", "[ 0 ]") {} 151 }; 152 153 TEST_CASE_FIXTURE(GetEmptySubgraphInputsOutputsFixture, "GetEmptySubgraphInputs") 154 { 155 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 156 m_GraphBinary.size()); 157 TfLiteParserImpl::TensorIdRawPtrVector subgraphTensors = TfLiteParserImpl::GetSubgraphInputs(model, 0); 158 CHECK_EQ(0, subgraphTensors.size()); 159 } 160 161 TEST_CASE_FIXTURE(GetEmptySubgraphInputsOutputsFixture, "GetEmptySubgraphOutputs") 162 { 163 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 164 m_GraphBinary.size()); 165 TfLiteParserImpl::TensorIdRawPtrVector subgraphTensors = TfLiteParserImpl::GetSubgraphOutputs(model, 0); 166 CHECK_EQ(0, subgraphTensors.size()); 167 } 168 169 TEST_CASE_FIXTURE(GetSubgraphInputsOutputsFixture, "GetSubgraphInputs") 170 { 171 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 172 m_GraphBinary.size()); 173 TfLiteParserImpl::TensorIdRawPtrVector subgraphTensors = TfLiteParserImpl::GetSubgraphInputs(model, 0); 174 CHECK_EQ(1, subgraphTensors.size()); 175 CHECK_EQ(1, subgraphTensors[0].first); 176 CheckTensors(subgraphTensors[0].second, 4, { 1, 2, 2, 1 }, tflite::TensorType::TensorType_UINT8, 1, 177 "InputTensor", { -1.2f }, { 25.5f }, { 0.25f }, { 10 }); 178 } 179 180 TEST_CASE_FIXTURE(GetSubgraphInputsOutputsFixture, "GetSubgraphOutputsSimpleQuantized") 181 { 182 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 183 m_GraphBinary.size()); 184 TfLiteParserImpl::TensorIdRawPtrVector subgraphTensors = TfLiteParserImpl::GetSubgraphOutputs(model, 0); 185 CHECK_EQ(1, subgraphTensors.size()); 186 CHECK_EQ(0, subgraphTensors[0].first); 187 CheckTensors(subgraphTensors[0].second, 4, { 1, 1, 1, 1 }, tflite::TensorType::TensorType_UINT8, 0, 188 "OutputTensor", { 0.0f }, { 255.0f }, { 1.0f }, { 0 }); 189 } 190 191 TEST_CASE_FIXTURE(GetSubgraphInputsOutputsFixture, "GetSubgraphInputsEmptyMinMax") 192 { 193 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 194 m_GraphBinary.size()); 195 TfLiteParserImpl::TensorIdRawPtrVector subgraphTensors = TfLiteParserImpl::GetSubgraphInputs(model, 1); 196 CHECK_EQ(1, subgraphTensors.size()); 197 CHECK_EQ(0, subgraphTensors[0].first); 198 CheckTensors(subgraphTensors[0].second, 4, { 1, 3, 3, 1 }, tflite::TensorType::TensorType_UINT8, 0, 199 "ConvInputTensor", { }, { }, { 1.0f }, { 0 }); 200 } 201 202 TEST_CASE_FIXTURE(GetSubgraphInputsOutputsFixture, "GetSubgraphOutputs") 203 { 204 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 205 m_GraphBinary.size()); 206 TfLiteParserImpl::TensorIdRawPtrVector subgraphTensors = TfLiteParserImpl::GetSubgraphOutputs(model, 1); 207 CHECK_EQ(1, subgraphTensors.size()); 208 CHECK_EQ(1, subgraphTensors[0].first); 209 CheckTensors(subgraphTensors[0].second, 4, { 1, 1, 1, 1 }, tflite::TensorType::TensorType_UINT8, 1, 210 "ConvOutputTensor", { 0.0f }, { 511.0f }, { 2.0f }, { 0 }); 211 } 212 213 TEST_CASE("GetSubgraphInputsNullModel") 214 { 215 CHECK_THROWS_AS(TfLiteParserImpl::GetSubgraphInputs(nullptr, 0), armnn::ParseException); 216 } 217 218 TEST_CASE("GetSubgraphOutputsNullModel") 219 { 220 CHECK_THROWS_AS(TfLiteParserImpl::GetSubgraphOutputs(nullptr, 0), armnn::ParseException); 221 } 222 223 TEST_CASE_FIXTURE(GetSubgraphInputsOutputsFixture, "GetSubgraphInputsInvalidSubgraph") 224 { 225 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 226 m_GraphBinary.size()); 227 CHECK_THROWS_AS(TfLiteParserImpl::GetSubgraphInputs(model, 2), armnn::ParseException); 228 } 229 230 TEST_CASE_FIXTURE(GetSubgraphInputsOutputsFixture, "GetSubgraphOutputsInvalidSubgraph") 231 { 232 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 233 m_GraphBinary.size()); 234 CHECK_THROWS_AS(TfLiteParserImpl::GetSubgraphOutputs(model, 2), armnn::ParseException); 235 } 236 237 }