1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersFixture.hpp" 7*89c4ff92SAndroid Build Coastguard Worker 8*89c4ff92SAndroid Build Coastguard Worker using armnnTfLiteParser::TfLiteParserImpl; 9*89c4ff92SAndroid Build Coastguard Worker using ModelPtr = TfLiteParserImpl::ModelPtr; 10*89c4ff92SAndroid Build Coastguard Worker using TensorRawPtr = TfLiteParserImpl::TensorRawPtr; 11*89c4ff92SAndroid Build Coastguard Worker 12*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("TensorflowLiteParser_GetSubgraphInputsOutputs") 13*89c4ff92SAndroid Build Coastguard Worker { 14*89c4ff92SAndroid Build Coastguard Worker struct GetSubgraphInputsOutputsMainFixture : public ParserFlatbuffersFixture 15*89c4ff92SAndroid Build Coastguard Worker { GetSubgraphInputsOutputsMainFixtureGetSubgraphInputsOutputsMainFixture16*89c4ff92SAndroid Build Coastguard Worker explicit GetSubgraphInputsOutputsMainFixture(const std::string& inputs, const std::string& outputs) 17*89c4ff92SAndroid Build Coastguard Worker { 18*89c4ff92SAndroid Build Coastguard Worker m_JsonString = R"( 19*89c4ff92SAndroid Build Coastguard Worker { 20*89c4ff92SAndroid Build Coastguard Worker "version": 3, 21*89c4ff92SAndroid Build Coastguard Worker "operator_codes": [ { "builtin_code": "AVERAGE_POOL_2D" }, { "builtin_code": "CONV_2D" } ], 22*89c4ff92SAndroid Build Coastguard Worker "subgraphs": [ 23*89c4ff92SAndroid Build Coastguard Worker { 24*89c4ff92SAndroid Build Coastguard Worker "tensors": [ 25*89c4ff92SAndroid Build Coastguard Worker { 26*89c4ff92SAndroid Build Coastguard Worker "shape": [ 1, 1, 1, 1 ] , 27*89c4ff92SAndroid Build Coastguard Worker "type": "UINT8", 28*89c4ff92SAndroid Build Coastguard Worker "buffer": 0, 29*89c4ff92SAndroid Build Coastguard Worker "name": "OutputTensor", 30*89c4ff92SAndroid Build Coastguard Worker "quantization": { 31*89c4ff92SAndroid Build Coastguard Worker "min": [ 0.0 ], 32*89c4ff92SAndroid Build Coastguard Worker "max": [ 255.0 ], 33*89c4ff92SAndroid Build Coastguard Worker "scale": [ 1.0 ], 34*89c4ff92SAndroid Build Coastguard Worker "zero_point": [ 0 ] 35*89c4ff92SAndroid Build Coastguard Worker } 36*89c4ff92SAndroid Build Coastguard Worker }, 37*89c4ff92SAndroid Build Coastguard Worker { 38*89c4ff92SAndroid Build Coastguard Worker "shape": [ 1, 2, 2, 1 ] , 39*89c4ff92SAndroid Build Coastguard Worker "type": "UINT8", 40*89c4ff92SAndroid Build Coastguard Worker "buffer": 1, 41*89c4ff92SAndroid Build Coastguard Worker "name": "InputTensor", 42*89c4ff92SAndroid Build Coastguard Worker "quantization": { 43*89c4ff92SAndroid Build Coastguard Worker "min": [ -1.2 ], 44*89c4ff92SAndroid Build Coastguard Worker "max": [ 25.5 ], 45*89c4ff92SAndroid Build Coastguard Worker "scale": [ 0.25 ], 46*89c4ff92SAndroid Build Coastguard Worker "zero_point": [ 10 ] 47*89c4ff92SAndroid Build Coastguard Worker } 48*89c4ff92SAndroid Build Coastguard Worker } 49*89c4ff92SAndroid Build Coastguard Worker ], 50*89c4ff92SAndroid Build Coastguard Worker "inputs": )" 51*89c4ff92SAndroid Build Coastguard Worker + inputs 52*89c4ff92SAndroid Build Coastguard Worker + R"(, 53*89c4ff92SAndroid Build Coastguard Worker "outputs": )" 54*89c4ff92SAndroid Build Coastguard Worker + outputs 55*89c4ff92SAndroid Build Coastguard Worker + R"(, 56*89c4ff92SAndroid Build Coastguard Worker "operators": [ { 57*89c4ff92SAndroid Build Coastguard Worker "opcode_index": 0, 58*89c4ff92SAndroid Build Coastguard Worker "inputs": [ 1 ], 59*89c4ff92SAndroid Build Coastguard Worker "outputs": [ 0 ], 60*89c4ff92SAndroid Build Coastguard Worker "builtin_options_type": "Pool2DOptions", 61*89c4ff92SAndroid Build Coastguard Worker "builtin_options": 62*89c4ff92SAndroid Build Coastguard Worker { 63*89c4ff92SAndroid Build Coastguard Worker "padding": "VALID", 64*89c4ff92SAndroid Build Coastguard Worker "stride_w": 2, 65*89c4ff92SAndroid Build Coastguard Worker "stride_h": 2, 66*89c4ff92SAndroid Build Coastguard Worker "filter_width": 2, 67*89c4ff92SAndroid Build Coastguard Worker "filter_height": 2, 68*89c4ff92SAndroid Build Coastguard Worker "fused_activation_function": "NONE" 69*89c4ff92SAndroid Build Coastguard Worker }, 70*89c4ff92SAndroid Build Coastguard Worker "custom_options_format": "FLEXBUFFERS" 71*89c4ff92SAndroid Build Coastguard Worker } ] 72*89c4ff92SAndroid Build Coastguard Worker }, 73*89c4ff92SAndroid Build Coastguard Worker { 74*89c4ff92SAndroid Build Coastguard Worker "tensors": [ 75*89c4ff92SAndroid Build Coastguard Worker { 76*89c4ff92SAndroid Build Coastguard Worker "shape": [ 1, 3, 3, 1 ], 77*89c4ff92SAndroid Build Coastguard Worker "type": "UINT8", 78*89c4ff92SAndroid Build Coastguard Worker "buffer": 0, 79*89c4ff92SAndroid Build Coastguard Worker "name": "ConvInputTensor", 80*89c4ff92SAndroid Build Coastguard Worker "quantization": { 81*89c4ff92SAndroid Build Coastguard Worker "scale": [ 1.0 ], 82*89c4ff92SAndroid Build Coastguard Worker "zero_point": [ 0 ], 83*89c4ff92SAndroid Build Coastguard Worker } 84*89c4ff92SAndroid Build Coastguard Worker }, 85*89c4ff92SAndroid Build Coastguard Worker { 86*89c4ff92SAndroid Build Coastguard Worker "shape": [ 1, 1, 1, 1 ], 87*89c4ff92SAndroid Build Coastguard Worker "type": "UINT8", 88*89c4ff92SAndroid Build Coastguard Worker "buffer": 1, 89*89c4ff92SAndroid Build Coastguard Worker "name": "ConvOutputTensor", 90*89c4ff92SAndroid Build Coastguard Worker "quantization": { 91*89c4ff92SAndroid Build Coastguard Worker "min": [ 0.0 ], 92*89c4ff92SAndroid Build Coastguard Worker "max": [ 511.0 ], 93*89c4ff92SAndroid Build Coastguard Worker "scale": [ 2.0 ], 94*89c4ff92SAndroid Build Coastguard Worker "zero_point": [ 0 ], 95*89c4ff92SAndroid Build Coastguard Worker } 96*89c4ff92SAndroid Build Coastguard Worker }, 97*89c4ff92SAndroid Build Coastguard Worker { 98*89c4ff92SAndroid Build Coastguard Worker "shape": [ 1, 3, 3, 1 ], 99*89c4ff92SAndroid Build Coastguard Worker "type": "UINT8", 100*89c4ff92SAndroid Build Coastguard Worker "buffer": 2, 101*89c4ff92SAndroid Build Coastguard Worker "name": "filterTensor", 102*89c4ff92SAndroid Build Coastguard Worker "quantization": { 103*89c4ff92SAndroid Build Coastguard Worker "min": [ 0.0 ], 104*89c4ff92SAndroid Build Coastguard Worker "max": [ 255.0 ], 105*89c4ff92SAndroid Build Coastguard Worker "scale": [ 1.0 ], 106*89c4ff92SAndroid Build Coastguard Worker "zero_point": [ 0 ], 107*89c4ff92SAndroid Build Coastguard Worker } 108*89c4ff92SAndroid Build Coastguard Worker } 109*89c4ff92SAndroid Build Coastguard Worker ], 110*89c4ff92SAndroid Build Coastguard Worker "inputs": [ 0 ], 111*89c4ff92SAndroid Build Coastguard Worker "outputs": [ 1 ], 112*89c4ff92SAndroid Build Coastguard Worker "operators": [ 113*89c4ff92SAndroid Build Coastguard Worker { 114*89c4ff92SAndroid Build Coastguard Worker "opcode_index": 0, 115*89c4ff92SAndroid Build Coastguard Worker "inputs": [ 0, 2 ], 116*89c4ff92SAndroid Build Coastguard Worker "outputs": [ 1 ], 117*89c4ff92SAndroid Build Coastguard Worker "builtin_options_type": "Conv2DOptions", 118*89c4ff92SAndroid Build Coastguard Worker "builtin_options": { 119*89c4ff92SAndroid Build Coastguard Worker "padding": "VALID", 120*89c4ff92SAndroid Build Coastguard Worker "stride_w": 1, 121*89c4ff92SAndroid Build Coastguard Worker "stride_h": 1, 122*89c4ff92SAndroid Build Coastguard Worker "fused_activation_function": "NONE" 123*89c4ff92SAndroid Build Coastguard Worker }, 124*89c4ff92SAndroid Build Coastguard Worker "custom_options_format": "FLEXBUFFERS" 125*89c4ff92SAndroid Build Coastguard Worker } 126*89c4ff92SAndroid Build Coastguard Worker ], 127*89c4ff92SAndroid Build Coastguard Worker } 128*89c4ff92SAndroid Build Coastguard Worker ], 129*89c4ff92SAndroid Build Coastguard Worker "description": "Test Subgraph Inputs Outputs", 130*89c4ff92SAndroid Build Coastguard Worker "buffers" : [ 131*89c4ff92SAndroid Build Coastguard Worker { }, 132*89c4ff92SAndroid Build Coastguard Worker { }, 133*89c4ff92SAndroid Build Coastguard Worker { "data": [ 2,1,0, 6,2,1, 4,1,2 ], }, 134*89c4ff92SAndroid Build Coastguard Worker { }, 135*89c4ff92SAndroid Build Coastguard Worker ] 136*89c4ff92SAndroid Build Coastguard Worker })"; 137*89c4ff92SAndroid Build Coastguard Worker 138*89c4ff92SAndroid Build Coastguard Worker ReadStringToBinary(); 139*89c4ff92SAndroid Build Coastguard Worker } 140*89c4ff92SAndroid Build Coastguard Worker 141*89c4ff92SAndroid Build Coastguard Worker }; 142*89c4ff92SAndroid Build Coastguard Worker 143*89c4ff92SAndroid Build Coastguard Worker struct GetEmptySubgraphInputsOutputsFixture : GetSubgraphInputsOutputsMainFixture 144*89c4ff92SAndroid Build Coastguard Worker { GetEmptySubgraphInputsOutputsFixtureGetEmptySubgraphInputsOutputsFixture145*89c4ff92SAndroid Build Coastguard Worker GetEmptySubgraphInputsOutputsFixture() : GetSubgraphInputsOutputsMainFixture("[ ]", "[ ]") {} 146*89c4ff92SAndroid Build Coastguard Worker }; 147*89c4ff92SAndroid Build Coastguard Worker 148*89c4ff92SAndroid Build Coastguard Worker struct GetSubgraphInputsOutputsFixture : GetSubgraphInputsOutputsMainFixture 149*89c4ff92SAndroid Build Coastguard Worker { GetSubgraphInputsOutputsFixtureGetSubgraphInputsOutputsFixture150*89c4ff92SAndroid Build Coastguard Worker GetSubgraphInputsOutputsFixture() : GetSubgraphInputsOutputsMainFixture("[ 1 ]", "[ 0 ]") {} 151*89c4ff92SAndroid Build Coastguard Worker }; 152*89c4ff92SAndroid Build Coastguard Worker 153*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GetEmptySubgraphInputsOutputsFixture, "GetEmptySubgraphInputs") 154*89c4ff92SAndroid Build Coastguard Worker { 155*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 156*89c4ff92SAndroid Build Coastguard Worker m_GraphBinary.size()); 157*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::TensorIdRawPtrVector subgraphTensors = TfLiteParserImpl::GetSubgraphInputs(model, 0); 158*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(0, subgraphTensors.size()); 159*89c4ff92SAndroid Build Coastguard Worker } 160*89c4ff92SAndroid Build Coastguard Worker 161*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GetEmptySubgraphInputsOutputsFixture, "GetEmptySubgraphOutputs") 162*89c4ff92SAndroid Build Coastguard Worker { 163*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 164*89c4ff92SAndroid Build Coastguard Worker m_GraphBinary.size()); 165*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::TensorIdRawPtrVector subgraphTensors = TfLiteParserImpl::GetSubgraphOutputs(model, 0); 166*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(0, subgraphTensors.size()); 167*89c4ff92SAndroid Build Coastguard Worker } 168*89c4ff92SAndroid Build Coastguard Worker 169*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GetSubgraphInputsOutputsFixture, "GetSubgraphInputs") 170*89c4ff92SAndroid Build Coastguard Worker { 171*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 172*89c4ff92SAndroid Build Coastguard Worker m_GraphBinary.size()); 173*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::TensorIdRawPtrVector subgraphTensors = TfLiteParserImpl::GetSubgraphInputs(model, 0); 174*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(1, subgraphTensors.size()); 175*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(1, subgraphTensors[0].first); 176*89c4ff92SAndroid Build Coastguard Worker CheckTensors(subgraphTensors[0].second, 4, { 1, 2, 2, 1 }, tflite::TensorType::TensorType_UINT8, 1, 177*89c4ff92SAndroid Build Coastguard Worker "InputTensor", { -1.2f }, { 25.5f }, { 0.25f }, { 10 }); 178*89c4ff92SAndroid Build Coastguard Worker } 179*89c4ff92SAndroid Build Coastguard Worker 180*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GetSubgraphInputsOutputsFixture, "GetSubgraphOutputsSimpleQuantized") 181*89c4ff92SAndroid Build Coastguard Worker { 182*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 183*89c4ff92SAndroid Build Coastguard Worker m_GraphBinary.size()); 184*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::TensorIdRawPtrVector subgraphTensors = TfLiteParserImpl::GetSubgraphOutputs(model, 0); 185*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(1, subgraphTensors.size()); 186*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(0, subgraphTensors[0].first); 187*89c4ff92SAndroid Build Coastguard Worker CheckTensors(subgraphTensors[0].second, 4, { 1, 1, 1, 1 }, tflite::TensorType::TensorType_UINT8, 0, 188*89c4ff92SAndroid Build Coastguard Worker "OutputTensor", { 0.0f }, { 255.0f }, { 1.0f }, { 0 }); 189*89c4ff92SAndroid Build Coastguard Worker } 190*89c4ff92SAndroid Build Coastguard Worker 191*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GetSubgraphInputsOutputsFixture, "GetSubgraphInputsEmptyMinMax") 192*89c4ff92SAndroid Build Coastguard Worker { 193*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 194*89c4ff92SAndroid Build Coastguard Worker m_GraphBinary.size()); 195*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::TensorIdRawPtrVector subgraphTensors = TfLiteParserImpl::GetSubgraphInputs(model, 1); 196*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(1, subgraphTensors.size()); 197*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(0, subgraphTensors[0].first); 198*89c4ff92SAndroid Build Coastguard Worker CheckTensors(subgraphTensors[0].second, 4, { 1, 3, 3, 1 }, tflite::TensorType::TensorType_UINT8, 0, 199*89c4ff92SAndroid Build Coastguard Worker "ConvInputTensor", { }, { }, { 1.0f }, { 0 }); 200*89c4ff92SAndroid Build Coastguard Worker } 201*89c4ff92SAndroid Build Coastguard Worker 202*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GetSubgraphInputsOutputsFixture, "GetSubgraphOutputs") 203*89c4ff92SAndroid Build Coastguard Worker { 204*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 205*89c4ff92SAndroid Build Coastguard Worker m_GraphBinary.size()); 206*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::TensorIdRawPtrVector subgraphTensors = TfLiteParserImpl::GetSubgraphOutputs(model, 1); 207*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(1, subgraphTensors.size()); 208*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(1, subgraphTensors[0].first); 209*89c4ff92SAndroid Build Coastguard Worker CheckTensors(subgraphTensors[0].second, 4, { 1, 1, 1, 1 }, tflite::TensorType::TensorType_UINT8, 1, 210*89c4ff92SAndroid Build Coastguard Worker "ConvOutputTensor", { 0.0f }, { 511.0f }, { 2.0f }, { 0 }); 211*89c4ff92SAndroid Build Coastguard Worker } 212*89c4ff92SAndroid Build Coastguard Worker 213*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("GetSubgraphInputsNullModel") 214*89c4ff92SAndroid Build Coastguard Worker { 215*89c4ff92SAndroid Build Coastguard Worker CHECK_THROWS_AS(TfLiteParserImpl::GetSubgraphInputs(nullptr, 0), armnn::ParseException); 216*89c4ff92SAndroid Build Coastguard Worker } 217*89c4ff92SAndroid Build Coastguard Worker 218*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("GetSubgraphOutputsNullModel") 219*89c4ff92SAndroid Build Coastguard Worker { 220*89c4ff92SAndroid Build Coastguard Worker CHECK_THROWS_AS(TfLiteParserImpl::GetSubgraphOutputs(nullptr, 0), armnn::ParseException); 221*89c4ff92SAndroid Build Coastguard Worker } 222*89c4ff92SAndroid Build Coastguard Worker 223*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GetSubgraphInputsOutputsFixture, "GetSubgraphInputsInvalidSubgraph") 224*89c4ff92SAndroid Build Coastguard Worker { 225*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 226*89c4ff92SAndroid Build Coastguard Worker m_GraphBinary.size()); 227*89c4ff92SAndroid Build Coastguard Worker CHECK_THROWS_AS(TfLiteParserImpl::GetSubgraphInputs(model, 2), armnn::ParseException); 228*89c4ff92SAndroid Build Coastguard Worker } 229*89c4ff92SAndroid Build Coastguard Worker 230*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GetSubgraphInputsOutputsFixture, "GetSubgraphOutputsInvalidSubgraph") 231*89c4ff92SAndroid Build Coastguard Worker { 232*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 233*89c4ff92SAndroid Build Coastguard Worker m_GraphBinary.size()); 234*89c4ff92SAndroid Build Coastguard Worker CHECK_THROWS_AS(TfLiteParserImpl::GetSubgraphOutputs(model, 2), armnn::ParseException); 235*89c4ff92SAndroid Build Coastguard Worker } 236*89c4ff92SAndroid Build Coastguard Worker 237*89c4ff92SAndroid Build Coastguard Worker }