xref: /aosp_15_r20/external/mesa3d/src/compiler/nir/nir_opt_comparison_pre.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright © 2018 Intel Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include "util/u_vector.h"
25 #include "nir_builder.h"
26 #include "nir_instr_set.h"
27 #include "nir_search_helpers.h"
28 
29 /* Partial redundancy elimination of compares
30  *
31  * Searches for comparisons of the form 'a cmp b' that dominate arithmetic
32  * instructions like 'b - a'.  The comparison is replaced by the arithmetic
33  * instruction, and the result is compared with zero.  For example,
34  *
35  *       vec1 32 ssa_111 = flt 0.37, ssa_110.w
36  *       if ssa_111 {
37  *               block block_1:
38  *              vec1 32 ssa_112 = fadd ssa_110.w, -0.37
39  *              ...
40  *
41  * becomes
42  *
43  *       vec1 32 ssa_111 = fadd ssa_110.w, -0.37
44  *       vec1 32 ssa_112 = flt 0.0, ssa_111
45  *       if ssa_112 {
46  *               block block_1:
47  *              ...
48  */
49 
50 struct block_queue {
51    /**
52     * Stack of blocks from the current location in the CFG to the entry point
53     * of the function.
54     *
55     * This is sort of a poor man's dominator tree.
56     */
57    struct exec_list blocks;
58 
59    /** List of freed block_instructions structures that can be reused. */
60    struct exec_list reusable_blocks;
61 };
62 
63 struct block_instructions {
64    struct exec_node node;
65 
66    /**
67     * Set of comparison instructions from the block that are candidates for
68     * being replaced by add instructions.
69     */
70    struct u_vector instructions;
71 };
72 
73 static void
block_queue_init(struct block_queue * bq)74 block_queue_init(struct block_queue *bq)
75 {
76    exec_list_make_empty(&bq->blocks);
77    exec_list_make_empty(&bq->reusable_blocks);
78 }
79 
80 static void
block_queue_finish(struct block_queue * bq)81 block_queue_finish(struct block_queue *bq)
82 {
83    struct block_instructions *n;
84 
85    while ((n = (struct block_instructions *)exec_list_pop_head(&bq->blocks)) != NULL) {
86       u_vector_finish(&n->instructions);
87       free(n);
88    }
89 
90    while ((n = (struct block_instructions *)exec_list_pop_head(&bq->reusable_blocks)) != NULL) {
91       free(n);
92    }
93 }
94 
95 static struct block_instructions *
push_block(struct block_queue * bq)96 push_block(struct block_queue *bq)
97 {
98    struct block_instructions *bi =
99       (struct block_instructions *)exec_list_pop_head(&bq->reusable_blocks);
100 
101    if (bi == NULL) {
102       bi = calloc(1, sizeof(struct block_instructions));
103 
104       if (bi == NULL)
105          return NULL;
106    }
107 
108    if (!u_vector_init_pow2(&bi->instructions, 8, sizeof(nir_alu_instr *))) {
109       free(bi);
110       return NULL;
111    }
112 
113    exec_list_push_tail(&bq->blocks, &bi->node);
114 
115    return bi;
116 }
117 
118 static void
pop_block(struct block_queue * bq,struct block_instructions * bi)119 pop_block(struct block_queue *bq, struct block_instructions *bi)
120 {
121    u_vector_finish(&bi->instructions);
122    exec_node_remove(&bi->node);
123    exec_list_push_head(&bq->reusable_blocks, &bi->node);
124 }
125 
126 static void
add_instruction_for_block(struct block_instructions * bi,nir_alu_instr * alu)127 add_instruction_for_block(struct block_instructions *bi,
128                           nir_alu_instr *alu)
129 {
130    nir_alu_instr **data =
131       u_vector_add(&bi->instructions);
132 
133    *data = alu;
134 }
135 
136 /**
137  * Determine if the ALU instruction is used by an if-condition or used by a
138  * logic-not that is used by an if-condition.
139  */
140 static bool
is_compatible_condition(const nir_alu_instr * instr)141 is_compatible_condition(const nir_alu_instr *instr)
142 {
143    if (is_used_by_if(instr))
144       return true;
145 
146    nir_foreach_use(src, &instr->def) {
147       const nir_instr *const user_instr = nir_src_parent_instr(src);
148 
149       if (user_instr->type != nir_instr_type_alu)
150          continue;
151 
152       const nir_alu_instr *const user_alu = nir_instr_as_alu(user_instr);
153 
154       if (user_alu->op != nir_op_inot)
155          continue;
156 
157       if (is_used_by_if(user_alu))
158          return true;
159    }
160 
161    return false;
162 }
163 
164 static void
rewrite_compare_instruction(nir_builder * bld,nir_alu_instr * orig_cmp,nir_alu_instr * orig_add,bool zero_on_left)165 rewrite_compare_instruction(nir_builder *bld, nir_alu_instr *orig_cmp,
166                             nir_alu_instr *orig_add, bool zero_on_left)
167 {
168    bld->cursor = nir_before_instr(&orig_cmp->instr);
169 
170    /* This is somewhat tricky.  The compare instruction may be something like
171     * (fcmp, a, b) while the add instruction is something like (fadd, fneg(a),
172     * b).  This is problematic because the SSA value for the fneg(a) may not
173     * exist yet at the compare instruction.
174     *
175     * We fabricate the operands of the new add.  This is done using
176     * information provided by zero_on_left.  If zero_on_left is true, we know
177     * the resulting compare instruction is (fcmp, 0.0, (fadd, x, y)).  If the
178     * original compare instruction was (fcmp, a, b), x = b and y = -a.  If
179     * zero_on_left is false, the resulting compare instruction is (fcmp,
180     * (fadd, x, y), 0.0) and x = a and y = -b.
181     */
182    nir_def *const a = nir_ssa_for_alu_src(bld, orig_cmp, 0);
183    nir_def *const b = nir_ssa_for_alu_src(bld, orig_cmp, 1);
184 
185    nir_def *const fadd = zero_on_left
186                             ? nir_fadd(bld, b, nir_fneg(bld, a))
187                             : nir_fadd(bld, a, nir_fneg(bld, b));
188 
189    nir_def *const zero =
190       nir_imm_floatN_t(bld, 0.0, orig_add->def.bit_size);
191 
192    nir_def *const cmp = zero_on_left
193                            ? nir_build_alu(bld, orig_cmp->op, zero, fadd, NULL, NULL)
194                            : nir_build_alu(bld, orig_cmp->op, fadd, zero, NULL, NULL);
195 
196    /* Generating extra moves of the results is the easy way to make sure the
197     * writemasks match the original instructions.  Later optimization passes
198     * will clean these up.  This is similar to nir_replace_instr (in
199     * nir_search.c).
200     */
201    nir_alu_instr *mov_add = nir_alu_instr_create(bld->shader, nir_op_mov);
202    nir_def_init(&mov_add->instr, &mov_add->def,
203                 orig_add->def.num_components,
204                 orig_add->def.bit_size);
205    mov_add->src[0].src = nir_src_for_ssa(fadd);
206 
207    nir_builder_instr_insert(bld, &mov_add->instr);
208 
209    nir_alu_instr *mov_cmp = nir_alu_instr_create(bld->shader, nir_op_mov);
210    nir_def_init(&mov_cmp->instr, &mov_cmp->def,
211                 orig_cmp->def.num_components,
212                 orig_cmp->def.bit_size);
213    mov_cmp->src[0].src = nir_src_for_ssa(cmp);
214 
215    nir_builder_instr_insert(bld, &mov_cmp->instr);
216 
217    nir_def_rewrite_uses(&orig_cmp->def,
218                         &mov_cmp->def);
219    nir_def_rewrite_uses(&orig_add->def,
220                         &mov_add->def);
221 
222    /* We know these have no more uses because we just rewrote them all, so we
223     * can remove them.
224     */
225    nir_instr_remove(&orig_cmp->instr);
226    nir_instr_remove(&orig_add->instr);
227 }
228 
229 static bool
comparison_pre_block(nir_block * block,struct block_queue * bq,nir_builder * bld)230 comparison_pre_block(nir_block *block, struct block_queue *bq, nir_builder *bld)
231 {
232    bool progress = false;
233 
234    struct block_instructions *bi = push_block(bq);
235    if (bi == NULL)
236       return false;
237 
238    /* Starting with the current block, examine each instruction.  If the
239     * instruction is a comparison that matches the '±a cmp ±b' pattern, add it
240     * to the block_instructions::instructions set.  If the instruction is an
241     * add instruction, walk up the block queue looking at the stored
242     * instructions.  If a matching comparison is found, move the addition and
243     * replace the comparison with a different comparison based on the result
244     * of the addition.  All of the blocks in the queue are guaranteed to be
245     * dominators of the current block.
246     *
247     * After processing the current block, recurse into the blocks dominated by
248     * the current block.
249     */
250    nir_foreach_instr_safe(instr, block) {
251       if (instr->type != nir_instr_type_alu)
252          continue;
253 
254       nir_alu_instr *const alu = nir_instr_as_alu(instr);
255 
256       if (alu->def.num_components != 1)
257          continue;
258 
259       static const uint8_t swizzle[NIR_MAX_VEC_COMPONENTS] = { 0 };
260 
261       switch (alu->op) {
262       case nir_op_fadd: {
263          /* If the instruction is fadd, check it against comparison
264           * instructions that dominate it.
265           */
266          struct block_instructions *b =
267             (struct block_instructions *)exec_list_get_head_raw(&bq->blocks);
268 
269          while (b->node.next != NULL) {
270             nir_alu_instr **a;
271             bool rewrote_compare = false;
272 
273             u_vector_foreach(a, &b->instructions)
274             {
275                nir_alu_instr *const cmp = *a;
276 
277                if (cmp == NULL)
278                   continue;
279 
280                /* The operands of both instructions are, with some liberty,
281                 * commutative.  Check all four permutations.  The third and
282                 * fourth permutations are negations of the first two.
283                 */
284                if ((nir_alu_srcs_equal(cmp, alu, 0, 0) &&
285                     nir_alu_srcs_negative_equal(cmp, alu, 1, 1)) ||
286                    (nir_alu_srcs_equal(cmp, alu, 0, 1) &&
287                     nir_alu_srcs_negative_equal(cmp, alu, 1, 0))) {
288                   /* These are the cases where (A cmp B) matches either (A +
289                    * -B) or (-B + A)
290                    *
291                    *    A cmp B <=> A + -B cmp 0
292                    */
293                   rewrite_compare_instruction(bld, cmp, alu, false);
294 
295                   *a = NULL;
296                   rewrote_compare = true;
297                   break;
298                } else if ((nir_alu_srcs_equal(cmp, alu, 1, 0) &&
299                            nir_alu_srcs_negative_equal(cmp, alu, 0, 1)) ||
300                           (nir_alu_srcs_equal(cmp, alu, 1, 1) &&
301                            nir_alu_srcs_negative_equal(cmp, alu, 0, 0))) {
302                   /* This is the case where (A cmp B) matches (B + -A) or (-A
303                    * + B).
304                    *
305                    *    A cmp B <=> 0 cmp B + -A
306                    */
307                   rewrite_compare_instruction(bld, cmp, alu, true);
308 
309                   *a = NULL;
310                   rewrote_compare = true;
311                   break;
312                }
313             }
314 
315             /* Bail after a compare in the most dominating block is found.
316              * This is necessary because 'alu' has been removed from the
317              * instruction stream.  Should there be a matching compare in
318              * another block, calling rewrite_compare_instruction again will
319              * try to operate on a node that is not in the list as if it were
320              * in the list.
321              *
322              * FINISHME: There may be opportunity for additional optimization
323              * here.  I discovered this problem due to a shader in Guacamelee.
324              * It may be possible to rewrite the matching compares that are
325              * encountered later to reuse the result from the compare that was
326              * first rewritten.  It's also possible that this is just taken
327              * care of by calling the optimization pass repeatedly.
328              */
329             if (rewrote_compare) {
330                progress = true;
331                break;
332             }
333 
334             b = (struct block_instructions *)b->node.next;
335          }
336 
337          break;
338       }
339 
340       case nir_op_flt:
341       case nir_op_fge:
342       case nir_op_fneu:
343       case nir_op_feq:
344          /* If the instruction is a comparison that is used by an if-statement
345           * and neither operand is immediate value 0, add it to the set.
346           */
347          if (is_compatible_condition(alu) &&
348              is_not_const_zero(NULL, alu, 0, 1, swizzle) &&
349              is_not_const_zero(NULL, alu, 1, 1, swizzle))
350             add_instruction_for_block(bi, alu);
351 
352          break;
353 
354       default:
355          break;
356       }
357    }
358 
359    for (unsigned i = 0; i < block->num_dom_children; i++) {
360       nir_block *child = block->dom_children[i];
361 
362       if (comparison_pre_block(child, bq, bld))
363          progress = true;
364    }
365 
366    pop_block(bq, bi);
367 
368    return progress;
369 }
370 
371 bool
nir_opt_comparison_pre_impl(nir_function_impl * impl)372 nir_opt_comparison_pre_impl(nir_function_impl *impl)
373 {
374    struct block_queue bq;
375    nir_builder bld;
376 
377    block_queue_init(&bq);
378    bld = nir_builder_create(impl);
379 
380    nir_metadata_require(impl, nir_metadata_dominance);
381 
382    const bool progress =
383       comparison_pre_block(nir_start_block(impl), &bq, &bld);
384 
385    block_queue_finish(&bq);
386 
387    if (progress) {
388       nir_metadata_preserve(impl, nir_metadata_control_flow);
389    } else {
390       nir_metadata_preserve(impl, nir_metadata_all);
391    }
392 
393    return progress;
394 }
395 
396 bool
nir_opt_comparison_pre(nir_shader * shader)397 nir_opt_comparison_pre(nir_shader *shader)
398 {
399    bool progress = false;
400 
401    nir_foreach_function_impl(impl, shader) {
402       progress |= nir_opt_comparison_pre_impl(impl);
403    }
404 
405    return progress;
406 }
407