1// Copyright 2021 Google LLC 2// 3// This source code is licensed under the BSD-style license found in the 4// LICENSE file in the root directory of this source tree. 5 6$assert KERNEL_TILE >= 2 7$assert REQUANTIZATION == "FP32" 8$assert VARIANT in ["FMAGIC", "IMAGIC", "LRINTF"] 9$assert DATATYPE in ["QC8", "QS8", "QU8"] 10#include <assert.h> 11$if VARIANT == "LRINTF": 12 #include <math.h> 13 14#include <xnnpack/dwconv.h> 15#include <xnnpack/math.h> 16$if CHANNEL_TILE % 4 != 0: 17 #include <xnnpack/unaligned.h> 18 19 20$PARAMS_STRUCT = REQUANTIZATION.lower() + "_scalar" + ("_" + VARIANT.lower() if VARIANT else "") 21$PARAMS_UNION = "xnn_%s_conv_minmax_params" % DATATYPE.lower() 22$XINT8_T = "uint8_t" if DATATYPE == "QU8" else "int8_t" 23$MIN_F32 = "__builtin_wasm_min_f32" if WASM else "math_min_f32" 24$MAX_F32 = "__builtin_wasm_max_f32" if WASM else "math_max_f32" 25void xnn_${DATATYPE.lower()}_dwconv_minmax_${REQUANTIZATION.lower()}_ukernel_up${CHANNEL_TILE}x${KERNEL_TILE}__${"wasm" if WASM else "scalar"}_${VARIANT.lower()}( 26 size_t channels, 27 size_t output_width, 28 const ${XINT8_T}** input, 29 const void* weights, 30 ${XINT8_T}* output, 31 size_t input_stride, 32 size_t output_increment, 33 size_t input_offset, 34 const ${XINT8_T}* zero, 35 const union ${PARAMS_UNION} params[restrict XNN_MIN_ELEMENTS(1)]) 36{ 37 assert(channels != 0); 38 assert(output_width != 0); 39 40 $if DATATYPE != "QC8": 41 const float vscale = params->${PARAMS_STRUCT}.scale; 42 $if VARIANT == "FMAGIC": 43 const float voutput_min_less_zero_point = params->${PARAMS_STRUCT}.output_min_less_zero_point; 44 const float voutput_max_less_zero_point = params->${PARAMS_STRUCT}.output_max_less_zero_point; 45 const float vmagic_bias = params->${PARAMS_STRUCT}.magic_bias; 46 const int32_t vmagic_bias_less_output_zero_point = params->${PARAMS_STRUCT}.magic_bias_less_output_zero_point; 47 $elif VARIANT == "IMAGIC": 48 const float vmagic_bias = params->${PARAMS_STRUCT}.magic_bias; 49 const int32_t vmagic_min = params->${PARAMS_STRUCT}.magic_min; 50 const int32_t vmagic_max = params->${PARAMS_STRUCT}.magic_max; 51 const int32_t vmagic_bias_less_zero_point = params->${PARAMS_STRUCT}.magic_bias_less_zero_point; 52 $elif VARIANT == "LRINTF": 53 const float voutput_min_less_zero_point = params->${PARAMS_STRUCT}.output_min_less_zero_point; 54 const float voutput_max_less_zero_point = params->${PARAMS_STRUCT}.output_max_less_zero_point; 55 const int32_t voutput_zero_point = params->${PARAMS_STRUCT}.output_zero_point; 56 $if DATATYPE == "QU8": 57 const int32_t vkernel_zero_point = params->${PARAMS_STRUCT}.kernel_zero_point; 58 do { 59 $for K in range(KERNEL_TILE): 60 const ${XINT8_T}* i${K} = input[${K}]; 61 assert(i${K} != NULL); 62 if XNN_UNPREDICTABLE(i${K} != zero) { 63 i${K} = (const ${XINT8_T}*) ((uintptr_t) i${K} + input_offset); 64 } 65 input = (const ${XINT8_T}**) ((uintptr_t) input + input_stride); 66 67 size_t c = channels; 68 const void* w = weights; 69 $if CHANNEL_TILE == 1: 70 do { 71 int32_t vacc = unaligned_load_s32(w); 72 73 $for K in range(KERNEL_TILE): 74 $if DATATYPE == "QU8": 75 const int32_t vi${K} = (int32_t) (uint32_t) *i${K}++; 76 $else: 77 const int32_t vi${K} = (int32_t) *i${K}++; 78 $if DATATYPE == "QU8": 79 const int32_t vk${K} = (int32_t) (uint32_t) ((const ${XINT8_T}*) ((uintptr_t) w + sizeof(int32_t)))[${K}] - vkernel_zero_point; 80 $else: 81 const int32_t vk${K} = ((const ${XINT8_T}*) ((uintptr_t) w + sizeof(int32_t)))[${K}]; 82 vacc += vi${K} * vk${K}; 83 84 w = (const void*) ((uintptr_t) w + sizeof(int32_t) + ${KERNEL_TILE} * sizeof(${XINT8_T})); 85 86 $if DATATYPE == "QC8": 87 $if CHANNEL_TILE % 4 != 0: 88 const float vscale = unaligned_load_f32(w); 89 w = (const void*) ((const float*) w + 1); 90 $else: 91 const float vscale = *((const float*) w); 92 w = (const void*) ((const float*) w + 1); 93 float vfpacc = (float) vacc * vscale; 94 95 $if VARIANT == "FMAGIC": 96 vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point); 97 vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point); 98 vfpacc += vmagic_bias; 99 int32_t vout = (int32_t) float_as_uint32(vfpacc) - vmagic_bias_less_output_zero_point; 100 $elif VARIANT == "IMAGIC": 101 vfpacc += vmagic_bias; 102 int32_t vout = (int32_t) float_as_uint32(vfpacc); 103 vout = math_max_s32(vout, vmagic_min); 104 vout = math_min_s32(vout, vmagic_max); 105 vout -= vmagic_bias_less_zero_point; 106 $elif VARIANT == "LRINTF": 107 vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point); 108 vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point); 109 const int32_t vrndacc = (int32_t) lrintf(vfpacc); 110 int32_t vout = vrndacc + voutput_zero_point; 111 112 *output++ = (${XINT8_T}) vout; 113 } while (--c != 0); 114 $else: 115 for (; c >= ${CHANNEL_TILE}; c -= ${CHANNEL_TILE}) { 116 $if CHANNEL_TILE % 4 != 0: 117 $for C in range(CHANNEL_TILE): 118 int32_t vacc${C} = unaligned_indexed_load_s32(w, ${C}); 119 $else: 120 $for C in range(CHANNEL_TILE): 121 int32_t vacc${C} = ((const int32_t*) w)[${C}]; 122 123 $for K in range(KERNEL_TILE): 124 125 $for C in range(CHANNEL_TILE): 126 $if DATATYPE == "QU8": 127 const int32_t vi${K}x${C} = (int32_t) (uint32_t) i${K}[${C}]; 128 $else: 129 const int32_t vi${K}x${C} = (int32_t) i${K}[${C}]; 130 i${K} += ${CHANNEL_TILE}; 131 132 $for C in range(CHANNEL_TILE): 133 $if DATATYPE == "QU8": 134 const int32_t vk${K}x${C} = (int32_t) (uint32_t) ((const ${XINT8_T}*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t)))[${K * CHANNEL_TILE + C}] - vkernel_zero_point; 135 $else: 136 const int32_t vk${K}x${C} = (int32_t) ((const ${XINT8_T}*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t)))[${K * CHANNEL_TILE + C}]; 137 138 $for C in range(CHANNEL_TILE): 139 vacc${C} += vi${K}x${C} * vk${K}x${C}; 140 141 w = (const void*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${KERNEL_TILE * CHANNEL_TILE} * sizeof(${XINT8_T})); 142 143 $for C in range(CHANNEL_TILE): 144 float vfpacc${C} = (float) vacc${C}; 145 146 $if DATATYPE == "QC8": 147 $if CHANNEL_TILE % 4 != 0: 148 $for C in range(CHANNEL_TILE): 149 const float vscale${C} = unaligned_indexed_load_f32(w, ${C}); 150 $else: 151 $for C in range(CHANNEL_TILE): 152 const float vscale${C} = ((const float*) w)[${C}]; 153 w = (const void*) ((const float*) w + ${CHANNEL_TILE}); 154 155 $for C in range(CHANNEL_TILE): 156 vfpacc${C} *= vscale${C}; 157 $else: 158 $for C in range(CHANNEL_TILE): 159 vfpacc${C} *= vscale; 160 161 $if VARIANT == "FMAGIC": 162 $for C in range(CHANNEL_TILE): 163 vfpacc${C} = ${MAX_F32}(vfpacc${C}, voutput_min_less_zero_point); 164 165 $for C in range(CHANNEL_TILE): 166 vfpacc${C} = ${MIN_F32}(vfpacc${C}, voutput_max_less_zero_point); 167 168 $for C in range(CHANNEL_TILE): 169 vfpacc${C} += vmagic_bias; 170 171 $for C in range(CHANNEL_TILE): 172 int32_t vout${C} = (int32_t) float_as_uint32(vfpacc${C}) - vmagic_bias_less_output_zero_point; 173 $elif VARIANT == "IMAGIC": 174 $for C in range(CHANNEL_TILE): 175 vfpacc${C} += vmagic_bias; 176 177 $for C in range(CHANNEL_TILE): 178 int32_t vout${C} = (int32_t) float_as_uint32(vfpacc${C}); 179 180 $for C in range(CHANNEL_TILE): 181 vout${C} = math_max_s32(vout${C}, vmagic_min); 182 183 $for C in range(CHANNEL_TILE): 184 vout${C} = math_min_s32(vout${C}, vmagic_max); 185 186 $for C in range(CHANNEL_TILE): 187 vout${C} -= vmagic_bias_less_zero_point; 188 $elif VARIANT == "LRINTF": 189 $for C in range(CHANNEL_TILE): 190 vfpacc${C} = ${MAX_F32}(vfpacc${C}, voutput_min_less_zero_point); 191 192 $for C in range(CHANNEL_TILE): 193 vfpacc${C} = ${MIN_F32}(vfpacc${C}, voutput_max_less_zero_point); 194 195 $for C in range(CHANNEL_TILE): 196 const int32_t vrndacc${C} = (int32_t) lrintf(vfpacc${C}); 197 198 $for C in range(CHANNEL_TILE): 199 int32_t vout${C} = (int32_t) vrndacc${C} + voutput_zero_point; 200 201 $for C in range(CHANNEL_TILE): 202 output[${C}] = (${XINT8_T}) vout${C}; 203 output += ${CHANNEL_TILE}; 204 } 205 if XNN_UNLIKELY(c != 0) { 206 $if CHANNEL_TILE == 2: 207 int32_t vacc = unaligned_load_s32(w); 208 209 $for K in range(KERNEL_TILE): 210 $if DATATYPE == "QU8": 211 const int32_t vi${K} = (int32_t) (uint32_t) *i${K}; 212 $else: 213 const int32_t vi${K} = (int32_t) *i${K}; 214 $if DATATYPE == "QU8": 215 const int32_t vk${K} = (int32_t) (uint32_t) ((const ${XINT8_T}*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t)))[${K * CHANNEL_TILE}] - vkernel_zero_point; 216 $else: 217 const int32_t vk${K} = (int32_t) ((const ${XINT8_T}*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t)))[${K * CHANNEL_TILE}]; 218 vacc += vi${K} * vk${K}; 219 220 $if DATATYPE == "QC8": 221 $if CHANNEL_TILE % 4 != 0: 222 typedef XNN_UNALIGNED float unaligned_float; 223 const float vscale = *((const unaligned_float*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${KERNEL_TILE * CHANNEL_TILE} * sizeof(${XINT8_T}))); 224 $else: 225 const float vscale = *((const float*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${KERNEL_TILE * CHANNEL_TILE} * sizeof(${XINT8_T}))); 226 float vfpacc = (float) vacc * vscale; 227 228 $if VARIANT == "FMAGIC": 229 vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point); 230 vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point); 231 vfpacc += vmagic_bias; 232 int32_t vout = (int32_t) float_as_uint32(vfpacc) - vmagic_bias_less_output_zero_point; 233 $elif VARIANT == "IMAGIC": 234 vfpacc += vmagic_bias; 235 int32_t vout = (int32_t) float_as_uint32(vfpacc); 236 vout = math_max_s32(vout, vmagic_min); 237 vout = math_min_s32(vout, vmagic_max); 238 vout -= vmagic_bias_less_zero_point; 239 $elif VARIANT == "LRINTF": 240 vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point); 241 vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point); 242 const int32_t vrndacc = (int32_t) lrintf(vfpacc); 243 int32_t vout = vrndacc + voutput_zero_point; 244 245 *output++ = (${XINT8_T}) vout; 246 $else: 247 const ${XINT8_T}* k = (const ${XINT8_T}*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t)); 248 do { 249 int32_t vacc = *((const int32_t*) w); 250 w = (const void*) ((uintptr_t) w + sizeof(int32_t)); 251 252 $for K in range(KERNEL_TILE): 253 $if DATATYPE == "QU8": 254 const int32_t vi${K} = (int32_t) (uint32_t) *i${K}++; 255 $else: 256 const int32_t vi${K} = (int32_t) *i${K}++; 257 $if DATATYPE == "QU8": 258 const int32_t vk${K} = (int32_t) (uint32_t) k[${K * CHANNEL_TILE}] - vkernel_zero_point; 259 $else: 260 const int32_t vk${K} = (int32_t) k[${K * CHANNEL_TILE}]; 261 vacc += vi${K} * vk${K}; 262 k += 1; 263 264 $if DATATYPE == "QC8": 265 $if CHANNEL_TILE % 4 != 0: 266 const float vscale = unaligned_load_f32((const void*) ((uintptr_t) w + ${CHANNEL_TILE - 1} * sizeof(int32_t) + ${KERNEL_TILE * CHANNEL_TILE} * sizeof(${XINT8_T}))); 267 $else: 268 const float vscale = *((const float*) ((uintptr_t) w + ${CHANNEL_TILE - 1} * sizeof(int32_t) + ${KERNEL_TILE * CHANNEL_TILE} * sizeof(${XINT8_T}))); 269 float vfpacc = (float) vacc * vscale; 270 271 $if VARIANT == "FMAGIC": 272 vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point); 273 vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point); 274 vfpacc += vmagic_bias; 275 int32_t vout = (int32_t) float_as_uint32(vfpacc) - vmagic_bias_less_output_zero_point; 276 $elif VARIANT == "IMAGIC": 277 vfpacc += vmagic_bias; 278 int32_t vout = (int32_t) float_as_uint32(vfpacc); 279 vout = math_max_s32(vout, vmagic_min); 280 vout = math_min_s32(vout, vmagic_max); 281 vout -= vmagic_bias_less_zero_point; 282 $elif VARIANT == "LRINTF": 283 vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point); 284 vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point); 285 const int32_t vrndacc = (int32_t) lrintf(vfpacc); 286 int32_t vout = vrndacc + voutput_zero_point; 287 288 *output++ = (${XINT8_T}) vout; 289 } while (--c != 0); 290 } 291 292 output = (${XINT8_T}*) ((uintptr_t) output + output_increment); 293 } while (--output_width != 0); 294} 295