xref: /aosp_15_r20/external/XNNPACK/src/s16-window/neon.c.in (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1// Copyright 2022 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
// Template file: expanded by XNNPACK's xngen code generator.
// Template parameters:
//   BATCH_TILE - int16 elements processed per main-loop iteration (multiple of 8).
//   SIMD_TILE  - number of 8-lane NEON vectors per BATCH_TILE.
//   SHIFT      - compile-time right-shift amount; 0 selects the runtime-shift variant.
6$assert BATCH_TILE % 8 == 0
7$assert BATCH_TILE >= 8
8$SIMD_TILE = BATCH_TILE // 8
9#include <assert.h>
10#include <stddef.h>
11#include <stdint.h>
12
13#include <arm_neon.h>
14
15#include <xnnpack/math.h>
16#include <xnnpack/window.h>
17
// Kernel-name suffix: "_shiftN" for fixed-shift specializations, empty for the runtime-shift variant.
18$SHIFT_VARIANT = "_shift%s" % SHIFT if SHIFT else ""
19
// Applies a window (elementwise weighting) to each row of an int16 matrix:
//   output[r][i] = saturate_s16((input[r][i] * weights[i]) >> shift)
// The same `weights` vector is re-applied to every row (`w` restarts each row).
//
// Shift handling, selected at template-expansion time:
//   SHIFT == 0  - runtime `shift`: 32-bit products are shifted with vshlq_s32 by
//                 a negative amount (right shift), then saturating-narrowed (vqmovn).
//   SHIFT == 15 - uses vqdmulh, whose saturating (a*b*2) >> 16 equals (a*b) >> 15,
//                 so no separate shift instruction is needed.
//   other SHIFT - saturating shift-right-narrow with an immediate (vqshrn_n_s32).
//
// rows       - number of rows to process; must be non-zero.
// batch_size - elements per row; must be non-zero.
// input      - rows * batch_size int16 input elements, rows stored contiguously.
// weights    - batch_size int16 window coefficients.
// output     - rows * batch_size int16 output elements.
// shift      - right-shift applied to each 32-bit product; must be < 32
//              (asserted equal to ${SHIFT} in the fixed-shift specializations).
20void xnn_s16_window${SHIFT_VARIANT}_ukernel__neon_x${BATCH_TILE}(
21    size_t rows,
22    size_t batch_size,
23    const int16_t* input,
24    const int16_t* weights,
25    int16_t* output,
26    uint32_t shift)
27{
28  assert(rows != 0);
29  assert(batch_size != 0);
30  assert(input != NULL);
31  assert(weights != NULL);
32  assert(output != NULL);
33  $if SHIFT != 0:
34    assert(shift == ${SHIFT});
35  $else:
36    assert(shift < 32);
37
38  $if SHIFT == 0:
39    const int32x4_t vshift = vdupq_n_s32(-(int32_t)shift);  // negative to shift right.
40
  // One iteration per row; `n` counts the remaining BYTES in the row and the
  // weight pointer `w` restarts at the first coefficient for every row.
41  do {
42    const int16_t* w = weights;
43    size_t n = batch_size * sizeof(int16_t);
44    $if BATCH_TILE > 8:
45      for (; n >= ${BATCH_TILE} * sizeof(int16_t); n -= ${BATCH_TILE} * sizeof(int16_t)) {
46        $for N in range(SIMD_TILE):
47          const int16x8_t vi${N} = vld1q_s16(input); input += 8;
48
49        $for N in range(SIMD_TILE):
50          const int16x8_t vw${N} = vld1q_s16(w); w += 8;
51
52        $if SHIFT == 15:
53          $for N in range(SIMD_TILE):
54            const int16x8_t vout${N} = vqdmulhq_s16(vi${N}, vw${N});
55        $else:
56          $for N in range(SIMD_TILE):
57            int32x4_t vacc${N}_lo = vmull_s16(vget_low_s16(vi${N}), vget_low_s16(vw${N}));
58            int32x4_t vacc${N}_hi = vmull_s16(vget_high_s16(vi${N}), vget_high_s16(vw${N}));
59
60          $if SHIFT != 0:
61            $for N in range(SIMD_TILE):
62              const int16x4_t vshift${N}_lo = vqshrn_n_s32(vacc${N}_lo, ${SHIFT});
63              const int16x4_t vshift${N}_hi = vqshrn_n_s32(vacc${N}_hi, ${SHIFT});
64
65            $for N in range(SIMD_TILE):
66              const int16x8_t vout${N} = vcombine_s16(vshift${N}_lo, vshift${N}_hi);
67          $else:
68            $for N in range(SIMD_TILE):
69              vacc${N}_lo = vshlq_s32(vacc${N}_lo, vshift);
70              vacc${N}_hi = vshlq_s32(vacc${N}_hi, vshift);
71
72            $for N in range(SIMD_TILE):
73              const int16x8_t vout${N} = vcombine_s16(vqmovn_s32(vacc${N}_lo), vqmovn_s32(vacc${N}_hi));
74
75        $for N in range(SIMD_TILE):
76          vst1q_s16(output, vout${N}); output += 8;
77      }
78
79    // Remainder of full vectors
80    for (; n >= 8 * sizeof(int16_t); n -= 8 * sizeof(int16_t)) {
81      const int16x8_t vi = vld1q_s16(input); input += 8;
82      const int16x8_t vw = vld1q_s16(w); w += 8;
83      $if SHIFT == 15:
84        const int16x8_t vout = vqdmulhq_s16(vi, vw);
85      $else:
86        int32x4_t vacc_lo = vmull_s16(vget_low_s16(vi), vget_low_s16(vw));
87        int32x4_t vacc_hi = vmull_s16(vget_high_s16(vi), vget_high_s16(vw));
88        $if SHIFT != 0:
89          const int16x4_t vshift_lo = vqshrn_n_s32(vacc_lo, ${SHIFT});
90          const int16x4_t vshift_hi = vqshrn_n_s32(vacc_hi, ${SHIFT});
91          const int16x8_t vout = vcombine_s16(vshift_lo, vshift_hi);
92        $else:
93          vacc_lo = vshlq_s32(vacc_lo, vshift);
94          vacc_hi = vshlq_s32(vacc_hi, vshift);
95          const int16x8_t vout = vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi));
96      vst1q_s16(output, vout); output += 8;
97    }
98
99    assert(n % 2 == 0);
100    // Remainder of 1 to 7 batch_size
101    if XNN_UNLIKELY(n != 0) {
      // NOTE(review): both loads read a full 8-element vector although only 1-7
      // elements remain; presumably relies on XNNPACK's allowed over-read past
      // the end of the buffer -- confirm allocation padding. Only `n` bytes are
      // ever stored below.
102      const int16x8_t vi = vld1q_s16(input); input = (const int16_t*) ((uintptr_t) input + n);
103      const int16x8_t vw = vld1q_s16(w);
104      $if SHIFT == 15:
105        int16x4_t vout = vqdmulh_s16(vget_low_s16(vi), vget_low_s16(vw));
106      $else:
107        int32x4_t vacc = vmull_s16(vget_low_s16(vi), vget_low_s16(vw));
108        $if SHIFT != 0:
109          int16x4_t vout = vqshrn_n_s32(vacc, ${SHIFT});
110        $else:
111          vacc = vshlq_s32(vacc, vshift);
112          int16x4_t vout = vqmovn_s32(vacc);
      // `n` is in bytes: 4 elements = 8 bytes, 2 elements = 4 bytes, 1 element = 2 bytes.
113      if (n & (4 * sizeof(int16_t))) {
114        vst1_s16(output, vout); output += 4;
115        $if SHIFT == 15:
116          vout = vqdmulh_s16(vget_high_s16(vi), vget_high_s16(vw));
117        $else:
118          vacc = vmull_s16(vget_high_s16(vi), vget_high_s16(vw));
119          $if SHIFT != 0:
120            vout = vqshrn_n_s32(vacc, ${SHIFT});
121          $else:
122            vacc = vshlq_s32(vacc, vshift);
123            vout = vqmovn_s32(vacc);
124      }
      // Store two int16 elements as one 32-bit lane, then rotate the next pair into lane 0.
125      if (n & (2 * sizeof(int16_t))) {
126        vst1_lane_u32((void*) output, vreinterpret_u32_s16(vout), 0); output += 2;
127        vout = vext_s16(vout, vout, 2);
128      }
129      if (n & (1 * sizeof(int16_t))) {
130        vst1_lane_s16(output, vout, 0); output += 1;
131      }
132    }
133
134  } while (--rows != 0);
135}
136