xref: /aosp_15_r20/external/XNNPACK/src/f16-vsigmoid/neonfp16arith.c.in (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert BATCH_TILE % 8 == 0
$assert BATCH_TILE >= 8
$SIMD_TILE = BATCH_TILE // 8
$assert DIV_ALGO in ["DIV", "NR1FMA", "NR1RECPS"]
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/common.h>
#include <xnnpack/vunary.h>
18
19void xnn_f16_vsigmoid_ukernel__neonfp16arith_rr2_p2_${DIV_ALGO.lower()}_x${BATCH_TILE}(
20    size_t batch,
21    const void* input,
22    void* output,
23    const union xnn_f16_sigmoid_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
24{
25  assert(batch % sizeof(__fp16) == 0);
26
27  const float16x8_t vmagic_bias = vreinterpretq_f16_u16(vld1q_dup_u16(&params->neonfp16arith_rr2_p2.magic_bias));
28  const float16x8_t vminus_log2e = vreinterpretq_f16_u16(vld1q_dup_u16(&params->neonfp16arith_rr2_p2.minus_log2e));
29  const float16x8_t vln2_hi = vreinterpretq_f16_u16(vld1q_dup_u16(&params->neonfp16arith_rr2_p2.ln2_hi));
30  const float16x8_t vln2_lo = vreinterpretq_f16_u16(vld1q_dup_u16(&params->neonfp16arith_rr2_p2.ln2_lo));
31  const float16x8_t vc2 = vreinterpretq_f16_u16(vld1q_dup_u16(&params->neonfp16arith_rr2_p2.c2));
32  const float16x8_t vc1 = vreinterpretq_f16_u16(vld1q_dup_u16(&params->neonfp16arith_rr2_p2.c1));
33  const float16x8_t vone = vmovq_n_f16(1.0f);
34  const float16x8_t vdenorm_cutoff = vreinterpretq_f16_u16(vld1q_dup_u16(&params->neonfp16arith_rr2_p2.denorm_cutoff));
35
36  const __fp16* i = (const __fp16*) input;
37  __fp16* o = (__fp16*) output;
38  $if BATCH_TILE > 8:
39    for (; batch >= ${BATCH_TILE} * sizeof(__fp16); batch -= ${BATCH_TILE} * sizeof(__fp16)) {
40      $for N in range(SIMD_TILE):
41        const float16x8_t vx${ABC[N]} = vld1q_f16(i); i += 8;
42
43      $for N in range(SIMD_TILE):
44        const float16x8_t vz${ABC[N]} = vabsq_f16(vx${ABC[N]});
45
46      $for N in range(SIMD_TILE):
47        float16x8_t vn${ABC[N]} = vfmaq_f16(vmagic_bias, vz${ABC[N]}, vminus_log2e);
48
49      $for N in range(SIMD_TILE):
50        const float16x8_t vs${ABC[N]} = vreinterpretq_f16_s16(vshlq_n_s16(vreinterpretq_s16_f16(vn${ABC[N]}), 10));
51
52      $for N in range(SIMD_TILE):
53        vn${ABC[N]} = vsubq_f16(vn${ABC[N]}, vmagic_bias);
54
55      $for N in range(SIMD_TILE):
56        float16x8_t vt${ABC[N]} = vfmaq_f16(vz${ABC[N]}, vn${ABC[N]}, vln2_hi);
57
58      $for N in range(SIMD_TILE):
59        vt${ABC[N]} = vfmaq_f16(vt${ABC[N]}, vn${ABC[N]}, vln2_lo);
60
61      $for N in range(SIMD_TILE):
62        const float16x8_t vp${ABC[N]} = vfmaq_f16(vc1, vc2, vt${ABC[N]});
63
64      $for N in range(SIMD_TILE):
65        vt${ABC[N]} = vmulq_f16(vt${ABC[N]}, vs${ABC[N]});
66
67      $for N in range(SIMD_TILE):
68        const float16x8_t ve${ABC[N]} = vfmaq_f16(vs${ABC[N]}, vp${ABC[N]}, vt${ABC[N]});
69
70      $for N in range(SIMD_TILE):
71        const float16x8_t vd${ABC[N]} = vaddq_f16(ve${ABC[N]}, vone);
72
73      $if DIV_ALGO == "DIV":
74        $for N in range(SIMD_TILE):
75          float16x8_t vf${ABC[N]} = vdivq_f16(ve${ABC[N]}, vd${ABC[N]});
76      $else:
77        $for N in range(SIMD_TILE):
78          float16x8_t vr${ABC[N]} = vrecpeq_f16(vd${ABC[N]});
79
80        $if DIV_ALGO == "NR1FMA":
81          $for N in range(SIMD_TILE):
82            const float16x8_t vadj${ABC[N]} = vfmsq_f16(vone, vr${N}, vd${N});
83
84          $for N in range(SIMD_TILE):
85            vr${ABC[N]} = vfmaq_f16(vr${ABC[N]}, vr${ABC[N]}, vadj${ABC[N]});
86        $else:
87          $for N in range(SIMD_TILE):
88            const float16x8_t vadj${ABC[N]} = vrecpsq_f16(vr${ABC[N]}, vd${ABC[N]});
89
90          $for N in range(SIMD_TILE):
91            vr${ABC[N]} = vmulq_f16(vr${ABC[N]}, vadj${ABC[N]});
92
93        $for N in range(SIMD_TILE):
94          float16x8_t vf${ABC[N]} = vmulq_f16(ve${ABC[N]}, vr${ABC[N]});
95
96      $for N in range(SIMD_TILE):
97        vf${ABC[N]} = vreinterpretq_f16_u16(vbicq_u16(vreinterpretq_u16_f16(vf${ABC[N]}), vcagtq_f16(vx${ABC[N]}, vdenorm_cutoff)));
98
99      $for N in range(SIMD_TILE):
100        const uint16x8_t vm${ABC[N]} = vcltq_f16(vx${ABC[N]}, vmovq_n_f16(0.0f));
101
102      $for N in range(SIMD_TILE):
103        vf${ABC[N]} = vbslq_f16(vm${ABC[N]}, vf${ABC[N]}, vsubq_f16(vone, vf${ABC[N]}));
104
105      $for N in range(SIMD_TILE):
106        vst1q_f16(o, vf${ABC[N]}); o += 8;
107    }
108  for (; batch >= 8 * sizeof(__fp16); batch -= 8 * sizeof(__fp16)) {
109    const float16x8_t vx = vld1q_f16(i); i += 8;
110
111    const float16x8_t vz = vabsq_f16(vx);
112
113    float16x8_t vn = vfmaq_f16(vmagic_bias, vz, vminus_log2e);
114    const float16x8_t vs = vreinterpretq_f16_s16(vshlq_n_s16(vreinterpretq_s16_f16(vn), 10));
115    vn = vsubq_f16(vn, vmagic_bias);
116
117    float16x8_t vt = vfmaq_f16(vz, vn, vln2_hi);
118    vt = vfmaq_f16(vt, vn, vln2_lo);
119
120    const float16x8_t vp = vfmaq_f16(vc1, vc2, vt);
121    vt = vmulq_f16(vt, vs);
122    const float16x8_t ve = vfmaq_f16(vs, vp, vt);
123    const float16x8_t vd = vaddq_f16(ve, vone);
124
125    $if DIV_ALGO == "DIV":
126      float16x8_t vf = vdivq_f16(ve, vd);
127    $else:
128      float16x8_t vr = vrecpeq_f16(vd);
129      $if DIV_ALGO == "NR1FMA":
130        const float16x8_t vadj = vfmsq_f16(vone, vr, vd);
131        vr = vfmaq_f16(vr, vr, vadj);
132      $else:
133        const float16x8_t vadj = vrecpsq_f16(vr, vd);
134        vr = vmulq_f16(vr, vadj);
135
136      float16x8_t vf = vmulq_f16(ve, vr);
137    vf = vreinterpretq_f16_u16(vbicq_u16(vreinterpretq_u16_f16(vf), vcagtq_f16(vx, vdenorm_cutoff)));
138    const uint16x8_t vm = vcltq_f16(vx, vmovq_n_f16(0.0f));
139    vf = vbslq_f16(vm, vf, vsubq_f16(vone, vf));
140
141    vst1q_f16(o, vf); o += 8;
142  }
143  if XNN_UNLIKELY(batch != 0) {
144    const float16x8_t vx = vld1q_f16(i);
145
146    const float16x8_t vz = vabsq_f16(vx);
147
148    float16x8_t vn = vfmaq_f16(vmagic_bias, vz, vminus_log2e);
149    const float16x8_t vs = vreinterpretq_f16_s16(vshlq_n_s16(vreinterpretq_s16_f16(vn), 10));
150    vn = vsubq_f16(vn, vmagic_bias);
151
152    float16x8_t vt = vfmaq_f16(vz, vn, vln2_hi);
153    vt = vfmaq_f16(vt, vn, vln2_lo);
154
155    const float16x8_t vp = vfmaq_f16(vc1, vc2, vt);
156    vt = vmulq_f16(vt, vs);
157    const float16x8_t ve = vfmaq_f16(vs, vp, vt);
158    const float16x8_t vd = vaddq_f16(ve, vone);
159
160    $if DIV_ALGO == "DIV":
161      float16x8_t vf = vdivq_f16(ve, vd);
162    $else:
163      float16x8_t vr = vrecpeq_f16(vd);
164      $if DIV_ALGO == "NR1FMA":
165        const float16x8_t vadj = vfmsq_f16(vone, vr, vd);
166        vr = vfmaq_f16(vr, vr, vadj);
167      $else:
168        const float16x8_t vadj = vrecpsq_f16(vr, vd);
169        vr = vmulq_f16(vr, vadj);
170
171      float16x8_t vf = vmulq_f16(ve, vr);
172    vf = vreinterpretq_f16_u16(vbicq_u16(vreinterpretq_u16_f16(vf), vcagtq_f16(vx, vdenorm_cutoff)));
173    const uint16x8_t vm = vcltq_f16(vx, vmovq_n_f16(0.0f));
174    vf = vbslq_f16(vm, vf, vsubq_f16(vone, vf));
175
176    float16x4_t vf_lo = vget_low_f16(vf);
177    if (batch & (4 * sizeof(__fp16))) {
178      vst1_f16(o, vf_lo); o += 4;
179      vf_lo = vget_high_f16(vf);
180    }
181    if (batch & (2 * sizeof(__fp16))) {
182      vst1_f16(o, vf_lo); o += 2;
183      vf_lo = vext_f16(vf_lo, vf_lo, 2);
184    }
185    if (batch & (1 * sizeof(__fp16))) {
186      vst1_lane_f16(o, vf_lo, 0);
187    }
188  }
189}
190