xref: /aosp_15_r20/external/mesa3d/src/compiler/nir/nir_lower_int_to_float.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright © 2018 Intel Corporation
3  * Copyright © 2019 Vasily Khoruzhick <[email protected]>
4  *
5  * Permission is hereby granted, free of charge, to any person obtaining a
6  * copy of this software and associated documentation files (the "Software"),
7  * to deal in the Software without restriction, including without limitation
8  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
9  * and/or sell copies of the Software, and to permit persons to whom the
10  * Software is furnished to do so, subject to the following conditions:
11  *
12  * The above copyright notice and this permission notice (including the next
13  * paragraph) shall be included in all copies or substantial portions of the
14  * Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
19  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
21  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
22  * IN THE SOFTWARE.
23  */
24 
25 #include "nir.h"
26 #include "nir_builder.h"
27 
28 static bool
assert_ssa_def_is_not_int(nir_def * def,void * arg)29 assert_ssa_def_is_not_int(nir_def *def, void *arg)
30 {
31    ASSERTED BITSET_WORD *int_types = arg;
32    assert(!BITSET_TEST(int_types, def->index));
33    return true;
34 }
35 
36 static bool
instr_has_only_trivial_swizzles(nir_alu_instr * alu)37 instr_has_only_trivial_swizzles(nir_alu_instr *alu)
38 {
39    const nir_op_info *info = &nir_op_infos[alu->op];
40 
41    for (unsigned i = 0; i < info->num_inputs; i++) {
42       for (unsigned chan = 0; chan < alu->def.num_components; chan++) {
43          if (alu->src[i].swizzle[chan] != chan)
44             return false;
45       }
46    }
47    return true;
48 }
49 
50 /* Recognize the y = x - ffract(x) patterns from lowered ffloor.
51  * It only works for the simple case when no swizzling is involved.
52  */
53 static bool
check_for_lowered_ffloor(nir_alu_instr * fadd)54 check_for_lowered_ffloor(nir_alu_instr *fadd)
55 {
56    if (!instr_has_only_trivial_swizzles(fadd))
57       return false;
58 
59    nir_alu_instr *fneg = NULL;
60    nir_src x;
61    for (unsigned i = 0; i < 2; i++) {
62       nir_alu_instr *fadd_src_alu = nir_src_as_alu_instr(fadd->src[i].src);
63       if (fadd_src_alu && fadd_src_alu->op == nir_op_fneg) {
64          fneg = fadd_src_alu;
65          x = fadd->src[1 - i].src;
66       }
67    }
68 
69    if (!fneg || !instr_has_only_trivial_swizzles(fneg))
70       return false;
71 
72    nir_alu_instr *ffract = nir_src_as_alu_instr(fneg->src[0].src);
73    if (ffract && ffract->op == nir_op_ffract &&
74        nir_srcs_equal(ffract->src[0].src, x) &&
75        instr_has_only_trivial_swizzles(ffract))
76       return true;
77 
78    return false;
79 }
80 
81 static bool
lower_alu_instr(nir_builder * b,nir_alu_instr * alu)82 lower_alu_instr(nir_builder *b, nir_alu_instr *alu)
83 {
84    const nir_op_info *info = &nir_op_infos[alu->op];
85 
86    bool is_bool_only = alu->def.bit_size == 1;
87    for (unsigned i = 0; i < info->num_inputs; i++) {
88       if (alu->src[i].src.ssa->bit_size != 1)
89          is_bool_only = false;
90    }
91 
92    if (is_bool_only) {
93       /* avoid lowering integers ops are used for booleans (ieq,ine,etc) */
94       return false;
95    }
96 
97    b->cursor = nir_before_instr(&alu->instr);
98 
99    /* Replacement SSA value */
100    nir_def *rep = NULL;
101    switch (alu->op) {
102    case nir_op_mov:
103    case nir_op_vec2:
104    case nir_op_vec3:
105    case nir_op_vec4:
106    case nir_op_bcsel:
107       /* These we expect to have integers but the opcode doesn't change */
108       break;
109 
110    case nir_op_b2i32:
111       alu->op = nir_op_b2f32;
112       break;
113    case nir_op_i2f32:
114       alu->op = nir_op_mov;
115       break;
116    case nir_op_u2f32:
117       alu->op = nir_op_mov;
118       break;
119 
120    case nir_op_f2i32: {
121       alu->op = nir_op_ftrunc;
122 
123       /* If the source was already integer, then we did't need to truncate and
124        * can switch it to a mov that can be copy-propagated away.
125        */
126       nir_alu_instr *src_alu = nir_src_as_alu_instr(alu->src[0].src);
127       if (src_alu) {
128          switch (src_alu->op) {
129          /* Check for the y = x - ffract(x) patterns from lowered ffloor. */
130          case nir_op_fadd:
131             if (check_for_lowered_ffloor(src_alu))
132                alu->op = nir_op_mov;
133             break;
134          case nir_op_fround_even:
135          case nir_op_fceil:
136          case nir_op_ftrunc:
137          case nir_op_ffloor:
138             alu->op = nir_op_mov;
139             break;
140          default:
141             break;
142          }
143       }
144       break;
145    }
146 
147    case nir_op_f2u32:
148       alu->op = nir_op_ffloor;
149       break;
150 
151    case nir_op_ilt:
152       alu->op = nir_op_flt;
153       break;
154    case nir_op_ige:
155       alu->op = nir_op_fge;
156       break;
157    case nir_op_ieq:
158       alu->op = nir_op_feq;
159       break;
160    case nir_op_ine:
161       alu->op = nir_op_fneu;
162       break;
163    case nir_op_ult:
164       alu->op = nir_op_flt;
165       break;
166    case nir_op_uge:
167       alu->op = nir_op_fge;
168       break;
169 
170    case nir_op_iadd:
171       alu->op = nir_op_fadd;
172       break;
173    case nir_op_isub:
174       alu->op = nir_op_fsub;
175       break;
176    case nir_op_imul:
177       alu->op = nir_op_fmul;
178       break;
179 
180    case nir_op_idiv: {
181       nir_def *x = nir_ssa_for_alu_src(b, alu, 0);
182       nir_def *y = nir_ssa_for_alu_src(b, alu, 1);
183 
184       /* Hand-lower fdiv, since lower_int_to_float is after nir_opt_algebraic. */
185       if (b->shader->options->lower_fdiv) {
186          rep = nir_ftrunc(b, nir_fmul(b, x, nir_frcp(b, y)));
187       } else {
188          rep = nir_ftrunc(b, nir_fdiv(b, x, y));
189       }
190       break;
191    }
192 
193    case nir_op_iabs:
194       alu->op = nir_op_fabs;
195       break;
196    case nir_op_ineg:
197       alu->op = nir_op_fneg;
198       break;
199    case nir_op_imax:
200       alu->op = nir_op_fmax;
201       break;
202    case nir_op_imin:
203       alu->op = nir_op_fmin;
204       break;
205    case nir_op_umax:
206       alu->op = nir_op_fmax;
207       break;
208    case nir_op_umin:
209       alu->op = nir_op_fmin;
210       break;
211 
212    case nir_op_ball_iequal2:
213       alu->op = nir_op_ball_fequal2;
214       break;
215    case nir_op_ball_iequal3:
216       alu->op = nir_op_ball_fequal3;
217       break;
218    case nir_op_ball_iequal4:
219       alu->op = nir_op_ball_fequal4;
220       break;
221    case nir_op_bany_inequal2:
222       alu->op = nir_op_bany_fnequal2;
223       break;
224    case nir_op_bany_inequal3:
225       alu->op = nir_op_bany_fnequal3;
226       break;
227    case nir_op_bany_inequal4:
228       alu->op = nir_op_bany_fnequal4;
229       break;
230 
231    case nir_op_i32csel_gt:
232       alu->op = nir_op_fcsel_gt;
233       break;
234    case nir_op_i32csel_ge:
235       alu->op = nir_op_fcsel_ge;
236       break;
237 
238    default:
239       assert(nir_alu_type_get_base_type(info->output_type) != nir_type_int &&
240              nir_alu_type_get_base_type(info->output_type) != nir_type_uint);
241       for (unsigned i = 0; i < info->num_inputs; i++) {
242          assert(nir_alu_type_get_base_type(info->input_types[i]) != nir_type_int &&
243                 nir_alu_type_get_base_type(info->input_types[i]) != nir_type_uint);
244       }
245       return false;
246    }
247 
248    if (rep) {
249       /* We've emitted a replacement instruction */
250       nir_def_replace(&alu->def, rep);
251    }
252 
253    return true;
254 }
255 
256 static bool
nir_lower_int_to_float_impl(nir_function_impl * impl)257 nir_lower_int_to_float_impl(nir_function_impl *impl)
258 {
259    bool progress = false;
260    BITSET_WORD *float_types = NULL, *int_types = NULL;
261 
262    nir_builder b = nir_builder_create(impl);
263 
264    nir_index_ssa_defs(impl);
265    float_types = calloc(BITSET_WORDS(impl->ssa_alloc),
266                         sizeof(BITSET_WORD));
267    int_types = calloc(BITSET_WORDS(impl->ssa_alloc),
268                       sizeof(BITSET_WORD));
269    nir_gather_types(impl, float_types, int_types);
270 
271    nir_foreach_block(block, impl) {
272       nir_foreach_instr_safe(instr, block) {
273          switch (instr->type) {
274          case nir_instr_type_alu:
275             progress |= lower_alu_instr(&b, nir_instr_as_alu(instr));
276             break;
277 
278          case nir_instr_type_load_const: {
279             nir_load_const_instr *load = nir_instr_as_load_const(instr);
280             if (load->def.bit_size != 1 && BITSET_TEST(int_types, load->def.index)) {
281                for (unsigned i = 0; i < load->def.num_components; i++)
282                   load->value[i].f32 = load->value[i].i32;
283             }
284             break;
285          }
286 
287          case nir_instr_type_intrinsic:
288          case nir_instr_type_undef:
289          case nir_instr_type_phi:
290          case nir_instr_type_tex:
291             break;
292 
293          default:
294             nir_foreach_def(instr, assert_ssa_def_is_not_int, (void *)int_types);
295             break;
296          }
297       }
298    }
299 
300    if (progress) {
301       nir_metadata_preserve(impl, nir_metadata_control_flow);
302    } else {
303       nir_metadata_preserve(impl, nir_metadata_all);
304    }
305 
306    free(float_types);
307    free(int_types);
308 
309    return progress;
310 }
311 
312 bool
nir_lower_int_to_float(nir_shader * shader)313 nir_lower_int_to_float(nir_shader *shader)
314 {
315    bool progress = false;
316 
317    nir_foreach_function_impl(impl, shader) {
318       if (nir_lower_int_to_float_impl(impl))
319          progress = true;
320    }
321 
322    return progress;
323 }
324