//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "ParserFlatbuffersFixture.hpp"

using armnnTfLiteParser::TfLiteParserImpl;
using ModelPtr = TfLiteParserImpl::ModelPtr;

11 TEST_SUITE("TensorflowLiteParser_GetTensorIds")
12 {
13 struct GetTensorIdsFixture : public ParserFlatbuffersFixture
14 {
GetTensorIdsFixtureGetTensorIdsFixture15     explicit GetTensorIdsFixture(const std::string& inputs, const std::string& outputs)
16     {
17         m_JsonString = R"(
18         {
19             "version": 3,
20             "operator_codes": [ { "builtin_code": "AVERAGE_POOL_2D" } ],
21             "subgraphs": [
22             {
23                 "tensors": [
24                 {
25                     "shape": [ 1, 1, 1, 1 ] ,
26                     "type": "UINT8",
27                             "buffer": 0,
28                             "name": "OutputTensor",
29                             "quantization": {
30                                 "min": [ 0.0 ],
31                                 "max": [ 255.0 ],
32                                 "scale": [ 1.0 ],
33                                 "zero_point": [ 0 ]
34                             }
35                 },
36                 {
37                     "shape": [ 1, 2, 2, 1 ] ,
38                     "type": "UINT8",
39                             "buffer": 1,
40                             "name": "InputTensor",
41                             "quantization": {
42                                 "min": [ 0.0 ],
43                                 "max": [ 255.0 ],
44                                 "scale": [ 1.0 ],
45                                 "zero_point": [ 0 ]
46                             }
47                 }
48                 ],
49                 "inputs": [ 1 ],
50                 "outputs": [ 0 ],
51                 "operators": [ {
52                         "opcode_index": 0,
53                         "inputs": )"
54                             + inputs
55                             + R"(,
56                         "outputs": )"
57                             + outputs
58                             + R"(,
59                         "builtin_options_type": "Pool2DOptions",
60                         "builtin_options":
61                         {
62                             "padding": "VALID",
63                             "stride_w": 2,
64                             "stride_h": 2,
65                             "filter_width": 2,
66                             "filter_height": 2,
67                             "fused_activation_function": "NONE"
68                         },
69                         "custom_options_format": "FLEXBUFFERS"
70                     } ]
71                 }
72             ],
73             "description": "Test loading a model",
74             "buffers" : [ {}, {} ]
75         })";
76 
77         ReadStringToBinary();
78     }
79 };
80 
81 struct GetEmptyTensorIdsFixture : GetTensorIdsFixture
82 {
GetEmptyTensorIdsFixtureGetEmptyTensorIdsFixture83     GetEmptyTensorIdsFixture() : GetTensorIdsFixture("[ ]", "[ ]") {}
84 };
85 
86 struct GetInputOutputTensorIdsFixture : GetTensorIdsFixture
87 {
GetInputOutputTensorIdsFixtureGetInputOutputTensorIdsFixture88     GetInputOutputTensorIdsFixture() : GetTensorIdsFixture("[ 0, 1, 2 ]", "[ 3 ]") {}
89 };
90 
91 TEST_CASE_FIXTURE(GetEmptyTensorIdsFixture, "GetEmptyInputTensorIds")
92 {
93     TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
94                                                                              m_GraphBinary.size());
95     std::vector<int32_t> expectedIds = { };
96     std::vector<int32_t> inputTensorIds = TfLiteParserImpl::GetInputTensorIds(model, 0, 0);
97     CHECK(std::equal(expectedIds.begin(), expectedIds.end(),
98                                   inputTensorIds.begin(), inputTensorIds.end()));
99 }
100 
101 TEST_CASE_FIXTURE(GetEmptyTensorIdsFixture, "GetEmptyOutputTensorIds")
102 {
103     TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
104                                                                              m_GraphBinary.size());
105     std::vector<int32_t> expectedIds = { };
106     std::vector<int32_t> outputTensorIds = TfLiteParserImpl::GetOutputTensorIds(model, 0, 0);
107     CHECK(std::equal(expectedIds.begin(), expectedIds.end(),
108                                   outputTensorIds.begin(), outputTensorIds.end()));
109 }
110 
111 TEST_CASE_FIXTURE(GetInputOutputTensorIdsFixture, "GetInputTensorIds")
112 {
113     TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
114                                                                              m_GraphBinary.size());
115     std::vector<int32_t> expectedInputIds = { 0, 1, 2 };
116     std::vector<int32_t> inputTensorIds = TfLiteParserImpl::GetInputTensorIds(model, 0, 0);
117     CHECK(std::equal(expectedInputIds.begin(), expectedInputIds.end(),
118                                   inputTensorIds.begin(), inputTensorIds.end()));
119 }
120 
121 TEST_CASE_FIXTURE(GetInputOutputTensorIdsFixture, "GetOutputTensorIds")
122 {
123     TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
124                                                                              m_GraphBinary.size());
125     std::vector<int32_t> expectedOutputIds = { 3 };
126     std::vector<int32_t> outputTensorIds = TfLiteParserImpl::GetOutputTensorIds(model, 0, 0);
127     CHECK(std::equal(expectedOutputIds.begin(), expectedOutputIds.end(),
128                                   outputTensorIds.begin(), outputTensorIds.end()));
129 }
130 
131 TEST_CASE_FIXTURE(GetInputOutputTensorIdsFixture, "GetInputTensorIdsNullModel")
132 {
133     CHECK_THROWS_AS(TfLiteParserImpl::GetInputTensorIds(nullptr, 0, 0), armnn::ParseException);
134 }
135 
136 TEST_CASE_FIXTURE(GetInputOutputTensorIdsFixture, "GetOutputTensorIdsNullModel")
137 {
138     CHECK_THROWS_AS(TfLiteParserImpl::GetOutputTensorIds(nullptr, 0, 0), armnn::ParseException);
139 }
140 
141 TEST_CASE_FIXTURE(GetInputOutputTensorIdsFixture, "GetInputTensorIdsInvalidSubgraph")
142 {
143     TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
144                                                                              m_GraphBinary.size());
145     CHECK_THROWS_AS(TfLiteParserImpl::GetInputTensorIds(model, 1, 0), armnn::ParseException);
146 }
147 
148 TEST_CASE_FIXTURE( GetInputOutputTensorIdsFixture, "GetOutputTensorIdsInvalidSubgraph")
149 {
150     TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
151                                                                              m_GraphBinary.size());
152     CHECK_THROWS_AS(TfLiteParserImpl::GetOutputTensorIds(model, 1, 0), armnn::ParseException);
153 }
154 
155 TEST_CASE_FIXTURE(GetInputOutputTensorIdsFixture, "GetInputTensorIdsInvalidOperator")
156 {
157     TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
158                                                                              m_GraphBinary.size());
159     CHECK_THROWS_AS(TfLiteParserImpl::GetInputTensorIds(model, 0, 1), armnn::ParseException);
160 }
161 
162 TEST_CASE_FIXTURE(GetInputOutputTensorIdsFixture, "GetOutputTensorIdsInvalidOperator")
163 {
164     TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
165                                                                              m_GraphBinary.size());
166     CHECK_THROWS_AS(TfLiteParserImpl::GetOutputTensorIds(model, 0, 1), armnn::ParseException);
167 }
168 
169 }
170