1 // Copyright 2019 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6 #include <assert.h>
7
8 #include <xnnpack/argmaxpool.h>
9 #include <xnnpack/math.h>
10
11
// Folds `count` pooling rows, at column `channel`, into a running
// (maximum, argmax) pair. Row r is reported as pooling index
// `first_index + r`. The comparison is a strict `>`, so an equal value never
// replaces the current winner (the lowest index wins ties) and a NaN row never
// displaces the current maximum.
static inline void argmaxpool_scan_rows_f32(
    const float* const* rows,
    uint32_t count,
    size_t channel,
    uint32_t first_index,
    float* max_value,
    uint32_t* max_index)
{
  float best = *max_value;
  uint32_t best_index = *max_index;
  for (uint32_t r = 0; r < count; r++) {
    const float value = rows[r][channel];
    if (value > best) {
      best = value;
      best_index = first_index + r;
    }
  }
  *max_value = best;
  *max_index = best_index;
}

// Multipass argmax pooling over f32 data.
//
// For each of `output_pixels` output pixels, this computes, per channel, the
// maximum over `pooling_elements` input rows and the position (0-based pooling
// index) of that maximum.
//
// Parameters:
//   output_pixels       number of output pixels to produce (must be non-zero).
//   pooling_elements    rows pooled per output pixel (must be greater than 9).
//   channels            values per row (must be non-zero).
//   input               array of row pointers. Each pointer is shifted by
//                       `input_offset` bytes before use.
//                       - The first pass consumes 9 entries.
//                       - Each middle pass consumes 8 entries.
//                       - The final pass reads the next 8 entries, but only the
//                         first `remaining` are actually pooled.
//                       After the final pass, `input` advances by
//                       `input_increment` bytes from the start of that group.
//   accumulation_buffer scratch of at least `channels` floats (running maxima).
//   index_buffer        scratch of at least `channels` uint32_t (running argmax).
//   output              receives `channels` maxima per pixel. After each pixel
//                       it advances past those values plus `output_increment`
//                       bytes.
//   index               receives `channels` argmax values per pixel, written
//                       back-to-back across pixels.
//
// Ties resolve to the lowest pooling index.
void xnn_f32_argmaxpool_ukernel_9p8x__scalar_c1(
    size_t output_pixels,
    size_t pooling_elements,
    size_t channels,
    const float** input,
    size_t input_offset,
    float* accumulation_buffer,
    uint32_t* index_buffer,
    float* output,
    uint32_t* index,
    size_t input_increment,
    size_t output_increment)
{
  assert(output_pixels != 0);
  assert(pooling_elements != 0);
  assert(pooling_elements > 9);
  assert(channels != 0);

  const float* rows[9];
  do {
    // Pass 1: pooling rows 0..8 seed the scratch buffers.
    for (size_t r = 0; r < 9; r++) {
      rows[r] = (const float*) ((uintptr_t) input[r] + input_offset);
    }
    input += 9;
    for (size_t ch = 0; ch < channels; ch++) {
      float best = rows[0][ch];
      uint32_t best_index = 0;
      argmaxpool_scan_rows_f32(rows + 1, 8, ch, 1, &best, &best_index);
      accumulation_buffer[ch] = best;
      index_buffer[ch] = best_index;
    }

    // Middle passes: fold 8 more rows into the scratch buffers while more
    // than 8 rows are still pending.
    uint32_t next_index = 9;
    size_t remaining = pooling_elements - 9;
    while (remaining > 8) {
      for (size_t r = 0; r < 8; r++) {
        rows[r] = (const float*) ((uintptr_t) input[r] + input_offset);
      }
      input += 8;
      for (size_t ch = 0; ch < channels; ch++) {
        argmaxpool_scan_rows_f32(
            rows, 8, ch, next_index, &accumulation_buffer[ch], &index_buffer[ch]);
      }
      next_index += 8;
      remaining -= 8;
    }

    // Final pass: 1..8 rows remain.
    for (size_t r = 0; r < 8; r++) {
      rows[r] = (const float*) ((uintptr_t) input[r] + input_offset);
    }
    input = (const float**) ((uintptr_t) input + input_increment);
    // Slots past the remaining count alias row 0. Because the compare is a
    // strict `>`, a repeat of row 0 can never change the winner.
    for (size_t r = remaining; r < 8; r++) {
      rows[r] = rows[0];
    }
    for (size_t ch = 0; ch < channels; ch++) {
      float best = accumulation_buffer[ch];
      uint32_t best_index = index_buffer[ch];
      argmaxpool_scan_rows_f32(rows, 8, ch, next_index, &best, &best_index);
      output[ch] = best;
      index[ch] = best_index;
    }

    output = (float*) ((uintptr_t) (output + channels) + output_increment);
    index += channels;
  } while (--output_pixels != 0);
}
303