1// Copyright 2021 Google LLC 2// 3// This source code is licensed under the BSD-style license found in the 4// LICENSE file in the root directory of this source tree. 5 6$assert REQUANTIZATION == "FP32" 7$assert VARIANT in ["FMAGIC", "IMAGIC", "LRINTF"] 8$assert DATATYPE in ["QC8", "QS8", "QU8"] 9#include <assert.h> 10$if VARIANT == "LRINTF": 11 #include <math.h> 12 13#include <xnnpack/math.h> 14#include <xnnpack/gemm.h> 15$if NR % 4 != 0: 16 #include <xnnpack/unaligned.h> 17 18 19$PARAMS_STRUCT = REQUANTIZATION.lower() + "_scalar" + ("_" + VARIANT.lower() if VARIANT else "") 20$PARAMS_UNION = "xnn_%s_conv_minmax_params" % DATATYPE.lower() 21$XINT8_T = "uint8_t" if DATATYPE == "QU8" else "int8_t" 22$MIN_F32 = "__builtin_wasm_min_f32" if WASM else "math_min_f32" 23$MAX_F32 = "__builtin_wasm_max_f32" if WASM else "math_max_f32" 24void xnn_${DATATYPE.lower()}_gemm_minmax_${REQUANTIZATION.lower()}_ukernel_${MR}x${NR}__${"wasm" if WASM else "scalar"}_${VARIANT.lower()}( 25 size_t mr, 26 size_t nc, 27 size_t kc, 28 const ${XINT8_T}* restrict a, 29 size_t a_stride, 30 const void* restrict w, 31 ${XINT8_T}* restrict c, 32 size_t cm_stride, 33 size_t cn_stride, 34 const union ${PARAMS_UNION} params[restrict XNN_MIN_ELEMENTS(1)]) 35{ 36 assert(mr != 0); 37 assert(mr <= ${MR}); 38 assert(nc != 0); 39 assert(kc != 0); 40 41 const ${XINT8_T}* a0 = a; 42 ${XINT8_T}* c0 = c; 43 $for M in range(1, MR): 44 const ${XINT8_T}* a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M-1} + a_stride); 45 ${XINT8_T}* c${M} = (${XINT8_T}*) ((uintptr_t) c${M-1} + cm_stride); 46 $if M % 2 == 0: 47 if XNN_UNPREDICTABLE(mr <= ${M}) { 48 a${M} = a${M-1}; 49 c${M} = c${M-1}; 50 } 51 $elif M + 1 == MR: 52 if XNN_UNPREDICTABLE(mr != ${M+1}) { 53 a${M} = a${M-1}; 54 c${M} = c${M-1}; 55 } 56 $else: 57 if XNN_UNPREDICTABLE(mr < ${M+1}) { 58 a${M} = a${M-1}; 59 c${M} = c${M-1}; 60 } 61 62 $if DATATYPE == "QU8": 63 const int32_t vb_zero_point = params->${PARAMS_STRUCT}.kernel_zero_point; 64 do { 65 $if NR % 4 != 0: 66 $for N in range(NR): 67 int32_t vacc0x${N} = unaligned_indexed_load_s32(w, ${N}); 68 $else: 69 $for N in range(NR): 70 int32_t vacc0x${N} = ((const int32_t*) w)[${N}]; 71 $for M in range(1, MR): 72 $for N in range(NR): 73 int32_t vacc${M}x${N} = vacc0x${N}; 74 w = (const void*) ((const int32_t*) w + ${NR}); 75 76 size_t k = kc; 77 do { 78 $for M in range(MR): 79 $if DATATYPE == "QU8": 80 const int32_t va${M} = (int32_t) (uint32_t) *a${M}++; 81 $else: 82 const int32_t va${M} = (int32_t) *a${M}++; 83 84 $for N in range(NR): 85 $if DATATYPE == "QU8": 86 const int32_t vb${N} = (int32_t) (uint32_t) ((const uint8_t*) w)[${N}] - vb_zero_point; 87 $else: 88 const int32_t vb${N} = (int32_t) ((const int8_t*) w)[${N}]; 89 w = (const void*) ((const ${XINT8_T}*) w + ${NR}); 90 91 $for M in range(MR): 92 $for N in range(NR): 93 vacc${M}x${N} += va${M} * vb${N}; 94 95 k -= sizeof(${XINT8_T}); 96 } while (k != 0); 97 98 $for M in range(MR): 99 $for N in range(NR): 100 float vfpacc${M}x${N} = (float) vacc${M}x${N}; 101 102 $if DATATYPE == "QC8": 103 $if NR % 4 != 0: 104 $for N in range(NR): 105 const float vscale${N} = unaligned_indexed_load_f32(w, ${N}); 106 $for M in range(MR): 107 vfpacc${M}x${N} *= vscale${N}; 108 $else: 109 $for N in range(NR): 110 const float vscale${N} = ((const float*) w)[${N}]; 111 $for M in range(MR): 112 vfpacc${M}x${N} *= vscale${N}; 113 w = (const void*) ((const float*) w + ${NR}); 114 $else: 115 const float vscale = params->${PARAMS_STRUCT}.scale; 116 $for M in range(MR): 117 $for N in range(NR): 118 vfpacc${M}x${N} *= vscale; 119 120 $if VARIANT == "FMAGIC": 121 const float voutput_min_less_zero_point = params->${PARAMS_STRUCT}.output_min_less_zero_point; 122 $for M in range(MR): 123 $for N in range(NR): 124 vfpacc${M}x${N} = ${MAX_F32}(vfpacc${M}x${N}, voutput_min_less_zero_point); 125 126 const float voutput_max_less_zero_point = params->${PARAMS_STRUCT}.output_max_less_zero_point; 127 $for M in range(MR): 128 $for N in range(NR): 129 vfpacc${M}x${N} = ${MIN_F32}(vfpacc${M}x${N}, voutput_max_less_zero_point); 130 131 const float vmagic_bias = params->${PARAMS_STRUCT}.magic_bias; 132 $for M in range(MR): 133 $for N in range(NR): 134 vfpacc${M}x${N} += vmagic_bias; 135 136 const int32_t vmagic_bias_less_output_zero_point = params->${PARAMS_STRUCT}.magic_bias_less_output_zero_point; 137 $for M in range(MR): 138 $for N in range(NR): 139 int32_t vout${M}x${N} = (int32_t) float_as_uint32(vfpacc${M}x${N}) - vmagic_bias_less_output_zero_point; 140 $elif VARIANT == "IMAGIC": 141 const float vmagic_bias = params->${PARAMS_STRUCT}.magic_bias; 142 $for M in range(MR): 143 $for N in range(NR): 144 vfpacc${M}x${N} += vmagic_bias; 145 146 $for M in range(MR): 147 $for N in range(NR): 148 int32_t vout${M}x${N} = (int32_t) float_as_uint32(vfpacc${M}x${N}); 149 150 const int32_t vmagic_min = params->${PARAMS_STRUCT}.magic_min; 151 $for M in range(MR): 152 $for N in range(NR): 153 vout${M}x${N} = math_max_s32(vout${M}x${N}, vmagic_min); 154 155 const int32_t vmagic_max = params->${PARAMS_STRUCT}.magic_max; 156 $for M in range(MR): 157 $for N in range(NR): 158 vout${M}x${N} = math_min_s32(vout${M}x${N}, vmagic_max); 159 160 const int32_t vmagic_bias_less_zero_point = params->${PARAMS_STRUCT}.magic_bias_less_zero_point; 161 $for M in range(MR): 162 $for N in range(NR): 163 vout${M}x${N} -= vmagic_bias_less_zero_point; 164 $elif VARIANT == "LRINTF": 165 const float voutput_min_less_zero_point = params->${PARAMS_STRUCT}.output_min_less_zero_point; 166 $for M in range(MR): 167 $for N in range(NR): 168 vfpacc${M}x${N} = ${MAX_F32}(vfpacc${M}x${N}, voutput_min_less_zero_point); 169 170 const float voutput_max_less_zero_point = params->${PARAMS_STRUCT}.output_max_less_zero_point; 171 $for M in range(MR): 172 $for N in range(NR): 173 vfpacc${M}x${N} = ${MIN_F32}(vfpacc${M}x${N}, voutput_max_less_zero_point); 174 175 $for M in range(MR): 176 $for N in range(NR): 177 const int32_t vrndacc${M}x${N} = (int32_t) lrintf(vfpacc${M}x${N}); 178 179 const int32_t voutput_zero_point = params->${PARAMS_STRUCT}.output_zero_point; 180 $for M in range(MR): 181 $for N in range(NR): 182 int32_t vout${M}x${N} = vrndacc${M}x${N} + voutput_zero_point; 183 184 if XNN_LIKELY(nc >= ${NR}) { 185 $for M in range(MR): 186 $for N in range(NR): 187 c${M}[${N}] = (${XINT8_T}) vout${M}x${N}; 188 189 $for M in range(MR): 190 a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M} - kc); 191 192 $for M in range(MR): 193 c${M} = (${XINT8_T}*) ((uintptr_t) c${M} + cn_stride); 194 195 nc -= ${NR}; 196 } else { 197 $for LOG2N in reversed(range(NR.bit_length() - 1)): 198 if (nc & ${1 << LOG2N}) { 199 $for M in range(MR): 200 $for N in range(1 << LOG2N): 201 c${M}[${N}] = (${XINT8_T}) vout${M}x${N}; 202 $if LOG2N != 0: 203 $for N in range(1 << (LOG2N - 1)): 204 vout${M}x${N} = vout${M}x${N + (1 << LOG2N)}; 205 c${M} += ${1 << LOG2N}; 206 } 207 208 nc = 0; 209 } 210 } while (nc != 0); 211} 212