xref: /aosp_15_r20/external/mesa3d/src/compiler/nir/nir_lower_frexp.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright © 2015 Intel Corporation
3  * Copyright © 2019 Valve Corporation
4  *
5  * Permission is hereby granted, free of charge, to any person obtaining a
6  * copy of this software and associated documentation files (the "Software"),
7  * to deal in the Software without restriction, including without limitation
8  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
9  * and/or sell copies of the Software, and to permit persons to whom the
10  * Software is furnished to do so, subject to the following conditions:
11  *
12  * The above copyright notice and this permission notice (including the next
13  * paragraph) shall be included in all copies or substantial portions of the
14  * Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
19  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
21  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
22  * IN THE SOFTWARE.
23  */
24 
25 #include "nir.h"
26 #include "nir_builder.h"
27 
28 static nir_def *
lower_frexp_sig(nir_builder * b,nir_def * x)29 lower_frexp_sig(nir_builder *b, nir_def *x)
30 {
31    nir_def *abs_x = nir_fabs(b, x);
32    nir_def *zero = nir_imm_floatN_t(b, 0, x->bit_size);
33    nir_def *sign_mantissa_mask, *exponent_value;
34 
35    switch (x->bit_size) {
36    case 16:
37       /* Half-precision floating-point values are stored as
38        *   1 sign bit;
39        *   5 exponent bits;
40        *   10 mantissa bits.
41        *
42        * An exponent shift of 10 will shift the mantissa out, leaving only the
43        * exponent and sign bit (which itself may be zero, if the absolute value
44        * was taken before the bitcast and shift).
45        */
46       sign_mantissa_mask = nir_imm_intN_t(b, 0x83ffu, 16);
47       /* Exponent of floating-point values in the range [0.5, 1.0). */
48       exponent_value = nir_imm_intN_t(b, 0x3800u, 16);
49       break;
50    case 32:
51       /* Single-precision floating-point values are stored as
52        *   1 sign bit;
53        *   8 exponent bits;
54        *   23 mantissa bits.
55        *
56        * An exponent shift of 23 will shift the mantissa out, leaving only the
57        * exponent and sign bit (which itself may be zero, if the absolute value
58        * was taken before the bitcast and shift.
59        */
60       sign_mantissa_mask = nir_imm_int(b, 0x807fffffu);
61       /* Exponent of floating-point values in the range [0.5, 1.0). */
62       exponent_value = nir_imm_int(b, 0x3f000000u);
63       break;
64    case 64:
65       /* Double-precision floating-point values are stored as
66        *   1 sign bit;
67        *   11 exponent bits;
68        *   52 mantissa bits.
69        *
70        * An exponent shift of 20 will shift the remaining mantissa bits out,
71        * leaving only the exponent and sign bit (which itself may be zero, if
72        * the absolute value was taken before the bitcast and shift.
73        */
74       sign_mantissa_mask = nir_imm_int(b, 0x800fffffu);
75       /* Exponent of floating-point values in the range [0.5, 1.0). */
76       exponent_value = nir_imm_int(b, 0x3fe00000u);
77       break;
78    default:
79       unreachable("Invalid bitsize");
80    }
81 
82    if (x->bit_size == 64) {
83       /* We only need to deal with the exponent so first we extract the upper
84        * 32 bits using nir_unpack_64_2x32_split_y.
85        */
86       nir_def *upper_x = nir_unpack_64_2x32_split_y(b, x);
87 
88       /* If x is ±0, ±Inf, or NaN, return x unmodified. */
89       nir_def *new_upper =
90          nir_bcsel(b,
91                    nir_iand(b,
92                             nir_flt(b, zero, abs_x),
93                             nir_fisfinite(b, x)),
94                    nir_ior(b,
95                            nir_iand(b, upper_x, sign_mantissa_mask),
96                            exponent_value),
97                    upper_x);
98 
99       nir_def *lower_x = nir_unpack_64_2x32_split_x(b, x);
100 
101       return nir_pack_64_2x32_split(b, lower_x, new_upper);
102    } else {
103       /* If x is ±0, ±Inf, or NaN, return x unmodified. */
104       return nir_bcsel(b,
105                        nir_iand(b,
106                                 nir_flt(b, zero, abs_x),
107                                 nir_fisfinite(b, x)),
108                        nir_ior(b,
109                                nir_iand(b, x, sign_mantissa_mask),
110                                exponent_value),
111                        x);
112    }
113 }
114 
115 static nir_def *
lower_frexp_exp(nir_builder * b,nir_def * x)116 lower_frexp_exp(nir_builder *b, nir_def *x)
117 {
118    nir_def *abs_x = nir_fabs(b, x);
119    nir_def *zero = nir_imm_floatN_t(b, 0, x->bit_size);
120    nir_def *is_not_zero = nir_fneu(b, abs_x, zero);
121    nir_def *exponent;
122 
123    switch (x->bit_size) {
124    case 16: {
125       nir_def *exponent_shift = nir_imm_int(b, 10);
126       nir_def *exponent_bias = nir_imm_intN_t(b, -14, 16);
127 
128       /* Significand return must be of the same type as the input, but the
129        * exponent must be a 32-bit integer.
130        */
131       exponent = nir_i2i32(b, nir_iadd(b, nir_ushr(b, abs_x, exponent_shift),
132                                        nir_bcsel(b, is_not_zero, exponent_bias, zero)));
133       break;
134    }
135    case 32: {
136       nir_def *exponent_shift = nir_imm_int(b, 23);
137       nir_def *exponent_bias = nir_imm_int(b, -126);
138 
139       exponent = nir_iadd(b, nir_ushr(b, abs_x, exponent_shift),
140                           nir_bcsel(b, is_not_zero, exponent_bias, zero));
141       break;
142    }
143    case 64: {
144       nir_def *exponent_shift = nir_imm_int(b, 20);
145       nir_def *exponent_bias = nir_imm_int(b, -1022);
146 
147       nir_def *zero32 = nir_imm_int(b, 0);
148       nir_def *abs_upper_x = nir_unpack_64_2x32_split_y(b, abs_x);
149 
150       exponent = nir_iadd(b, nir_ushr(b, abs_upper_x, exponent_shift),
151                           nir_bcsel(b, is_not_zero, exponent_bias, zero32));
152       break;
153    }
154    default:
155       unreachable("Invalid bitsize");
156    }
157 
158    return exponent;
159 }
160 
161 static bool
lower_frexp_instr(nir_builder * b,nir_instr * instr,UNUSED void * cb_data)162 lower_frexp_instr(nir_builder *b, nir_instr *instr, UNUSED void *cb_data)
163 {
164    if (instr->type != nir_instr_type_alu)
165       return false;
166 
167    nir_alu_instr *alu_instr = nir_instr_as_alu(instr);
168    nir_def *lower;
169 
170    b->cursor = nir_before_instr(instr);
171 
172    switch (alu_instr->op) {
173    case nir_op_frexp_sig:
174       lower = lower_frexp_sig(b, nir_ssa_for_alu_src(b, alu_instr, 0));
175       break;
176    case nir_op_frexp_exp:
177       lower = lower_frexp_exp(b, nir_ssa_for_alu_src(b, alu_instr, 0));
178       break;
179    default:
180       return false;
181    }
182 
183    nir_def_rewrite_uses(&alu_instr->def, lower);
184    nir_instr_remove(instr);
185    return true;
186 }
187 
188 bool
nir_lower_frexp(nir_shader * shader)189 nir_lower_frexp(nir_shader *shader)
190 {
191    return nir_shader_instructions_pass(shader, lower_frexp_instr,
192                                        nir_metadata_control_flow,
193                                        NULL);
194 }
195