xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/non_max_suppression_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <initializer_list>
16 #include <vector>
17 
18 #include <gmock/gmock.h>
19 #include <gtest/gtest.h>
20 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
21 #include "tensorflow/lite/kernels/test_util.h"
22 #include "tensorflow/lite/schema/schema_generated.h"
23 
24 namespace tflite {
25 namespace {
26 
27 using ::testing::ElementsAreArray;
28 
29 class BaseNMSOp : public SingleOpModel {
30  public:
SetScores(std::initializer_list<float> data)31   void SetScores(std::initializer_list<float> data) {
32     PopulateTensor(input_scores_, data);
33   }
34 
SetMaxOutputSize(int max_output_size)35   void SetMaxOutputSize(int max_output_size) {
36     PopulateTensor(input_max_output_size_, {max_output_size});
37   }
38 
SetScoreThreshold(float score_threshold)39   void SetScoreThreshold(float score_threshold) {
40     PopulateTensor(input_score_threshold_, {score_threshold});
41   }
42 
GetSelectedIndices()43   std::vector<int> GetSelectedIndices() {
44     return ExtractVector<int>(output_selected_indices_);
45   }
46 
GetSelectedScores()47   std::vector<float> GetSelectedScores() {
48     return ExtractVector<float>(output_selected_scores_);
49   }
50 
GetNumSelectedIndices()51   std::vector<int> GetNumSelectedIndices() {
52     return ExtractVector<int>(output_num_selected_indices_);
53   }
54 
55  protected:
56   int input_boxes_;
57   int input_scores_;
58   int input_max_output_size_;
59   int input_iou_threshold_;
60   int input_score_threshold_;
61   int input_sigma_;
62 
63   int output_selected_indices_;
64   int output_selected_scores_;
65   int output_num_selected_indices_;
66 };
67 
68 class NonMaxSuppressionV4OpModel : public BaseNMSOp {
69  public:
NonMaxSuppressionV4OpModel(const float iou_threshold,const bool static_shaped_outputs,const int max_output_size=-1)70   explicit NonMaxSuppressionV4OpModel(const float iou_threshold,
71                                       const bool static_shaped_outputs,
72                                       const int max_output_size = -1) {
73     const int num_boxes = 6;
74     input_boxes_ = AddInput({TensorType_FLOAT32, {num_boxes, 4}});
75     input_scores_ = AddInput({TensorType_FLOAT32, {num_boxes}});
76     if (static_shaped_outputs) {
77       input_max_output_size_ =
78           AddConstInput(TensorType_INT32, {max_output_size});
79     } else {
80       input_max_output_size_ = AddInput(TensorType_INT32);
81     }
82     input_iou_threshold_ = AddConstInput(TensorType_FLOAT32, {iou_threshold});
83     input_score_threshold_ = AddInput({TensorType_FLOAT32, {}});
84 
85     output_selected_indices_ = AddOutput(TensorType_INT32);
86 
87     output_num_selected_indices_ = AddOutput(TensorType_INT32);
88 
89     SetBuiltinOp(BuiltinOperator_NON_MAX_SUPPRESSION_V4,
90                  BuiltinOptions_NonMaxSuppressionV4Options,
91                  CreateNonMaxSuppressionV4Options(builder_).Union());
92     BuildInterpreter({GetShape(input_boxes_), GetShape(input_scores_),
93                       GetShape(input_max_output_size_),
94                       GetShape(input_iou_threshold_),
95                       GetShape(input_score_threshold_)});
96 
97     // Default data.
98     PopulateTensor<float>(input_boxes_, {
99                                             1, 1,     0, 0,     // Box 0
100                                             0, 0.1,   1, 1.1,   // Box 1
101                                             0, .9f,   1, -0.1,  // Box 2
102                                             0, 10,    1, 11,    // Box 3
103                                             1, 10.1f, 0, 11.1,  // Box 4
104                                             1, 101,   0, 100    // Box 5
105                                         });
106   }
107 };
108 
TEST(NonMaxSuppressionV4OpModel, TestOutput) {
  // Static output shape: selected_indices always has max_output_size (6)
  // slots, and unused trailing slots are zero-filled.
  NonMaxSuppressionV4OpModel model(/**iou_threshold=**/ 0.5,
                                   /**static_shaped_outputs=**/ true,
                                   /**max_output_size=**/ 6);
  model.SetScores({0.9, 0.75, 0.6, 0.95, 0.5, 0.3});

  // With a 0.4 score threshold, boxes 3 and 0 are selected.
  model.SetScoreThreshold(0.4);
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetNumSelectedIndices(), ElementsAreArray({2}));
  EXPECT_THAT(model.GetSelectedIndices(),
              ElementsAreArray({3, 0, 0, 0, 0, 0}));

  // A threshold above every score selects nothing. The two slots that held
  // indices 3 and 0 after the previous invocation must be zeroed out too.
  model.SetScoreThreshold(0.99);
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetNumSelectedIndices(), ElementsAreArray({0}));
  EXPECT_THAT(model.GetSelectedIndices(),
              ElementsAreArray({0, 0, 0, 0, 0, 0}));
}
125 
TEST(NonMaxSuppressionV4OpModel, TestDynamicOutput) {
  // Dynamic output shape: max_output_size is a runtime input, so the
  // selected_indices length follows whatever value is set before Invoke().
  NonMaxSuppressionV4OpModel model(/**iou_threshold=**/ 0.5,
                                   /**static_shaped_outputs=**/ false);
  model.SetScores({0.9, 0.75, 0.6, 0.95, 0.5, 0.3});
  model.SetScoreThreshold(0.4);

  // Capacity of one: only the top-scoring box 3 fits.
  model.SetMaxOutputSize(1);
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetNumSelectedIndices(), ElementsAreArray({1}));
  EXPECT_THAT(model.GetSelectedIndices(), ElementsAreArray({3}));

  // Capacity of two: boxes 3 and 0.
  model.SetMaxOutputSize(2);
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetNumSelectedIndices(), ElementsAreArray({2}));
  EXPECT_THAT(model.GetSelectedIndices(), ElementsAreArray({3, 0}));

  // Nothing passes the threshold; both slots are zeroed.
  model.SetScoreThreshold(0.99);
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetNumSelectedIndices(), ElementsAreArray({0}));
  EXPECT_THAT(model.GetSelectedIndices(), ElementsAreArray({0, 0}));
}
147 
TEST(NonMaxSuppressionV4OpModel, TestOutputWithZeroMaxOutput) {
  // A constant max_output_size of 0 must be accepted and yield no selections,
  // even though several boxes pass the score threshold.
  NonMaxSuppressionV4OpModel nms(/**iou_threshold=**/ 0.5,
                                 /**static_shaped_outputs=**/ true,
                                 /**max_output_size=**/ 0);
  nms.SetScores({0.9, 0.75, 0.6, 0.95, 0.5, 0.3});
  nms.SetScoreThreshold(0.4);
  ASSERT_EQ(nms.Invoke(), kTfLiteOk);
  EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({0}));
  // With static output shapes, selected_indices is sized by max_output_size,
  // so it should have no elements at all.
  EXPECT_THAT(nms.GetSelectedIndices(), ::testing::IsEmpty());
}
157 
158 class NonMaxSuppressionV5OpModel : public BaseNMSOp {
159  public:
NonMaxSuppressionV5OpModel(const float iou_threshold,const float sigma,const bool static_shaped_outputs,const int max_output_size=-1)160   explicit NonMaxSuppressionV5OpModel(const float iou_threshold,
161                                       const float sigma,
162                                       const bool static_shaped_outputs,
163                                       const int max_output_size = -1) {
164     const int num_boxes = 6;
165     input_boxes_ = AddInput({TensorType_FLOAT32, {num_boxes, 4}});
166     input_scores_ = AddInput({TensorType_FLOAT32, {num_boxes}});
167     if (static_shaped_outputs) {
168       input_max_output_size_ =
169           AddConstInput(TensorType_INT32, {max_output_size});
170     } else {
171       input_max_output_size_ = AddInput(TensorType_INT32);
172     }
173     input_iou_threshold_ = AddConstInput(TensorType_FLOAT32, {iou_threshold});
174     input_score_threshold_ = AddInput({TensorType_FLOAT32, {}});
175     input_sigma_ = AddConstInput(TensorType_FLOAT32, {sigma});
176 
177     output_selected_indices_ = AddOutput(TensorType_INT32);
178     output_selected_scores_ = AddOutput(TensorType_FLOAT32);
179     output_num_selected_indices_ = AddOutput(TensorType_INT32);
180 
181     SetBuiltinOp(BuiltinOperator_NON_MAX_SUPPRESSION_V5,
182                  BuiltinOptions_NonMaxSuppressionV5Options,
183                  CreateNonMaxSuppressionV5Options(builder_).Union());
184 
185     BuildInterpreter(
186         {GetShape(input_boxes_), GetShape(input_scores_),
187          GetShape(input_max_output_size_), GetShape(input_iou_threshold_),
188          GetShape(input_score_threshold_), GetShape(input_sigma_)});
189 
190     // Default data.
191     PopulateTensor<float>(input_boxes_, {
192                                             1, 1,     0, 0,     // Box 0
193                                             0, 0.1,   1, 1.1,   // Box 1
194                                             0, .9f,   1, -0.1,  // Box 2
195                                             0, 10,    1, 11,    // Box 3
196                                             1, 10.1f, 0, 11.1,  // Box 4
197                                             1, 101,   0, 100    // Box 5
198                                         });
199   }
200 };
201 
TEST(NonMaxSuppressionV5OpModel, TestOutput) {
  NonMaxSuppressionV5OpModel nms(/**iou_threshold=**/ 0.5,
                                 /**sigma=**/ 0.5,
                                 /**static_shaped_outputs=**/ true,
                                 /**max_output_size=**/ 6);
  nms.SetScores({0.9, 0.75, 0.6, 0.95, 0.5, 0.3});
  nms.SetScoreThreshold(0.0);
  ASSERT_EQ(nms.Invoke(), kTfLiteOk);
  EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({3}));
  EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({3, 0, 5, 0, 0, 0}));
  // Selected scores come out of the sigma (soft-NMS) path and may be
  // rescaled, so compare them with a tolerance rather than exact float
  // equality.
  EXPECT_THAT(nms.GetSelectedScores(),
              ElementsAreArray(ArrayFloatNear({0.95, 0.9, 0.3, 0.0, 0.0, 0.0})));

  // No candidate gets selected. But the outputs should be zeroed out.
  nms.SetScoreThreshold(0.99);
  ASSERT_EQ(nms.Invoke(), kTfLiteOk);
  EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({0}));
  EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({0, 0, 0, 0, 0, 0}));
  EXPECT_THAT(nms.GetSelectedScores(),
              ElementsAreArray(ArrayFloatNear({0.0, 0.0, 0.0, 0.0, 0.0, 0.0})));
}
223 
TEST(NonMaxSuppressionV5OpModel, TestDynamicOutput) {
  // Dynamic output shape: max_output_size is a runtime input set before each
  // Invoke(), so no constructor max_output_size is passed (it would be
  // ignored in the dynamic path anyway).
  NonMaxSuppressionV5OpModel nms(/**iou_threshold=**/ 0.5,
                                 /**sigma=**/ 0.5,
                                 /**static_shaped_outputs=**/ false);
  nms.SetScores({0.9, 0.75, 0.6, 0.95, 0.5, 0.3});
  nms.SetScoreThreshold(0.0);

  // Selected scores come out of the sigma (soft-NMS) path and may be
  // rescaled, so they are compared with a tolerance throughout.
  nms.SetMaxOutputSize(2);
  ASSERT_EQ(nms.Invoke(), kTfLiteOk);
  EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({2}));
  EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({3, 0}));
  EXPECT_THAT(nms.GetSelectedScores(),
              ElementsAreArray(ArrayFloatNear({0.95, 0.9})));

  nms.SetMaxOutputSize(1);
  ASSERT_EQ(nms.Invoke(), kTfLiteOk);
  EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({1}));
  EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({3}));
  EXPECT_THAT(nms.GetSelectedScores(),
              ElementsAreArray(ArrayFloatNear({0.95})));

  nms.SetMaxOutputSize(3);
  ASSERT_EQ(nms.Invoke(), kTfLiteOk);
  EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({3}));
  EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({3, 0, 5}));
  EXPECT_THAT(nms.GetSelectedScores(),
              ElementsAreArray(ArrayFloatNear({0.95, 0.9, 0.3})));

  // No candidate gets selected. But the outputs should be zeroed out.
  nms.SetScoreThreshold(0.99);
  ASSERT_EQ(nms.Invoke(), kTfLiteOk);
  EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({0}));
  EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({0, 0, 0}));
  EXPECT_THAT(nms.GetSelectedScores(),
              ElementsAreArray(ArrayFloatNear({0.0, 0.0, 0.0})));
}
257 }  // namespace
258 }  // namespace tflite
259