xref: /aosp_15_r20/external/XNNPACK/src/f32-qs8-vcvt/neon.c.in (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1// Copyright 2021 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
// Template parameters: BATCH_TILE is the number of elements processed per
// main-loop iteration and must be a positive multiple of 8; DATATYPE selects
// the output type (QS8 or QU8). ABC supplies the characters used to build
// per-lane-group variable suffixes in the kernel (e.g. vx0123, vacc01234567).
6$assert BATCH_TILE % 8 == 0
7$assert BATCH_TILE >= 8
8$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
9#include <assert.h>
10
11#include <arm_neon.h>
12
13#include <xnnpack/common.h>
14#include <xnnpack/intrinsics-polyfill.h>
15#include <xnnpack/vcvt.h>
16
17
// Per-DATATYPE type and intrinsic names substituted into the kernel: QS8 maps
// to the signed int8 types/intrinsics, QU8 to the unsigned uint8 ones. Note the
// int16 -> 8-bit narrowing is not a simple signed/unsigned swap: QS8 uses
// vqmovn_s16 (saturate to int8) while QU8 uses vqmovun_s16 (saturate a signed
// int16 into the unsigned 0..255 range).
18$XINT8_T = {"QS8": "int8_t", "QU8": "uint8_t"}[DATATYPE]
19$XINT8X8_T = {"QS8": "int8x8_t", "QU8": "uint8x8_t"}[DATATYPE]
20$XINT8X16_T = {"QS8": "int8x16_t", "QU8": "uint8x16_t"}[DATATYPE]
21$VLD1Q_DUP_X8 = {"QS8": "vld1q_dup_s8", "QU8": "vld1q_dup_u8"}[DATATYPE]
22$VLD1_DUP_X8 = {"QS8": "vld1_dup_s8", "QU8": "vld1_dup_u8"}[DATATYPE]
23$VST1Q_X8 = {"QS8": "vst1q_s8", "QU8": "vst1q_u8"}[DATATYPE]
24$VST1_X8 = {"QS8": "vst1_s8", "QU8": "vst1_u8"}[DATATYPE]
25$VST1_LANE_X8 = {"QS8": "vst1_lane_s8", "QU8": "vst1_lane_u8"}[DATATYPE]
26$VQMOVXN_S16 = {"QS8": "vqmovn_s16", "QU8": "vqmovun_s16"}[DATATYPE]
27$VEXT_X8 = {"QS8": "vext_s8", "QU8": "vext_u8"}[DATATYPE]
28$VCOMBINE_X8 = {"QS8": "vcombine_s8", "QU8": "vcombine_u8"}[DATATYPE]
29$VGET_LOW_X8 = {"QS8": "vget_low_s8", "QU8": "vget_low_u8"}[DATATYPE]
30$VREINTERPRET_U16_X8 = {"QS8": "vreinterpret_u16_s8", "QU8": "vreinterpret_u16_u8"}[DATATYPE]
31$VREINTERPRET_U32_X8 = {"QS8": "vreinterpret_u32_s8", "QU8": "vreinterpret_u32_u8"}[DATATYPE]
32$VMINQ_X8 = {"QS8": "vminq_s8", "QU8": "vminq_u8"}[DATATYPE]
33$VMIN_X8 = {"QS8": "vmin_s8", "QU8": "vmin_u8"}[DATATYPE]
34$VMAXQ_X8 = {"QS8": "vmaxq_s8", "QU8": "vmaxq_u8"}[DATATYPE]
35$VMAX_X8 = {"QS8": "vmax_s8", "QU8": "vmax_u8"}[DATATYPE]
// Quantizes a batch of 32-bit floats to 8-bit integers (int8_t when DATATYPE
// is QS8, uint8_t when it is QU8) using NEON.
//
// n:      size of the input batch in BYTES; must be non-zero and a multiple of
//         sizeof(float). n / sizeof(float) elements are read from x and the
//         same number of 8-bit values are written to y.
// x:      input floats.
// y:      output buffer, one 8-bit value per input element.
// params: conversion parameters; only params->neon.scale, .magic_bias,
//         .magic_bias_less_zero_point, .output_min and .output_max are read.
//
// Per element the kernel computes
//   acc = saturating_sub_s32(bits_of(x * scale + magic_bias),
//                            magic_bias_less_zero_point)
//   y   = clamp(saturating_narrow_8(saturating_narrow_s16(acc)),
//               output_min, output_max)
// where the final narrowing is vqmovn_s16 for QS8 and vqmovun_s16 for QU8.
// NOTE(review): this is presumably the "magic bias" rounding trick: adding
// magic_bias leaves round(x * scale) in the low mantissa bits, so subtracting
// magic_bias_less_zero_point (the bias bit pattern minus the zero point)
// yields round(x * scale) + zero_point. Confirm against the routine that
// initializes params->neon.
//
// The remainder path (1-7 elements left) always loads two whole float32x4
// vectors, so it may read up to 3 floats past the end of x; stores never go
// past the last output element. XNN_OOB_READS presumably annotates this
// over-read; confirm against its definition.
36void xnn_f32_${DATATYPE.lower()}_vcvt_ukernel__neon_x${BATCH_TILE}(
37    size_t n,
38    const float* x,
39    ${XINT8_T}* y,
40    const union xnn_f32_${DATATYPE.lower()}_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
41{
42  assert(n != 0);
43  assert(n % sizeof(float) == 0);
44  assert(x != NULL);
45  assert(y != NULL);
46
  // Broadcast the scalar conversion parameters to every vector lane.
47  const float32x4_t vscale = vld1q_dup_f32(&params->neon.scale);
48  const float32x4_t vmagic_bias = vld1q_dup_f32(&params->neon.magic_bias);
49  const int32x4_t vmagic_bias_less_zero_point = vld1q_dup_s32(&params->neon.magic_bias_less_zero_point);
  // The clamp bounds are 16-lane vectors when the main loop handles more than
  // 8 elements per iteration (the 8-lane paths then use their low halves),
  // and 8-lane vectors otherwise.
50  $if BATCH_TILE > 8:
51    const ${XINT8X16_T} voutput_min = ${VLD1Q_DUP_X8}(&params->neon.output_min);
52    const ${XINT8X16_T} voutput_max = ${VLD1Q_DUP_X8}(&params->neon.output_max);
53  $else:
54    const ${XINT8X8_T} voutput_min = ${VLD1_DUP_X8}(&params->neon.output_min);
55    const ${XINT8X8_T} voutput_max = ${VLD1_DUP_X8}(&params->neon.output_max);
  // Main loop, generated only when BATCH_TILE > 8: converts BATCH_TILE
  // elements per iteration, 4 floats per float32x4_t register.
56  $if BATCH_TILE > 8:
57    for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) {
      // Load the inputs, multiply by scale, then add the magic bias.
58      $for N in range(0, BATCH_TILE, 4):
59        float32x4_t vx${ABC[N:N+4]} = vld1q_f32(x); x += 4;
60
61      $for N in range(0, BATCH_TILE, 4):
62        vx${ABC[N:N+4]} = vmulq_f32(vx${ABC[N:N+4]}, vscale);
63
64      $for N in range(0, BATCH_TILE, 4):
65        vx${ABC[N:N+4]} = vaddq_f32(vx${ABC[N:N+4]}, vmagic_bias);
66
      // Reinterpret the biased floats as int32 and subtract the
      // bias/zero-point offset with saturation.
67      $for N in range(0, BATCH_TILE, 4):
68        const int32x4_t vacc${ABC[N:N+4]} = vqsubq_s32(vreinterpretq_s32_f32(vx${ABC[N:N+4]}), vmagic_bias_less_zero_point);
69
      // Saturating narrow int32 -> int16, 8 lanes per register.
70      $for N in range(0, BATCH_TILE, 8):
71        const int16x8_t vacc${ABC[N:N+8]} = vcombine_s16(vqmovn_s32(vacc${ABC[N:N+4]}), vqmovn_s32(vacc${ABC[N+4:N+8]}));
72
      // Saturating narrow int16 -> 8-bit: 16 lanes per register, plus one
      // 8-lane register for a trailing group of 8 when BATCH_TILE % 16 == 8.
73      $for N in range(0, BATCH_TILE, 16):
74        $if N + 8 < BATCH_TILE:
75          ${XINT8X16_T} vy${ABC[N:N+16]} = ${VCOMBINE_X8}(${VQMOVXN_S16}(vacc${ABC[N:N+8]}), ${VQMOVXN_S16}(vacc${ABC[N+8:N+16]}));
76        $else:
77          ${XINT8X8_T} vy${ABC[N:N+8]} = ${VQMOVXN_S16}(vacc${ABC[N:N+8]});
78
      // Clamp to [output_min, output_max], then store BATCH_TILE outputs.
79      $for N in range(0, BATCH_TILE, 16):
80        $if N + 8 < BATCH_TILE:
81          vy${ABC[N:N+16]} = ${VMAXQ_X8}(vy${ABC[N:N+16]}, voutput_min);
82        $else:
83          vy${ABC[N:N+8]} = ${VMAX_X8}(vy${ABC[N:N+8]}, ${VGET_LOW_X8}(voutput_min));
84
85      $for N in range(0, BATCH_TILE, 16):
86        $if N + 8 < BATCH_TILE:
87          vy${ABC[N:N+16]} = ${VMINQ_X8}(vy${ABC[N:N+16]}, voutput_max);
88        $else:
89          vy${ABC[N:N+8]} = ${VMIN_X8}(vy${ABC[N:N+8]}, ${VGET_LOW_X8}(voutput_max));
90
91      $for N in range(0, BATCH_TILE, 16):
92        $if N + 8 < BATCH_TILE:
93          ${VST1Q_X8}(y, vy${ABC[N:N+16]}); y += 16;
94        $else:
95          ${VST1_X8}(y, vy${ABC[N:N+8]}); y += 8;
96    }
  // Convert remaining groups of 8 elements (when BATCH_TILE == 8 this is the
  // only full-vector loop).
97  for (; n >= 8 * sizeof(float); n -= 8 * sizeof(float)) {
98    float32x4_t vx_lo = vld1q_f32(x); x += 4;
99    float32x4_t vx_hi = vld1q_f32(x); x += 4;
100
101    vx_lo = vmulq_f32(vx_lo, vscale);
102    vx_hi = vmulq_f32(vx_hi, vscale);
103
104    vx_lo = vaddq_f32(vx_lo, vmagic_bias);
105    vx_hi = vaddq_f32(vx_hi, vmagic_bias);
106
107    const int32x4_t vacc_lo = vqsubq_s32(vreinterpretq_s32_f32(vx_lo), vmagic_bias_less_zero_point);
108    const int32x4_t vacc_hi = vqsubq_s32(vreinterpretq_s32_f32(vx_hi), vmagic_bias_less_zero_point);
109
110    const int16x8_t vacc = vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi));
111
112    ${XINT8X8_T} vy = ${VQMOVXN_S16}(vacc);
    // With BATCH_TILE > 8 the clamp bounds are 16-lane; use their low halves.
113    $if BATCH_TILE > 8:
114      vy = ${VMAX_X8}(vy, ${VGET_LOW_X8}(voutput_min));
115      vy = ${VMIN_X8}(vy, ${VGET_LOW_X8}(voutput_max));
116    $else:
117      vy = ${VMAX_X8}(vy, voutput_min);
118      vy = ${VMIN_X8}(vy, voutput_max);
119    ${VST1_X8}(y, vy); y += 8;
120  }
121  if XNN_UNLIKELY(n != 0) {
122    assert(n >= 1 * sizeof(float));
123    assert(n <= 7 * sizeof(float));
    // Two full vectors are loaded even though only 1-7 elements remain (this
    // over-reads x). x_hi points 4 floats ahead when at least 4 elements
    // remain; otherwise it aliases x, so vx_hi duplicates vx_lo and those
    // upper lanes are never stored.
124    float32x4_t vx_lo = vld1q_f32(x);
125    const float* x_hi = (const float*) ((uintptr_t) x + (n & (4 * sizeof(float))));
126    float32x4_t vx_hi = vld1q_f32(x_hi);
127
128    vx_lo = vmulq_f32(vx_lo, vscale);
129    vx_hi = vmulq_f32(vx_hi, vscale);
130
131    vx_lo = vaddq_f32(vx_lo, vmagic_bias);
132    vx_hi = vaddq_f32(vx_hi, vmagic_bias);
133
134    const int32x4_t vacc_lo = vqsubq_s32(vreinterpretq_s32_f32(vx_lo), vmagic_bias_less_zero_point);
135    const int32x4_t vacc_hi = vqsubq_s32(vreinterpretq_s32_f32(vx_hi), vmagic_bias_less_zero_point);
136
137    const int16x8_t vacc = vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi));
138
139    ${XINT8X8_T} vy = ${VQMOVXN_S16}(vacc);
    // With BATCH_TILE > 8 the clamp bounds are 16-lane; use their low halves.
140    $if BATCH_TILE > 8:
141      vy = ${VMAX_X8}(vy, ${VGET_LOW_X8}(voutput_min));
142      vy = ${VMIN_X8}(vy, ${VGET_LOW_X8}(voutput_max));
143    $else:
144      vy = ${VMAX_X8}(vy, voutput_min);
145      vy = ${VMIN_X8}(vy, voutput_max);
146
    // Store exactly the remaining outputs using the bits of n: 4, then 2, then
    // 1 element(s). After each partial store, vext rotates the vector so the
    // next unstored lane moves into lane 0.
147    if (n & (4 * sizeof(float))) {
148      vst1_lane_u32((void*) y, ${VREINTERPRET_U32_X8}(vy), 0); y += 4;
149      vy = ${VEXT_X8}(vy, vy, 4);
150    }
151    if (n & (2 * sizeof(float))) {
152      vst1_lane_u16((void*) y, ${VREINTERPRET_U16_X8}(vy), 0); y += 2;
153      vy = ${VEXT_X8}(vy, vy, 2);
154    }
155    if (n & (1 * sizeof(float))) {
156      ${VST1_LANE_X8}(y, vy, 0);
157    }
158  }
159}
160