xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/DetectionPostProcess.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017, 2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ParserFlatbuffersFixture.hpp"
7 #include "ParserPrototxtFixture.hpp"
8 #include "ParserHelper.hpp"
9 #include <GraphUtils.hpp>
10 
11 #include <armnn/utility/PolymorphicDowncast.hpp>
12 #include <armnnUtils/QuantizeHelper.hpp>
13 
14 TEST_SUITE("TensorflowLiteParser_DetectionPostProcess")
15 {
16 struct DetectionPostProcessFixture : ParserFlatbuffersFixture
17 {
DetectionPostProcessFixtureDetectionPostProcessFixture18     explicit DetectionPostProcessFixture(const std::string& custom_options)
19     {
20         /*
21             The following values were used for the custom_options:
22             use_regular_nms = true
23             max_classes_per_detection = 1
24             detections_per_class = 1
25             nms_score_threshold = 0.0
26             nms_iou_threshold = 0.5
27             max_detections = 3
28             max_detections = 3
29             num_classes = 2
30             h_scale = 5
31             w_scale = 5
32             x_scale = 10
33             y_scale = 10
34         */
35         m_JsonString = R"(
36             {
37                 "version": 3,
38                 "operator_codes": [{
39                     "builtin_code": "CUSTOM",
40                     "custom_code": "TFLite_Detection_PostProcess"
41                 }],
42                 "subgraphs": [{
43                     "tensors": [{
44                             "shape": [1, 6, 4],
45                             "type": "UINT8",
46                             "buffer": 0,
47                             "name": "box_encodings",
48                             "quantization": {
49                                 "min": [0.0],
50                                 "max": [255.0],
51                                 "scale": [1.0],
52                                 "zero_point": [ 1 ]
53                             }
54                         },
55                         {
56                             "shape": [1, 6, 3],
57                             "type": "UINT8",
58                             "buffer": 1,
59                             "name": "scores",
60                             "quantization": {
61                                 "min": [0.0],
62                                 "max": [255.0],
63                                 "scale": [0.01],
64                                 "zero_point": [0]
65                             }
66                         },
67                         {
68                             "shape": [6, 4],
69                             "type": "UINT8",
70                             "buffer": 2,
71                             "name": "anchors",
72                             "quantization": {
73                                 "min": [0.0],
74                                 "max": [255.0],
75                                 "scale": [0.5],
76                                 "zero_point": [0]
77                             }
78                         },
79                         {
80                             "type": "FLOAT32",
81                             "buffer": 3,
82                             "name": "detection_boxes",
83                             "quantization": {}
84                         },
85                         {
86                             "type": "FLOAT32",
87                             "buffer": 4,
88                             "name": "detection_classes",
89                             "quantization": {}
90                         },
91                         {
92                             "type": "FLOAT32",
93                             "buffer": 5,
94                             "name": "detection_scores",
95                             "quantization": {}
96                         },
97                         {
98                             "type": "FLOAT32",
99                             "buffer": 6,
100                             "name": "num_detections",
101                             "quantization": {}
102                         }
103                     ],
104                     "inputs": [0, 1, 2],
105                     "outputs": [3, 4, 5, 6],
106                     "operators": [{
107                         "opcode_index": 0,
108                         "inputs": [0, 1, 2],
109                         "outputs": [3, 4, 5, 6],
110                         "builtin_options_type": 0,
111                         "custom_options": [)" + custom_options + R"(],
112                         "custom_options_format": "FLEXBUFFERS"
113                     }]
114                 }],
115                 "buffers": [{},
116                     {},
117                     { "data": [ 1, 1,   2, 2,
118                                 1, 1,   2, 2,
119                                 1, 1,   2, 2,
120                                 1, 21,  2, 2,
121                                 1, 21,  2, 2,
122                                 1, 201, 2, 2]},
123                     {},
124                     {},
125                     {},
126                     {},
127                 ]
128             }
129         )";
130     }
131 };
132 
133 struct ParseDetectionPostProcessCustomOptions : DetectionPostProcessFixture
134 {
135 private:
GenerateDescriptorParseDetectionPostProcessCustomOptions136     static armnn::DetectionPostProcessDescriptor GenerateDescriptor()
137     {
138         static armnn::DetectionPostProcessDescriptor descriptor;
139         descriptor.m_UseRegularNms          = true;
140         descriptor.m_MaxDetections          = 3u;
141         descriptor.m_MaxClassesPerDetection = 1u;
142         descriptor.m_DetectionsPerClass     = 1u;
143         descriptor.m_NumClasses             = 2u;
144         descriptor.m_NmsScoreThreshold      = 0.0f;
145         descriptor.m_NmsIouThreshold        = 0.5f;
146         descriptor.m_ScaleH                 = 5.0f;
147         descriptor.m_ScaleW                 = 5.0f;
148         descriptor.m_ScaleX                 = 10.0f;
149         descriptor.m_ScaleY                 = 10.0f;
150 
151         return descriptor;
152     }
153 
154 public:
ParseDetectionPostProcessCustomOptionsParseDetectionPostProcessCustomOptions155     ParseDetectionPostProcessCustomOptions()
156         : DetectionPostProcessFixture(
157             GenerateDetectionPostProcessJsonString(GenerateDescriptor()))
158     {}
159 };
160 
161 TEST_CASE_FIXTURE(ParseDetectionPostProcessCustomOptions, "ParseDetectionPostProcess")
162 {
163     Setup();
164 
165     // Inputs
166     using UnquantizedContainer = std::vector<float>;
167     UnquantizedContainer boxEncodings =
168     {
169         0.0f,  0.0f, 0.0f, 0.0f,
170         0.0f,  1.0f, 0.0f, 0.0f,
171         0.0f, -1.0f, 0.0f, 0.0f,
172         0.0f,  0.0f, 0.0f, 0.0f,
173         0.0f,  1.0f, 0.0f, 0.0f,
174         0.0f,  0.0f, 0.0f, 0.0f
175     };
176 
177     UnquantizedContainer scores =
178     {
179         0.0f, 0.9f,  0.8f,
180         0.0f, 0.75f, 0.72f,
181         0.0f, 0.6f,  0.5f,
182         0.0f, 0.93f, 0.95f,
183         0.0f, 0.5f,  0.4f,
184         0.0f, 0.3f,  0.2f
185     };
186 
187     // Outputs
188     UnquantizedContainer detectionBoxes =
189     {
190         0.0f, 10.0f, 1.0f, 11.0f,
191         0.0f, 10.0f, 1.0f, 11.0f,
192         0.0f, 0.0f,  0.0f, 0.0f
193     };
194 
195     UnquantizedContainer detectionClasses = { 1.0f,  0.0f,  0.0f };
196     UnquantizedContainer detectionScores  = { 0.95f, 0.93f, 0.0f };
197 
198     UnquantizedContainer numDetections    = { 2.0f };
199 
200     // Quantize inputs and outputs
201     using QuantizedContainer = std::vector<uint8_t>;
202 
203     QuantizedContainer quantBoxEncodings = armnnUtils::QuantizedVector<uint8_t>(boxEncodings, 1.00f, 1);
204     QuantizedContainer quantScores       = armnnUtils::QuantizedVector<uint8_t>(scores,       0.01f, 0);
205 
206     std::map<std::string, QuantizedContainer> input =
207     {
208         { "box_encodings", quantBoxEncodings },
209         { "scores", quantScores }
210     };
211 
212     std::map<std::string, UnquantizedContainer> output =
213     {
214         { "detection_boxes", detectionBoxes},
215         { "detection_classes", detectionClasses},
216         { "detection_scores", detectionScores},
217         { "num_detections", numDetections}
218     };
219 
220     RunTest<armnn::DataType::QAsymmU8, armnn::DataType::Float32>(0, input, output);
221 }
222 
223 TEST_CASE_FIXTURE(ParseDetectionPostProcessCustomOptions, "DetectionPostProcessGraphStructureTest")
224 {
225     /*
226        Inputs:            box_encodings  scores
227                                \          /
228                             DetectionPostProcess
229                           /        /     \       \
230                          /        /       \       \
231        Outputs:     detection detection detection num_detections
232                     boxes     classes   scores
233     */
234 
235     ReadStringToBinary();
236 
237     armnn::INetworkPtr network = m_Parser->CreateNetworkFromBinary(m_GraphBinary);
238 
239     auto optimized = Optimize(*network, { armnn::Compute::CpuRef }, m_Runtime->GetDeviceSpec());
240 
241     armnn::Graph& graph = GetGraphForTesting(optimized.get());
242 
243     // Check the number of layers in the graph
244     CHECK((graph.GetNumInputs() == 2));
245     CHECK((graph.GetNumOutputs() == 4));
246     CHECK((graph.GetNumLayers() == 7));
247 
248     // Input layers
249     armnn::Layer* boxEncodingLayer = GetFirstLayerWithName(graph, "box_encodings");
250     CHECK((boxEncodingLayer->GetType() == armnn::LayerType::Input));
251     CHECK(CheckNumberOfInputSlot(boxEncodingLayer, 0));
252     CHECK(CheckNumberOfOutputSlot(boxEncodingLayer, 1));
253 
254     armnn::Layer* scoresLayer = GetFirstLayerWithName(graph, "scores");
255     CHECK((scoresLayer->GetType() == armnn::LayerType::Input));
256     CHECK(CheckNumberOfInputSlot(scoresLayer, 0));
257     CHECK(CheckNumberOfOutputSlot(scoresLayer, 1));
258 
259     // DetectionPostProcess layer
260     armnn::Layer* detectionPostProcessLayer = GetFirstLayerWithName(graph, "DetectionPostProcess:0:0");
261     CHECK((detectionPostProcessLayer->GetType() == armnn::LayerType::DetectionPostProcess));
262     CHECK(CheckNumberOfInputSlot(detectionPostProcessLayer, 2));
263     CHECK(CheckNumberOfOutputSlot(detectionPostProcessLayer, 4));
264 
265     // Output layers
266     armnn::Layer* detectionBoxesLayer = GetFirstLayerWithName(graph, "detection_boxes");
267     CHECK((detectionBoxesLayer->GetType() == armnn::LayerType::Output));
268     CHECK(CheckNumberOfInputSlot(detectionBoxesLayer, 1));
269     CHECK(CheckNumberOfOutputSlot(detectionBoxesLayer, 0));
270 
271     armnn::Layer* detectionClassesLayer = GetFirstLayerWithName(graph, "detection_classes");
272     CHECK((detectionClassesLayer->GetType() == armnn::LayerType::Output));
273     CHECK(CheckNumberOfInputSlot(detectionClassesLayer, 1));
274     CHECK(CheckNumberOfOutputSlot(detectionClassesLayer, 0));
275 
276     armnn::Layer* detectionScoresLayer = GetFirstLayerWithName(graph, "detection_scores");
277     CHECK((detectionScoresLayer->GetType() == armnn::LayerType::Output));
278     CHECK(CheckNumberOfInputSlot(detectionScoresLayer, 1));
279     CHECK(CheckNumberOfOutputSlot(detectionScoresLayer, 0));
280 
281     armnn::Layer* numDetectionsLayer = GetFirstLayerWithName(graph, "num_detections");
282     CHECK((numDetectionsLayer->GetType() == armnn::LayerType::Output));
283     CHECK(CheckNumberOfInputSlot(numDetectionsLayer, 1));
284     CHECK(CheckNumberOfOutputSlot(numDetectionsLayer, 0));
285 
286     // Check the connections
287     armnn::TensorInfo boxEncodingTensor(armnn::TensorShape({ 1, 6, 4 }), armnn::DataType::QAsymmU8, 1, 1);
288     armnn::TensorInfo scoresTensor(armnn::TensorShape({ 1, 6, 3 }), armnn::DataType::QAsymmU8,
289                                                       0.00999999978f, 0);
290 
291     armnn::TensorInfo detectionBoxesTensor(armnn::TensorShape({ 1, 3, 4 }), armnn::DataType::Float32);
292     armnn::TensorInfo detectionClassesTensor(armnn::TensorShape({ 1, 3 }), armnn::DataType::Float32);
293     armnn::TensorInfo detectionScoresTensor(armnn::TensorShape({ 1, 3 }), armnn::DataType::Float32);
294     armnn::TensorInfo numDetectionsTensor(armnn::TensorShape({ 1 } ), armnn::DataType::Float32);
295 
296     CHECK(IsConnected(boxEncodingLayer, detectionPostProcessLayer, 0, 0, boxEncodingTensor));
297     CHECK(IsConnected(scoresLayer, detectionPostProcessLayer, 0, 1, scoresTensor));
298     CHECK(IsConnected(detectionPostProcessLayer, detectionBoxesLayer, 0, 0, detectionBoxesTensor));
299     CHECK(IsConnected(detectionPostProcessLayer, detectionClassesLayer, 1, 0, detectionClassesTensor));
300     CHECK(IsConnected(detectionPostProcessLayer, detectionScoresLayer, 2, 0, detectionScoresTensor));
301     CHECK(IsConnected(detectionPostProcessLayer, numDetectionsLayer, 3, 0, numDetectionsTensor));
302 }
303 
304 }
305