1// Copyright 2021 Google LLC 2// 3// This source code is licensed under the BSD-style license found in the 4// LICENSE file in the root directory of this source tree. 5 6$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" 7$assert NR % 8 == 0 8$assert 8 <= NR <= 16 9$assert REQUANTIZATION == "RNDNU" 10#include <assert.h> 11 12#include <arm_neon.h> 13 14#include <xnnpack/gemm.h> 15#include <xnnpack/math.h> 16 17 18void xnn_qs8_igemm_minmax_rndnu_ukernel_${MR}x${NR}c16__neon_mlal( 19 size_t mr, 20 size_t nc, 21 size_t kc, 22 size_t ks, 23 const int8_t** restrict a, 24 const void* restrict w, 25 int8_t* restrict c, 26 size_t cm_stride, 27 size_t cn_stride, 28 size_t a_offset, 29 const int8_t* zero, 30 const union xnn_qs8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS 31{ 32 assert(mr != 0); 33 assert(mr <= ${MR}); 34 assert(nc != 0); 35 assert(kc != 0); 36 assert(ks != 0); 37 assert(ks % (${MR} * sizeof(void*)) == 0); 38 assert(a_offset % sizeof(int8_t) == 0); 39 assert(a != NULL); 40 assert(w != NULL); 41 assert(c != NULL); 42 43 kc = round_up_po2(kc, 16 * sizeof(int8_t)); 44 int8_t* c0 = c; 45 $for M in range(1, MR): 46 int8_t* c${M} = (int8_t*) ((uintptr_t) c${M-1} + cm_stride); 47 $if M % 2 == 0: 48 if XNN_UNPREDICTABLE(mr <= ${M}) { 49 c${M} = c${M-1}; 50 } 51 $elif M + 1 == MR: 52 if XNN_UNPREDICTABLE(mr != ${M+1}) { 53 c${M} = c${M-1}; 54 } 55 $else: 56 if XNN_UNPREDICTABLE(mr < ${M+1}) { 57 c${M} = c${M-1}; 58 } 59 60 do { 61 $for N in range(NR): 62 int32x4_t vacc0x${N} = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); 63 $for M in range(1, MR): 64 $for N in range(NR): 65 int32x4_t vacc${M}x${N} = vacc0x${N}; 66 67 size_t p = ks; 68 do { 69 $for M in range(MR): 70 const int8_t* restrict a${M} = a[${M}]; 71 if XNN_UNPREDICTABLE(a${M} != zero) { 72 a${M} = (const int8_t*) ((uintptr_t) a${M} + a_offset); 73 } 74 a += ${MR}; 75 76 // KC loop of 16 with up to 15 remainder 77 size_t k = kc; 78 while (k != 0) { 79 $for M in range(MR): 80 const int8x16_t va${M} = vld1q_s8(a${M}); a${M} += 16; 81 82 $for N in range(NR): 83 const int8x16_t vb${N} = vld1q_s8(w); w = (const void*) ((uintptr_t) w + 16 * sizeof(int8_t)); 84 85 $for N in range(NR): 86 $for M in range(MR): 87 int16x8_t vprod${M}x${N} = vmull_s8(vget_low_s8(vb${N}), vget_low_s8(va${M})); 88 $for M in range(MR): 89 vprod${M}x${N} = vmlal_s8(vprod${M}x${N}, vget_high_s8(vb${N}), vget_high_s8(va${M})); 90 $for M in range(MR): 91 vacc${M}x${N} = vpadalq_s16(vacc${M}x${N}, vprod${M}x${N}); 92 93 k -= 16 * sizeof(int8_t); 94 } 95 96 p -= ${MR} * sizeof(void*); 97 } while (p != 0); 98 99#if XNN_ARCH_ARM64 100 $for M in range(MR): 101 $for N in range(0, NR, 4): 102 const int32x4_t vsum${M}x${ABC[N:N+2]} = vpaddq_s32(vacc${M}x${N}, vacc${M}x${N+1}); 103 const int32x4_t vsum${M}x${ABC[N+2:N+4]} = vpaddq_s32(vacc${M}x${N+2}, vacc${M}x${N+3}); 104 $for M in range(MR): 105 $for N in range(0, NR, 4): 106 int32x4_t vacc${M}x${ABC[N:N+4]} = vpaddq_s32(vsum${M}x${ABC[N:N+2]}, vsum${M}x${ABC[N+2:N+4]}); 107#else 108 $for M in range(MR): 109 $for N in range(0, NR, 4): 110 const int32x2_t vpsum${M}x${ABC[N]} = vadd_s32(vget_low_s32(vacc${M}x${N}), vget_high_s32(vacc${M}x${N})); 111 const int32x2_t vpsum${M}x${ABC[N+1]} = vadd_s32(vget_low_s32(vacc${M}x${N+1}), vget_high_s32(vacc${M}x${N+1})); 112 const int32x2_t vpsum${M}x${ABC[N+2]} = vadd_s32(vget_low_s32(vacc${M}x${N+2}), vget_high_s32(vacc${M}x${N+2})); 113 const int32x2_t vpsum${M}x${ABC[N+3]} = vadd_s32(vget_low_s32(vacc${M}x${N+3}), vget_high_s32(vacc${M}x${N+3})); 114 const int32x2_t vsum${M}x${ABC[N:N+2]} = vpadd_s32(vpsum${M}x${ABC[N]}, vpsum${M}x${ABC[N+1]}); 115 const int32x2_t vsum${M}x${ABC[N+2:N+4]} = vpadd_s32(vpsum${M}x${ABC[N+2]}, vpsum${M}x${ABC[N+3]}); 116 int32x4_t vacc${M}x${ABC[N:N+4]} = vcombine_s32(vsum${M}x${ABC[N:N+2]}, vsum${M}x${ABC[N+2:N+4]} ); 117#endif 118 119 const int32x4_t vright_pre_shift = vld1q_dup_s32(¶ms->rndnu_neon.right_pre_shift); 120 const int32x4_t vmultiplier = vld1q_dup_s32(¶ms->rndnu_neon.multiplier); 121 const int32x4_t vright_post_shift = vld1q_dup_s32(¶ms->rndnu_neon.right_post_shift); 122 123 $for M in range(MR): 124 $for N in range(0, NR, 4): 125 vacc${M}x${ABC[N:N+4]} = vqshlq_s32(vacc${M}x${ABC[N:N+4]}, vright_pre_shift); 126 127 $for M in range(MR): 128 $for N in range(0, NR, 4): 129 vacc${M}x${ABC[N:N+4]} = vqdmulhq_s32(vacc${M}x${ABC[N:N+4]}, vmultiplier); 130 131 $for M in range(MR): 132 $for N in range(0, NR, 4): 133 vacc${M}x${ABC[N:N+4]} = vrshlq_s32(vacc${M}x${ABC[N:N+4]}, vright_post_shift); 134 135 const int16x8_t voutput_zero_point = vld1q_dup_s16(¶ms->rndnu_neon.output_zero_point); 136#if XNN_ARCH_ARM64 137 $for M in range(MR): 138 $for N in range(0, NR, 8): 139 const int16x8_t vacc${M}x${ABC[N:N+8]} = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc${M}x${ABC[N:N+4]}), vacc${M}x${ABC[N+4:N+8]}), voutput_zero_point); 140 $for M in range(MR): 141 $for N in range(0, NR, 16): 142 $if N + 8 < NR: 143 int8x16_t vout${M}x${ABC[N:N+16]} = vqmovn_high_s16(vqmovn_s16(vacc${M}x${ABC[N:N+8]}), vacc${M}x${ABC[N+8:N+16]}); 144 $elif M % 2 == 1: 145 int8x16_t vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vqmovn_high_s16(vqmovn_s16(vacc${M-1}x${ABC[N:N+8]}), vacc${M}x${ABC[N:N+8]}); 146 $elif M + 1 == MR: 147 int8x8_t vout${M}x${ABC[N:N+8]} = vqmovn_s16(vacc${M}x${ABC[N:N+8]}); 148#else 149 $for M in range(MR): 150 $for N in range(0, NR, 8): 151 const int16x8_t vacc${M}x${ABC[N:N+8]} = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc${M}x${ABC[N:N+4]}), vqmovn_s32(vacc${M}x${ABC[N+4:N+8]})), voutput_zero_point); 152 153 $for M in range(MR): 154 $for N in range(0, NR, 16): 155 $if N + 8 < NR: 156 int8x16_t vout${M}x${ABC[N:N+16]} = vcombine_s8(vqmovn_s16(vacc${M}x${ABC[N:N+8]}), vqmovn_s16(vacc${M}x${ABC[N+8:N+16]})); 157 $elif M % 2 == 1: 158 int8x16_t vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vcombine_s8(vqmovn_s16(vacc${M-1}x${ABC[N:N+8]}), vqmovn_s16(vacc${M}x${ABC[N:N+8]})); 159 $elif M + 1 == MR: 160 int8x8_t vout${M}x${ABC[N:N+8]} = vqmovn_s16(vacc${M}x${ABC[N:N+8]}); 161#endif 162 $if NR == 8 and MR == 1: 163 const int8x8_t voutput_min = vld1_dup_s8(¶ms->rndnu_neon.output_min); 164 const int8x8_t voutput_max = vld1_dup_s8(¶ms->rndnu_neon.output_max); 165 $else: 166 const int8x16_t voutput_min = vld1q_dup_s8(¶ms->rndnu_neon.output_min); 167 const int8x16_t voutput_max = vld1q_dup_s8(¶ms->rndnu_neon.output_max); 168 169 $for M in reversed(range(MR)): 170 $for N in range(0, NR, 16): 171 $if N + 8 < NR: 172 vout${M}x${ABC[N:N+16]} = vmaxq_s8(vout${M}x${ABC[N:N+16]}, voutput_min); 173 $elif M % 2 == 1: 174 vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vmaxq_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}, voutput_min); 175 $elif M + 1 == MR: 176 $if NR == 8 and MR == 1: 177 vout${M}x${ABC[N:N+8]} = vmax_s8(vout${M}x${ABC[N:N+8]}, voutput_min); 178 $else: 179 vout${M}x${ABC[N:N+8]} = vmax_s8(vout${M}x${ABC[N:N+8]}, vget_low_s8(voutput_min)); 180 181 $for M in reversed(range(MR)): 182 $for N in range(0, NR, 16): 183 $if N + 8 < NR: 184 vout${M}x${ABC[N:N+16]} = vminq_s8(vout${M}x${ABC[N:N+16]}, voutput_max); 185 $elif M % 2 == 1: 186 vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vminq_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}, voutput_max); 187 $elif M + 1 == MR: 188 $if NR == 8 and MR == 1: 189 vout${M}x${ABC[N:N+8]} = vmin_s8(vout${M}x${ABC[N:N+8]}, voutput_max); 190 $else: 191 vout${M}x${ABC[N:N+8]} = vmin_s8(vout${M}x${ABC[N:N+8]}, vget_low_s8(voutput_max)); 192 193 if (nc >= ${NR}) { 194 $for M in reversed(range(MR)): 195 $for N in range(0, NR, 16): 196 $if N + 8 < NR: 197 vst1q_s8(c${M} + ${N}, vout${M}x${ABC[N:N+16]}); 198 $elif M % 2 == 1: 199 vst1_s8(c${M} + ${N}, vget_high_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]})); 200 vst1_s8(c${M-1} + ${N}, vget_low_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]})); 201 $elif M + 1 == MR: 202 vst1_s8(c${M} + ${N}, vout${M}x${ABC[N:N+8]}); 203 204 $for M in reversed(range(MR)): 205 c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride); 206 207 a = (const int8_t**restrict) ((uintptr_t) a - ks); 208 209 nc -= ${NR}; 210 } else { 211 $if NR == 16: 212 $for M in reversed(range(MR)): 213 $if M % 2 == 1: 214 int8x16_t vout${M-1}x01234567_${M}x01234567 = vcombine_s8(vget_low_s8(vout${M-1}x0123456789ABCDEF), vget_low_s8(vout${M}x0123456789ABCDEF)); 215 $elif M + 1 == MR: 216 int8x8_t vout${M}x01234567 = vget_low_s8(vout${M}x0123456789ABCDEF); 217 if (nc & 8) { 218 $for M in reversed(range(MR)): 219 $if M % 2 == 1: 220 vst1_s8(c${M}, vget_high_s8(vout${M-1}x01234567_${M}x01234567)); c${M} += 8; 221 vst1_s8(c${M-1}, vget_low_s8(vout${M-1}x01234567_${M}x01234567)); c${M-1} += 8; 222 $elif M + 1 == MR: 223 vst1_s8(c${M}, vout${M}x01234567); c${M} += 8; 224 $for M in reversed(range(MR)): 225 $if M % 2 == 1: 226 vout${M-1}x01234567_${M}x01234567 = vcombine_s8(vget_high_s8(vout${M-1}x0123456789ABCDEF), vget_high_s8(vout${M}x0123456789ABCDEF)); 227 $elif M + 1 == MR: 228 vout${M}x01234567 = vget_high_s8(vout${M}x0123456789ABCDEF); 229 } 230 if (nc & 4) { 231 $for M in reversed(range(MR)): 232 $if M % 2 == 1: 233 vst1q_lane_u32((void*) c${M}, vreinterpretq_u32_s8(vout${M-1}x01234567_${M}x01234567), 2); c${M} += 4; 234 vst1q_lane_u32((void*) c${M-1}, vreinterpretq_u32_s8(vout${M-1}x01234567_${M}x01234567), 0); c${M-1} += 4; 235 $elif M + 1 == MR: 236 vst1_lane_u32((void*) c${M}, vreinterpret_u32_s8(vout${M}x01234567), 0); c${M} += 4; 237 $for M in reversed(range(MR)): 238 $if M % 2 == 1: 239 vout${M-1}x01234567_${M}x01234567 = vextq_s8(vout${M-1}x01234567_${M}x01234567, vout${M-1}x01234567_${M}x01234567, 4); 240 $elif M + 1 == MR: 241 vout${M}x01234567 = vext_s8(vout${M}x01234567, vout${M}x01234567, 4); 242 } 243 if (nc & 2) { 244 $for M in reversed(range(MR)): 245 $if M % 2 == 1: 246 vst1q_lane_u16((void*) c${M}, vreinterpretq_u16_s8(vout${M-1}x01234567_${M}x01234567), 4); c${M} += 2; 247 vst1q_lane_u16((void*) c${M-1}, vreinterpretq_u16_s8(vout${M-1}x01234567_${M}x01234567), 0); c${M-1} += 2; 248 $elif M + 1 == MR: 249 vst1_lane_u16((void*) c${M}, vreinterpret_u16_s8(vout${M}x01234567), 0); c${M} += 2; 250 $for M in reversed(range(MR)): 251 $if M % 2 == 1: 252 vout${M-1}x01234567_${M}x01234567 = vextq_s8(vout${M-1}x01234567_${M}x01234567, vout${M-1}x01234567_${M}x01234567, 2); 253 $elif M + 1 == MR: 254 vout${M}x01234567 = vext_s8(vout${M}x01234567, vout${M}x01234567, 2); 255 } 256 if (nc & 1) { 257 $for M in reversed(range(MR)): 258 $if M % 2 == 1: 259 vst1q_lane_s8(c${M}, vout${M-1}x01234567_${M}x01234567, 8); 260 vst1q_lane_s8(c${M-1}, vout${M-1}x01234567_${M}x01234567, 0); 261 $elif M + 1 == MR: 262 vst1_lane_s8(c${M}, vout${M}x01234567, 0); 263 } 264 265 nc = 0; 266 } 267 } while (nc != 0); 268} 269