1 // 2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "ParserFlatbuffersFixture.hpp" 7 8 using armnnTfLiteParser::TfLiteParserImpl; 9 using ModelPtr = TfLiteParserImpl::ModelPtr; 10 11 TEST_SUITE("TensorflowLiteParser_GetTensorIds") 12 { 13 struct GetTensorIdsFixture : public ParserFlatbuffersFixture 14 { GetTensorIdsFixtureGetTensorIdsFixture15 explicit GetTensorIdsFixture(const std::string& inputs, const std::string& outputs) 16 { 17 m_JsonString = R"( 18 { 19 "version": 3, 20 "operator_codes": [ { "builtin_code": "AVERAGE_POOL_2D" } ], 21 "subgraphs": [ 22 { 23 "tensors": [ 24 { 25 "shape": [ 1, 1, 1, 1 ] , 26 "type": "UINT8", 27 "buffer": 0, 28 "name": "OutputTensor", 29 "quantization": { 30 "min": [ 0.0 ], 31 "max": [ 255.0 ], 32 "scale": [ 1.0 ], 33 "zero_point": [ 0 ] 34 } 35 }, 36 { 37 "shape": [ 1, 2, 2, 1 ] , 38 "type": "UINT8", 39 "buffer": 1, 40 "name": "InputTensor", 41 "quantization": { 42 "min": [ 0.0 ], 43 "max": [ 255.0 ], 44 "scale": [ 1.0 ], 45 "zero_point": [ 0 ] 46 } 47 } 48 ], 49 "inputs": [ 1 ], 50 "outputs": [ 0 ], 51 "operators": [ { 52 "opcode_index": 0, 53 "inputs": )" 54 + inputs 55 + R"(, 56 "outputs": )" 57 + outputs 58 + R"(, 59 "builtin_options_type": "Pool2DOptions", 60 "builtin_options": 61 { 62 "padding": "VALID", 63 "stride_w": 2, 64 "stride_h": 2, 65 "filter_width": 2, 66 "filter_height": 2, 67 "fused_activation_function": "NONE" 68 }, 69 "custom_options_format": "FLEXBUFFERS" 70 } ] 71 } 72 ], 73 "description": "Test loading a model", 74 "buffers" : [ {}, {} ] 75 })"; 76 77 ReadStringToBinary(); 78 } 79 }; 80 81 struct GetEmptyTensorIdsFixture : GetTensorIdsFixture 82 { GetEmptyTensorIdsFixtureGetEmptyTensorIdsFixture83 GetEmptyTensorIdsFixture() : GetTensorIdsFixture("[ ]", "[ ]") {} 84 }; 85 86 struct GetInputOutputTensorIdsFixture : GetTensorIdsFixture 87 { GetInputOutputTensorIdsFixtureGetInputOutputTensorIdsFixture88 GetInputOutputTensorIdsFixture() : GetTensorIdsFixture("[ 0, 1, 2 ]", "[ 3 ]") {} 89 }; 90 91 TEST_CASE_FIXTURE(GetEmptyTensorIdsFixture, "GetEmptyInputTensorIds") 92 { 93 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 94 m_GraphBinary.size()); 95 std::vector<int32_t> expectedIds = { }; 96 std::vector<int32_t> inputTensorIds = TfLiteParserImpl::GetInputTensorIds(model, 0, 0); 97 CHECK(std::equal(expectedIds.begin(), expectedIds.end(), 98 inputTensorIds.begin(), inputTensorIds.end())); 99 } 100 101 TEST_CASE_FIXTURE(GetEmptyTensorIdsFixture, "GetEmptyOutputTensorIds") 102 { 103 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 104 m_GraphBinary.size()); 105 std::vector<int32_t> expectedIds = { }; 106 std::vector<int32_t> outputTensorIds = TfLiteParserImpl::GetOutputTensorIds(model, 0, 0); 107 CHECK(std::equal(expectedIds.begin(), expectedIds.end(), 108 outputTensorIds.begin(), outputTensorIds.end())); 109 } 110 111 TEST_CASE_FIXTURE(GetInputOutputTensorIdsFixture, "GetInputTensorIds") 112 { 113 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 114 m_GraphBinary.size()); 115 std::vector<int32_t> expectedInputIds = { 0, 1, 2 }; 116 std::vector<int32_t> inputTensorIds = TfLiteParserImpl::GetInputTensorIds(model, 0, 0); 117 CHECK(std::equal(expectedInputIds.begin(), expectedInputIds.end(), 118 inputTensorIds.begin(), inputTensorIds.end())); 119 } 120 121 TEST_CASE_FIXTURE(GetInputOutputTensorIdsFixture, "GetOutputTensorIds") 122 { 123 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 124 m_GraphBinary.size()); 125 std::vector<int32_t> expectedOutputIds = { 3 }; 126 std::vector<int32_t> outputTensorIds = TfLiteParserImpl::GetOutputTensorIds(model, 0, 0); 127 CHECK(std::equal(expectedOutputIds.begin(), expectedOutputIds.end(), 128 outputTensorIds.begin(), outputTensorIds.end())); 129 } 130 131 TEST_CASE_FIXTURE(GetInputOutputTensorIdsFixture, "GetInputTensorIdsNullModel") 132 { 133 CHECK_THROWS_AS(TfLiteParserImpl::GetInputTensorIds(nullptr, 0, 0), armnn::ParseException); 134 } 135 136 TEST_CASE_FIXTURE(GetInputOutputTensorIdsFixture, "GetOutputTensorIdsNullModel") 137 { 138 CHECK_THROWS_AS(TfLiteParserImpl::GetOutputTensorIds(nullptr, 0, 0), armnn::ParseException); 139 } 140 141 TEST_CASE_FIXTURE(GetInputOutputTensorIdsFixture, "GetInputTensorIdsInvalidSubgraph") 142 { 143 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 144 m_GraphBinary.size()); 145 CHECK_THROWS_AS(TfLiteParserImpl::GetInputTensorIds(model, 1, 0), armnn::ParseException); 146 } 147 148 TEST_CASE_FIXTURE( GetInputOutputTensorIdsFixture, "GetOutputTensorIdsInvalidSubgraph") 149 { 150 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 151 m_GraphBinary.size()); 152 CHECK_THROWS_AS(TfLiteParserImpl::GetOutputTensorIds(model, 1, 0), armnn::ParseException); 153 } 154 155 TEST_CASE_FIXTURE(GetInputOutputTensorIdsFixture, "GetInputTensorIdsInvalidOperator") 156 { 157 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 158 m_GraphBinary.size()); 159 CHECK_THROWS_AS(TfLiteParserImpl::GetInputTensorIds(model, 0, 1), armnn::ParseException); 160 } 161 162 TEST_CASE_FIXTURE(GetInputOutputTensorIdsFixture, "GetOutputTensorIdsInvalidOperator") 163 { 164 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(), 165 m_GraphBinary.size()); 166 CHECK_THROWS_AS(TfLiteParserImpl::GetOutputTensorIds(model, 0, 1), armnn::ParseException); 167 } 168 169 } 170