// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

// Template parameters:
//   BATCH_TILE - number of int16 elements processed per main-loop iteration
//                (multiple of 8; each NEON q-register holds 8 int16 lanes).
//   SHIFT      - compile-time right-shift amount, or 0 to take the shift
//                from the runtime `shift` argument instead.
$assert BATCH_TILE % 8 == 0
$assert BATCH_TILE >= 8
$SIMD_TILE = BATCH_TILE // 8
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include <arm_neon.h>

#include <xnnpack/math.h>
#include <xnnpack/window.h>

// Kernels specialized for a fixed shift carry it in their name (e.g. "_shift15").
$SHIFT_VARIANT = "_shift%s" % SHIFT if SHIFT else ""

// Applies a window function to `rows` rows of `batch_size` int16 elements:
// for each row, output[i] = saturate((input[i] * weights[i]) >> shift).
// The same `weights` vector is reused for every row; `input`/`output` advance
// row by row. The widening multiply (vmull_s16) is exact in 32 bits; the
// shift-and-narrow back to int16 saturates on overflow.
void xnn_s16_window${SHIFT_VARIANT}_ukernel__neon_x${BATCH_TILE}(
    size_t rows,
    size_t batch_size,
    const int16_t* input,
    const int16_t* weights,
    int16_t* output,
    uint32_t shift)
{
  assert(rows != 0);
  assert(batch_size != 0);
  assert(input != NULL);
  assert(weights != NULL);
  assert(output != NULL);
  $if SHIFT != 0:
    // Shift is baked into this specialization; the runtime argument must agree.
    assert(shift == ${SHIFT});
  $else:
    // Runtime shift: shifting an int32 by >= 32 would be undefined.
    assert(shift < 32);

  $if SHIFT == 0:
    // vshlq_s32 shifts left by a signed per-lane amount, so a negated shift
    // performs the right shift (NEON has no variable right-shift intrinsic).
    const int32x4_t vshift = vdupq_n_s32(-(int32_t)shift);  // negative to shift right.

  do {
    const int16_t* w = weights;
    size_t n = batch_size * sizeof(int16_t);  // remaining bytes in this row
    $if BATCH_TILE > 8:
      // Main loop: ${BATCH_TILE} elements (${SIMD_TILE} q-registers) per iteration.
      for (; n >= ${BATCH_TILE} * sizeof(int16_t); n -= ${BATCH_TILE} * sizeof(int16_t)) {
        $for N in range(SIMD_TILE):
          const int16x8_t vi${N} = vld1q_s16(input); input += 8;

        $for N in range(SIMD_TILE):
          const int16x8_t vw${N} = vld1q_s16(w); w += 8;

        $if SHIFT == 15:
          // vqdmulhq_s16 computes saturate((a * b * 2) >> 16), i.e. a
          // saturating (a * b) >> 15 in a single instruction.
          $for N in range(SIMD_TILE):
            const int16x8_t vout${N} = vqdmulhq_s16(vi${N}, vw${N});
        $else:
          // Widen to 32 bits so the product is exact before shifting.
          $for N in range(SIMD_TILE):
            int32x4_t vacc${N}_lo = vmull_s16(vget_low_s16(vi${N}), vget_low_s16(vw${N}));
            int32x4_t vacc${N}_hi = vmull_s16(vget_high_s16(vi${N}), vget_high_s16(vw${N}));

          $if SHIFT != 0:
            // Compile-time shift: saturating shift-right-narrow with an immediate.
            $for N in range(SIMD_TILE):
              const int16x4_t vshift${N}_lo = vqshrn_n_s32(vacc${N}_lo, ${SHIFT});
              const int16x4_t vshift${N}_hi = vqshrn_n_s32(vacc${N}_hi, ${SHIFT});

            $for N in range(SIMD_TILE):
              const int16x8_t vout${N} = vcombine_s16(vshift${N}_lo, vshift${N}_hi);
          $else:
            // Runtime shift: shift left by the negated amount, then saturating narrow.
            $for N in range(SIMD_TILE):
              vacc${N}_lo = vshlq_s32(vacc${N}_lo, vshift);
              vacc${N}_hi = vshlq_s32(vacc${N}_hi, vshift);

            $for N in range(SIMD_TILE):
              const int16x8_t vout${N} = vcombine_s16(vqmovn_s32(vacc${N}_lo), vqmovn_s32(vacc${N}_hi));

        $for N in range(SIMD_TILE):
          vst1q_s16(output, vout${N}); output += 8;
      }

    // Remainder of full vectors
    for (; n >= 8 * sizeof(int16_t); n -= 8 * sizeof(int16_t)) {
      const int16x8_t vi = vld1q_s16(input); input += 8;
      const int16x8_t vw = vld1q_s16(w); w += 8;
      $if SHIFT == 15:
        const int16x8_t vout = vqdmulhq_s16(vi, vw);
      $else:
        int32x4_t vacc_lo = vmull_s16(vget_low_s16(vi), vget_low_s16(vw));
        int32x4_t vacc_hi = vmull_s16(vget_high_s16(vi), vget_high_s16(vw));
        $if SHIFT != 0:
          const int16x4_t vshift_lo = vqshrn_n_s32(vacc_lo, ${SHIFT});
          const int16x4_t vshift_hi = vqshrn_n_s32(vacc_hi, ${SHIFT});
          const int16x8_t vout = vcombine_s16(vshift_lo, vshift_hi);
        $else:
          vacc_lo = vshlq_s32(vacc_lo, vshift);
          vacc_hi = vshlq_s32(vacc_hi, vshift);
          const int16x8_t vout = vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi));
      vst1q_s16(output, vout); output += 8;
    }

    assert(n % 2 == 0);  // n counts bytes; elements are 2 bytes each
    // Remainder of 1 to 7 batch_size
    if XNN_UNLIKELY(n != 0) {
      // NOTE(review): loads a full 8-element vector even though only 1-7
      // elements remain — presumably relies on the XNNPACK convention that
      // input/weight buffers have readable padding past the end; confirm
      // callers allocate the extra bytes. Stores below are exact-sized.
      const int16x8_t vi = vld1q_s16(input); input = (const int16_t*) ((uintptr_t) input + n);
      const int16x8_t vw = vld1q_s16(w);
      // Process the low 4 lanes first.
      $if SHIFT == 15:
        int16x4_t vout = vqdmulh_s16(vget_low_s16(vi), vget_low_s16(vw));
      $else:
        int32x4_t vacc = vmull_s16(vget_low_s16(vi), vget_low_s16(vw));
        $if SHIFT != 0:
          int16x4_t vout = vqshrn_n_s32(vacc, ${SHIFT});
        $else:
          vacc = vshlq_s32(vacc, vshift);
          int16x4_t vout = vqmovn_s32(vacc);
      if (n & (4 * sizeof(int16_t))) {
        // At least 4 elements remain: store them and compute the high 4 lanes.
        vst1_s16(output, vout); output += 4;
        $if SHIFT == 15:
          vout = vqdmulh_s16(vget_high_s16(vi), vget_high_s16(vw));
        $else:
          vacc = vmull_s16(vget_high_s16(vi), vget_high_s16(vw));
          $if SHIFT != 0:
            vout = vqshrn_n_s32(vacc, ${SHIFT});
          $else:
            vacc = vshlq_s32(vacc, vshift);
            vout = vqmovn_s32(vacc);
      }
      if (n & (2 * sizeof(int16_t))) {
        // Store 2 elements as one 32-bit lane, then rotate them out of the way.
        vst1_lane_u32((void*) output, vreinterpret_u32_s16(vout), 0); output += 2;
        vout = vext_s16(vout, vout, 2);
      }
      if (n & (1 * sizeof(int16_t))) {
        vst1_lane_s16(output, vout, 0); output += 1;
      }
    }

  } while (--rows != 0);
}