// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert SSE in [2, 4]
$assert not AVX or SSE == 4
$assert BATCH_TILE % 8 == 0
$assert BATCH_TILE >= 8
$SIMD_TILE = BATCH_TILE // 8
$SSE_HEADER = {2: "emmintrin.h", 4: "smmintrin.h"}[SSE]
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#include <assert.h>

#include <${SSE_HEADER}>

#include <xnnpack/common.h>
#include <xnnpack/unaligned.h>
#include <xnnpack/vcvt.h>


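// Converts IEEE FP32 elements to IEEE FP16 using only integer and FP32 SIMD operations:
// |x| is rescaled so that FP32 arithmetic rounds the value at FP16 precision, the FP16
// exponent and mantissa fields are extracted with shifts and masks, NaN inputs are replaced
// with a canonical FP16 NaN, and the original sign bit is re-attached at the end.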
$ISA = "avx" if AVX else {2: "sse2", 4: "sse41"}[SSE]
void xnn_f32_f16_vcvt_ukernel__${ISA}_x${BATCH_TILE}(
    size_t n,
    const float* input,
    void* output,
    const union xnn_f32_f16_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(n != 0);
  assert(n % sizeof(float) == 0);
  assert(input != NULL);
  assert(output != NULL);

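  // Load the precomputed conversion constants: the mask that clears the sign bit, the exponent
  // bias and its lower clamp, the scale factors that drive FP32 overflow and rounding, the masks
  // that select the FP16 exponent and mantissa fields, and the canonical FP16 NaN pattern.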
  const __m128 vnonsign_mask = _mm_load_ps((const float*) params->sse2.nonsign_mask);
  const __m128i vexp_bias = _mm_load_si128((const __m128i*) params->sse2.exp_bias);
  const __m128 vscale_to_inf = _mm_load_ps(params->sse2.scale_to_inf);
  const __m128i vexpw_max = _mm_load_si128((const __m128i*) params->sse2.expw_max);
  const __m128 vscale_to_zero = _mm_load_ps(params->sse2.scale_to_zero);
  const __m128i vbias_min = _mm_load_si128((const __m128i*) params->sse2.bias_min);
  const __m128i vmanth_mask = _mm_load_si128((const __m128i*) params->sse2.manth_mask);
  const __m128i vexph_mask = _mm_load_si128((const __m128i*) params->sse2.exph_mask);
  const __m128i vnanh = _mm_load_si128((const __m128i*) params->sse2.nanh);

  uint16_t* o = (uint16_t*) output;
  $if BATCH_TILE > 8:
    for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) {
      const __m128 vx0 = _mm_loadu_ps(input);
      $for N in range(1, 2*SIMD_TILE):
        const __m128 vx${N} = _mm_loadu_ps(input + ${N * 4});
      input += ${BATCH_TILE};

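      // Split each element into its magnitude |x| = x & nonsign_mask and its sign bit x ^ |x|.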
      $for N in range(2*SIMD_TILE):
        const __m128 vabsx${N} = _mm_and_ps(vx${N}, vnonsign_mask);

      $for N in range(2*SIMD_TILE):
        const __m128 vsignx${N} = _mm_xor_ps(vx${N}, vabsx${N});

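      // Start building the FP16 exponent: add exp_bias to the FP32 bit pattern of |x|.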
      $for N in range(2*SIMD_TILE):
        __m128i vbias${N} = _mm_add_epi32(_mm_castps_si128(vabsx${N}), vexp_bias);

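      // Scale |x| up by scale_to_inf so that overly large inputs saturate to FP32 infinity.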
      $for N in range(2*SIMD_TILE):
        __m128 vf${N} = _mm_mul_ps(vabsx${N}, vscale_to_inf);

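      // Inputs whose magnitude bits compare above expw_max are NaN; remember them in a per-element mask.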
      $for N in range(2*SIMD_TILE):
        const __m128i vnanmaskw${N} = _mm_cmpgt_epi32(_mm_castps_si128(vabsx${N}), vexpw_max);

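      // Keep only the exponent field of the bias value.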
      $for N in range(2*SIMD_TILE):
        vbias${N} = _mm_and_si128(vbias${N}, vexpw_max);

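      // Scale back down by scale_to_zero; infinities produced by the previous multiplication are preserved.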
      $for N in range(2*SIMD_TILE):
        vf${N} = _mm_mul_ps(vf${N}, vscale_to_zero);

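      // Narrow the 32-bit NaN masks and sign words to 16 bits using signed saturation.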
      $for N in range(SIMD_TILE):
        const __m128i vnanmaskh${N} = _mm_packs_epi32(vnanmaskw${2*N}, vnanmaskw${2*N+1});

      $for N in range(SIMD_TILE):
        const __m128i vsignh${N} = _mm_packs_epi32(_mm_castps_si128(vsignx${2*N}), _mm_castps_si128(vsignx${2*N+1}));

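      // Clamp the bias from below so that inputs in the FP16 subnormal range are handled correctly.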
      $for N in range(2*SIMD_TILE):
        vbias${N} = _mm_max_epi16(vbias${N}, vbias_min);

      $if SSE < 4:
        $for N in range(SIMD_TILE):
          __m128i vh${N} = _mm_and_si128(vnanh, vnanmaskh${N});

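      // Add the bias as an FP32 value; the FP32 addition rounds the result at FP16 precision.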
      $for N in range(2*SIMD_TILE):
        vf${N} = _mm_add_ps(vf${N}, _mm_castsi128_ps(vbias${N}));

      $if SSE < 4:
        $for N in range(SIMD_TILE):
          vh${N} = _mm_or_si128(vh${N}, vsignh${N});

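      // Extract the FP16 exponent and mantissa fields from the rounded FP32 bit pattern and combine them.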
      $for N in range(2*SIMD_TILE):
        __m128i vexpw${N} = _mm_srli_epi32(_mm_castps_si128(vf${N}), 13);

      $for N in range(2*SIMD_TILE):
        const __m128i vmantw${N} = _mm_and_si128(_mm_castps_si128(vf${N}), vmanth_mask);

      $for N in range(2*SIMD_TILE):
        vexpw${N} = _mm_and_si128(vexpw${N}, vexph_mask);

      $for N in range(2*SIMD_TILE):
        const __m128i vnonsignw${N} = _mm_add_epi32(vmantw${N}, vexpw${N});

      $for N in range(SIMD_TILE):
        const __m128i vnonsignh${N} = _mm_packs_epi32(vnonsignw${2*N}, vnonsignw${2*N+1});

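      // Substitute the canonical FP16 NaN for NaN inputs and re-attach the sign bits.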
      $if SSE == 4:
        $for N in range(SIMD_TILE):
          const __m128i vabsh${N} = _mm_blendv_epi8(vnonsignh${N}, vnanh, vnanmaskh${N});

        $for N in range(SIMD_TILE):
          const __m128i vh${N} = _mm_or_si128(vabsh${N}, vsignh${N});
      $else:
        $for N in range(SIMD_TILE):
          vh${N} = _mm_or_si128(vh${N}, _mm_andnot_si128(vnanmaskh${N}, vnonsignh${N}));

      _mm_storeu_si128((__m128i*) o, vh0);
      $for N in range(1, SIMD_TILE):
        _mm_storeu_si128((__m128i*) (o + ${N * 8}), vh${N});
      o += ${BATCH_TILE};
    }
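  // Convert one block of 8 elements (two FP32 vectors into one FP16 vector) per iteration.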
  for (; n >= 8 * sizeof(float); n -= 8 * sizeof(float)) {
    const __m128 vx_lo = _mm_loadu_ps(input);
    const __m128 vx_hi = _mm_loadu_ps(input + 4);
    input += 8;

    const __m128 vabsx_lo = _mm_and_ps(vx_lo, vnonsign_mask);
    const __m128 vabsx_hi = _mm_and_ps(vx_hi, vnonsign_mask);

    const __m128 vsignx_lo = _mm_xor_ps(vx_lo, vabsx_lo);
    const __m128 vsignx_hi = _mm_xor_ps(vx_hi, vabsx_hi);
    __m128i vbias_lo = _mm_add_epi32(_mm_castps_si128(vabsx_lo), vexp_bias);
    __m128i vbias_hi = _mm_add_epi32(_mm_castps_si128(vabsx_hi), vexp_bias);
    __m128 vf_lo = _mm_mul_ps(vabsx_lo, vscale_to_inf);
    __m128 vf_hi = _mm_mul_ps(vabsx_hi, vscale_to_inf);
    const __m128i vnanmaskw_lo = _mm_cmpgt_epi32(_mm_castps_si128(vabsx_lo), vexpw_max);
    const __m128i vnanmaskw_hi = _mm_cmpgt_epi32(_mm_castps_si128(vabsx_hi), vexpw_max);

    vbias_lo = _mm_and_si128(vbias_lo, vexpw_max);
    vbias_hi = _mm_and_si128(vbias_hi, vexpw_max);
    vf_lo = _mm_mul_ps(vf_lo, vscale_to_zero);
    vf_hi = _mm_mul_ps(vf_hi, vscale_to_zero);
    const __m128i vnanmaskh = _mm_packs_epi32(vnanmaskw_lo, vnanmaskw_hi);
    const __m128i vsignh = _mm_packs_epi32(_mm_castps_si128(vsignx_lo), _mm_castps_si128(vsignx_hi));

    vbias_lo = _mm_max_epi16(vbias_lo, vbias_min);
    vbias_hi = _mm_max_epi16(vbias_hi, vbias_min);
    $if SSE < 4:
      __m128i vh = _mm_and_si128(vnanh, vnanmaskh);

    vf_lo = _mm_add_ps(vf_lo, _mm_castsi128_ps(vbias_lo));
    vf_hi = _mm_add_ps(vf_hi, _mm_castsi128_ps(vbias_hi));
    $if SSE < 4:
      vh = _mm_or_si128(vh, vsignh);

    __m128i vexpw_lo = _mm_srli_epi32(_mm_castps_si128(vf_lo), 13);
    __m128i vexpw_hi = _mm_srli_epi32(_mm_castps_si128(vf_hi), 13);
    const __m128i vmantw_lo = _mm_and_si128(_mm_castps_si128(vf_lo), vmanth_mask);
    const __m128i vmantw_hi = _mm_and_si128(_mm_castps_si128(vf_hi), vmanth_mask);

    vexpw_lo = _mm_and_si128(vexpw_lo, vexph_mask);
    vexpw_hi = _mm_and_si128(vexpw_hi, vexph_mask);

    const __m128i vnonsignw_lo = _mm_add_epi32(vmantw_lo, vexpw_lo);
    const __m128i vnonsignw_hi = _mm_add_epi32(vmantw_hi, vexpw_hi);

    const __m128i vnonsignh = _mm_packs_epi32(vnonsignw_lo, vnonsignw_hi);

    $if SSE == 4:
      const __m128i vabsh = _mm_blendv_epi8(vnonsignh, vnanh, vnanmaskh);

      const __m128i vh = _mm_or_si128(vabsh, vsignh);
    $else:
      vh = _mm_or_si128(vh, _mm_andnot_si128(vnanmaskh, vnonsignh));

    _mm_storeu_si128((__m128i*) o, vh);
    o += 8;
  }
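  // Handle the final 1-7 elements. The full-width loads may read past the end of the input
  // buffer; this is why the kernel is annotated with XNN_OOB_READS.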
  if XNN_UNPREDICTABLE(n != 0) {
    const __m128 vx_lo = _mm_loadu_ps(input);
    const float* input_hi = (const float*) ((uintptr_t) input + (n & (4 * sizeof(float))));
    const __m128 vx_hi = _mm_loadu_ps(input_hi);

    const __m128 vabsx_lo = _mm_and_ps(vx_lo, vnonsign_mask);
    const __m128 vabsx_hi = _mm_and_ps(vx_hi, vnonsign_mask);

    const __m128 vsignx_lo = _mm_xor_ps(vx_lo, vabsx_lo);
    const __m128 vsignx_hi = _mm_xor_ps(vx_hi, vabsx_hi);
    __m128i vbias_lo = _mm_add_epi32(_mm_castps_si128(vabsx_lo), vexp_bias);
    __m128i vbias_hi = _mm_add_epi32(_mm_castps_si128(vabsx_hi), vexp_bias);
    __m128 vf_lo = _mm_mul_ps(vabsx_lo, vscale_to_inf);
    __m128 vf_hi = _mm_mul_ps(vabsx_hi, vscale_to_inf);
    const __m128i vnanmaskw_lo = _mm_cmpgt_epi32(_mm_castps_si128(vabsx_lo), vexpw_max);
    const __m128i vnanmaskw_hi = _mm_cmpgt_epi32(_mm_castps_si128(vabsx_hi), vexpw_max);

    vbias_lo = _mm_and_si128(vbias_lo, vexpw_max);
    vbias_hi = _mm_and_si128(vbias_hi, vexpw_max);
    vf_lo = _mm_mul_ps(vf_lo, vscale_to_zero);
    vf_hi = _mm_mul_ps(vf_hi, vscale_to_zero);
    const __m128i vnanmaskh = _mm_packs_epi32(vnanmaskw_lo, vnanmaskw_hi);
    const __m128i vsignh = _mm_packs_epi32(_mm_castps_si128(vsignx_lo), _mm_castps_si128(vsignx_hi));

    vbias_lo = _mm_max_epi16(vbias_lo, vbias_min);
    vbias_hi = _mm_max_epi16(vbias_hi, vbias_min);
    $if SSE < 4:
      __m128i vh = _mm_and_si128(vnanh, vnanmaskh);

    vf_lo = _mm_add_ps(vf_lo, _mm_castsi128_ps(vbias_lo));
    vf_hi = _mm_add_ps(vf_hi, _mm_castsi128_ps(vbias_hi));
    $if SSE < 4:
      vh = _mm_or_si128(vh, vsignh);

    __m128i vexpw_lo = _mm_srli_epi32(_mm_castps_si128(vf_lo), 13);
    __m128i vexpw_hi = _mm_srli_epi32(_mm_castps_si128(vf_hi), 13);
    const __m128i vmantw_lo = _mm_and_si128(_mm_castps_si128(vf_lo), vmanth_mask);
    const __m128i vmantw_hi = _mm_and_si128(_mm_castps_si128(vf_hi), vmanth_mask);

    vexpw_lo = _mm_and_si128(vexpw_lo, vexph_mask);
    vexpw_hi = _mm_and_si128(vexpw_hi, vexph_mask);

    const __m128i vnonsignw_lo = _mm_add_epi32(vmantw_lo, vexpw_lo);
    const __m128i vnonsignw_hi = _mm_add_epi32(vmantw_hi, vexpw_hi);

    const __m128i vnonsignh = _mm_packs_epi32(vnonsignw_lo, vnonsignw_hi);

    $if SSE == 4:
      const __m128i vabsh = _mm_blendv_epi8(vnonsignh, vnanh, vnanmaskh);

      __m128i vh = _mm_or_si128(vabsh, vsignh);
    $else:
      vh = _mm_or_si128(vh, _mm_andnot_si128(vnanmaskh, vnonsignh));

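    // Store the remaining 4, 2, and/or 1 converted halves, shifting the stored elements out of
    // the register after each partial store.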
    if (n & (4 * sizeof(float))) {
      _mm_storel_epi64((__m128i*) o, vh);
      vh = _mm_unpackhi_epi64(vh, vh);
      o += 4;
    }
    if (n & (2 * sizeof(float))) {
      unaligned_store_u32(o, (uint32_t) _mm_cvtsi128_si32(vh));
      vh = _mm_srli_epi64(vh, 32);
      o += 2;
    }
    if (n & (1 * sizeof(float))) {
      $if SSE == 4:
        *o = (uint16_t) _mm_extract_epi16(vh, 0);
      $else:
        *o = (uint16_t) _mm_cvtsi128_si32(vh);
    }
  }
}