// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$assert NR % 8 == 0
$assert 8 <= NR <= 16
$assert REQUANTIZATION in ["FP32", "RNDNU"]
$assert not CHANNELWISE or REQUANTIZATION == "FP32"
#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/gemm.h>
$if REQUANTIZATION == "FP32":
  #include <xnnpack/intrinsics-polyfill.h>
#include <xnnpack/math.h>


$DATATYPE = "qc8" if CHANNELWISE else "qs8"
$PARAMS_STRUCT = REQUANTIZATION.lower() + "_" + ("neonv8" if REQUANTIZATION == "FP32" else "neon")
$PARAMS_UNION = "xnn_%s_conv_minmax_params" % DATATYPE.lower()
void xnn_${DATATYPE}_gemm_minmax_${REQUANTIZATION.lower()}_ukernel_${MR}x${NR}c4__neondot(
    size_t mr,
    size_t nc,
    size_t kc,
    const int8_t* restrict a,
    size_t a_stride,
    const void* restrict w,
    int8_t* restrict c,
    size_t cm_stride,
    size_t cn_stride,
    const union ${PARAMS_UNION} params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(mr != 0);
  assert(mr <= ${MR});
  assert(nc != 0);
  assert(kc != 0);
  assert(kc % sizeof(int8_t) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);

  kc = round_up_po2(kc, 4 * sizeof(int8_t));
  const int8_t* a0 = a;
  int8_t* c0 = c;
  $for M in range(1, MR):
    const int8_t* a${M} = (const int8_t*) ((uintptr_t) a${M-1} + a_stride);
    int8_t* c${M} = (int8_t*) ((uintptr_t) c${M-1} + cm_stride);
    $if M % 2 == 0:
      if XNN_UNPREDICTABLE(mr <= ${M}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }
    $elif M + 1 == MR:
      if XNN_UNPREDICTABLE(mr != ${M+1}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }
    $else:
      if XNN_UNPREDICTABLE(mr < ${M+1}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }

  // Loop over groups of ${NR} columns.
  do {
    // Initialize accumulators with bias. ${NR} bias values are loaded from the
    // weight matrix at the start of the group of ${NR} columns.
    $for N in range(0, NR, 4):
      int32x4_t vacc0x${ABC[N:N+4]} = vld1q_s32(w); w = (const void*) ((const int32_t*) w + 4);
    $for M in range(1, MR):
      $for N in range(0, NR, 4):
        int32x4_t vacc${M}x${ABC[N:N+4]} = vacc0x${ABC[N:N+4]};

    // Inner accumulation loop along the K dimension for this group of ${NR} columns.
    size_t k = kc;
    // 2x partially unrolled loop to load 8 bytes at a time.
    while (k >= 8 * sizeof(int8_t)) {
      // Load a ${MR}x8 block of activations.
      $for M in range(MR):
        const int8x8_t va${M}x01234567 = vld1_s8(a${M}); a${M} += 8;

      // Load an 8x${NR} block of weights.
      $for K in range(0, 8, 4):
        $for N in range(0, NR, 4):
          const int8x16_t vb${ABC[K:K+4]}x${ABC[N:N+4]} = vld1q_s8(w); w = (const void*) ((const int8_t*) w + 16);

      // Multiply-accumulate: ${MR}x8 * 8x${NR} --> ${MR}x${NR}. Each vdotq_lane_s32
      // accumulates, for 4 adjacent columns, the dot product of 4 consecutive
      // weight bytes with the 4 activation bytes selected by the lane index.
      $for K in range(0, 8, 4):
        $for M in range(MR):
          $for N in range(0, NR, 4):
            vacc${M}x${ABC[N:N+4]} = vdotq_lane_s32(vacc${M}x${ABC[N:N+4]}, vb${ABC[K:K+4]}x${ABC[N:N+4]}, va${M}x01234567, ${K//4});

      k -= 8 * sizeof(int8_t);
    }
    // Handle up to 4 final positions of `k`.
    if XNN_UNLIKELY(k != 0) {
      // Load a ${MR}x4 block of activations; each 8-byte load consumes only
      // its low 4 bytes (lane 0 of the dot products below).
      $for M in range(MR):
        const int8x8_t va${M}x01234567 = vld1_s8(a${M}); a${M} += 4;

      // Load a 4x${NR} block of weights.
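      // Note: the packed weights are assumed to be padded out to the rounded-up
      // `kc` (as XNNPACK's weight-packing routines do), so this 16-byte load
      // stays in bounds. Only the 8-byte activation loads above may read past
      // the end of a row, which the XNN_OOB_READS annotation permits.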
      $for N in range(0, NR, 4):
        const int8x16_t vb0123x${ABC[N:N+4]} = vld1q_s8(w); w = (const void*) ((const int8_t*) w + 16);

      // Multiply-accumulate: ${MR}x4 * 4x${NR} --> ${MR}x${NR}.
      $for M in range(MR):
        $for N in range(0, NR, 4):
          vacc${M}x${ABC[N:N+4]} = vdotq_lane_s32(vacc${M}x${ABC[N:N+4]}, vb0123x${ABC[N:N+4]}, va${M}x01234567, 0);
    }

    $if REQUANTIZATION == "RNDNU":
      // Requantize: apply the pre-shift (vqshlq_s32), the saturating doubling
      // multiply-high by the fixed-point multiplier (vqdmulhq_s32), then the
      // rounding post-shift (vrshlq_s32).
      const int32x4_t vright_pre_shift = vld1q_dup_s32(&params->${PARAMS_STRUCT}.right_pre_shift);
      const int32x4_t vmultiplier = vld1q_dup_s32(&params->${PARAMS_STRUCT}.multiplier);
      const int32x4_t vright_post_shift = vld1q_dup_s32(&params->${PARAMS_STRUCT}.right_post_shift);

      $for M in range(MR):
        $for N in range(0, NR, 4):
          vacc${M}x${ABC[N:N+4]} = vqshlq_s32(vacc${M}x${ABC[N:N+4]}, vright_pre_shift);

      $for M in range(MR):
        $for N in range(0, NR, 4):
          vacc${M}x${ABC[N:N+4]} = vqdmulhq_s32(vacc${M}x${ABC[N:N+4]}, vmultiplier);

      $for M in range(MR):
        $for N in range(0, NR, 4):
          vacc${M}x${ABC[N:N+4]} = vrshlq_s32(vacc${M}x${ABC[N:N+4]}, vright_post_shift);
    $elif REQUANTIZATION == "FP32":
      // Requantize in floating point: convert to float, scale (per-channel
      // scales come from the packed weights in the channelwise case), then
      // convert back with round-to-nearest-even (vcvtnq_s32_f32).
      $for M in range(MR):
        $for N in range(0, NR, 4):
          float32x4_t vfpacc${M}x${ABC[N:N+4]} = vcvtq_f32_s32(vacc${M}x${ABC[N:N+4]});

      $if CHANNELWISE:
        $for N in range(0, NR, 4):
          const float32x4_t vscale${ABC[N:N+4]} = vld1q_f32((const float*) w); w = (const void*) ((const float*) w + 4);
          $for M in range(MR):
            vfpacc${M}x${ABC[N:N+4]} = vmulq_f32(vfpacc${M}x${ABC[N:N+4]}, vscale${ABC[N:N+4]});
      $else:
        const float32x4_t vscale = vld1q_dup_f32(&params->${PARAMS_STRUCT}.scale);
        $for M in range(MR):
          $for N in range(0, NR, 4):
            vfpacc${M}x${ABC[N:N+4]} = vmulq_f32(vfpacc${M}x${ABC[N:N+4]}, vscale);

      $for M in range(MR):
        $for N in range(0, NR, 4):
          vacc${M}x${ABC[N:N+4]} = vcvtnq_s32_f32(vfpacc${M}x${ABC[N:N+4]});

    // Narrow with saturation: int32 -> int16 (adding the output zero point),
    // then int16 -> int8.
    const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->${PARAMS_STRUCT}.output_zero_point);
#if XNN_ARCH_ARM64
    $for M in range(MR):
      $for N in range(0, NR, 8):
        const int16x8_t vacc${M}x${ABC[N:N+8]} = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc${M}x${ABC[N:N+4]}), vacc${M}x${ABC[N+4:N+8]}), voutput_zero_point);

    $for M in range(MR):
      $for N in range(0, NR, 16):
        $if N + 8 < NR:
          int8x16_t vout${M}x${ABC[N:N+16]} = vqmovn_high_s16(vqmovn_s16(vacc${M}x${ABC[N:N+8]}), vacc${M}x${ABC[N+8:N+16]});
        $elif M % 2 == 1:
          int8x16_t vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vqmovn_high_s16(vqmovn_s16(vacc${M-1}x${ABC[N:N+8]}), vacc${M}x${ABC[N:N+8]});
        $elif M + 1 == MR:
          int8x8_t vout${M}x${ABC[N:N+8]} = vqmovn_s16(vacc${M}x${ABC[N:N+8]});
#else
    $for M in range(MR):
      $for N in range(0, NR, 8):
        const int16x8_t vacc${M}x${ABC[N:N+8]} = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc${M}x${ABC[N:N+4]}), vqmovn_s32(vacc${M}x${ABC[N+4:N+8]})), voutput_zero_point);

    $for M in range(MR):
      $for N in range(0, NR, 16):
        $if N + 8 < NR:
          int8x16_t vout${M}x${ABC[N:N+16]} = vcombine_s8(vqmovn_s16(vacc${M}x${ABC[N:N+8]}), vqmovn_s16(vacc${M}x${ABC[N+8:N+16]}));
        $elif M % 2 == 1:
          int8x16_t vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vcombine_s8(vqmovn_s16(vacc${M-1}x${ABC[N:N+8]}), vqmovn_s16(vacc${M}x${ABC[N:N+8]}));
        $elif M + 1 == MR:
          int8x8_t vout${M}x${ABC[N:N+8]} = vqmovn_s16(vacc${M}x${ABC[N:N+8]});
#endif
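    // Clamp the narrowed values to the fused [output_min, output_max] range.
    // When NR == 8 and MR == 1 there is at most one 8-byte output vector, so
    // 64-bit clamp constants suffice; all other tile sizes use 128-bit ones.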
    $if NR == 8 and MR == 1:
      const int8x8_t voutput_min = vld1_dup_s8(&params->${PARAMS_STRUCT}.output_min);
      const int8x8_t voutput_max = vld1_dup_s8(&params->${PARAMS_STRUCT}.output_max);
    $else:
      const int8x16_t voutput_min = vld1q_dup_s8(&params->${PARAMS_STRUCT}.output_min);
      const int8x16_t voutput_max = vld1q_dup_s8(&params->${PARAMS_STRUCT}.output_max);

    $for M in range(MR):
      $for N in range(0, NR, 16):
        $if N + 8 < NR:
          vout${M}x${ABC[N:N+16]} = vmaxq_s8(vout${M}x${ABC[N:N+16]}, voutput_min);
        $elif M % 2 == 1:
          vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vmaxq_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}, voutput_min);
        $elif M + 1 == MR:
          $if NR == 8 and MR == 1:
            vout${M}x${ABC[N:N+8]} = vmax_s8(vout${M}x${ABC[N:N+8]}, voutput_min);
          $else:
            vout${M}x${ABC[N:N+8]} = vmax_s8(vout${M}x${ABC[N:N+8]}, vget_low_s8(voutput_min));

    $for M in range(MR):
      $for N in range(0, NR, 16):
        $if N + 8 < NR:
          vout${M}x${ABC[N:N+16]} = vminq_s8(vout${M}x${ABC[N:N+16]}, voutput_max);
        $elif M % 2 == 1:
          vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vminq_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}, voutput_max);
        $elif M + 1 == MR:
          $if NR == 8 and MR == 1:
            vout${M}x${ABC[N:N+8]} = vmin_s8(vout${M}x${ABC[N:N+8]}, voutput_max);
          $else:
            vout${M}x${ABC[N:N+8]} = vmin_s8(vout${M}x${ABC[N:N+8]}, vget_low_s8(voutput_max));

    if (nc >= ${NR}) {
      // Main case where the ${NR} columns fit in the destination.
      $for M in range(MR):
        $for N in range(0, NR, 16):
          $if N + 8 < NR:
            vst1q_s8(c${M} + ${N}, vout${M}x${ABC[N:N+16]});
          $elif M % 2 == 1:
            vst1_s8(c${M-1} + ${N}, vget_low_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}));
            vst1_s8(c${M} + ${N}, vget_high_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}));
          $elif M + 1 == MR:
            vst1_s8(c${M} + ${N}, vout${M}x${ABC[N:N+8]});

      // Advance to the next ${NR} columns.
      $for M in range(MR):
        c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride);

      // Rewind the activation pointers to the start of the row.
      $for M in range(MR):
        a${M} = (const int8_t*) ((uintptr_t) a${M} - kc);

      nc -= ${NR};
    } else {
      // Final case where not all of the ${NR} columns fit in the destination.
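      // Store the remaining 1..${NR-1} columns with progressively narrower
      // stores (8, then 4, then 2, then 1 byte), using vext to shift the
      // already-stored bytes out of the output vectors after each step.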
      $if NR == 16:
        $for M in range(MR):
          $if M % 2 == 1:
            int8x16_t vout${M-1}x01234567_${M}x01234567 = vcombine_s8(vget_low_s8(vout${M-1}x0123456789ABCDEF), vget_low_s8(vout${M}x0123456789ABCDEF));
          $elif M + 1 == MR:
            int8x8_t vout${M}x01234567 = vget_low_s8(vout${M}x0123456789ABCDEF);
        if (nc & 8) {
          $for M in range(MR):
            $if M % 2 == 1:
              vst1_s8(c${M-1}, vget_low_s8(vout${M-1}x01234567_${M}x01234567)); c${M-1} += 8;
              vst1_s8(c${M}, vget_high_s8(vout${M-1}x01234567_${M}x01234567)); c${M} += 8;
            $elif M + 1 == MR:
              vst1_s8(c${M}, vout${M}x01234567); c${M} += 8;
          $for M in range(MR):
            $if M % 2 == 1:
              vout${M-1}x01234567_${M}x01234567 = vcombine_s8(vget_high_s8(vout${M-1}x0123456789ABCDEF), vget_high_s8(vout${M}x0123456789ABCDEF));
            $elif M + 1 == MR:
              vout${M}x01234567 = vget_high_s8(vout${M}x0123456789ABCDEF);
        }
      if (nc & 4) {
        $for M in range(MR):
          $if M % 2 == 1:
            vst1q_lane_u32((void*) c${M-1}, vreinterpretq_u32_s8(vout${M-1}x01234567_${M}x01234567), 0); c${M-1} += 4;
            vst1q_lane_u32((void*) c${M}, vreinterpretq_u32_s8(vout${M-1}x01234567_${M}x01234567), 2); c${M} += 4;
          $elif M + 1 == MR:
            vst1_lane_u32((void*) c${M}, vreinterpret_u32_s8(vout${M}x01234567), 0); c${M} += 4;
        $for M in range(MR):
          $if M % 2 == 1:
            vout${M-1}x01234567_${M}x01234567 = vextq_s8(vout${M-1}x01234567_${M}x01234567, vout${M-1}x01234567_${M}x01234567, 4);
          $elif M + 1 == MR:
            vout${M}x01234567 = vext_s8(vout${M}x01234567, vout${M}x01234567, 4);
      }
      if (nc & 2) {
        $for M in range(MR):
          $if M % 2 == 1:
            vst1q_lane_u16((void*) c${M-1}, vreinterpretq_u16_s8(vout${M-1}x01234567_${M}x01234567), 0); c${M-1} += 2;
            vst1q_lane_u16((void*) c${M}, vreinterpretq_u16_s8(vout${M-1}x01234567_${M}x01234567), 4); c${M} += 2;
          $elif M + 1 == MR:
            vst1_lane_u16((void*) c${M}, vreinterpret_u16_s8(vout${M}x01234567), 0); c${M} += 2;
        $for M in range(MR):
          $if M % 2 == 1:
            vout${M-1}x01234567_${M}x01234567 = vextq_s8(vout${M-1}x01234567_${M}x01234567, vout${M-1}x01234567_${M}x01234567, 2);
          $elif M + 1 == MR:
            vout${M}x01234567 = vext_s8(vout${M}x01234567, vout${M}x01234567, 2);
      }
      if (nc & 1) {
        $for M in range(MR):
          $if M % 2 == 1:
            vst1q_lane_s8(c${M-1}, vout${M-1}x01234567_${M}x01234567, 0);
            vst1q_lane_s8(c${M}, vout${M-1}x01234567_${M}x01234567, 8);
          $elif M + 1 == MR:
            vst1_lane_s8(c${M}, vout${M}x01234567, 0);
      }

      nc = 0;
    }
  } while (nc != 0);
}