1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersFixture.hpp" 7*89c4ff92SAndroid Build Coastguard Worker 8*89c4ff92SAndroid Build Coastguard Worker using armnnTfLiteParser::TfLiteParserImpl; 9*89c4ff92SAndroid Build Coastguard Worker using ModelPtr = TfLiteParserImpl::ModelPtr; 10*89c4ff92SAndroid Build Coastguard Worker 11*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("TensorflowLiteParser_GetInputsOutputs") 12*89c4ff92SAndroid Build Coastguard Worker { 13*89c4ff92SAndroid Build Coastguard Worker struct GetInputsOutputsMainFixture : public ParserFlatbuffersFixture 14*89c4ff92SAndroid Build Coastguard Worker { GetInputsOutputsMainFixtureGetInputsOutputsMainFixture15*89c4ff92SAndroid Build Coastguard Worker explicit GetInputsOutputsMainFixture(const std::string& inputs, const std::string& outputs) 16*89c4ff92SAndroid Build Coastguard Worker { 17*89c4ff92SAndroid Build Coastguard Worker m_JsonString = R"( 18*89c4ff92SAndroid Build Coastguard Worker { 19*89c4ff92SAndroid Build Coastguard Worker "version": 3, 20*89c4ff92SAndroid Build Coastguard Worker "operator_codes": [ { "builtin_code": "AVERAGE_POOL_2D" }, { "builtin_code": "CONV_2D" } ], 21*89c4ff92SAndroid Build Coastguard Worker "subgraphs": [ 22*89c4ff92SAndroid Build Coastguard Worker { 23*89c4ff92SAndroid Build Coastguard Worker "tensors": [ 24*89c4ff92SAndroid Build Coastguard Worker { 25*89c4ff92SAndroid Build Coastguard Worker "shape": [ 1, 1, 1, 1 ] , 26*89c4ff92SAndroid Build Coastguard Worker "type": "UINT8", 27*89c4ff92SAndroid Build Coastguard Worker "buffer": 0, 28*89c4ff92SAndroid Build Coastguard Worker "name": "OutputTensor", 29*89c4ff92SAndroid Build Coastguard Worker "quantization": { 30*89c4ff92SAndroid Build Coastguard Worker "min": [ 0.0 ], 31*89c4ff92SAndroid Build Coastguard Worker "max": [ 255.0 ], 32*89c4ff92SAndroid Build Coastguard Worker "scale": [ 1.0 ], 33*89c4ff92SAndroid Build Coastguard Worker "zero_point": [ 0 ] 34*89c4ff92SAndroid Build Coastguard Worker } 35*89c4ff92SAndroid Build Coastguard Worker }, 36*89c4ff92SAndroid Build Coastguard Worker { 37*89c4ff92SAndroid Build Coastguard Worker "shape": [ 1, 2, 2, 1 ] , 38*89c4ff92SAndroid Build Coastguard Worker "type": "UINT8", 39*89c4ff92SAndroid Build Coastguard Worker "buffer": 1, 40*89c4ff92SAndroid Build Coastguard Worker "name": "InputTensor", 41*89c4ff92SAndroid Build Coastguard Worker "quantization": { 42*89c4ff92SAndroid Build Coastguard Worker "min": [ -1.2 ], 43*89c4ff92SAndroid Build Coastguard Worker "max": [ 25.5 ], 44*89c4ff92SAndroid Build Coastguard Worker "scale": [ 0.25 ], 45*89c4ff92SAndroid Build Coastguard Worker "zero_point": [ 10 ] 46*89c4ff92SAndroid Build Coastguard Worker } 47*89c4ff92SAndroid Build Coastguard Worker } 48*89c4ff92SAndroid Build Coastguard Worker ], 49*89c4ff92SAndroid Build Coastguard Worker "inputs": [ 1 ], 50*89c4ff92SAndroid Build Coastguard Worker "outputs": [ 0 ], 51*89c4ff92SAndroid Build Coastguard Worker "operators": [ { 52*89c4ff92SAndroid Build Coastguard Worker "opcode_index": 0, 53*89c4ff92SAndroid Build Coastguard Worker "inputs": )" 54*89c4ff92SAndroid Build Coastguard Worker + inputs 55*89c4ff92SAndroid Build Coastguard Worker + R"(, 56*89c4ff92SAndroid Build Coastguard Worker "outputs": )" 57*89c4ff92SAndroid Build Coastguard Worker + outputs 58*89c4ff92SAndroid Build Coastguard Worker + R"(, 59*89c4ff92SAndroid Build Coastguard Worker "builtin_options_type": "Pool2DOptions", 60*89c4ff92SAndroid Build Coastguard Worker "builtin_options": 61*89c4ff92SAndroid Build Coastguard Worker { 62*89c4ff92SAndroid Build Coastguard Worker "padding": "VALID", 63*89c4ff92SAndroid Build Coastguard Worker "stride_w": 2, 64*89c4ff92SAndroid Build Coastguard Worker "stride_h": 2, 65*89c4ff92SAndroid Build Coastguard Worker "filter_width": 2, 66*89c4ff92SAndroid Build Coastguard Worker "filter_height": 2, 67*89c4ff92SAndroid Build Coastguard Worker "fused_activation_function": "NONE" 68*89c4ff92SAndroid Build Coastguard Worker }, 69*89c4ff92SAndroid Build Coastguard Worker "custom_options_format": "FLEXBUFFERS" 70*89c4ff92SAndroid Build Coastguard Worker } ] 71*89c4ff92SAndroid Build Coastguard Worker }, 72*89c4ff92SAndroid Build Coastguard Worker { 73*89c4ff92SAndroid Build Coastguard Worker "tensors": [ 74*89c4ff92SAndroid Build Coastguard Worker { 75*89c4ff92SAndroid Build Coastguard Worker "shape": [ 1, 3, 3, 1 ], 76*89c4ff92SAndroid Build Coastguard Worker "type": "UINT8", 77*89c4ff92SAndroid Build Coastguard Worker "buffer": 0, 78*89c4ff92SAndroid Build Coastguard Worker "name": "ConvInputTensor", 79*89c4ff92SAndroid Build Coastguard Worker "quantization": { 80*89c4ff92SAndroid Build Coastguard Worker "scale": [ 1.0 ], 81*89c4ff92SAndroid Build Coastguard Worker "zero_point": [ 0 ], 82*89c4ff92SAndroid Build Coastguard Worker } 83*89c4ff92SAndroid Build Coastguard Worker }, 84*89c4ff92SAndroid Build Coastguard Worker { 85*89c4ff92SAndroid Build Coastguard Worker "shape": [ 1, 1, 1, 1 ], 86*89c4ff92SAndroid Build Coastguard Worker "type": "UINT8", 87*89c4ff92SAndroid Build Coastguard Worker "buffer": 1, 88*89c4ff92SAndroid Build Coastguard Worker "name": "ConvOutputTensor", 89*89c4ff92SAndroid Build Coastguard Worker "quantization": { 90*89c4ff92SAndroid Build Coastguard Worker "min": [ 0.0 ], 91*89c4ff92SAndroid Build Coastguard Worker "max": [ 511.0 ], 92*89c4ff92SAndroid Build Coastguard Worker "scale": [ 2.0 ], 93*89c4ff92SAndroid Build Coastguard Worker "zero_point": [ 0 ], 94*89c4ff92SAndroid Build Coastguard Worker } 95*89c4ff92SAndroid Build Coastguard Worker }, 96*89c4ff92SAndroid Build Coastguard Worker { 97*89c4ff92SAndroid Build Coastguard Worker "shape": [ 1, 3, 3, 1 ], 98*89c4ff92SAndroid Build Coastguard Worker "type": "UINT8", 99*89c4ff92SAndroid Build Coastguard Worker "buffer": 2, 100*89c4ff92SAndroid Build Coastguard Worker "name": "filterTensor", 101*89c4ff92SAndroid Build Coastguard Worker "quantization": { 102*89c4ff92SAndroid Build Coastguard Worker "min": [ 0.0 ], 103*89c4ff92SAndroid Build Coastguard Worker "max": [ 255.0 ], 104*89c4ff92SAndroid Build Coastguard Worker "scale": [ 1.0 ], 105*89c4ff92SAndroid Build Coastguard Worker "zero_point": [ 0 ], 106*89c4ff92SAndroid Build Coastguard Worker } 107*89c4ff92SAndroid Build Coastguard Worker } 108*89c4ff92SAndroid Build Coastguard Worker ], 109*89c4ff92SAndroid Build Coastguard Worker "inputs": [ 0 ], 110*89c4ff92SAndroid Build Coastguard Worker "outputs": [ 1 ], 111*89c4ff92SAndroid Build Coastguard Worker "operators": [ 112*89c4ff92SAndroid Build Coastguard Worker { 113*89c4ff92SAndroid Build Coastguard Worker "opcode_index": 0, 114*89c4ff92SAndroid Build Coastguard Worker "inputs": [ 0, 2 ], 115*89c4ff92SAndroid Build Coastguard Worker "outputs": [ 1 ], 116*89c4ff92SAndroid Build Coastguard Worker "builtin_options_type": "Conv2DOptions", 117*89c4ff92SAndroid Build Coastguard Worker "builtin_options": { 118*89c4ff92SAndroid Build Coastguard Worker "padding": "VALID", 119*89c4ff92SAndroid Build Coastguard Worker "stride_w": 1, 120*89c4ff92SAndroid Build Coastguard Worker "stride_h": 1, 121*89c4ff92SAndroid Build Coastguard Worker "fused_activation_function": "NONE" 122*89c4ff92SAndroid Build Coastguard Worker }, 123*89c4ff92SAndroid Build Coastguard Worker "custom_options_format": "FLEXBUFFERS" 124*89c4ff92SAndroid Build Coastguard Worker } 125*89c4ff92SAndroid Build Coastguard Worker ], 126*89c4ff92SAndroid Build Coastguard Worker } 127*89c4ff92SAndroid Build Coastguard Worker ], 128*89c4ff92SAndroid Build Coastguard Worker "description": "Test Subgraph Inputs Outputs", 129*89c4ff92SAndroid Build Coastguard Worker "buffers" : [ 130*89c4ff92SAndroid Build Coastguard Worker { }, 131*89c4ff92SAndroid Build Coastguard Worker { }, 132*89c4ff92SAndroid Build Coastguard Worker { "data": [ 2,1,0, 6,2,1, 4,1,2 ], }, 133*89c4ff92SAndroid Build Coastguard Worker { }, 134*89c4ff92SAndroid Build Coastguard Worker ] 135*89c4ff92SAndroid Build Coastguard Worker })"; 136*89c4ff92SAndroid Build Coastguard Worker 137*89c4ff92SAndroid Build Coastguard Worker ReadStringToBinary(); 138*89c4ff92SAndroid Build Coastguard Worker } 139*89c4ff92SAndroid Build Coastguard Worker 140*89c4ff92SAndroid Build Coastguard Worker }; 141*89c4ff92SAndroid Build Coastguard Worker 142*89c4ff92SAndroid Build Coastguard Worker struct GetEmptyInputsOutputsFixture : GetInputsOutputsMainFixture 143*89c4ff92SAndroid Build Coastguard Worker { GetEmptyInputsOutputsFixtureGetEmptyInputsOutputsFixture144*89c4ff92SAndroid Build Coastguard Worker GetEmptyInputsOutputsFixture() : GetInputsOutputsMainFixture("[ ]", "[ ]") {} 145*89c4ff92SAndroid Build Coastguard Worker }; 146*89c4ff92SAndroid Build Coastguard Worker 147*89c4ff92SAndroid Build Coastguard Worker struct GetInputsOutputsFixture : GetInputsOutputsMainFixture 148*89c4ff92SAndroid Build Coastguard Worker { GetInputsOutputsFixtureGetInputsOutputsFixture149*89c4ff92SAndroid Build Coastguard Worker GetInputsOutputsFixture() : GetInputsOutputsMainFixture("[ 1 ]", "[ 0 ]") {} 150*89c4ff92SAndroid Build Coastguard Worker }; 151*89c4ff92SAndroid Build Coastguard Worker 152*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GetEmptyInputsOutputsFixture, "GetEmptyInputs") 153*89c4ff92SAndroid Build Coastguard Worker { 154*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 155*89c4ff92SAndroid Build Coastguard Worker m_GraphBinary.size()); 156*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::TensorRawPtrVector tensors = TfLiteParserImpl::GetInputs(model, 0, 0); 157*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(0, tensors.size()); 158*89c4ff92SAndroid Build Coastguard Worker } 159*89c4ff92SAndroid Build Coastguard Worker 160*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GetEmptyInputsOutputsFixture, "GetEmptyOutputs") 161*89c4ff92SAndroid Build Coastguard Worker { 162*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 163*89c4ff92SAndroid Build Coastguard Worker m_GraphBinary.size()); 164*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::TensorRawPtrVector tensors = TfLiteParserImpl::GetOutputs(model, 0, 0); 165*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(0, tensors.size()); 166*89c4ff92SAndroid Build Coastguard Worker } 167*89c4ff92SAndroid Build Coastguard Worker 168*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GetInputsOutputsFixture, "GetInputs") 169*89c4ff92SAndroid Build Coastguard Worker { 170*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 171*89c4ff92SAndroid Build Coastguard Worker m_GraphBinary.size()); 172*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::TensorRawPtrVector tensors = TfLiteParserImpl::GetInputs(model, 0, 0); 173*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(1, tensors.size()); 174*89c4ff92SAndroid Build Coastguard Worker CheckTensors(tensors[0], 4, { 1, 2, 2, 1 }, tflite::TensorType::TensorType_UINT8, 1, 175*89c4ff92SAndroid Build Coastguard Worker "InputTensor", { -1.2f }, { 25.5f }, { 0.25f }, { 10 }); 176*89c4ff92SAndroid Build Coastguard Worker } 177*89c4ff92SAndroid Build Coastguard Worker 178*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GetInputsOutputsFixture, "GetOutputs") 179*89c4ff92SAndroid Build Coastguard Worker { 180*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 181*89c4ff92SAndroid Build Coastguard Worker m_GraphBinary.size()); 182*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::TensorRawPtrVector tensors = TfLiteParserImpl::GetOutputs(model, 0, 0); 183*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(1, tensors.size()); 184*89c4ff92SAndroid Build Coastguard Worker CheckTensors(tensors[0], 4, { 1, 1, 1, 1 }, tflite::TensorType::TensorType_UINT8, 0, 185*89c4ff92SAndroid Build Coastguard Worker "OutputTensor", { 0.0f }, { 255.0f }, { 1.0f }, { 0 }); 186*89c4ff92SAndroid Build Coastguard Worker } 187*89c4ff92SAndroid Build Coastguard Worker 188*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GetInputsOutputsFixture, "GetInputsMultipleInputs") 189*89c4ff92SAndroid Build Coastguard Worker { 190*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 191*89c4ff92SAndroid Build Coastguard Worker m_GraphBinary.size()); 192*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::TensorRawPtrVector tensors = TfLiteParserImpl::GetInputs(model, 1, 0); 193*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(2, tensors.size()); 194*89c4ff92SAndroid Build Coastguard Worker CheckTensors(tensors[0], 4, { 1, 3, 3, 1 }, tflite::TensorType::TensorType_UINT8, 0, 195*89c4ff92SAndroid Build Coastguard Worker "ConvInputTensor", { }, { }, { 1.0f }, { 0 }); 196*89c4ff92SAndroid Build Coastguard Worker CheckTensors(tensors[1], 4, { 1, 3, 3, 1 }, tflite::TensorType::TensorType_UINT8, 2, 197*89c4ff92SAndroid Build Coastguard Worker "filterTensor", { 0.0f }, { 255.0f }, { 1.0f }, { 0 }); 198*89c4ff92SAndroid Build Coastguard Worker } 199*89c4ff92SAndroid Build Coastguard Worker 200*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GetInputsOutputsFixture, "GetOutputs2") 201*89c4ff92SAndroid Build Coastguard Worker { 202*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 203*89c4ff92SAndroid Build Coastguard Worker m_GraphBinary.size()); 204*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::TensorRawPtrVector tensors = TfLiteParserImpl::GetOutputs(model, 1, 0); 205*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(1, tensors.size()); 206*89c4ff92SAndroid Build Coastguard Worker CheckTensors(tensors[0], 4, { 1, 1, 1, 1 }, tflite::TensorType::TensorType_UINT8, 1, 207*89c4ff92SAndroid Build Coastguard Worker "ConvOutputTensor", { 0.0f }, { 511.0f }, { 2.0f }, { 0 }); 208*89c4ff92SAndroid Build Coastguard Worker } 209*89c4ff92SAndroid Build Coastguard Worker 210*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("GetInputsNullModel") 211*89c4ff92SAndroid Build Coastguard Worker { 212*89c4ff92SAndroid Build Coastguard Worker CHECK_THROWS_AS(TfLiteParserImpl::GetInputs(nullptr, 0, 0), armnn::ParseException); 213*89c4ff92SAndroid Build Coastguard Worker } 214*89c4ff92SAndroid Build Coastguard Worker 215*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("GetOutputsNullModel") 216*89c4ff92SAndroid Build Coastguard Worker { 217*89c4ff92SAndroid Build Coastguard Worker CHECK_THROWS_AS(TfLiteParserImpl::GetOutputs(nullptr, 0, 0), armnn::ParseException); 218*89c4ff92SAndroid Build Coastguard Worker } 219*89c4ff92SAndroid Build Coastguard Worker 220*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GetInputsOutputsFixture, "GetInputsInvalidSubgraph") 221*89c4ff92SAndroid Build Coastguard Worker { 222*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 223*89c4ff92SAndroid Build Coastguard Worker m_GraphBinary.size()); 224*89c4ff92SAndroid Build Coastguard Worker CHECK_THROWS_AS(TfLiteParserImpl::GetInputs(model, 2, 0), armnn::ParseException); 225*89c4ff92SAndroid Build Coastguard Worker } 226*89c4ff92SAndroid Build Coastguard Worker 227*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GetInputsOutputsFixture, "GetOutputsInvalidSubgraph") 228*89c4ff92SAndroid Build Coastguard Worker { 229*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 230*89c4ff92SAndroid Build Coastguard Worker m_GraphBinary.size()); 231*89c4ff92SAndroid Build Coastguard Worker CHECK_THROWS_AS(TfLiteParserImpl::GetOutputs(model, 2, 0), armnn::ParseException); 232*89c4ff92SAndroid Build Coastguard Worker } 233*89c4ff92SAndroid Build Coastguard Worker 234*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GetInputsOutputsFixture, "GetInputsInvalidOperator") 235*89c4ff92SAndroid Build Coastguard Worker { 236*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 237*89c4ff92SAndroid Build Coastguard Worker m_GraphBinary.size()); 238*89c4ff92SAndroid Build Coastguard Worker CHECK_THROWS_AS(TfLiteParserImpl::GetInputs(model, 0, 1), armnn::ParseException); 239*89c4ff92SAndroid Build Coastguard Worker } 240*89c4ff92SAndroid Build Coastguard Worker 241*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GetInputsOutputsFixture, "GetOutputsInvalidOperator") 242*89c4ff92SAndroid Build Coastguard Worker { 243*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 244*89c4ff92SAndroid Build Coastguard Worker m_GraphBinary.size()); 245*89c4ff92SAndroid Build Coastguard Worker CHECK_THROWS_AS(TfLiteParserImpl::GetOutputs(model, 0, 1), armnn::ParseException); 246*89c4ff92SAndroid Build Coastguard Worker } 247*89c4ff92SAndroid Build Coastguard Worker 248*89c4ff92SAndroid Build Coastguard Worker }