xref: /aosp_15_r20/external/XNNPACK/src/f32-f16-vcvt/scalar-bitcast.c.in (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5
6$assert BATCH_TILE >= 1
7#include <assert.h>
8
9#include <xnnpack/common.h>
10#include <xnnpack/math.h>
11#include <xnnpack/vcvt.h>
12
13
14void xnn_f32_f16_vcvt_ukernel__scalar_bitcast_x${BATCH_TILE}(
15    size_t n,
16    const float* input,
17    void* output,
18    const union xnn_f32_f16_cvt_params params[restrict XNN_MIN_ELEMENTS(1)])
19{
20  assert(n != 0);
21  assert(n % sizeof(float) == 0);
22  assert(input != NULL);
23  assert(output != NULL);
24
25  const uint32_t vnonsign_mask = params->scalar_bitcast.nonsign_mask;
26  const uint32_t vexp_bias = params->scalar_bitcast.exp_bias;
27  const float vscale_to_inf = params->scalar_bitcast.scale_to_inf;
28  const uint32_t vexpw_max = params->scalar_bitcast.expw_max;
29  const float vscale_to_zero = params->scalar_bitcast.scale_to_zero;
30  const uint32_t vbias_min = params->scalar_bitcast.bias_min;
31  const uint16_t vexph_mask = params->scalar_bitcast.exph_mask;
32  const uint16_t vmanth_mask = params->scalar_bitcast.manth_mask;
33  const uint16_t vnanh = params->scalar_bitcast.nanh;
34
35  const uint32_t* i = (const uint32_t*) input;
36  uint16_t* o = (uint16_t*) output;
37  $if BATCH_TILE > 1:
38    for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) {
39      $for N in range(BATCH_TILE):
40        const uint32_t vw${N} = i[${N}];
41      i += ${BATCH_TILE};
42
43      $for N in range(BATCH_TILE):
44        const uint32_t vnonsignw${N} = vw${N} & vnonsign_mask;
45
46      $for N in range(BATCH_TILE):
47        float vf${N} = uint32_as_float(vnonsignw${N});
48      $for N in range(BATCH_TILE):
49        const uint32_t vsignw${N} = vw${N} ^ vnonsignw${N};
50      $for N in range(BATCH_TILE):
51        uint32_t vbias${N} = vnonsignw${N} + vexp_bias;
52
53      $for N in range(BATCH_TILE):
54        vf${N} *= vscale_to_inf;
55      $for N in range(BATCH_TILE):
56        vbias${N} &= vexpw_max;
57
58      $for N in range(BATCH_TILE):
59        vf${N} *= vscale_to_zero;
60      $for N in range(BATCH_TILE):
61        vbias${N} = math_max_u32(vbias${N}, vbias_min);
62
63      $for N in range(BATCH_TILE):
64        vf${N} += uint32_as_float(vbias${N});
65
66      $for N in range(BATCH_TILE):
67        const uint32_t vbits${N} = float_as_uint32(vf${N});
68
69      $for N in range(BATCH_TILE):
70        const uint16_t vexph${N} = (uint16_t) (vbits${N} >> 13) & vexph_mask;
71      $for N in range(BATCH_TILE):
72        const uint16_t vmanth${N} = (uint16_t) vbits${N} & vmanth_mask;
73      $for N in range(BATCH_TILE):
74        const uint16_t vsignh${N} = (uint16_t) (vsignw${N} >> 16);
75
76      $for N in range(BATCH_TILE):
77        uint16_t vh${N} = vexph${N} + vmanth${N};
78      $for N in range(BATCH_TILE):
79        if XNN_UNPREDICTABLE(vnonsignw${N} > vexpw_max) {
80          vh${N} = vnanh;
81        }
82      $for N in range(BATCH_TILE):
83        vh${N} |= vsignh${N};
84
85      $for N in range(BATCH_TILE):
86        o[${N}] = vh${N};
87      o += ${BATCH_TILE};
88    }
89  $if BATCH_TILE == 1:
90    do {
91      const uint32_t vw = *i++;
92
93      const uint32_t vnonsignw = vw & vnonsign_mask;
94
95      float vf = uint32_as_float(vnonsignw);
96      const uint32_t vsignw = vw ^ vnonsignw;
97      uint32_t vbias = vnonsignw + vexp_bias;
98
99      vf *= vscale_to_inf;
100      vbias &= vexpw_max;
101
102      vf *= vscale_to_zero;
103      vbias = math_max_u32(vbias, vbias_min);
104
105      vf += uint32_as_float(vbias);
106
107      const uint32_t vbits = float_as_uint32(vf);
108
109      const uint16_t vexph = (uint16_t) (vbits >> 13) & vexph_mask;
110      const uint16_t vmanth = (uint16_t) vbits & vmanth_mask;
111      const uint16_t vsignh = (uint16_t) (vsignw >> 16);
112
113      uint16_t vh = vexph + vmanth;
114      if XNN_UNPREDICTABLE(vnonsignw > vexpw_max) {
115        vh = vnanh;
116      }
117      vh |= vsignh;
118
119      *o++ = vh;
120
121      n -= sizeof(float);
122    } while (n != 0);
123  $elif BATCH_TILE == 2:
124    if XNN_UNLIKELY(n != 0) {
125      const uint32_t vw = *i;
126
127      const uint32_t vnonsignw = vw & vnonsign_mask;
128
129      float vf = uint32_as_float(vnonsignw);
130      const uint32_t vsignw = vw ^ vnonsignw;
131      uint32_t vbias = vnonsignw + vexp_bias;
132
133      vf *= vscale_to_inf;
134      vbias &= vexpw_max;
135
136      vf *= vscale_to_zero;
137      vbias = math_max_u32(vbias, vbias_min);
138
139      vf += uint32_as_float(vbias);
140
141      const uint32_t vbits = float_as_uint32(vf);
142
143      const uint16_t vexph = (uint16_t) (vbits >> 13) & vexph_mask;
144      const uint16_t vmanth = (uint16_t) vbits & vmanth_mask;
145      const uint16_t vsignh = (uint16_t) (vsignw >> 16);
146
147      uint16_t vh = vexph + vmanth;
148      if XNN_UNPREDICTABLE(vnonsignw > vexpw_max) {
149        vh = vnanh;
150      }
151      vh |= vsignh;
152
153      *o = vh;
154    }
155  $else:
156    if XNN_UNLIKELY(n != 0) {
157      do {
158        const uint32_t vw = *i++;
159
160        const uint32_t vnonsignw = vw & vnonsign_mask;
161
162        float vf = uint32_as_float(vnonsignw);
163        const uint32_t vsignw = vw ^ vnonsignw;
164        uint32_t vbias = vnonsignw + vexp_bias;
165
166        vf *= vscale_to_inf;
167        vbias &= vexpw_max;
168
169        vf *= vscale_to_zero;
170        vbias = math_max_u32(vbias, vbias_min);
171
172        vf += uint32_as_float(vbias);
173
174        const uint32_t vbits = float_as_uint32(vf);
175
176        const uint16_t vexph = (uint16_t) (vbits >> 13) & vexph_mask;
177        const uint16_t vmanth = (uint16_t) vbits & vmanth_mask;
178        const uint16_t vsignh = (uint16_t) (vsignw >> 16);
179
180        uint16_t vh = vexph + vmanth;
181        if XNN_UNPREDICTABLE(vnonsignw > vexpw_max) {
182          vh = vnanh;
183        }
184        vh |= vsignh;
185
186        *o++ = vh;
187
188        n -= sizeof(float);
189      } while (n != 0);
190    }
191}
192