xref: /aosp_15_r20/external/XNNPACK/src/qs8-dwconv/unipass-scalar.c.in (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1// Copyright 2021 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6$assert KERNEL_TILE >= 2
7$assert REQUANTIZATION == "FP32"
8$assert VARIANT in ["FMAGIC", "IMAGIC", "LRINTF"]
9$assert DATATYPE in ["QC8", "QS8", "QU8"]
10#include <assert.h>
11$if VARIANT == "LRINTF":
12  #include <math.h>
13
14#include <xnnpack/dwconv.h>
15#include <xnnpack/math.h>
16$if CHANNEL_TILE % 4 != 0:
17  #include <xnnpack/unaligned.h>
18
19
20$PARAMS_STRUCT = REQUANTIZATION.lower() + "_scalar" + ("_" + VARIANT.lower() if VARIANT else "")
21$PARAMS_UNION = "xnn_%s_conv_minmax_params" % DATATYPE.lower()
22$XINT8_T = "uint8_t" if DATATYPE == "QU8" else "int8_t"
23$MIN_F32 = "__builtin_wasm_min_f32" if WASM else "math_min_f32"
24$MAX_F32 = "__builtin_wasm_max_f32" if WASM else "math_max_f32"
25void xnn_${DATATYPE.lower()}_dwconv_minmax_${REQUANTIZATION.lower()}_ukernel_up${CHANNEL_TILE}x${KERNEL_TILE}__${"wasm" if WASM else "scalar"}_${VARIANT.lower()}(
26    size_t channels,
27    size_t output_width,
28    const ${XINT8_T}** input,
29    const void* weights,
30    ${XINT8_T}* output,
31    size_t input_stride,
32    size_t output_increment,
33    size_t input_offset,
34    const ${XINT8_T}* zero,
35    const union ${PARAMS_UNION} params[restrict XNN_MIN_ELEMENTS(1)])
36{
37  assert(channels != 0);
38  assert(output_width != 0);
39
40  $if DATATYPE != "QC8":
41    const float vscale = params->${PARAMS_STRUCT}.scale;
42  $if VARIANT == "FMAGIC":
43    const float voutput_min_less_zero_point = params->${PARAMS_STRUCT}.output_min_less_zero_point;
44    const float voutput_max_less_zero_point = params->${PARAMS_STRUCT}.output_max_less_zero_point;
45    const float vmagic_bias = params->${PARAMS_STRUCT}.magic_bias;
46    const int32_t vmagic_bias_less_output_zero_point = params->${PARAMS_STRUCT}.magic_bias_less_output_zero_point;
47  $elif VARIANT == "IMAGIC":
48    const float vmagic_bias = params->${PARAMS_STRUCT}.magic_bias;
49    const int32_t vmagic_min = params->${PARAMS_STRUCT}.magic_min;
50    const int32_t vmagic_max = params->${PARAMS_STRUCT}.magic_max;
51    const int32_t vmagic_bias_less_zero_point = params->${PARAMS_STRUCT}.magic_bias_less_zero_point;
52  $elif VARIANT == "LRINTF":
53    const float voutput_min_less_zero_point = params->${PARAMS_STRUCT}.output_min_less_zero_point;
54    const float voutput_max_less_zero_point = params->${PARAMS_STRUCT}.output_max_less_zero_point;
55    const int32_t voutput_zero_point = params->${PARAMS_STRUCT}.output_zero_point;
56  $if DATATYPE == "QU8":
57    const int32_t vkernel_zero_point = params->${PARAMS_STRUCT}.kernel_zero_point;
58  do {
59    $for K in range(KERNEL_TILE):
60      const ${XINT8_T}* i${K} = input[${K}];
61      assert(i${K} != NULL);
62      if XNN_UNPREDICTABLE(i${K} != zero) {
63        i${K} = (const ${XINT8_T}*) ((uintptr_t) i${K} + input_offset);
64      }
65    input = (const ${XINT8_T}**) ((uintptr_t) input + input_stride);
66
67    size_t c = channels;
68    const void* w = weights;
69    $if CHANNEL_TILE == 1:
70      do {
71        int32_t vacc = unaligned_load_s32(w);
72
73        $for K in range(KERNEL_TILE):
74          $if DATATYPE == "QU8":
75            const int32_t vi${K} = (int32_t) (uint32_t) *i${K}++;
76          $else:
77            const int32_t vi${K} = (int32_t) *i${K}++;
78          $if DATATYPE == "QU8":
79            const int32_t vk${K} = (int32_t) (uint32_t) ((const ${XINT8_T}*) ((uintptr_t) w + sizeof(int32_t)))[${K}] - vkernel_zero_point;
80          $else:
81            const int32_t vk${K} = ((const ${XINT8_T}*) ((uintptr_t) w + sizeof(int32_t)))[${K}];
82          vacc += vi${K} * vk${K};
83
84        w = (const void*) ((uintptr_t) w + sizeof(int32_t) + ${KERNEL_TILE} * sizeof(${XINT8_T}));
85
86        $if DATATYPE == "QC8":
87          $if CHANNEL_TILE % 4 != 0:
88            const float vscale = unaligned_load_f32(w);
89            w = (const void*) ((const float*) w + 1);
90          $else:
91            const float vscale = *((const float*) w);
92            w = (const void*) ((const float*) w + 1);
93        float vfpacc = (float) vacc * vscale;
94
95        $if VARIANT == "FMAGIC":
96          vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point);
97          vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point);
98          vfpacc += vmagic_bias;
99          int32_t vout = (int32_t) float_as_uint32(vfpacc) - vmagic_bias_less_output_zero_point;
100        $elif VARIANT == "IMAGIC":
101          vfpacc += vmagic_bias;
102          int32_t vout = (int32_t) float_as_uint32(vfpacc);
103          vout = math_max_s32(vout, vmagic_min);
104          vout = math_min_s32(vout, vmagic_max);
105          vout -= vmagic_bias_less_zero_point;
106        $elif VARIANT == "LRINTF":
107          vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point);
108          vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point);
109          const int32_t vrndacc = (int32_t) lrintf(vfpacc);
110          int32_t vout = vrndacc + voutput_zero_point;
111
112        *output++ = (${XINT8_T}) vout;
113      } while (--c != 0);
114    $else:
115      for (; c >= ${CHANNEL_TILE}; c -= ${CHANNEL_TILE}) {
116        $if CHANNEL_TILE % 4 != 0:
117          $for C in range(CHANNEL_TILE):
118            int32_t vacc${C} = unaligned_indexed_load_s32(w, ${C});
119        $else:
120          $for C in range(CHANNEL_TILE):
121            int32_t vacc${C} = ((const int32_t*) w)[${C}];
122
123        $for K in range(KERNEL_TILE):
124
125          $for C in range(CHANNEL_TILE):
126            $if DATATYPE == "QU8":
127              const int32_t vi${K}x${C} = (int32_t) (uint32_t) i${K}[${C}];
128            $else:
129              const int32_t vi${K}x${C} = (int32_t) i${K}[${C}];
130          i${K} += ${CHANNEL_TILE};
131
132          $for C in range(CHANNEL_TILE):
133            $if DATATYPE == "QU8":
134              const int32_t vk${K}x${C} = (int32_t) (uint32_t) ((const ${XINT8_T}*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t)))[${K * CHANNEL_TILE + C}] - vkernel_zero_point;
135            $else:
136              const int32_t vk${K}x${C} = (int32_t) ((const ${XINT8_T}*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t)))[${K * CHANNEL_TILE + C}];
137
138          $for C in range(CHANNEL_TILE):
139            vacc${C} += vi${K}x${C} * vk${K}x${C};
140
141        w = (const void*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${KERNEL_TILE * CHANNEL_TILE} * sizeof(${XINT8_T}));
142
143        $for C in range(CHANNEL_TILE):
144          float vfpacc${C} = (float) vacc${C};
145
146        $if DATATYPE == "QC8":
147          $if CHANNEL_TILE % 4 != 0:
148            $for C in range(CHANNEL_TILE):
149              const float vscale${C} = unaligned_indexed_load_f32(w, ${C});
150          $else:
151            $for C in range(CHANNEL_TILE):
152              const float vscale${C} = ((const float*) w)[${C}];
153          w = (const void*) ((const float*) w + ${CHANNEL_TILE});
154
155          $for C in range(CHANNEL_TILE):
156            vfpacc${C} *= vscale${C};
157        $else:
158          $for C in range(CHANNEL_TILE):
159            vfpacc${C} *= vscale;
160
161        $if VARIANT == "FMAGIC":
162          $for C in range(CHANNEL_TILE):
163            vfpacc${C} = ${MAX_F32}(vfpacc${C}, voutput_min_less_zero_point);
164
165          $for C in range(CHANNEL_TILE):
166            vfpacc${C} = ${MIN_F32}(vfpacc${C}, voutput_max_less_zero_point);
167
168          $for C in range(CHANNEL_TILE):
169            vfpacc${C} += vmagic_bias;
170
171          $for C in range(CHANNEL_TILE):
172            int32_t vout${C} = (int32_t) float_as_uint32(vfpacc${C}) - vmagic_bias_less_output_zero_point;
173        $elif VARIANT == "IMAGIC":
174          $for C in range(CHANNEL_TILE):
175            vfpacc${C} += vmagic_bias;
176
177          $for C in range(CHANNEL_TILE):
178            int32_t vout${C} = (int32_t) float_as_uint32(vfpacc${C});
179
180          $for C in range(CHANNEL_TILE):
181            vout${C} = math_max_s32(vout${C}, vmagic_min);
182
183          $for C in range(CHANNEL_TILE):
184            vout${C} = math_min_s32(vout${C}, vmagic_max);
185
186          $for C in range(CHANNEL_TILE):
187            vout${C} -= vmagic_bias_less_zero_point;
188        $elif VARIANT == "LRINTF":
189          $for C in range(CHANNEL_TILE):
190            vfpacc${C} = ${MAX_F32}(vfpacc${C}, voutput_min_less_zero_point);
191
192          $for C in range(CHANNEL_TILE):
193            vfpacc${C} = ${MIN_F32}(vfpacc${C}, voutput_max_less_zero_point);
194
195          $for C in range(CHANNEL_TILE):
196            const int32_t vrndacc${C} = (int32_t) lrintf(vfpacc${C});
197
198          $for C in range(CHANNEL_TILE):
199            int32_t vout${C} = (int32_t) vrndacc${C} + voutput_zero_point;
200
201        $for C in range(CHANNEL_TILE):
202          output[${C}] = (${XINT8_T}) vout${C};
203        output += ${CHANNEL_TILE};
204      }
205      if XNN_UNLIKELY(c != 0) {
206        $if CHANNEL_TILE == 2:
207          int32_t vacc = unaligned_load_s32(w);
208
209          $for K in range(KERNEL_TILE):
210            $if DATATYPE == "QU8":
211              const int32_t vi${K} = (int32_t) (uint32_t) *i${K};
212            $else:
213              const int32_t vi${K} = (int32_t) *i${K};
214            $if DATATYPE == "QU8":
215              const int32_t vk${K} = (int32_t) (uint32_t) ((const ${XINT8_T}*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t)))[${K * CHANNEL_TILE}] - vkernel_zero_point;
216            $else:
217              const int32_t vk${K} = (int32_t) ((const ${XINT8_T}*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t)))[${K * CHANNEL_TILE}];
218            vacc += vi${K} * vk${K};
219
220          $if DATATYPE == "QC8":
221            $if CHANNEL_TILE % 4 != 0:
222              typedef XNN_UNALIGNED float unaligned_float;
223              const float vscale = *((const unaligned_float*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${KERNEL_TILE * CHANNEL_TILE} * sizeof(${XINT8_T})));
224            $else:
225              const float vscale = *((const float*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${KERNEL_TILE * CHANNEL_TILE} * sizeof(${XINT8_T})));
226          float vfpacc = (float) vacc * vscale;
227
228          $if VARIANT == "FMAGIC":
229            vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point);
230            vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point);
231            vfpacc += vmagic_bias;
232            int32_t vout = (int32_t) float_as_uint32(vfpacc) - vmagic_bias_less_output_zero_point;
233          $elif VARIANT == "IMAGIC":
234            vfpacc += vmagic_bias;
235            int32_t vout = (int32_t) float_as_uint32(vfpacc);
236            vout = math_max_s32(vout, vmagic_min);
237            vout = math_min_s32(vout, vmagic_max);
238            vout -= vmagic_bias_less_zero_point;
239          $elif VARIANT == "LRINTF":
240            vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point);
241            vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point);
242            const int32_t vrndacc = (int32_t) lrintf(vfpacc);
243            int32_t vout = vrndacc + voutput_zero_point;
244
245          *output++ = (${XINT8_T}) vout;
246        $else:
247          const ${XINT8_T}* k = (const ${XINT8_T}*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t));
248          do {
249            int32_t vacc = *((const int32_t*) w);
250            w = (const void*) ((uintptr_t) w + sizeof(int32_t));
251
252            $for K in range(KERNEL_TILE):
253              $if DATATYPE == "QU8":
254                const int32_t vi${K} = (int32_t) (uint32_t) *i${K}++;
255              $else:
256                const int32_t vi${K} = (int32_t) *i${K}++;
257              $if DATATYPE == "QU8":
258                const int32_t vk${K} = (int32_t) (uint32_t) k[${K * CHANNEL_TILE}] - vkernel_zero_point;
259              $else:
260                const int32_t vk${K} = (int32_t) k[${K * CHANNEL_TILE}];
261              vacc += vi${K} * vk${K};
262            k += 1;
263
264            $if DATATYPE == "QC8":
265              $if CHANNEL_TILE % 4 != 0:
266                const float vscale = unaligned_load_f32((const void*) ((uintptr_t) w + ${CHANNEL_TILE - 1} * sizeof(int32_t) + ${KERNEL_TILE * CHANNEL_TILE} * sizeof(${XINT8_T})));
267              $else:
268                const float vscale = *((const float*) ((uintptr_t) w + ${CHANNEL_TILE - 1} * sizeof(int32_t) + ${KERNEL_TILE * CHANNEL_TILE} * sizeof(${XINT8_T})));
269            float vfpacc = (float) vacc * vscale;
270
271            $if VARIANT == "FMAGIC":
272              vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point);
273              vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point);
274              vfpacc += vmagic_bias;
275              int32_t vout = (int32_t) float_as_uint32(vfpacc) - vmagic_bias_less_output_zero_point;
276            $elif VARIANT == "IMAGIC":
277              vfpacc += vmagic_bias;
278              int32_t vout = (int32_t) float_as_uint32(vfpacc);
279              vout = math_max_s32(vout, vmagic_min);
280              vout = math_min_s32(vout, vmagic_max);
281              vout -= vmagic_bias_less_zero_point;
282            $elif VARIANT == "LRINTF":
283              vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point);
284              vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point);
285              const int32_t vrndacc = (int32_t) lrintf(vfpacc);
286              int32_t vout = vrndacc + voutput_zero_point;
287
288            *output++ = (${XINT8_T}) vout;
289          } while (--c != 0);
290      }
291
292    output = (${XINT8_T}*) ((uintptr_t) output + output_increment);
293  } while (--output_width != 0);
294}
295