1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <initializer_list>
16 #include <vector>
17
18 #include <gmock/gmock.h>
19 #include <gtest/gtest.h>
20 #include "flatbuffers/flatbuffers.h" // from @flatbuffers
21 #include "tensorflow/lite/kernels/test_util.h"
22 #include "tensorflow/lite/schema/schema_generated.h"
23
24 namespace tflite {
25 namespace {
26
27 using ::testing::ElementsAreArray;
28
29 class BaseNMSOp : public SingleOpModel {
30 public:
SetScores(std::initializer_list<float> data)31 void SetScores(std::initializer_list<float> data) {
32 PopulateTensor(input_scores_, data);
33 }
34
SetMaxOutputSize(int max_output_size)35 void SetMaxOutputSize(int max_output_size) {
36 PopulateTensor(input_max_output_size_, {max_output_size});
37 }
38
SetScoreThreshold(float score_threshold)39 void SetScoreThreshold(float score_threshold) {
40 PopulateTensor(input_score_threshold_, {score_threshold});
41 }
42
GetSelectedIndices()43 std::vector<int> GetSelectedIndices() {
44 return ExtractVector<int>(output_selected_indices_);
45 }
46
GetSelectedScores()47 std::vector<float> GetSelectedScores() {
48 return ExtractVector<float>(output_selected_scores_);
49 }
50
GetNumSelectedIndices()51 std::vector<int> GetNumSelectedIndices() {
52 return ExtractVector<int>(output_num_selected_indices_);
53 }
54
55 protected:
56 int input_boxes_;
57 int input_scores_;
58 int input_max_output_size_;
59 int input_iou_threshold_;
60 int input_score_threshold_;
61 int input_sigma_;
62
63 int output_selected_indices_;
64 int output_selected_scores_;
65 int output_num_selected_indices_;
66 };
67
68 class NonMaxSuppressionV4OpModel : public BaseNMSOp {
69 public:
NonMaxSuppressionV4OpModel(const float iou_threshold,const bool static_shaped_outputs,const int max_output_size=-1)70 explicit NonMaxSuppressionV4OpModel(const float iou_threshold,
71 const bool static_shaped_outputs,
72 const int max_output_size = -1) {
73 const int num_boxes = 6;
74 input_boxes_ = AddInput({TensorType_FLOAT32, {num_boxes, 4}});
75 input_scores_ = AddInput({TensorType_FLOAT32, {num_boxes}});
76 if (static_shaped_outputs) {
77 input_max_output_size_ =
78 AddConstInput(TensorType_INT32, {max_output_size});
79 } else {
80 input_max_output_size_ = AddInput(TensorType_INT32);
81 }
82 input_iou_threshold_ = AddConstInput(TensorType_FLOAT32, {iou_threshold});
83 input_score_threshold_ = AddInput({TensorType_FLOAT32, {}});
84
85 output_selected_indices_ = AddOutput(TensorType_INT32);
86
87 output_num_selected_indices_ = AddOutput(TensorType_INT32);
88
89 SetBuiltinOp(BuiltinOperator_NON_MAX_SUPPRESSION_V4,
90 BuiltinOptions_NonMaxSuppressionV4Options,
91 CreateNonMaxSuppressionV4Options(builder_).Union());
92 BuildInterpreter({GetShape(input_boxes_), GetShape(input_scores_),
93 GetShape(input_max_output_size_),
94 GetShape(input_iou_threshold_),
95 GetShape(input_score_threshold_)});
96
97 // Default data.
98 PopulateTensor<float>(input_boxes_, {
99 1, 1, 0, 0, // Box 0
100 0, 0.1, 1, 1.1, // Box 1
101 0, .9f, 1, -0.1, // Box 2
102 0, 10, 1, 11, // Box 3
103 1, 10.1f, 0, 11.1, // Box 4
104 1, 101, 0, 100 // Box 5
105 });
106 }
107 };
108
TEST(NonMaxSuppressionV4OpModel, TestOutput) {
  NonMaxSuppressionV4OpModel model(/**iou_threshold=**/ 0.5,
                                   /**static_shaped_outputs=**/ true,
                                   /**max_output_size=**/ 6);
  model.SetScores({0.9, 0.75, 0.6, 0.95, 0.5, 0.3});

  // At a 0.4 score threshold, boxes 3 and 0 are selected. The static-shaped
  // output keeps all six slots and pads the unused ones with zeros.
  model.SetScoreThreshold(0.4);
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetNumSelectedIndices(), ElementsAreArray({2}));
  EXPECT_THAT(model.GetSelectedIndices(), ElementsAreArray({3, 0, 0, 0, 0, 0}));

  // A threshold above every score selects nothing. The two slots written by
  // the previous Invoke() must be reset to zero along with the rest.
  model.SetScoreThreshold(0.99);
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetNumSelectedIndices(), ElementsAreArray({0}));
  EXPECT_THAT(model.GetSelectedIndices(), ElementsAreArray({0, 0, 0, 0, 0, 0}));
}
125
TEST(NonMaxSuppressionV4OpModel, TestDynamicOutput) {
  // max_output_size is a runtime input, so each Invoke() uses whatever value
  // was last written with SetMaxOutputSize().
  NonMaxSuppressionV4OpModel model(/**iou_threshold=**/ 0.5,
                                   /**static_shaped_outputs=**/ false);
  model.SetScores({0.9, 0.75, 0.6, 0.95, 0.5, 0.3});
  model.SetScoreThreshold(0.4);

  // Capped at one result: only the top-scoring box 3 is returned.
  model.SetMaxOutputSize(1);
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetNumSelectedIndices(), ElementsAreArray({1}));
  EXPECT_THAT(model.GetSelectedIndices(), ElementsAreArray({3}));

  // Capped at two results: the output grows to two entries.
  model.SetMaxOutputSize(2);
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetNumSelectedIndices(), ElementsAreArray({2}));
  EXPECT_THAT(model.GetSelectedIndices(), ElementsAreArray({3, 0}));

  // Nothing clears a 0.99 threshold, so both slots are zeroed.
  model.SetScoreThreshold(0.99);
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetNumSelectedIndices(), ElementsAreArray({0}));
  EXPECT_THAT(model.GetSelectedIndices(), ElementsAreArray({0, 0}));
}
147
TEST(NonMaxSuppressionV4OpModel, TestOutputWithZeroMaxOutput) {
  // A constant max_output_size of 0 must still run cleanly and report zero
  // selections, even though several scores clear the 0.4 threshold.
  NonMaxSuppressionV4OpModel model(/**iou_threshold=**/ 0.5,
                                   /**static_shaped_outputs=**/ true,
                                   /**max_output_size=**/ 0);
  model.SetScores({0.9, 0.75, 0.6, 0.95, 0.5, 0.3});
  model.SetScoreThreshold(0.4);

  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetNumSelectedIndices(), ElementsAreArray({0}));
}
157
158 class NonMaxSuppressionV5OpModel : public BaseNMSOp {
159 public:
NonMaxSuppressionV5OpModel(const float iou_threshold,const float sigma,const bool static_shaped_outputs,const int max_output_size=-1)160 explicit NonMaxSuppressionV5OpModel(const float iou_threshold,
161 const float sigma,
162 const bool static_shaped_outputs,
163 const int max_output_size = -1) {
164 const int num_boxes = 6;
165 input_boxes_ = AddInput({TensorType_FLOAT32, {num_boxes, 4}});
166 input_scores_ = AddInput({TensorType_FLOAT32, {num_boxes}});
167 if (static_shaped_outputs) {
168 input_max_output_size_ =
169 AddConstInput(TensorType_INT32, {max_output_size});
170 } else {
171 input_max_output_size_ = AddInput(TensorType_INT32);
172 }
173 input_iou_threshold_ = AddConstInput(TensorType_FLOAT32, {iou_threshold});
174 input_score_threshold_ = AddInput({TensorType_FLOAT32, {}});
175 input_sigma_ = AddConstInput(TensorType_FLOAT32, {sigma});
176
177 output_selected_indices_ = AddOutput(TensorType_INT32);
178 output_selected_scores_ = AddOutput(TensorType_FLOAT32);
179 output_num_selected_indices_ = AddOutput(TensorType_INT32);
180
181 SetBuiltinOp(BuiltinOperator_NON_MAX_SUPPRESSION_V5,
182 BuiltinOptions_NonMaxSuppressionV5Options,
183 CreateNonMaxSuppressionV5Options(builder_).Union());
184
185 BuildInterpreter(
186 {GetShape(input_boxes_), GetShape(input_scores_),
187 GetShape(input_max_output_size_), GetShape(input_iou_threshold_),
188 GetShape(input_score_threshold_), GetShape(input_sigma_)});
189
190 // Default data.
191 PopulateTensor<float>(input_boxes_, {
192 1, 1, 0, 0, // Box 0
193 0, 0.1, 1, 1.1, // Box 1
194 0, .9f, 1, -0.1, // Box 2
195 0, 10, 1, 11, // Box 3
196 1, 10.1f, 0, 11.1, // Box 4
197 1, 101, 0, 100 // Box 5
198 });
199 }
200 };
201
TEST(NonMaxSuppressionV5OpModel, TestOutput) {
  NonMaxSuppressionV5OpModel model(/**iou_threshold=**/ 0.5,
                                   /**sigma=**/ 0.5,
                                   /**static_shaped_outputs=**/ true,
                                   /**max_output_size=**/ 6);
  model.SetScores({0.9, 0.75, 0.6, 0.95, 0.5, 0.3});

  // With no score threshold, boxes 3, 0 and 5 are selected with their
  // original scores. The remaining three of the six static slots are zero.
  model.SetScoreThreshold(0.0);
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetNumSelectedIndices(), ElementsAreArray({3}));
  EXPECT_THAT(model.GetSelectedIndices(), ElementsAreArray({3, 0, 5, 0, 0, 0}));
  EXPECT_THAT(model.GetSelectedScores(),
              ElementsAreArray({0.95, 0.9, 0.3, 0.0, 0.0, 0.0}));

  // A threshold above every score selects nothing. Both the index and score
  // outputs written by the previous Invoke() must be zeroed out.
  model.SetScoreThreshold(0.99);
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetNumSelectedIndices(), ElementsAreArray({0}));
  EXPECT_THAT(model.GetSelectedIndices(), ElementsAreArray({0, 0, 0, 0, 0, 0}));
  EXPECT_THAT(model.GetSelectedScores(),
              ElementsAreArray({0.0, 0.0, 0.0, 0.0, 0.0, 0.0}));
}
223
TEST(NonMaxSuppressionV5OpModel, TestDynamicOutput) {
  // max_output_size is a runtime input here. The constructor's
  // max_output_size argument is ignored when static_shaped_outputs is false,
  // so it is not passed (matching the V4 dynamic-output test).
  NonMaxSuppressionV5OpModel nms(/**iou_threshold=**/ 0.5,
                                 /**sigma=**/ 0.5,
                                 /**static_shaped_outputs=**/ false);
  nms.SetScores({0.9, 0.75, 0.6, 0.95, 0.5, 0.3});
  nms.SetScoreThreshold(0.0);

  // The outputs hold exactly max_output_size entries on each Invoke().
  nms.SetMaxOutputSize(2);
  ASSERT_EQ(nms.Invoke(), kTfLiteOk);
  EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({2}));
  EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({3, 0}));
  EXPECT_THAT(nms.GetSelectedScores(), ElementsAreArray({0.95, 0.9}));

  // Shrinking the cap truncates the result to the top-scoring box.
  nms.SetMaxOutputSize(1);
  ASSERT_EQ(nms.Invoke(), kTfLiteOk);
  EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({1}));
  EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({3}));
  EXPECT_THAT(nms.GetSelectedScores(), ElementsAreArray({0.95}));

  // Growing the cap again lets the third surviving box (5) through.
  nms.SetMaxOutputSize(3);
  ASSERT_EQ(nms.Invoke(), kTfLiteOk);
  EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({3}));
  EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({3, 0, 5}));
  EXPECT_THAT(nms.GetSelectedScores(), ElementsAreArray({0.95, 0.9, 0.3}));

  // No candidate clears a 0.99 threshold. The three output slots must still
  // be zeroed out.
  nms.SetScoreThreshold(0.99);
  ASSERT_EQ(nms.Invoke(), kTfLiteOk);
  EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({0}));
  EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({0, 0, 0}));
  EXPECT_THAT(nms.GetSelectedScores(), ElementsAreArray({0.0, 0.0, 0.0}));
}
257 } // namespace
258 } // namespace tflite
259