1 // Auto-generated file. Do not edit!
2 // Template: src/f32-spmm/neon-blocked.c.in
3 // Generator: tools/xngen
4 //
5 // Copyright 2019 Google LLC
6 //
7 // This source code is licensed under the BSD-style license found in the
8 // LICENSE file in the root directory of this source tree.
9
10 #include <assert.h>
11
12 #include <arm_neon.h>
13
14 #include <xnnpack/spmm.h>
15
16
xnn_f32_spmm_minmax_ukernel_8x2__neonfma(size_t mc,size_t nc,const float * restrict input,const float * restrict weights,const int32_t * restrict widx_dmap,const uint32_t * restrict nidx_nnzmap,float * restrict output,size_t output_stride,const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS (1)])17 void xnn_f32_spmm_minmax_ukernel_8x2__neonfma(
18 size_t mc,
19 size_t nc,
20 const float*restrict input,
21 const float*restrict weights,
22 const int32_t*restrict widx_dmap,
23 const uint32_t*restrict nidx_nnzmap,
24 float*restrict output,
25 size_t output_stride,
26 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
27 {
28 assert(mc != 0);
29 assert(mc % sizeof(float) == 0);
30 assert(nc != 0);
31
32 const float32x4_t vmin = vld1q_dup_f32(¶ms->scalar.min);
33 const float32x4_t vmax = vld1q_dup_f32(¶ms->scalar.max);
34 size_t output_decrement = output_stride * nc - 8 * sizeof(float);
35 while XNN_LIKELY(mc >= 8 * sizeof(float)) {
36 const float*restrict w = weights;
37 const int32_t* dmap = widx_dmap;
38 const uint32_t* nnzmap = nidx_nnzmap;
39 size_t n = nc;
40 while (n >= 2) {
41 uint32_t nnz = *nnzmap++;
42 float32x4_t vacc0123n0 = vld1q_dup_f32(w); w += 1;
43 float32x4_t vacc4567n0 = vacc0123n0;
44 float32x4_t vacc0123n1 = vld1q_dup_f32(w); w += 1;
45 float32x4_t vacc4567n1 = vacc0123n1;
46 if XNN_LIKELY(nnz != 0) {
47 do {
48 const intptr_t diff = *dmap++;
49 const float32x4_t vi0123 = vld1q_f32(input);
50 const float32x4_t vi4567 = vld1q_f32(input + 4);
51 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
52 __builtin_prefetch(input + 16);
53 const float32x2_t vw = vld1_f32(w); w += 2;
54 __builtin_prefetch(w + 32);
55 vacc0123n0 = vfmaq_lane_f32(vacc0123n0, vi0123, vw, 0);
56 vacc4567n0 = vfmaq_lane_f32(vacc4567n0, vi4567, vw, 0);
57 vacc0123n1 = vfmaq_lane_f32(vacc0123n1, vi0123, vw, 1);
58 vacc4567n1 = vfmaq_lane_f32(vacc4567n1, vi4567, vw, 1);
59 } while (--nnz != 0);
60 }
61 float32x4_t vout0123n0 = vminq_f32(vacc0123n0, vmax);
62 float32x4_t vout4567n0 = vminq_f32(vacc4567n0, vmax);
63 float32x4_t vout0123n1 = vminq_f32(vacc0123n1, vmax);
64 float32x4_t vout4567n1 = vminq_f32(vacc4567n1, vmax);
65
66 vout0123n0 = vmaxq_f32(vout0123n0, vmin);
67 vout4567n0 = vmaxq_f32(vout4567n0, vmin);
68 vout0123n1 = vmaxq_f32(vout0123n1, vmin);
69 vout4567n1 = vmaxq_f32(vout4567n1, vmin);
70
71 vst1q_f32(output + 0, vout0123n0);
72 vst1q_f32(output + 4, vout4567n0);
73 output = (float*restrict) ((uintptr_t) output + output_stride);
74 vst1q_f32(output + 0, vout0123n1);
75 vst1q_f32(output + 4, vout4567n1);
76 output = (float*restrict) ((uintptr_t) output + output_stride);
77 n -= 2;
78 }
79
80 // clean up loop, fall back to nr=1
81 if XNN_UNLIKELY(n != 0) {
82 do {
83 uint32_t nnz = *nnzmap++;
84 float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
85 float32x4_t vacc4567 = vacc0123;
86 if XNN_LIKELY(nnz != 0) {
87 do {
88 const intptr_t diff = *dmap++;
89 const float32x4_t vi0123 = vld1q_f32(input);
90 const float32x4_t vi4567 = vld1q_f32(input + 4);
91 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
92 __builtin_prefetch(input + 16);
93 const float32x4_t vw = vld1q_dup_f32(w); w += 1;
94 __builtin_prefetch(w + 32);
95 vacc0123 = vfmaq_f32(vacc0123, vi0123, vw);
96 vacc4567 = vfmaq_f32(vacc4567, vi4567, vw);
97 } while (--nnz != 0);
98 }
99 float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
100 float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
101
102 vout0123 = vmaxq_f32(vout0123, vmin);
103 vout4567 = vmaxq_f32(vout4567, vmin);
104
105 vst1q_f32(output + 0, vout0123);
106 vst1q_f32(output + 4, vout4567);
107 output = (float*restrict) ((uintptr_t) output + output_stride);
108 n -= 1;
109 } while (n != 0);
110 }
111 output = (float*restrict) ((uintptr_t) output - output_decrement);
112 input += 8;
113 mc -= 8 * sizeof(float);
114 }
115 if XNN_UNLIKELY(mc != 0) {
116 output_decrement += 4 * sizeof(float);
117 if (mc & (4 * sizeof(float))) {
118 const float*restrict w = weights;
119 const int32_t* dmap = widx_dmap;
120 const uint32_t* nnzmap = nidx_nnzmap;
121 size_t n = nc;
122 while (n >= 2) {
123 uint32_t nnz = *nnzmap++;
124 float32x4_t vacc0123n0 = vld1q_dup_f32(w); w += 1;
125 float32x4_t vacc0123n1 = vld1q_dup_f32(w); w += 1;
126 if XNN_LIKELY(nnz != 0) {
127 do {
128 const intptr_t diff = *dmap++;
129 const float32x4_t vi0123 = vld1q_f32(input);
130 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
131 const float32x2_t vw = vld1_f32(w); w += 2;
132
133 vacc0123n0 = vfmaq_lane_f32(vacc0123n0, vi0123, vw, 0);
134 vacc0123n1 = vfmaq_lane_f32(vacc0123n1, vi0123, vw, 1);
135 } while (--nnz != 0);
136 }
137 float32x4_t vout0123n0 = vminq_f32(vacc0123n0, vmax);
138 float32x4_t vout0123n1 = vminq_f32(vacc0123n1, vmax);
139
140 vout0123n0 = vmaxq_f32(vout0123n0, vmin);
141 vout0123n1 = vmaxq_f32(vout0123n1, vmin);
142
143 vst1q_f32(output + 0, vout0123n0);
144 output = (float*restrict) ((uintptr_t) output + output_stride);
145 vst1q_f32(output + 0, vout0123n1);
146 output = (float*restrict) ((uintptr_t) output + output_stride);
147 n -= 2;
148 }
149
150 // clean up loop, fall back to nr=1
151 if XNN_UNLIKELY(n != 0) {
152 do {
153 uint32_t nnz = *nnzmap++;
154 float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
155 if XNN_LIKELY(nnz != 0) {
156 do {
157 const intptr_t diff = *dmap++;
158 const float32x4_t vi0123 = vld1q_f32(input);
159 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
160 const float32x4_t vw = vld1q_dup_f32(w); w += 1;
161 vacc0123 = vfmaq_f32(vacc0123, vi0123, vw);
162 } while (--nnz != 0);
163 }
164 float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
165
166 vout0123 = vmaxq_f32(vout0123, vmin);
167
168 vst1q_f32(output + 0, vout0123);
169 output = (float*restrict) ((uintptr_t) output + output_stride);
170 n -= 1;
171 } while (n != 0);
172 }
173 output = (float*restrict) ((uintptr_t) output - output_decrement);
174 input += 4;
175 }
176 output_decrement += 2 * sizeof(float);
177 if (mc & (2 * sizeof(float))) {
178 const float*restrict w = weights;
179 const int32_t* dmap = widx_dmap;
180 const uint32_t* nnzmap = nidx_nnzmap;
181 size_t n = nc;
182 while (n >= 2) {
183 uint32_t nnz = *nnzmap++;
184 float32x2_t vacc01n0 = vld1_dup_f32(w); w += 1;
185 float32x2_t vacc01n1 = vld1_dup_f32(w); w += 1;
186 if XNN_LIKELY(nnz != 0) {
187 do {
188 const intptr_t diff = *dmap++;
189 const float32x2_t vi01 = vld1_f32(input);
190 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
191 const float32x2_t vw = vld1_f32(w); w += 2;
192
193 vacc01n0 = vfma_lane_f32(vacc01n0, vi01, vw, 0);
194 vacc01n1 = vfma_lane_f32(vacc01n1, vi01, vw, 1);
195 } while (--nnz != 0);
196 }
197 float32x2_t vout01n0 = vmin_f32(vacc01n0, vget_low_f32(vmax));
198 float32x2_t vout01n1 = vmin_f32(vacc01n1, vget_low_f32(vmax));
199
200 vout01n0 = vmax_f32(vout01n0, vget_low_f32(vmin));
201 vout01n1 = vmax_f32(vout01n1, vget_low_f32(vmin));
202
203 vst1_f32(output + 0, vout01n0);
204 output = (float*restrict) ((uintptr_t) output + output_stride);
205 vst1_f32(output + 0, vout01n1);
206 output = (float*restrict) ((uintptr_t) output + output_stride);
207 n -= 2;
208 }
209
210 // clean up loop, fall back to nr=1
211 if XNN_UNLIKELY(n != 0) {
212 do {
213 uint32_t nnz = *nnzmap++;
214 float32x2_t vacc01 = vld1_dup_f32(w); w += 1;
215 if XNN_LIKELY(nnz != 0) {
216 do {
217 const intptr_t diff = *dmap++;
218 const float32x2_t vi01 = vld1_f32(input);
219 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
220 const float32x2_t vw = vld1_dup_f32(w); w += 1;
221 vacc01 = vfma_f32(vacc01, vi01, vw);
222 } while (--nnz != 0);
223 }
224 float32x2_t vout01 = vmin_f32(vacc01, vget_low_f32(vmax));
225 vout01 = vmax_f32(vout01, vget_low_f32(vmin));
226
227 vst1_f32(output, vout01);
228 output = (float*restrict) ((uintptr_t) output + output_stride);
229 n -= 1;
230 } while (n != 0);
231 }
232 output = (float*restrict) ((uintptr_t) output - output_decrement);
233 input += 2;
234 }
235 output_decrement += 1 * sizeof(float);
236 if (mc & (1 * sizeof(float))) {
237 const float*restrict w = weights;
238 const int32_t* dmap = widx_dmap;
239 const uint32_t* nnzmap = nidx_nnzmap;
240 size_t n = nc;
241 while (n >= 2) {
242 uint32_t nnz = *nnzmap++;
243 float32x2_t vacc0n0 = vld1_dup_f32(w); w += 1;
244 float32x2_t vacc0n1 = vld1_dup_f32(w); w += 1;
245 if XNN_LIKELY(nnz != 0) {
246 do {
247 const intptr_t diff = *dmap++;
248 const float32x2_t vi0 = vld1_dup_f32(input);
249 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
250 const float32x2_t vw = vld1_f32(w); w += 2;
251
252 vacc0n0 = vfma_lane_f32(vacc0n0, vi0, vw, 0);
253 vacc0n1 = vfma_lane_f32(vacc0n1, vi0, vw, 1);
254 } while (--nnz != 0);
255 }
256 float32x2_t vout0n0 = vmin_f32(vacc0n0, vget_low_f32(vmax));
257 float32x2_t vout0n1 = vmin_f32(vacc0n1, vget_low_f32(vmax));
258
259 vout0n0 = vmax_f32(vout0n0, vget_low_f32(vmin));
260 vout0n1 = vmax_f32(vout0n1, vget_low_f32(vmin));
261
262 vst1_lane_f32(output + 0, vout0n0, 0);
263 output = (float*restrict) ((uintptr_t) output + output_stride);
264 vst1_lane_f32(output + 0, vout0n1, 0);
265 output = (float*restrict) ((uintptr_t) output + output_stride);
266 n -= 2;
267 }
268
269 // clean up loop, fall back to nr=1
270 if XNN_UNLIKELY(n != 0) {
271 do {
272 uint32_t nnz = *nnzmap++;
273 float32x2_t vacc0 = vld1_dup_f32(w); w += 1;
274 if XNN_LIKELY(nnz != 0) {
275 do {
276 const intptr_t diff = *dmap++;
277 const float32x2_t vi0 = vld1_dup_f32(input);
278 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
279 const float32x2_t vw = vld1_dup_f32(w); w += 1;
280 vacc0 = vfma_f32(vacc0, vi0, vw);
281 } while (--nnz != 0);
282 }
283 float32x2_t vout0 = vmin_f32(vacc0, vget_low_f32(vmax));
284 vout0 = vmax_f32(vout0, vget_low_f32(vmin));
285
286 vst1_lane_f32(output, vout0, 1);
287 output = (float*restrict) ((uintptr_t) output + output_stride);
288 n -= 1;
289 } while (n != 0);
290 }
291 output = (float*restrict) ((uintptr_t) output - output_decrement);
292 input += 1;
293 }
294 }
295 }
296