// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <gtest/gtest.h>

#include "softmax-operator-tester.h"


// Smallest possible softmax width: one channel in a single row, run 100 times.
TEST(SOFTMAX_NC_F16, single_class) {
  SoftMaxOperatorTester()
    .batch_size(1)
    .channels(1)
    .iterations(100)
    .TestF16();
}

// Two channels in a single row, run 100 times.
TEST(SOFTMAX_NC_F16, two_classes) {
  SoftMaxOperatorTester()
    .batch_size(1)
    .channels(2)
    .iterations(100)
    .TestF16();
}

// Sweeps every channel count in [3, 100) with a single row and one iteration each.
TEST(SOFTMAX_NC_F16, many_classes) {
  for (size_t channels = 3; channels < 100; channels++) {
    SoftMaxOperatorTester()
      .batch_size(1)
      .channels(channels)
      .iterations(1)
      .TestF16();
  }
}

// Channel counts matching the CIFAR-10 and CIFAR-100 class counts.
TEST(SOFTMAX_NC_F16, cifar_classes) {
  // CIFAR-10
  SoftMaxOperatorTester()
    .batch_size(1)
    .channels(10)
    .iterations(15)
    .TestF16();
  // CIFAR-100
  SoftMaxOperatorTester()
    .batch_size(1)
    .channels(100)
    .iterations(15)
    .TestF16();
}

// Channel counts matching ImageNet-style classifier outputs (1000, 1001, 21841).
TEST(SOFTMAX_NC_F16, imagenet_classes) {
  // ImageNet-1K
  SoftMaxOperatorTester()
    .batch_size(1)
    .channels(1000)
    .iterations(10)
    .TestF16();
  // ImageNet-1K+1
  SoftMaxOperatorTester()
    .batch_size(1)
    .channels(1001)
    .iterations(10)
    .TestF16();
  // ImageNet-22K
  SoftMaxOperatorTester()
    .batch_size(1)
    .channels(21841)
    .iterations(10)
    .TestF16();
}

// Three densely packed rows; channel counts 1, 6, 11, ..., 96.
TEST(SOFTMAX_NC_F16, small_batch) {
  for (size_t channels = 1; channels < 100; channels += 5) {
    SoftMaxOperatorTester()
      .batch_size(3)
      .channels(channels)
      .iterations(3)
      .TestF16();
  }
}

// Three rows whose input stride (129) exceeds the largest tested channel count
// (96), so input rows are never contiguous.
TEST(SOFTMAX_NC_F16, small_batch_with_input_stride) {
  for (size_t channels = 1; channels < 100; channels += 5) {
    SoftMaxOperatorTester()
      .batch_size(3)
      .channels(channels)
      .input_stride(129)
      .iterations(3)
      .TestF16();
  }
}

// Three rows whose output stride (117) exceeds the largest tested channel count
// (96), so output rows are never contiguous.
TEST(SOFTMAX_NC_F16, small_batch_with_output_stride) {
  for (size_t channels = 1; channels < 100; channels += 5) {
    SoftMaxOperatorTester()
      .batch_size(3)
      .channels(channels)
      .output_stride(117)
      .iterations(3)
      .TestF16();
  }
}

// Three rows with distinct input (129) and output (117) strides, both larger
// than the largest tested channel count (96).
TEST(SOFTMAX_NC_F16, strided_batch_with_input_and_output_stride) {
  for (size_t channels = 1; channels < 100; channels += 5) {
    SoftMaxOperatorTester()
      .batch_size(3)
      .channels(channels)
      .input_stride(129)
      .output_stride(117)
      .iterations(3)
      .TestF16();
  }
}

// Smallest possible softmax width: one channel in a single row, run 100 times.
TEST(SOFTMAX_NC_F32, single_class) {
  SoftMaxOperatorTester()
    .batch_size(1)
    .channels(1)
    .iterations(100)
    .TestF32();
}

// Two channels in a single row, run 100 times.
TEST(SOFTMAX_NC_F32, two_classes) {
  SoftMaxOperatorTester()
    .batch_size(1)
    .channels(2)
    .iterations(100)
    .TestF32();
}

