xref: /aosp_15_r20/external/XNNPACK/test/softmax-nc.cc (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8 
9 #include <gtest/gtest.h>
10 
11 #include "softmax-operator-tester.h"
12 
13 
// Degenerate F16 case: a single row with exactly one channel.
TEST(SOFTMAX_NC_F16, single_class) {
  SoftMaxOperatorTester tester;
  tester.batch_size(1).channels(1).iterations(100);
  tester.TestF16();
}
21 
// Smallest non-trivial F16 case: a single row with two channels.
TEST(SOFTMAX_NC_F16, two_classes) {
  SoftMaxOperatorTester tester;
  tester.batch_size(1).channels(2).iterations(100);
  tester.TestF16();
}
29 
// Sweeps every channel count in [3, 100) for a single-row F16 softmax.
TEST(SOFTMAX_NC_F16, many_classes) {
  for (size_t channels = 3; channels < 100; channels++) {
    // Tag any failure with the channel count being exercised.
    SCOPED_TRACE(testing::Message() << "channels = " << channels);
    SoftMaxOperatorTester()
      .batch_size(1)
      .channels(channels)
      .iterations(1)
      .TestF16();
  }
}
39 
// F16 softmax over the CIFAR-10 (10) and CIFAR-100 (100) class counts.
// Each configuration is scoped with a trace so a failure names its dataset.
TEST(SOFTMAX_NC_F16, cifar_classes) {
  {
    SCOPED_TRACE("CIFAR-10");
    SoftMaxOperatorTester()
      .batch_size(1)
      .channels(10)
      .iterations(15)
      .TestF16();
  }
  {
    SCOPED_TRACE("CIFAR-100");
    SoftMaxOperatorTester()
      .batch_size(1)
      .channels(100)
      .iterations(15)
      .TestF16();
  }
}
54 
// F16 softmax over ImageNet-style class counts (1000, 1001, 21841).
// Each configuration is scoped with a trace so a failure names its dataset.
TEST(SOFTMAX_NC_F16, imagenet_classes) {
  {
    SCOPED_TRACE("ImageNet-1K");
    SoftMaxOperatorTester()
      .batch_size(1)
      .channels(1000)
      .iterations(10)
      .TestF16();
  }
  {
    SCOPED_TRACE("ImageNet-1K+1");
    SoftMaxOperatorTester()
      .batch_size(1)
      .channels(1001)
      .iterations(10)
      .TestF16();
  }
  {
    SCOPED_TRACE("ImageNet-22K");
    SoftMaxOperatorTester()
      .batch_size(1)
      .channels(21841)
      .iterations(10)
      .TestF16();
  }
}
75 
// F16 softmax on a 3-row batch with densely packed rows, for channel counts
// 1, 6, 11, ..., 96.
TEST(SOFTMAX_NC_F16, small_batch) {
  for (size_t channels = 1; channels < 100; channels += 5) {
    // Tag any failure with the channel count being exercised.
    SCOPED_TRACE(testing::Message() << "channels = " << channels);
    SoftMaxOperatorTester()
      .batch_size(3)
      .channels(channels)
      .iterations(3)
      .TestF16();
  }
}
85 
// F16 softmax on a 3-row batch whose input rows are 129 elements apart,
// for channel counts 1, 6, 11, ..., 96.
TEST(SOFTMAX_NC_F16, small_batch_with_input_stride) {
  for (size_t channels = 1; channels < 100; channels += 5) {
    // Tag any failure with the channel count being exercised.
    SCOPED_TRACE(testing::Message() << "channels = " << channels);
    SoftMaxOperatorTester()
      .batch_size(3)
      .channels(channels)
      .input_stride(129)
      .iterations(3)
      .TestF16();
  }
}
96 
// F16 softmax on a 3-row batch whose output rows are 117 elements apart,
// for channel counts 1, 6, 11, ..., 96.
TEST(SOFTMAX_NC_F16, small_batch_with_output_stride) {
  for (size_t channels = 1; channels < 100; channels += 5) {
    // Tag any failure with the channel count being exercised.
    SCOPED_TRACE(testing::Message() << "channels = " << channels);
    SoftMaxOperatorTester()
      .batch_size(3)
      .channels(channels)
      .output_stride(117)
      .iterations(3)
      .TestF16();
  }
}
107 
// F16 softmax on a 3-row batch with distinct input (129) and output (117)
// row strides, for channel counts 1, 6, 11, ..., 96.
TEST(SOFTMAX_NC_F16, strided_batch_with_input_and_output_stride) {
  for (size_t channels = 1; channels < 100; channels += 5) {
    // Tag any failure with the channel count being exercised.
    SCOPED_TRACE(testing::Message() << "channels = " << channels);
    SoftMaxOperatorTester()
      .batch_size(3)
      .channels(channels)
      .input_stride(129)
      .output_stride(117)
      .iterations(3)
      .TestF16();
  }
}
119 
120 
// Degenerate F32 case: a single row with exactly one channel.
TEST(SOFTMAX_NC_F32, single_class) {
  SoftMaxOperatorTester tester;
  tester.batch_size(1).channels(1).iterations(100);
  tester.TestF32();
}
128 
// Smallest non-trivial F32 case: a single row with two channels.
TEST(SOFTMAX_NC_F32, two_classes) {
  SoftMaxOperatorTester tester;
  tester.batch_size(1).channels(2).iterations(100);
  tester.TestF32();
}
136 
// Sweeps every channel count in [3, 100) for a single-row F32 softmax.
TEST(SOFTMAX_NC_F32, many_classes) {
  for (size_t channels = 3; channels < 100; channels++) {
    // Tag any failure with the channel count being exercised.
    SCOPED_TRACE(testing::Message() << "channels = " << channels);
    SoftMaxOperatorTester()
      .batch_size(1)
      .channels(channels)
      .iterations(1)
      .TestF32();
  }
}
146 
// F32 softmax over the CIFAR-10 (10) and CIFAR-100 (100) class counts.
// Each configuration is scoped with a trace so a failure names its dataset.
TEST(SOFTMAX_NC_F32, cifar_classes) {
  {
    SCOPED_TRACE("CIFAR-10");
    SoftMaxOperatorTester()
      .batch_size(1)
      .channels(10)
      .iterations(15)
      .TestF32();
  }
  {
    SCOPED_TRACE("CIFAR-100");
    SoftMaxOperatorTester()
      .batch_size(1)
      .channels(100)
      .iterations(15)
      .TestF32();
  }
}
161 
// F32 softmax over ImageNet-style class counts (1000, 1001, 21841).
// Each configuration is scoped with a trace so a failure names its dataset.
TEST(SOFTMAX_NC_F32, imagenet_classes) {
  {
    SCOPED_TRACE("ImageNet-1K");
    SoftMaxOperatorTester()
      .batch_size(1)
      .channels(1000)
      .iterations(10)
      .TestF32();
  }
  {
    SCOPED_TRACE("ImageNet-1K+1");
    SoftMaxOperatorTester()
      .batch_size(1)
      .channels(1001)
      .iterations(10)
      .TestF32();
  }
  {
    SCOPED_TRACE("ImageNet-22K");
    SoftMaxOperatorTester()
      .batch_size(1)
      .channels(21841)
      .iterations(10)
      .TestF32();
  }
}
182 
// F32 softmax on a 3-row batch with densely packed rows, for channel counts
// 1, 6, 11, ..., 96.
TEST(SOFTMAX_NC_F32, small_batch) {
  for (size_t channels = 1; channels < 100; channels += 5) {
    // Tag any failure with the channel count being exercised.
    SCOPED_TRACE(testing::Message() << "channels = " << channels);
    SoftMaxOperatorTester()
      .batch_size(3)
      .channels(channels)
      .iterations(3)
      .TestF32();
  }
}
192 
// F32 softmax on a 3-row batch whose input rows are 129 elements apart,
// for channel counts 1, 6, 11, ..., 96.
TEST(SOFTMAX_NC_F32, small_batch_with_input_stride) {
  for (size_t channels = 1; channels < 100; channels += 5) {
    // Tag any failure with the channel count being exercised.
    SCOPED_TRACE(testing::Message() << "channels = " << channels);
    SoftMaxOperatorTester()
      .batch_size(3)
      .channels(channels)
      .input_stride(129)
      .iterations(3)
      .TestF32();
  }
}
203 
// F32 softmax on a 3-row batch whose output rows are 117 elements apart,
// for channel counts 1, 6, 11, ..., 96.
TEST(SOFTMAX_NC_F32, small_batch_with_output_stride) {
  for (size_t channels = 1; channels < 100; channels += 5) {
    // Tag any failure with the channel count being exercised.
    SCOPED_TRACE(testing::Message() << "channels = " << channels);
    SoftMaxOperatorTester()
      .batch_size(3)
      .channels(channels)
      .output_stride(117)
      .iterations(3)
      .TestF32();
  }
}
214 
// F32 softmax on a 3-row batch with distinct input (129) and output (117)
// row strides, for channel counts 1, 6, 11, ..., 96.
TEST(SOFTMAX_NC_F32, strided_batch_with_input_and_output_stride) {
  for (size_t channels = 1; channels < 100; channels += 5) {
    // Tag any failure with the channel count being exercised.
    SCOPED_TRACE(testing::Message() << "channels = " << channels);
    SoftMaxOperatorTester()
      .batch_size(3)
      .channels(channels)
      .input_stride(129)
      .output_stride(117)
      .iterations(3)
      .TestF32();
  }
}
226 
227 
// Degenerate QU8 case: a single row with exactly one channel.
TEST(SOFTMAX_NC_QU8, single_class) {
  SoftMaxOperatorTester tester;
  tester.batch_size(1).channels(1).iterations(100);
  tester.TestQU8();
}
235 
// Smallest non-trivial QU8 case: a single row with two channels.
TEST(SOFTMAX_NC_QU8, two_classes) {
  SoftMaxOperatorTester tester;
  tester.batch_size(1).channels(2).iterations(100);
  tester.TestQU8();
}
243 
// Sweeps every channel count in [3, 100) for a single-row QU8 softmax.
TEST(SOFTMAX_NC_QU8, many_classes) {
  for (size_t channels = 3; channels < 100; channels++) {
    // Tag any failure with the channel count being exercised.
    SCOPED_TRACE(testing::Message() << "channels = " << channels);
    SoftMaxOperatorTester()
      .batch_size(1)
      .channels(channels)
      .iterations(1)
      .TestQU8();
  }
}
253 
// QU8 softmax over the CIFAR-10 (10) and CIFAR-100 (100) class counts.
// Each configuration is scoped with a trace so a failure names its dataset.
TEST(SOFTMAX_NC_QU8, cifar_classes) {
  {
    SCOPED_TRACE("CIFAR-10");
    SoftMaxOperatorTester()
      .batch_size(1)
      .channels(10)
      .iterations(15)
      .TestQU8();
  }
  {
    SCOPED_TRACE("CIFAR-100");
    SoftMaxOperatorTester()
      .batch_size(1)
      .channels(100)
      .iterations(15)
      .TestQU8();
  }
}
268 
// QU8 softmax over ImageNet-style class counts (1000, 1001, 21841).
// Each configuration is scoped with a trace so a failure names its dataset.
TEST(SOFTMAX_NC_QU8, imagenet_classes) {
  {
    SCOPED_TRACE("ImageNet-1K");
    SoftMaxOperatorTester()
      .batch_size(1)
      .channels(1000)
      .iterations(10)
      .TestQU8();
  }
  {
    SCOPED_TRACE("ImageNet-1K+1");
    SoftMaxOperatorTester()
      .batch_size(1)
      .channels(1001)
      .iterations(10)
      .TestQU8();
  }
  {
    SCOPED_TRACE("ImageNet-22K");
    SoftMaxOperatorTester()
      .batch_size(1)
      .channels(21841)
      .iterations(10)
      .TestQU8();
  }
}
289 
// QU8 softmax across channel counts 1, 6, ..., 96 combined with input scales
// growing geometrically (factor ~pi) from 1e-2 up to, but excluding, 1e+2.
TEST(SOFTMAX_NC_QU8, many_channels_with_input_scale) {
  for (size_t channels = 1; channels < 100; channels += 5) {
    for (float input_scale = 1.0e-2f; input_scale < 1.0e+2f; input_scale *= 3.14159265f) {
      // Tag any failure with the exact (channels, input_scale) pair.
      SCOPED_TRACE(testing::Message()
                   << "channels = " << channels << ", input_scale = " << input_scale);
      SoftMaxOperatorTester()
        .batch_size(1)
        .channels(channels)
        .input_scale(input_scale)
        .iterations(1)
        .TestQU8();
    }
  }
}
302 
// QU8 softmax across channel counts 1, 6, ..., 96 combined with input zero
// points 0, 51, 102, 153, 204, 255 (both ends of the uint8 range included).
TEST(SOFTMAX_NC_QU8, many_channels_with_input_zero_point) {
  for (size_t channels = 1; channels < 100; channels += 5) {
    // int32_t loop counter so the `<= 255` bound cannot wrap as a uint8_t would.
    for (int32_t input_zero_point = 0; input_zero_point <= 255; input_zero_point += 51) {
      // Tag any failure with the exact (channels, input_zero_point) pair.
      SCOPED_TRACE(testing::Message()
                   << "channels = " << channels << ", input_zero_point = " << input_zero_point);
      SoftMaxOperatorTester()
        .batch_size(1)
        .channels(channels)
        .input_zero_point(uint8_t(input_zero_point))
        .iterations(1)
        .TestQU8();
    }
  }
}
315 
// QU8 softmax on a 3-row batch with densely packed rows, for channel counts
// 1, 6, 11, ..., 96.
TEST(SOFTMAX_NC_QU8, small_batch) {
  for (size_t channels = 1; channels < 100; channels += 5) {
    // Tag any failure with the channel count being exercised.
    SCOPED_TRACE(testing::Message() << "channels = " << channels);
    SoftMaxOperatorTester()
      .batch_size(3)
      .channels(channels)
      .iterations(3)
      .TestQU8();
  }
}
325 
// QU8 softmax on a 3-row batch whose input rows are 129 elements apart,
// for channel counts 1, 6, 11, ..., 96.
TEST(SOFTMAX_NC_QU8, small_batch_with_input_stride) {
  for (size_t channels = 1; channels < 100; channels += 5) {
    // Tag any failure with the channel count being exercised.
    SCOPED_TRACE(testing::Message() << "channels = " << channels);
    SoftMaxOperatorTester()
      .batch_size(3)
      .channels(channels)
      .input_stride(129)
      .iterations(3)
      .TestQU8();
  }
}
336 
// QU8 softmax on a 3-row batch whose output rows are 117 elements apart,
// for channel counts 1, 6, 11, ..., 96.
TEST(SOFTMAX_NC_QU8, small_batch_with_output_stride) {
  for (size_t channels = 1; channels < 100; channels += 5) {
    // Tag any failure with the channel count being exercised.
    SCOPED_TRACE(testing::Message() << "channels = " << channels);
    SoftMaxOperatorTester()
      .batch_size(3)
      .channels(channels)
      .output_stride(117)
      .iterations(3)
      .TestQU8();
  }
}
347 
// QU8 softmax on a 3-row batch with distinct input (129) and output (117)
// row strides, for channel counts 1, 6, 11, ..., 96.
TEST(SOFTMAX_NC_QU8, strided_batch_with_input_and_output_stride) {
  for (size_t channels = 1; channels < 100; channels += 5) {
    // Tag any failure with the channel count being exercised.
    SCOPED_TRACE(testing::Message() << "channels = " << channels);
    SoftMaxOperatorTester()
      .batch_size(3)
      .channels(channels)
      .input_stride(129)
      .output_stride(117)
      .iterations(3)
      .TestQU8();
  }
}
359