/*
 * Copyright © 2018 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */
#include <math.h>
#include "util/u_vector.h"
#include "nir.h"
#include "nir_builder.h"

/**
 * Lower flrp instructions.
 *
 * Unlike the lowerings that are possible in nir_opt_algebraic, this pass can
 * examine more global information to determine a possibly more efficient
 * lowering for each flrp.
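 *
 * Throughout this file, flrp(a, b, c) is the linear interpolation
 * a * (1 - c) + b * c (i.e., GLSL's mix(a, b, c)).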
 */

static void
append_flrp_to_dead_list(struct u_vector *dead_flrp, struct nir_alu_instr *alu)
{
   struct nir_alu_instr **tail = u_vector_add(dead_flrp);
   *tail = alu;
}

/**
 * Replace flrp(a, b, c) with ffma(b, c, ffma(-a, c, a)).
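 *
 * This expands to b*c + (a - a*c) = a*(1 - c) + b*c, matching the strict
 * definition of flrp.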
 */
static void
replace_with_strict_ffma(struct nir_builder *bld, struct u_vector *dead_flrp,
                         struct nir_alu_instr *alu)
{
   nir_def *const a = nir_ssa_for_alu_src(bld, alu, 0);
   nir_def *const b = nir_ssa_for_alu_src(bld, alu, 1);
   nir_def *const c = nir_ssa_for_alu_src(bld, alu, 2);

   nir_def *const neg_a = nir_fneg(bld, a);
   nir_instr_as_alu(neg_a->parent_instr)->exact = alu->exact;
   nir_instr_as_alu(neg_a->parent_instr)->fp_fast_math = alu->fp_fast_math;

   nir_def *const inner_ffma = nir_ffma(bld, neg_a, c, a);
   nir_instr_as_alu(inner_ffma->parent_instr)->exact = alu->exact;
   nir_instr_as_alu(inner_ffma->parent_instr)->fp_fast_math = alu->fp_fast_math;

   nir_def *const outer_ffma = nir_ffma(bld, b, c, inner_ffma);
   nir_instr_as_alu(outer_ffma->parent_instr)->exact = alu->exact;
   nir_instr_as_alu(outer_ffma->parent_instr)->fp_fast_math = alu->fp_fast_math;

   nir_def_rewrite_uses(&alu->def, outer_ffma);

   /* DO NOT REMOVE the original flrp yet.  Many of the lowering choices are
    * based on other uses of the sources.  Removing the flrp may cause the
    * last flrp in a sequence to make a different, incorrect choice.
    */
   append_flrp_to_dead_list(dead_flrp, alu);
}

/**
 * Replace flrp(a, b, c) with ffma(a, (1 - c), b*c).
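 *
 * The (1 - c) and b*c terms are emitted as separate instructions so that
 * other lowered flrps that share c (and possibly b) can reuse them.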
 */
static void
replace_with_single_ffma(struct nir_builder *bld, struct u_vector *dead_flrp,
                         struct nir_alu_instr *alu)
{
   nir_def *const a = nir_ssa_for_alu_src(bld, alu, 0);
   nir_def *const b = nir_ssa_for_alu_src(bld, alu, 1);
   nir_def *const c = nir_ssa_for_alu_src(bld, alu, 2);

   nir_def *const neg_c = nir_fneg(bld, c);
   nir_instr_as_alu(neg_c->parent_instr)->exact = alu->exact;
   nir_instr_as_alu(neg_c->parent_instr)->fp_fast_math = alu->fp_fast_math;

   nir_def *const one_minus_c =
      nir_fadd(bld, nir_imm_floatN_t(bld, 1.0f, c->bit_size), neg_c);
   nir_instr_as_alu(one_minus_c->parent_instr)->exact = alu->exact;
   nir_instr_as_alu(one_minus_c->parent_instr)->fp_fast_math = alu->fp_fast_math;

   nir_def *const b_times_c = nir_fmul(bld, b, c);
   nir_instr_as_alu(b_times_c->parent_instr)->exact = alu->exact;
   nir_instr_as_alu(b_times_c->parent_instr)->fp_fast_math = alu->fp_fast_math;

   nir_def *const final_ffma = nir_ffma(bld, a, one_minus_c, b_times_c);
   nir_instr_as_alu(final_ffma->parent_instr)->exact = alu->exact;
   nir_instr_as_alu(final_ffma->parent_instr)->fp_fast_math = alu->fp_fast_math;

   nir_def_rewrite_uses(&alu->def, final_ffma);

   /* DO NOT REMOVE the original flrp yet.  Many of the lowering choices are
    * based on other uses of the sources.  Removing the flrp may cause the
    * last flrp in a sequence to make a different, incorrect choice.
    */
   append_flrp_to_dead_list(dead_flrp, alu);
}

/**
 * Replace flrp(a, b, c) with a(1-c) + bc.
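 *
 * The fneg is typically folded into a source modifier, so the expected cost
 * is four instructions: an add, two multiplies, and a final add.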
 */
static void
replace_with_strict(struct nir_builder *bld, struct u_vector *dead_flrp,
                    struct nir_alu_instr *alu)
{
   nir_def *const a = nir_ssa_for_alu_src(bld, alu, 0);
   nir_def *const b = nir_ssa_for_alu_src(bld, alu, 1);
   nir_def *const c = nir_ssa_for_alu_src(bld, alu, 2);

   nir_def *const neg_c = nir_fneg(bld, c);
   nir_instr_as_alu(neg_c->parent_instr)->exact = alu->exact;
   nir_instr_as_alu(neg_c->parent_instr)->fp_fast_math = alu->fp_fast_math;

   nir_def *const one_minus_c =
      nir_fadd(bld, nir_imm_floatN_t(bld, 1.0f, c->bit_size), neg_c);
   nir_instr_as_alu(one_minus_c->parent_instr)->exact = alu->exact;
   nir_instr_as_alu(one_minus_c->parent_instr)->fp_fast_math = alu->fp_fast_math;

   nir_def *const first_product = nir_fmul(bld, a, one_minus_c);
   nir_instr_as_alu(first_product->parent_instr)->exact = alu->exact;
   nir_instr_as_alu(first_product->parent_instr)->fp_fast_math = alu->fp_fast_math;

   nir_def *const second_product = nir_fmul(bld, b, c);
   nir_instr_as_alu(second_product->parent_instr)->exact = alu->exact;
   nir_instr_as_alu(second_product->parent_instr)->fp_fast_math = alu->fp_fast_math;

   nir_def *const sum = nir_fadd(bld, first_product, second_product);
   nir_instr_as_alu(sum->parent_instr)->exact = alu->exact;
   nir_instr_as_alu(sum->parent_instr)->fp_fast_math = alu->fp_fast_math;

   nir_def_rewrite_uses(&alu->def, sum);

   /* DO NOT REMOVE the original flrp yet.  Many of the lowering choices are
    * based on other uses of the sources.  Removing the flrp may cause the
    * last flrp in a sequence to make a different, incorrect choice.
    */
   append_flrp_to_dead_list(dead_flrp, alu);
}

/**
 * Replace flrp(a, b, c) with a + c(b-a).
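 *
 * This is the cheaper but less precise lowering; see the discussion in
 * convert_flrp_instruction for when it is acceptable.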
 */
static void
replace_with_fast(struct nir_builder *bld, struct u_vector *dead_flrp,
                  struct nir_alu_instr *alu)
{
   nir_def *const a = nir_ssa_for_alu_src(bld, alu, 0);
   nir_def *const b = nir_ssa_for_alu_src(bld, alu, 1);
   nir_def *const c = nir_ssa_for_alu_src(bld, alu, 2);

   nir_def *const neg_a = nir_fneg(bld, a);
   nir_instr_as_alu(neg_a->parent_instr)->exact = alu->exact;
   nir_instr_as_alu(neg_a->parent_instr)->fp_fast_math = alu->fp_fast_math;

   nir_def *const b_minus_a = nir_fadd(bld, b, neg_a);
   nir_instr_as_alu(b_minus_a->parent_instr)->exact = alu->exact;
   nir_instr_as_alu(b_minus_a->parent_instr)->fp_fast_math = alu->fp_fast_math;

   nir_def *const product = nir_fmul(bld, c, b_minus_a);
   nir_instr_as_alu(product->parent_instr)->exact = alu->exact;
   nir_instr_as_alu(product->parent_instr)->fp_fast_math = alu->fp_fast_math;

   nir_def *const sum = nir_fadd(bld, a, product);
   nir_instr_as_alu(sum->parent_instr)->exact = alu->exact;
   nir_instr_as_alu(sum->parent_instr)->fp_fast_math = alu->fp_fast_math;

   nir_def_rewrite_uses(&alu->def, sum);

   /* DO NOT REMOVE the original flrp yet.  Many of the lowering choices are
    * based on other uses of the sources.  Removing the flrp may cause the
    * last flrp in a sequence to make a different, incorrect choice.
    */
   append_flrp_to_dead_list(dead_flrp, alu);
}

/**
 * Replace flrp(a, b, c) with (b*c ± c) + a => b*c + (a ± c)
 *
 * \note: This only works if a = ±1.
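 *
 * With a = 1 and subtract_c, the result is b*c + (1 - c) = 1*(1 - c) + b*c.
 * With a = -1 and !subtract_c, it is b*c + (-1 + c) = -1*(1 - c) + b*c.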
 */
static void
replace_with_expanded_ffma_and_add(struct nir_builder *bld,
                                   struct u_vector *dead_flrp,
                                   struct nir_alu_instr *alu, bool subtract_c)
{
   nir_def *const a = nir_ssa_for_alu_src(bld, alu, 0);
   nir_def *const b = nir_ssa_for_alu_src(bld, alu, 1);
   nir_def *const c = nir_ssa_for_alu_src(bld, alu, 2);

   nir_def *const b_times_c = nir_fmul(bld, b, c);
   nir_instr_as_alu(b_times_c->parent_instr)->exact = alu->exact;
   nir_instr_as_alu(b_times_c->parent_instr)->fp_fast_math = alu->fp_fast_math;

   nir_def *inner_sum;

   if (subtract_c) {
      nir_def *const neg_c = nir_fneg(bld, c);
      nir_instr_as_alu(neg_c->parent_instr)->exact = alu->exact;
      nir_instr_as_alu(neg_c->parent_instr)->fp_fast_math = alu->fp_fast_math;

      inner_sum = nir_fadd(bld, a, neg_c);
   } else {
      inner_sum = nir_fadd(bld, a, c);
   }

   nir_instr_as_alu(inner_sum->parent_instr)->exact = alu->exact;
   nir_instr_as_alu(inner_sum->parent_instr)->fp_fast_math = alu->fp_fast_math;

   nir_def *const outer_sum = nir_fadd(bld, inner_sum, b_times_c);
   nir_instr_as_alu(outer_sum->parent_instr)->exact = alu->exact;
   nir_instr_as_alu(outer_sum->parent_instr)->fp_fast_math = alu->fp_fast_math;

   nir_def_rewrite_uses(&alu->def, outer_sum);

   /* DO NOT REMOVE the original flrp yet.  Many of the lowering choices are
    * based on other uses of the sources.  Removing the flrp may cause the
    * last flrp in a sequence to make a different, incorrect choice.
    */
   append_flrp_to_dead_list(dead_flrp, alu);
}

234 
235 /**
236  * Determines whether a swizzled source is constant w/ all components the same.
237  *
238  * The value of the constant is stored in \c result.
239  *
240  * \return
241  * True if all components of the swizzled source are the same constant.
242  * Otherwise false is returned.
243  */
static bool
all_same_constant(const nir_alu_instr *instr, unsigned src, double *result)
{
   nir_const_value *val = nir_src_as_const_value(instr->src[src].src);

   if (!val)
      return false;

   const uint8_t *const swizzle = instr->src[src].swizzle;
   const unsigned num_components = instr->def.num_components;

   if (instr->def.bit_size == 32) {
      const float first = val[swizzle[0]].f32;

      for (unsigned i = 1; i < num_components; i++) {
         if (val[swizzle[i]].f32 != first)
            return false;
      }

      *result = first;
   } else {
      const double first = val[swizzle[0]].f64;

      for (unsigned i = 1; i < num_components; i++) {
         if (val[swizzle[i]].f64 != first)
            return false;
      }

      *result = first;
   }

   return true;
}

static bool
sources_are_constants_with_similar_magnitudes(const nir_alu_instr *instr)
{
   nir_const_value *val0 = nir_src_as_const_value(instr->src[0].src);
   nir_const_value *val1 = nir_src_as_const_value(instr->src[1].src);

   if (val0 == NULL || val1 == NULL)
      return false;

   const uint8_t *const swizzle0 = instr->src[0].swizzle;
   const uint8_t *const swizzle1 = instr->src[1].swizzle;
   const unsigned num_components = instr->def.num_components;

   if (instr->def.bit_size == 32) {
      for (unsigned i = 0; i < num_components; i++) {
         int exp0;
         int exp1;

         frexpf(val0[swizzle0[i]].f32, &exp0);
         frexpf(val1[swizzle1[i]].f32, &exp1);

         /* If the difference between exponents is >= 24, then A+B will always
          * have the value of whichever of A and B has the larger absolute
          * value.  So, [0, 23] is the valid range.  The smaller the limit
          * value, the more precision will be maintained at a potential
          * performance cost.  Somewhat arbitrarily split the range in half.
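          * Note that (23 / 2) is integer division, so the effective limit
          * on the exponent difference is 11.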
          */
         if (abs(exp0 - exp1) > (23 / 2))
            return false;
      }
   } else {
      for (unsigned i = 0; i < num_components; i++) {
         int exp0;
         int exp1;

         frexp(val0[swizzle0[i]].f64, &exp0);
         frexp(val1[swizzle1[i]].f64, &exp1);

         /* If the difference between exponents is >= 53, then A+B will always
          * have the value of whichever of A and B has the larger absolute
          * value.  So, [0, 52] is the valid range.  The smaller the limit
          * value, the more precision will be maintained at a potential
          * performance cost.  Somewhat arbitrarily split the range in half.
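          * Note that (52 / 2) is integer division, so the effective limit
          * on the exponent difference is 26.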
          */
         if (abs(exp0 - exp1) > (52 / 2))
            return false;
      }
   }

   return true;
}

/**
 * Counts of similar types of nir_op_flrp instructions
 *
 * If a similar instruction fits into more than one category, it will only be
 * counted once.  The assumption is that no other instruction will have all
 * sources the same, or CSE would have removed one of the instructions.
 */
struct similar_flrp_stats {
   unsigned src2;
   unsigned src0_and_src2;
   unsigned src1_and_src2;
};

/**
 * Collect counts of similar FLRP instructions.
 *
 * This function only cares about similar instructions that have src2 in
 * common.
 */
static void
get_similar_flrp_stats(nir_alu_instr *alu, struct similar_flrp_stats *st)
{
   memset(st, 0, sizeof(*st));

   nir_foreach_use(other_use, alu->src[2].src.ssa) {
      /* Is the use also a flrp? */
      nir_instr *const other_instr = nir_src_parent_instr(other_use);
      if (other_instr->type != nir_instr_type_alu)
         continue;

      /* Ahem... don't match the instruction with itself. */
      if (other_instr == &alu->instr)
         continue;

      nir_alu_instr *const other_alu = nir_instr_as_alu(other_instr);
      if (other_alu->op != nir_op_flrp)
         continue;

      /* Does the other flrp use source 2 from the first flrp as its source 2
       * as well?
       */
      if (!nir_alu_srcs_equal(alu, other_alu, 2, 2))
         continue;

      if (nir_alu_srcs_equal(alu, other_alu, 0, 0))
         st->src0_and_src2++;
      else if (nir_alu_srcs_equal(alu, other_alu, 1, 1))
         st->src1_and_src2++;
      else
         st->src2++;
   }
}

static void
convert_flrp_instruction(nir_builder *bld,
                         struct u_vector *dead_flrp,
                         nir_alu_instr *alu,
                         bool always_precise)
{
   bool have_ffma = false;
   unsigned bit_size = alu->def.bit_size;

   if (bit_size == 16)
      have_ffma = !bld->shader->options->lower_ffma16;
   else if (bit_size == 32)
      have_ffma = !bld->shader->options->lower_ffma32;
   else if (bit_size == 64)
      have_ffma = !bld->shader->options->lower_ffma64;
   else
      unreachable("invalid bit_size");

   bld->cursor = nir_before_instr(&alu->instr);

   /* There are two methods to implement flrp(x, y, t).  The strictly correct
    * implementation according to the GLSL spec is:
    *
    *    x(1 - t) + yt
    *
    * This can also be implemented using two chained FMAs:
    *
    *    fma(y, t, fma(-x, t, x))
    *
    * This method, using either formulation, has better precision when the
    * difference between x and y is very large.  It guarantees that flrp(x, y,
    * 1) = y.  For example, flrp(1e38, 1.0, 1.0) is 1.0.  This is correct.
    *
    * The other possible implementation is:
    *
    *    x + t(y - x)
    *
    * This can also be formulated as an FMA:
    *
    *    fma(y - x, t, x)
    *
    * For this implementation, flrp(1e38, 1.0, 1.0) is 0.0.  Since 1.0 was
    * expected, that's a pretty significant error.
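    * (The intermediate y - x = 1.0 - 1e38 rounds to -1e38 in single
    * precision, so the final sum 1e38 + 1.0 * -1e38 is exactly 0.0.)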
    *
    * The choice made for lowering depends on a number of factors.
    *
    * - If the flrp is marked precise and FMA is supported:
    *
    *        fma(y, t, fma(-x, t, x))
    *
    *   This is strictly correct (maybe?), and the cost is two FMA
    *   instructions.  It at least maintains the flrp(x, y, 1.0) == y
    *   condition.
    *
    * - If the flrp is marked precise and FMA is not supported:
    *
    *        x(1 - t) + yt
    *
    *   This is strictly correct, and the cost is 4 instructions.  If FMA is
    *   supported, this may or may not be reduced to 3 instructions (a
    *   subtract, a multiply, and an FMA)... but in that case the other
    *   formulation should have been used.
    */
   if (alu->exact) {
      if (have_ffma)
         replace_with_strict_ffma(bld, dead_flrp, alu);
      else
         replace_with_strict(bld, dead_flrp, alu);

      return;
   }

   /*
    * - If x and y are both immediates and the relative magnitude of the
    *   values is similar (such that y-x does not lose too much precision):
    *
    *        x + t(y - x)
    *
    *   We rely on constant folding to eliminate y-x, and we rely on
    *   nir_opt_algebraic to possibly generate an FMA.  The cost is either one
    *   FMA or two instructions.
    */
   if (sources_are_constants_with_similar_magnitudes(alu)) {
      replace_with_fast(bld, dead_flrp, alu);
      return;
   }

   /*
    * - If x = 1:
    *
    *        (yt + -t) + 1
    *
    * - If x = -1:
    *
    *        (yt + t) - 1
    *
    *   In both cases, x is used in place of ±1 for simplicity.  Both forms
    *   lend themselves to ffma generation on platforms that support ffma.
    */
   double src0_as_constant;
   if (all_same_constant(alu, 0, &src0_as_constant)) {
      if (src0_as_constant == 1.0) {
         replace_with_expanded_ffma_and_add(bld, dead_flrp, alu,
                                            true /* subtract t */);
         return;
      } else if (src0_as_constant == -1.0) {
         replace_with_expanded_ffma_and_add(bld, dead_flrp, alu,
                                            false /* add t */);
         return;
      }
   }

   /*
    * - If y = ±1:
    *
    *        x(1 - t) + yt
    *
    *   In this case the multiply in yt will be eliminated by
    *   nir_opt_algebraic.  If FMA is supported, this results in fma(x, (1 -
    *   t), ±t) for two instructions.  If FMA is not supported, then the cost
    *   is 3 instructions.  We rely on nir_opt_algebraic to generate the FMA
    *   instructions as well.
    *
    *   Another possible replacement is
    *
    *        -xt + x ± t
    *
    *   Some groupings of this may be better on some platforms in some
    *   circumstances, but it is probably dependent on scheduling.  Further
    *   investigation may be required.
    */
   double src1_as_constant;
   if ((all_same_constant(alu, 1, &src1_as_constant) &&
        (src1_as_constant == -1.0 || src1_as_constant == 1.0))) {
      replace_with_strict(bld, dead_flrp, alu);
      return;
   }

   if (have_ffma) {
      if (always_precise) {
         replace_with_strict_ffma(bld, dead_flrp, alu);
         return;
      }

      /*
       * - If FMA is supported and another flrp(x, _, t) exists:
       *
       *        fma(y, t, fma(-x, t, x))
       *
       *   The hope is that the inner FMA calculation will be shared with the
       *   other lowered flrp.  This results in two FMA instructions for the
       *   first flrp and one FMA instruction for each additional flrp.  It
       *   also means that the live range for x might end after the inner
       *   ffma instead of after the last flrp.
       */
      struct similar_flrp_stats st;

      get_similar_flrp_stats(alu, &st);
      if (st.src0_and_src2 > 0) {
         replace_with_strict_ffma(bld, dead_flrp, alu);
         return;
      }

      /*
       * - If FMA is supported and another flrp(_, y, t) exists:
       *
       *        fma(x, (1 - t), yt)
       *
       *   The hope is that the (1 - t) and the yt will be shared with the
       *   other lowered flrp.  This results in 3 instructions for the first
       *   flrp and 1 for each additional flrp.
       */
      if (st.src1_and_src2 > 0) {
         replace_with_single_ffma(bld, dead_flrp, alu);
         return;
      }
   } else {
      if (always_precise) {
         replace_with_strict(bld, dead_flrp, alu);
         return;
      }

      /*
       * - If FMA is not supported and another flrp(x, _, t) exists:
       *
       *        x(1 - t) + yt
       *
       *   The hope is that the x(1 - t) will be shared with the other lowered
       *   flrp.  This results in 4 instructions for the first flrp and 2 for
       *   each additional flrp.
       *
       * - If FMA is not supported and another flrp(_, y, t) exists:
       *
       *        x(1 - t) + yt
       *
       *   The hope is that the (1 - t) and the yt will be shared with the
       *   other lowered flrp.  This results in 4 instructions for the first
       *   flrp and 2 for each additional flrp.
       */
      struct similar_flrp_stats st;

      get_similar_flrp_stats(alu, &st);
      if (st.src0_and_src2 > 0 || st.src1_and_src2 > 0) {
         replace_with_strict(bld, dead_flrp, alu);
         return;
      }
   }

   /*
    * - If t is constant:
    *
    *        x(1 - t) + yt
    *
    *   The cost is three instructions without FMA or two instructions with
    *   FMA.  This is the same cost as the imprecise lowering, but it gives
    *   the instruction scheduler a little more freedom.
    *
    *   There is no need to handle t = 0.5 specially.  nir_opt_algebraic
    *   already has optimizations to convert 0.5x + 0.5y to 0.5(x + y).
    */
   if (alu->src[2].src.ssa->parent_instr->type == nir_instr_type_load_const) {
      replace_with_strict(bld, dead_flrp, alu);
      return;
   }

   /*
    * - Otherwise
    *
    *        x + t(y - x)
    */
   replace_with_fast(bld, dead_flrp, alu);
}

static void
lower_flrp_impl(nir_function_impl *impl,
                struct u_vector *dead_flrp,
                unsigned lowering_mask,
                bool always_precise)
{
   nir_builder b = nir_builder_create(impl);

   nir_foreach_block(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         if (instr->type == nir_instr_type_alu) {
            nir_alu_instr *const alu = nir_instr_as_alu(instr);

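            /* The bit sizes (16, 32, 64) are powers of two, so a def's bit
             * size can be tested directly against lowering_mask (e.g.,
             * 16 | 64).
             */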
            if (alu->op == nir_op_flrp &&
                (alu->def.bit_size & lowering_mask)) {
               convert_flrp_instruction(&b, dead_flrp, alu, always_precise);
            }
         }
      }
   }

   nir_metadata_preserve(impl, nir_metadata_control_flow);
}

/**
 * \param lowering_mask - Bitwise-or of the bit sizes that need to be lowered
 *                        (e.g., 16 | 64 if only 16-bit and 64-bit flrp need
 *                        lowering).
 * \param always_precise - Always require precise lowering for flrp.  This
 *                        will always lower flrp to (a * (1 - c)) + (b * c)
 *                        or, when FFMA is available, to the equivalent
 *                        two-FFMA form.
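 *
 * A minimal usage sketch (hypothetical driver code), lowering only 16- and
 * 64-bit flrp while keeping the native 32-bit instruction:
 *
 *    bool progress = nir_lower_flrp(shader, 16 | 64, false);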
 */
bool
nir_lower_flrp(nir_shader *shader,
               unsigned lowering_mask,
               bool always_precise)
{
   struct u_vector dead_flrp;

   if (!u_vector_init_pow2(&dead_flrp, 8, sizeof(struct nir_alu_instr *)))
      return false;

   nir_foreach_function_impl(impl, shader) {
      lower_flrp_impl(impl, &dead_flrp, lowering_mask, always_precise);
   }

   /* Progress was made if the dead list is not empty.  Remove all the
    * instructions from the dead list.
    */
   const bool progress = u_vector_length(&dead_flrp) != 0;

   struct nir_alu_instr **instr;
   u_vector_foreach(instr, &dead_flrp)
      nir_instr_remove(&(*instr)->instr);

   u_vector_finish(&dead_flrp);

   return progress;
}