// Sweeps every channel count in [3, 100) with a single row and one iteration each.
TEST(SOFTMAX_NC_F32, many_classes) {
  for (size_t channels = 3; channels < 100; channels++) {
    SoftMaxOperatorTester()
      .batch_size(1)
      .channels(channels)
      .iterations(1)
      .TestF32();
  }
}

// Channel counts matching the CIFAR-10 and CIFAR-100 class counts.
TEST(SOFTMAX_NC_F32, cifar_classes) {
  // CIFAR-10
  SoftMaxOperatorTester()
    .batch_size(1)
    .channels(10)
    .iterations(15)
    .TestF32();
  // CIFAR-100
  SoftMaxOperatorTester()
    .batch_size(1)
    .channels(100)
    .iterations(15)
    .TestF32();
}

// Channel counts matching ImageNet-style classifier outputs (1000, 1001, 21841).
TEST(SOFTMAX_NC_F32, imagenet_classes) {
  // ImageNet-1K
  SoftMaxOperatorTester()
    .batch_size(1)
    .channels(1000)
    .iterations(10)
    .TestF32();
  // ImageNet-1K+1
  SoftMaxOperatorTester()
    .batch_size(1)
    .channels(1001)
    .iterations(10)
    .TestF32();
  // ImageNet-22K
  SoftMaxOperatorTester()
    .batch_size(1)
    .channels(21841)
    .iterations(10)
    .TestF32();
}

// Three densely packed rows; channel counts 1, 6, 11, ..., 96.
TEST(SOFTMAX_NC_F32, small_batch) {
  for (size_t channels = 1; channels < 100; channels += 5) {
    SoftMaxOperatorTester()
      .batch_size(3)
      .channels(channels)
      .iterations(3)
      .TestF32();
  }
}

// Three rows whose input stride (129) exceeds the largest tested channel count
// (96), so input rows are never contiguous.
TEST(SOFTMAX_NC_F32, small_batch_with_input_stride) {
  for (size_t channels = 1; channels < 100; channels += 5) {
    SoftMaxOperatorTester()
      .batch_size(3)
      .channels(channels)
      .input_stride(129)
      .iterations(3)
      .TestF32();
  }
}

// Three rows whose output stride (117) exceeds the largest tested channel count
// (96), so output rows are never contiguous.
TEST(SOFTMAX_NC_F32, small_batch_with_output_stride) {
  for (size_t channels = 1; channels < 100; channels += 5) {
    SoftMaxOperatorTester()
      .batch_size(3)
      .channels(channels)
      .output_stride(117)
      .iterations(3)
      .TestF32();
  }
}

// Three rows with distinct input (129) and output (117) strides, both larger
// than the largest tested channel count (96).
TEST(SOFTMAX_NC_F32, strided_batch_with_input_and_output_stride) {
  for (size_t channels = 1; channels < 100; channels += 5) {
    SoftMaxOperatorTester()
      .batch_size(3)
      .channels(channels)
      .input_stride(129)
      .output_stride(117)
      .iterations(3)
      .TestF32();
  }
}

// Smallest possible softmax width: one channel in a single row, run 100 times.
TEST(SOFTMAX_NC_QU8, single_class) {
  SoftMaxOperatorTester()
    .batch_size(1)
    .channels(1)
    .iterations(100)
    .TestQU8();
}

// Two channels in a single row, run 100 times.
TEST(SOFTMAX_NC_QU8, two_classes) {
  SoftMaxOperatorTester()
    .batch_size(1)
    .channels(2)
    .iterations(100)
    .TestQU8();
}

// Sweeps every channel count in [3, 100) with a single row and one iteration each.
TEST(SOFTMAX_NC_QU8, many_classes) {
  for (size_t channels = 3; channels < 100; channels++) {
    SoftMaxOperatorTester()
      .batch_size(1)
      .channels(channels)
      .iterations(1)
      .TestQU8();
  }
}

