// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert REQUANTIZATION == "FP32"
$assert DATATYPE in ["QC8", "QS8", "QU8"]
$assert 1 <= MR <= 2
$assert 1 <= NR <= 2
#include <assert.h>

#include <arm_acle.h>

#include <xnnpack/intrinsics-polyfill.h>
#include <xnnpack/math.h>
#include <xnnpack/gemm.h>
#include <xnnpack/unaligned.h>


$PARAMS_STRUCT = REQUANTIZATION.lower() + "_armsimd32"
$PARAMS_UNION = "xnn_%s_conv_minmax_params" % DATATYPE.lower()
$__XXTB16 = "__uxtb16" if DATATYPE == "QU8" else "__sxtb16"
$__XSAT = "__usat" if DATATYPE == "QU8" else "__ssat"
$__XSUB8 = "__usub8" if DATATYPE == "QU8" else "__ssub8"
$XINT8_T = "uint8_t" if DATATYPE == "QU8" else "int8_t"
// Quantized 8-bit IGEMM microkernel (FP32 requantization, 4 bytes of k per
// step) written with the Arm SIMD32 intrinsics from <arm_acle.h>. Every pass
// of the outer loop produces one tile of up to MR rows by NR columns.
//
//   mr        - number of rows to produce; must be in [1, MR]. A row beyond mr
//               aliases the previous row's output pointer.
//   nc        - number of output columns still to produce; must be non-zero.
//   kc        - bytes reduced per input pointer; must be non-zero and is
//               rounded up to a multiple of 4 (round_up_po2).
//   ks        - size in bytes of the entries of `a` consumed per tile; must be
//               a non-zero multiple of MR * sizeof(void*).
//   a         - array of input row pointers, read MR at a time and rewound by
//               ks bytes before the next tile.
//   w         - packed weights stream: NR int32 initial accumulator values
//               (NOTE(review): presumably the bias - confirm against the
//               packing code), then 4 bytes per column for every k step; the
//               QC8 variant additionally reads NR float scales after them.
//   c         - output pointer for row 0.
//   cm_stride - byte distance between consecutive output rows.
//   cn_stride - byte distance added to every row pointer after a tile.
//   a_offset  - byte offset added to each input pointer that differs from
//               `zero`.
//   zero      - sentinel input pointer that is used as-is, without a_offset.
//   params    - requantization constants and output clamping bounds.
void xnn_${DATATYPE.lower()}_igemm_minmax_${REQUANTIZATION.lower()}_ukernel_${MR}x${NR}c4__armsimd32(
    size_t mr,
    size_t nc,
    size_t kc,
    size_t ks,
    const ${XINT8_T}**restrict a,
    const void*restrict w,
    ${XINT8_T}*restrict c,
    size_t cm_stride,
    size_t cn_stride,
    size_t a_offset,
    const ${XINT8_T}* zero,
    const union ${PARAMS_UNION} params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(mr != 0);
  assert(mr <= ${MR});
  assert(nc != 0);
  assert(kc != 0);
  assert(ks != 0);
  assert(ks % (${MR} * sizeof(void*)) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);

  // The k loop below consumes 4 bytes per row per step and stops only when the
  // counter reaches exactly zero, so kc has to be a multiple of 4.
  kc = round_up_po2(kc, 4 * sizeof(int8_t));
  // One output pointer per row. A row that is not requested by mr reuses the
  // previous row's pointer; rows are stored last-to-first, so the lower row's
  // value is the one that remains in memory.
  ${XINT8_T}* c0 = c;
  $for M in range(1, MR):
    ${XINT8_T}* c${M} = (${XINT8_T}*) ((uintptr_t) c${M-1} + cm_stride);
    $if M % 2 == 0:
      if XNN_UNPREDICTABLE(mr <= ${M}) {
        c${M} = c${M-1};
      }
    $elif M + 1 == MR:
      if XNN_UNPREDICTABLE(mr != ${M+1}) {
        c${M} = c${M-1};
      }
    $else:
      if XNN_UNPREDICTABLE(mr < ${M+1}) {
        c${M} = c${M-1};
      }

  // Loop-invariant parameters, loaded once before the column loop.
  $if DATATYPE == "QU8":
    const int16x2_t vb_minus_zero_point = (int16x2_t) params->${PARAMS_STRUCT}.minus_kernel_zero_point;
  $if REQUANTIZATION == "FP32":
    $if DATATYPE != "QC8":
      const float vscale = params->${PARAMS_STRUCT}.scale;
    const float vmagic_bias = params->${PARAMS_STRUCT}.magic_bias;
  do {
    // Seed the accumulators of every row with the NR int32 values at the head
    // of the weights stream, then step w past them.
    $for N in range(NR):
      int32_t vacc0x${N} = ((const int32_t*) w)[${N}];
    $for M in range(1, MR):
      $for N in range(NR):
        int32_t vacc${M}x${N} = vacc0x${N};
    w = (const void*) ((const int32_t*) w + ${NR});

    // p counts down the bytes of pointer entries in `a` for this tile; each
    // step consumes MR pointers.
    size_t p = ks;
    do {
      // Fetch the next MR input pointers. a_offset is applied to every pointer
      // except the `zero` sentinel.
      $for M in range(MR):
        const ${XINT8_T}* restrict a${M} = a[${M}];
        assert(a${M} != NULL);
        if XNN_UNPREDICTABLE(a${M} != zero) {
          a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M} + a_offset);
        }
      a += ${MR};

      size_t k = kc;
      do {
        // Load 4 input bytes per row (unaligned load) and widen them into two
        // pairs of 16-bit lanes: bytes 0 and 2 (c02) directly, bytes 1 and 3
        // (c13) after rotating the word right by 8 bits.
        $for M in range(MR):
          const int8x4_t va${M} = (int8x4_t) unaligned_load_s32(a${M}); a${M} += 4;

        $for M in range(MR):
          const int16x2_t va${M}c02 = ${__XXTB16}(va${M});
          const int16x2_t va${M}c13 = ${__XXTB16}(__ror(va${M}, 8));

        // Per output column: read 4 weight bytes, widen them the same way (the
        // unsigned variant adds minus_kernel_zero_point while widening), and
        // accumulate both 16-bit products of each pair with __smlad.
        $for N in range(NR):
          const int8x4_t vb${N} = *((const int8x4_t*) w); w = (const int8_t*) w + 4;
          $if DATATYPE == "QU8":
            const int16x2_t vb${N}c02 = __uxtab16(vb_minus_zero_point, vb${N});
          $else:
            const int16x2_t vb${N}c02 = __sxtb16(vb${N});

          $for M in range(MR):
            vacc${M}x${N} = __smlad(va${M}c02, vb${N}c02, vacc${M}x${N});

          $if DATATYPE == "QU8":
            const int16x2_t vb${N}c13 = __uxtab16(vb_minus_zero_point, __ror(vb${N}, 8));
          $else:
            const int16x2_t vb${N}c13 = __sxtb16(__ror(vb${N}, 8));
          $for M in range(MR):
            vacc${M}x${N} = __smlad(va${M}c13, vb${N}c13, vacc${M}x${N});

        k -= 4 * sizeof(${XINT8_T});
      } while (k != 0);
      p -= ${MR} * sizeof(void*);
    } while (p != 0);

    // Requantization starts here: convert the int32 accumulators to float.
    $for M in range(MR):
      $for N in range(NR):
        float vfpacc${M}x${N} = (float) vacc${M}x${N};

    // Apply the scale. QC8 reads one float scale per output column from the
    // weights stream (and steps w past them); the other datatypes use the
    // single scale loaded from params.
    $if DATATYPE == "QC8":
      $for N in range(NR):
        const float vscale${N} = ((const float*) w)[${N}];
        $for M in range(MR):
          vfpacc${M}x${N} *= vscale${N};
      w = (const void*) ((const float*) w + ${NR});
    $else:
      $for M in range(MR):
        $for N in range(NR):
          vfpacc${M}x${N} *= vscale;

    // Add magic_bias, reinterpret the float bits as an integer, and subtract
    // magic_bias_less_zero_point with a saturating subtract.
    // NOTE(review): presumably the usual magic-bias float-to-integer rounding
    // with the output zero point folded into the subtrahend - confirm against
    // the code that initializes params.
    $for M in range(MR):
      $for N in range(NR):
        vfpacc${M}x${N} += vmagic_bias;

    $for M in range(MR):
      $for N in range(NR):
        int32_t vout${M}x${N} = (int32_t) float_as_uint32(vfpacc${M}x${N});

    const int32_t vmagic_bias_less_zero_point = params->${PARAMS_STRUCT}.magic_bias_less_zero_point;
    $for M in range(MR):
      $for N in range(NR):
        vout${M}x${N} = __qsub(vout${M}x${N}, vmagic_bias_less_zero_point);

    // Saturate every result to the 8-bit range of the output type.
    $for M in range(MR):
      $for N in range(NR):
        vout${M}x${N} = ${__XSAT}(vout${M}x${N}, 8);

    // Pack the columns of each row into the low bytes of one word (column 0 in
    // byte 0), then pack the rows into vout with the last row in the low
    // 16 bits and row 0 above it.
    $for M in range(MR):
      $if NR == 1:
        const uint32_t vout${M} = (uint32_t) vout${M}x0;
      $else:
        const uint32_t vout${M} = (uint32_t) (uint8_t) vout${M}x0 | ((uint32_t) vout${M}x1 << 8);

    $if MR == 1:
      uint32_t vout = vout0;
    $else:
      uint32_t vout = (uint32_t) (uint16_t) vout1 | (vout0 << 16);

    // Byte-wise clamp to [output_min, output_max]. The 8-bit subtract is issued
    // only for its effect on the per-byte GE flags, which the following __sel
    // reads; each subtract/select pair must stay adjacent and in this order.
    const int8x4_t voutput_min = (int8x4_t) params->${PARAMS_STRUCT}.output_min;
    ${__XSUB8}((int8x4_t) vout, voutput_min);
    vout = (uint32_t) __sel((uint8x4_t) vout, (uint8x4_t) voutput_min);

    const int8x4_t voutput_max = (int8x4_t) params->${PARAMS_STRUCT}.output_max;
    ${__XSUB8}((int8x4_t) vout, voutput_max);
    vout = (uint32_t) __sel((uint8x4_t) voutput_max, (uint8x4_t) vout);

    // Store the tile, last row first, shifting vout right by 16 bits between
    // rows. After a full tile the row pointers advance by cn_stride and `a` is
    // rewound by ks bytes so the next tile re-reads the same pointer entries.
    // With two columns per tile, a remainder of fewer columns stores a single
    // byte per row and ends the loop.
    $if NR == 2:
      if XNN_LIKELY(nc >= ${NR}) {
        $for M in reversed(range(MR)):
          unaligned_store_u16(c${M}, (uint16_t) vout);
          $if M != 0:
            vout >>= 16;

        $for M in reversed(range(MR)):
          c${M} = (${XINT8_T}*) ((uintptr_t) c${M} + cn_stride);

        a = (const ${XINT8_T}**restrict) ((uintptr_t) a - ks);
        nc -= ${NR};
      } else {
        $for M in reversed(range(MR)):
          *c${M} = (${XINT8_T}) vout;
          $if M != 0:
            vout >>= 16;

        nc = 0;
      }
    $else:
      $for M in reversed(range(MR)):
        *c${M} = (${XINT8_T}) vout;
        $if M != 0:
          vout >>= 16;

      $for M in reversed(range(MR)):
        c${M} = (${XINT8_T}*) ((uintptr_t) c${M} + cn_stride);

      a = (const ${XINT8_T}**restrict) ((uintptr_t) a - ks);
      nc -= 1;
  } while (nc != 0);
}
