/*
 * Copyright © Microsoft Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "nir_builder.h"

/* The following float-to-half conversion routines are based on the "half" library:
 * https://sourceforge.net/projects/half/
 *
 * half - IEEE 754-based half-precision floating-point library.
 *
 * Copyright (c) 2012-2019 Christian Rau <[email protected]>
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation
 * files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy,
 * modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
 * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
 * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
 * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *
 * Version 2.1.0
 */

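/* Apply rounding to a truncated fp16 magnitude. "value" is the truncated
 * result, "guard" the first dropped bit, "sticky" whether any lower dropped
 * bit was set, and "sign" the fp32 sign in bit 31. RTNE increments on
 * guard && (sticky || lsb), RU/RD increment the magnitude only for positive/
 * negative inputs respectively, and the default (RTZ/undef) just truncates.
 */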
static nir_def *
half_rounded(nir_builder *b, nir_def *value, nir_def *guard, nir_def *sticky,
             nir_def *sign, nir_rounding_mode mode)
{
   switch (mode) {
   case nir_rounding_mode_rtne:
      return nir_iadd(b, value, nir_iand(b, guard, nir_ior(b, sticky, value)));
   case nir_rounding_mode_ru:
      sign = nir_ushr_imm(b, sign, 31);
      return nir_iadd(b, value, nir_iand(b, nir_inot(b, sign), nir_ior(b, guard, sticky)));
   case nir_rounding_mode_rd:
      sign = nir_ushr_imm(b, sign, 31);
      return nir_iadd(b, value, nir_iand(b, sign, nir_ior(b, guard, sticky)));
   default:
      return value;
   }
}

static nir_def *
float_to_half_impl(nir_builder *b, nir_def *src, nir_rounding_mode mode)
{
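   /* Thresholds as fp32 bit patterns: an exponent field of 255 marks Inf/NaN,
    * and (127 + 16) << 23 is 2^16 = 65536, the first power of two above the
    * largest finite fp16 value (65504).
    */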
   nir_def *f32infinity = nir_imm_int(b, 255 << 23);
   nir_def *f16max = nir_imm_int(b, (127 + 16) << 23);

   nir_def *sign = nir_iand_imm(b, src, 0x80000000);
   nir_def *one = nir_imm_int(b, 1);

   nir_def *abs = nir_iand_imm(b, src, 0x7FFFFFFF);
   /* NaN or INF. For rtne, overflow also becomes INF, so combine the comparisons */
   nir_push_if(b, nir_ige(b, abs, mode == nir_rounding_mode_rtne ? f16max : f32infinity));
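   /* 0x7C00 is +Inf and 0x7E00 a quiet NaN in fp16 encoding; the sign bit is
    * ORed back in at the end of this function.
    */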
   nir_def *inf_nanfp16 = nir_bcsel(b,
                                    nir_ilt(b, f32infinity, abs),
                                    nir_imm_int(b, 0x7E00),
                                    nir_imm_int(b, 0x7C00));
   nir_push_else(b, NULL);

   nir_def *overflowed_fp16 = NULL;
   if (mode != nir_rounding_mode_rtne) {
      /* Handle overflow */
      nir_push_if(b, nir_ige(b, abs, f16max));
      switch (mode) {
      case nir_rounding_mode_rtz:
         overflowed_fp16 = nir_imm_int(b, 0x7BFF);
         break;
      case nir_rounding_mode_ru:
         /* Negative becomes max float, positive becomes inf */
         overflowed_fp16 = nir_bcsel(b, nir_i2b(b, sign), nir_imm_int(b, 0x7BFF), nir_imm_int(b, 0x7C00));
         break;
      case nir_rounding_mode_rd:
         /* Negative becomes inf, positive becomes max float */
         overflowed_fp16 = nir_bcsel(b, nir_i2b(b, sign), nir_imm_int(b, 0x7C00), nir_imm_int(b, 0x7BFF));
         break;
      default:
         unreachable("Should've been handled already");
      }
      nir_push_else(b, NULL);
   }

   nir_def *zero = nir_imm_int(b, 0);

   nir_push_if(b, nir_ige_imm(b, abs, 113 << 23));

   /* FP16 will be normal */
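   /* 113 << 23 is 2^-14 as an fp32 bit pattern, the smallest normal fp16.
    * Rebias the exponent from 127 to 15 (subtract 112) and keep the top 10
    * mantissa bits; guard and sticky come from the 13 bits dropped below.
    */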
   nir_def *value = nir_ior(b,
                            nir_ishl_imm(b,
                                         nir_iadd_imm(b,
                                                      nir_ushr_imm(b, abs, 23),
                                                      -112),
                                         10),
                            nir_iand_imm(b, nir_ushr_imm(b, abs, 13), 0x3FFF));
   nir_def *guard = nir_iand(b, nir_ushr_imm(b, abs, 12), one);
   nir_def *sticky = nir_bcsel(b, nir_ine(b, nir_iand_imm(b, abs, 0xFFF), zero), one, zero);
   nir_def *normal_fp16 = half_rounded(b, value, guard, sticky, sign, mode);

   nir_push_else(b, NULL);
   nir_push_if(b, nir_ige_imm(b, abs, 102 << 23));

   /* FP16 will be denormal */
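   /* 102 << 23 is 2^-25 as an fp32 bit pattern; values at or above it can
    * still round to a nonzero fp16 denormal. Make the implicit leading 1
    * explicit, then shift it down into the denormal mantissa; guard is the
    * highest bit shifted out and sticky covers the rest.
    */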
   nir_def *i = nir_isub_imm(b, 125, nir_ushr_imm(b, abs, 23));
   nir_def *masked = nir_ior_imm(b, nir_iand_imm(b, abs, 0x7FFFFF), 0x800000);
   value = nir_ushr(b, masked, nir_iadd(b, i, one));
   guard = nir_iand(b, nir_ushr(b, masked, i), one);
   sticky = nir_bcsel(b, nir_ine(b, nir_iand(b, masked, nir_isub(b, nir_ishl(b, one, i), one)), zero), one, zero);
   nir_def *denormal_fp16 = half_rounded(b, value, guard, sticky, sign, mode);

   nir_push_else(b, NULL);

   /* Handle underflow. Nonzero values need to shift up or down for round-up or round-down */
   nir_def *underflowed_fp16 = zero;
   if (mode == nir_rounding_mode_ru ||
       mode == nir_rounding_mode_rd) {
      nir_push_if(b, nir_i2b(b, abs));

      if (mode == nir_rounding_mode_ru)
         underflowed_fp16 = nir_bcsel(b, nir_i2b(b, sign), zero, one);
      else
         underflowed_fp16 = nir_bcsel(b, nir_i2b(b, sign), one, zero);

      nir_push_else(b, NULL);
      nir_pop_if(b, NULL);
      underflowed_fp16 = nir_if_phi(b, underflowed_fp16, zero);
   }

   nir_pop_if(b, NULL);
   nir_def *underflowed_or_denorm_fp16 = nir_if_phi(b, denormal_fp16, underflowed_fp16);

   nir_pop_if(b, NULL);
   nir_def *finite_fp16 = nir_if_phi(b, normal_fp16, underflowed_or_denorm_fp16);

   nir_def *finite_or_overflowed_fp16 = finite_fp16;
   if (mode != nir_rounding_mode_rtne) {
      nir_pop_if(b, NULL);
      finite_or_overflowed_fp16 = nir_if_phi(b, overflowed_fp16, finite_fp16);
   }

   nir_pop_if(b, NULL);
   nir_def *fp16 = nir_if_phi(b, inf_nanfp16, finite_or_overflowed_fp16);

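   /* Merge the sign back in, moved from fp32 bit 31 down to fp16 bit 15 */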
   return nir_u2u16(b, nir_ior(b, fp16, nir_ushr_imm(b, sign, 16)));
}

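/* First half of a split fp64 -> fp16 conversion: narrow the source to fp32 so
 * the remaining fp32 -> fp16 step can be done by the backend or by
 * float_to_half_impl. For RTNE, nudge the intermediate fp32 so that rounding
 * it to fp16 matches rounding the fp64 value directly; see the comment inside
 * for the cases this fixes.
 */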
static nir_def *
split_f2f16_conversion(nir_builder *b, nir_def *src, nir_rounding_mode rnd)
{
   nir_def *tmp = nir_f2f32(b, src);

   if (rnd == nir_rounding_mode_rtne) {
      /* We round down from double to half float by going through float in
       * between, but this can give us inaccurate results in some cases. One
       * such case is 0x40ee6a0000000001, which should round to 0x7b9b, but
       * going through float first turns into 0x7b9a instead. This is because
       * the first non-fitting bit is set, so we get a tie, but with the least
       * significant bit of the original number set, the tie should break
       * rounding up. The cast to float, however, turns into 0x47735000, which
       * when going to half still ties, but now we lost the tie-up bit, and
       * instead we round to the nearest even, which in this case is down.
       *
       * To fix this, we check if the original would have tied, and if the tie
       * would have rounded up, and if both are true, set the least
       * significant bit of the intermediate float to 1, so that a tie on the
       * next cast rounds up as well. If the rounding already got rid of the
       * tie, that set bit will just be truncated anyway and the end result
       * doesn't change.
       *
       * Another failing case is 0x40effdffffffffff. This one doesn't have the
       * tie from double to half, so it just rounds down to 0x7bff (65504.0),
       * but going through float first, it turns into 0x477ff000, which does
       * have the tie bit for half set, and when that one gets rounded it
       * turns into 0x7c00 (Infinity).
       * The fix for that one is to make sure the intermediate float does not
       * have the tie bit set if the original didn't have it.
       *
       * For the RTZ case, we don't need to do anything, as the intermediate
       * float should be ok already.
       */
      int significand_bits16 = 10;
      int significand_bits32 = 23;
      int significand_bits64 = 52;
      int f64_to_16_tie_bit = significand_bits64 - significand_bits16 - 1;
      int f32_to_16_tie_bit = significand_bits32 - significand_bits16 - 1;
      uint64_t f64_rounds_up_mask = ((1ULL << f64_to_16_tie_bit) - 1);

      nir_def *would_tie = nir_iand_imm(b, src, 1ULL << f64_to_16_tie_bit);
      nir_def *would_rnd_up = nir_iand_imm(b, src, f64_rounds_up_mask);

      nir_def *tie_up = nir_b2i32(b, nir_ine_imm(b, would_rnd_up, 0));

      nir_def *break_tie = nir_bcsel(b,
                                     nir_ine_imm(b, would_tie, 0),
                                     nir_imm_int(b, ~0),
                                     nir_imm_int(b, ~(1U << f32_to_16_tie_bit)));

      tmp = nir_ior(b, tmp, tie_up);
      tmp = nir_iand(b, tmp, break_tie);
   }

   return tmp;
}

static bool
lower_fp16_cast_impl(nir_builder *b, nir_instr *instr, void *data)
{
   nir_lower_fp16_cast_options options = *(nir_lower_fp16_cast_options *)data;
   nir_src *src;
   nir_def *dst;
   uint8_t *swizzle = NULL;
   nir_rounding_mode mode = nir_rounding_mode_undef;

   if (instr->type == nir_instr_type_alu) {
      nir_alu_instr *alu = nir_instr_as_alu(instr);
      src = &alu->src[0].src;
      swizzle = alu->src[0].swizzle;
      dst = &alu->def;
      switch (alu->op) {
      case nir_op_f2f16:
         if (b->shader->info.float_controls_execution_mode & FLOAT_CONTROLS_ROUNDING_MODE_RTZ_FP16)
            mode = nir_rounding_mode_rtz;
         else if (b->shader->info.float_controls_execution_mode & FLOAT_CONTROLS_ROUNDING_MODE_RTE_FP16)
            mode = nir_rounding_mode_rtne;
         break;
      case nir_op_f2f16_rtne:
         mode = nir_rounding_mode_rtne;
         break;
      case nir_op_f2f16_rtz:
         mode = nir_rounding_mode_rtz;
         break;
      case nir_op_f2f64:
         if (src->ssa->bit_size == 16 && (options & nir_lower_fp16_split_fp64)) {
            b->cursor = nir_before_instr(instr);
            nir_src_rewrite(src, nir_f2f32(b, src->ssa));
            return true;
         }
         return false;
      default:
         return false;
      }
   } else if (instr->type == nir_instr_type_intrinsic) {
      nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
      if (intrin->intrinsic != nir_intrinsic_convert_alu_types)
         return false;

      src = &intrin->src[0];
      dst = &intrin->def;
      mode = nir_intrinsic_rounding_mode(intrin);

      if (nir_intrinsic_src_type(intrin) == nir_type_float16 &&
          nir_intrinsic_dest_type(intrin) == nir_type_float64 &&
          (options & nir_lower_fp16_split_fp64)) {
         b->cursor = nir_before_instr(instr);
         nir_src_rewrite(src, nir_f2f32(b, src->ssa));
         return true;
      }

      if (nir_intrinsic_dest_type(intrin) != nir_type_float16)
         return false;
   } else {
      return false;
   }

   bool progress = false;
   if (src->ssa->bit_size == 64 && (options & nir_lower_fp16_split_fp64)) {
      b->cursor = nir_before_instr(instr);
      nir_src_rewrite(src, split_f2f16_conversion(b, src->ssa, mode));
      if (instr->type == nir_instr_type_intrinsic)
         nir_intrinsic_set_src_type(nir_instr_as_intrinsic(instr), nir_type_float32);
      progress = true;
   }

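   /* Work out which lowering option this rounding mode requires; if the
    * backend can handle it natively, stop here (possibly having already
    * split an fp64 source above).
    */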
   nir_lower_fp16_cast_options req_option = 0;
   switch (mode) {
   case nir_rounding_mode_rtz:
      req_option = nir_lower_fp16_rtz;
      break;
   case nir_rounding_mode_rtne:
      req_option = nir_lower_fp16_rtne;
      break;
   case nir_rounding_mode_ru:
      req_option = nir_lower_fp16_ru;
      break;
   case nir_rounding_mode_rd:
      req_option = nir_lower_fp16_rd;
      break;
   case nir_rounding_mode_undef:
      if ((options & nir_lower_fp16_all) == nir_lower_fp16_all) {
         /* Pick one arbitrarily for lowering */
         mode = nir_rounding_mode_rtne;
         req_option = nir_lower_fp16_rtne;
      }
      /* Otherwise assume the backend can handle f2f16 with undef rounding */
      break;
   default:
      unreachable("Invalid rounding mode");
   }
   if (!(options & req_option))
      return progress;

   b->cursor = nir_before_instr(instr);
   nir_def *rets[NIR_MAX_VEC_COMPONENTS] = { NULL };

   for (unsigned i = 0; i < dst->num_components; i++) {
      nir_def *comp = nir_channel(b, src->ssa, swizzle ? swizzle[i] : i);
      if (comp->bit_size == 64)
         comp = split_f2f16_conversion(b, comp, mode);
      rets[i] = float_to_half_impl(b, comp, mode);
   }

   nir_def *new_val = nir_vec(b, rets, dst->num_components);
   nir_def_rewrite_uses(dst, new_val);
   return true;
}

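/* Lower conversions to fp16 whose rounding mode is not supported natively by
 * the backend (as described by "options") into open-coded integer bit
 * manipulation, and optionally split fp64 <-> fp16 conversions so they go
 * through fp32.
 */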
bool
nir_lower_fp16_casts(nir_shader *shader, nir_lower_fp16_cast_options options)
{
   return nir_shader_instructions_pass(shader,
                                       lower_fp16_cast_impl,
                                       nir_metadata_none,
                                       &options);
}