// Channel counts matching the CIFAR-10 and CIFAR-100 class counts.
TEST(SOFTMAX_NC_QU8, cifar_classes) {
  // CIFAR-10
  SoftMaxOperatorTester()
    .batch_size(1)
    .channels(10)
    .iterations(15)
    .TestQU8();
  // CIFAR-100
  SoftMaxOperatorTester()
    .batch_size(1)
    .channels(100)
    .iterations(15)
    .TestQU8();
}

// Channel counts matching ImageNet-style classifier outputs (1000, 1001, 21841).
TEST(SOFTMAX_NC_QU8, imagenet_classes) {
  // ImageNet-1K
  SoftMaxOperatorTester()
    .batch_size(1)
    .channels(1000)
    .iterations(10)
    .TestQU8();
  // ImageNet-1K+1
  SoftMaxOperatorTester()
    .batch_size(1)
    .channels(1001)
    .iterations(10)
    .TestQU8();
  // ImageNet-22K
  SoftMaxOperatorTester()
    .batch_size(1)
    .channels(21841)
    .iterations(10)
    .TestQU8();
}

// Crosses channel counts 1, 6, ..., 96 with input scales 0.01 * pi^k that stay
// below 100 (k = 0..8), covering both very fine and very coarse quantization.
TEST(SOFTMAX_NC_QU8, many_channels_with_input_scale) {
  for (size_t channels = 1; channels < 100; channels += 5) {
    // Geometric sweep; the multiplier is only used to spread scales across
    // four orders of magnitude.
    for (float input_scale = 1.0e-2f; input_scale < 1.0e+2f; input_scale *= 3.14159265f) {
      SoftMaxOperatorTester()
        .batch_size(1)
        .channels(channels)
        .input_scale(input_scale)
        .iterations(1)
        .TestQU8();
    }
  }
}

// Crosses channel counts 1, 6, ..., 96 with input zero points
// 0, 51, 102, 153, 204, 255, which span the full uint8 range including both ends.
TEST(SOFTMAX_NC_QU8, many_channels_with_input_zero_point) {
  for (size_t channels = 1; channels < 100; channels += 5) {
    // int32_t loop variable so the `<= 255` bound cannot wrap as uint8_t would.
    for (int32_t input_zero_point = 0; input_zero_point <= 255; input_zero_point += 51) {
      SoftMaxOperatorTester()
        .batch_size(1)
        .channels(channels)
        .input_zero_point(static_cast<uint8_t>(input_zero_point))
        .iterations(1)
        .TestQU8();
    }
  }
}

// Three densely packed rows; channel counts 1, 6, 11, ..., 96.
TEST(SOFTMAX_NC_QU8, small_batch) {
  for (size_t channels = 1; channels < 100; channels += 5) {
    SoftMaxOperatorTester()
      .batch_size(3)
      .channels(channels)
      .iterations(3)
      .TestQU8();
  }
}

// Three rows whose input stride (129) exceeds the largest tested channel count
// (96), so input rows are never contiguous.
TEST(SOFTMAX_NC_QU8, small_batch_with_input_stride) {
  for (size_t channels = 1; channels < 100; channels += 5) {
    SoftMaxOperatorTester()
      .batch_size(3)
      .channels(channels)
      .input_stride(129)
      .iterations(3)
      .TestQU8();
  }
}

// Three rows whose output stride (117) exceeds the largest tested channel count
// (96), so output rows are never contiguous.
TEST(SOFTMAX_NC_QU8, small_batch_with_output_stride) {
  for (size_t channels = 1; channels < 100; channels += 5) {
    SoftMaxOperatorTester()
      .batch_size(3)
      .channels(channels)
      .output_stride(117)
      .iterations(3)
      .TestQU8();
  }
}

// Three rows with distinct input (129) and output (117) strides, both larger
// than the largest tested channel count (96).
TEST(SOFTMAX_NC_QU8, strided_batch_with_input_and_output_stride) {
  for (size_t channels = 1; channels < 100; channels += 5) {
    SoftMaxOperatorTester()
      .batch_size(3)
      .channels(channels)
      .input_stride(129)
      .output_stride(117)
      .iterations(3)
      .TestQU8();
  }
}
