xref: /aosp_15_r20/external/mesa3d/src/compiler/nir/nir_lower_goto_ifs.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright © 2020 Julian Winkler
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include "nir.h"
25 #include "nir_builder.h"
26 #include "nir_vla.h"
27 
28 #define NIR_LOWER_GOTO_IFS_DEBUG 0
29 
/** A routing destination: a set of blocks plus the decision tree that
 * selects between them when the set holds more than one block.
 */
struct path {
   /** Set of blocks which this path represents
    *
    * It's "reachable" not in the sense that these are all the nodes reachable
    * through this path but in the sense that, when you see one of these
    * blocks, you know you've reached this path.
    */
   struct set *reachable;

   /** Fork in the path, if reachable->entries > 1 */
   struct path_fork *fork;
};
42 
/** Binary decision node steering control flow down one of two subpaths.
 *
 * The boolean condition lives either in a helper variable (is_var == true)
 * or directly in an SSA def; paths[condition value] is the chosen subpath.
 */
struct path_fork {
   bool is_var;
   union {
      nir_variable *path_var; /* boolean helper var, valid if is_var */
      nir_def *path_ssa;      /* boolean SSA condition, valid otherwise */
   };
   struct path paths[2];
};
51 
/** The three ways control can leave the code currently being built:
 * fall through (regular), break out of the innermost loop (brk), or
 * continue it (cont). loop_backup stacks the outer routing state while
 * building inside a loop; see loop_routing_start()/loop_routing_end().
 */
struct routes {
   struct path regular;
   struct path brk;
   struct path cont;
   struct routes *loop_backup;
};
58 
/** One level in the ordered level list built by organize_levels() */
struct strct_lvl {
   struct list_head link;

   /** Set of blocks at the current level */
   struct set *blocks;

   /** Path for the next level */
   struct path out_path;

   /** Reach set from inside_outside if irreducible */
   struct set *reach;

   /** True if a skip region starts with this level */
   bool skip_start;

   /** True if a skip region ends with this level */
   bool skip_end;

   /** True if this level is irreducible */
   bool irreducible;
};
80 
81 static int
nir_block_ptr_cmp(const void * _a,const void * _b)82 nir_block_ptr_cmp(const void *_a, const void *_b)
83 {
84    const nir_block *const *a = _a;
85    const nir_block *const *b = _b;
86    return (int)(*a)->index - (int)(*b)->index;
87 }
88 
89 static void
print_block_set(const struct set * set)90 print_block_set(const struct set *set)
91 {
92    printf("{ ");
93    if (set != NULL) {
94       unsigned count = 0;
95       set_foreach(set, entry) {
96          if (count++)
97             printf(", ");
98          printf("%u", ((nir_block *)entry->key)->index);
99       }
100    }
101    printf(" }\n");
102 }
103 
104 /** Return a sorted array of blocks for a set
105  *
106  * Hash set ordering is non-deterministic.  We hash based on pointers and so,
107  * if any pointer ever changes from one run to another, the order of the set
108  * may change.  Any time we're going to make decisions which may affect the
109  * final structure which may depend on ordering, we should first sort the
110  * blocks.
111  */
112 static nir_block **
sorted_block_arr_for_set(const struct set * block_set,void * mem_ctx)113 sorted_block_arr_for_set(const struct set *block_set, void *mem_ctx)
114 {
115    const unsigned num_blocks = block_set->entries;
116    nir_block **block_arr = ralloc_array(mem_ctx, nir_block *, num_blocks);
117    unsigned i = 0;
118    set_foreach(block_set, entry)
119       block_arr[i++] = (nir_block *)entry->key;
120    assert(i == num_blocks);
121    qsort(block_arr, num_blocks, sizeof(*block_arr), nir_block_ptr_cmp);
122    return block_arr;
123 }
124 
125 static nir_block *
block_for_singular_set(const struct set * block_set)126 block_for_singular_set(const struct set *block_set)
127 {
128    assert(block_set->entries == 1);
129    return (nir_block *)_mesa_set_next_entry(block_set, NULL)->key;
130 }
131 
132 /**
133  * Sets all path variables to reach the target block via a fork
134  */
135 static void
set_path_vars(nir_builder * b,struct path_fork * fork,nir_block * target)136 set_path_vars(nir_builder *b, struct path_fork *fork, nir_block *target)
137 {
138    while (fork) {
139       for (int i = 0; i < 2; i++) {
140          if (_mesa_set_search(fork->paths[i].reachable, target)) {
141             if (fork->is_var) {
142                nir_store_var(b, fork->path_var, nir_imm_bool(b, i), 1);
143             } else {
144                assert(fork->path_ssa == NULL);
145                fork->path_ssa = nir_imm_bool(b, i);
146             }
147             fork = fork->paths[i].fork;
148             break;
149          }
150       }
151    }
152 }
153 
154 /**
155  * Sets all path variables to reach the both target blocks via a fork.
156  * If the blocks are in different fork paths, the condition will be used.
157  * As the fork is already created, the then and else blocks may be swapped,
158  * in this case the condition is inverted
159  */
160 static void
set_path_vars_cond(nir_builder * b,struct path_fork * fork,nir_def * condition,nir_block * then_block,nir_block * else_block)161 set_path_vars_cond(nir_builder *b, struct path_fork *fork, nir_def *condition,
162                    nir_block *then_block, nir_block *else_block)
163 {
164    int i;
165    while (fork) {
166       for (i = 0; i < 2; i++) {
167          if (_mesa_set_search(fork->paths[i].reachable, then_block)) {
168             if (_mesa_set_search(fork->paths[i].reachable, else_block)) {
169                if (fork->is_var) {
170                   nir_store_var(b, fork->path_var, nir_imm_bool(b, i), 1);
171                } else {
172                   assert(fork->path_ssa == NULL);
173                   fork->path_ssa = nir_imm_bool(b, i);
174                }
175                fork = fork->paths[i].fork;
176                break;
177             } else {
178                assert(condition->bit_size == 1);
179                assert(condition->num_components == 1);
180                nir_def *fork_cond = condition;
181                if (!i)
182                   fork_cond = nir_inot(b, fork_cond);
183                if (fork->is_var) {
184                   nir_store_var(b, fork->path_var, fork_cond, 1);
185                } else {
186                   assert(fork->path_ssa == NULL);
187                   fork->path_ssa = fork_cond;
188                }
189                set_path_vars(b, fork->paths[i].fork, then_block);
190                set_path_vars(b, fork->paths[!i].fork, else_block);
191                return;
192             }
193          }
194       }
195       assert(i < 2);
196    }
197 }
198 
199 /**
200  * Sets all path variables and places the right jump instruction to reach the
201  * target block
202  */
203 static void
route_to(nir_builder * b,struct routes * routing,nir_block * target)204 route_to(nir_builder *b, struct routes *routing, nir_block *target)
205 {
206    if (_mesa_set_search(routing->regular.reachable, target)) {
207       set_path_vars(b, routing->regular.fork, target);
208    } else if (_mesa_set_search(routing->brk.reachable, target)) {
209       set_path_vars(b, routing->brk.fork, target);
210       nir_jump(b, nir_jump_break);
211    } else if (_mesa_set_search(routing->cont.reachable, target)) {
212       set_path_vars(b, routing->cont.fork, target);
213       nir_jump(b, nir_jump_continue);
214    } else {
215       assert(!target->successors[0]); /* target is endblock */
216       nir_jump(b, nir_jump_return);
217    }
218 }
219 
220 /**
221  * Sets path vars and places the right jump instr to reach one of the two
222  * target blocks based on the condition. If the targets need different jump
223  * istructions, they will be placed into an if else statement.
224  * This can happen if one target is the loop head
225  *     A __
226  *     |   \
227  *     B    |
228  *     |\__/
229  *     C
230  */
231 static void
route_to_cond(nir_builder * b,struct routes * routing,nir_def * condition,nir_block * then_block,nir_block * else_block)232 route_to_cond(nir_builder *b, struct routes *routing, nir_def *condition,
233               nir_block *then_block, nir_block *else_block)
234 {
235    if (_mesa_set_search(routing->regular.reachable, then_block)) {
236       if (_mesa_set_search(routing->regular.reachable, else_block)) {
237          set_path_vars_cond(b, routing->regular.fork, condition,
238                             then_block, else_block);
239          return;
240       }
241    } else if (_mesa_set_search(routing->brk.reachable, then_block)) {
242       if (_mesa_set_search(routing->brk.reachable, else_block)) {
243          set_path_vars_cond(b, routing->brk.fork, condition,
244                             then_block, else_block);
245          nir_jump(b, nir_jump_break);
246          return;
247       }
248    } else if (_mesa_set_search(routing->cont.reachable, then_block)) {
249       if (_mesa_set_search(routing->cont.reachable, else_block)) {
250          set_path_vars_cond(b, routing->cont.fork, condition,
251                             then_block, else_block);
252          nir_jump(b, nir_jump_continue);
253          return;
254       }
255    }
256 
257    /* then and else blocks are in different routes */
258    nir_push_if(b, condition);
259    route_to(b, routing, then_block);
260    nir_push_else(b, NULL);
261    route_to(b, routing, else_block);
262    nir_pop_if(b, NULL);
263 }
264 
265 /**
266  * Merges the reachable sets of both fork subpaths into the forks entire
267  * reachable set
268  */
269 static struct set *
fork_reachable(struct path_fork * fork)270 fork_reachable(struct path_fork *fork)
271 {
272    struct set *reachable = _mesa_set_clone(fork->paths[0].reachable, fork);
273    set_foreach(fork->paths[1].reachable, entry)
274       _mesa_set_add_pre_hashed(reachable, entry->hash, entry->key);
275    return reachable;
276 }
277 
278 /**
279  * Modifies the routing to be the routing inside a loop. The old regular path
280  * becomes the new break path. The loop in path becomes the new regular and
281  * continue path.
282  * The lost routing information is stacked into the loop_backup stack.
283  * Also creates helper vars for multilevel loop jumping if needed.
284  * Also calls the nir builder to build the loop
285  */
286 static void
loop_routing_start(struct routes * routing,nir_builder * b,struct path loop_path,struct set * reach,void * mem_ctx)287 loop_routing_start(struct routes *routing, nir_builder *b,
288                    struct path loop_path, struct set *reach,
289                    void *mem_ctx)
290 {
291    if (NIR_LOWER_GOTO_IFS_DEBUG) {
292       printf("loop_routing_start:\n");
293       printf("    reach =                       ");
294       print_block_set(reach);
295       printf("    loop_path.reachable =         ");
296       print_block_set(loop_path.reachable);
297       printf("    routing->regular.reachable =  ");
298       print_block_set(routing->regular.reachable);
299       printf("    routing->brk.reachable =      ");
300       print_block_set(routing->brk.reachable);
301       printf("    routing->cont.reachable =     ");
302       print_block_set(routing->cont.reachable);
303       printf("\n");
304    }
305 
306    struct routes *routing_backup = rzalloc(mem_ctx, struct routes);
307    *routing_backup = *routing;
308    bool break_needed = false;
309    bool continue_needed = false;
310 
311    set_foreach(reach, entry) {
312       if (_mesa_set_search(loop_path.reachable, entry->key))
313          continue;
314       if (_mesa_set_search(routing->regular.reachable, entry->key))
315          continue;
316       if (_mesa_set_search(routing->brk.reachable, entry->key)) {
317          break_needed = true;
318          continue;
319       }
320       assert(_mesa_set_search(routing->cont.reachable, entry->key));
321       continue_needed = true;
322    }
323 
324    routing->brk = routing_backup->regular;
325    routing->cont = loop_path;
326    routing->regular = loop_path;
327    routing->loop_backup = routing_backup;
328 
329    if (break_needed) {
330       struct path_fork *fork = rzalloc(mem_ctx, struct path_fork);
331       fork->is_var = true;
332       fork->path_var = nir_local_variable_create(b->impl, glsl_bool_type(),
333                                                  "path_break");
334       fork->paths[0] = routing->brk;
335       fork->paths[1] = routing_backup->brk;
336       routing->brk.fork = fork;
337       routing->brk.reachable = fork_reachable(fork);
338    }
339    if (continue_needed) {
340       struct path_fork *fork = rzalloc(mem_ctx, struct path_fork);
341       fork->is_var = true;
342       fork->path_var = nir_local_variable_create(b->impl, glsl_bool_type(),
343                                                  "path_continue");
344       fork->paths[0] = routing->brk;
345       fork->paths[1] = routing_backup->cont;
346       routing->brk.fork = fork;
347       routing->brk.reachable = fork_reachable(fork);
348    }
349    nir_push_loop(b);
350 }
351 
352 /**
353  * Gets a forks condition as ssa def if the condition is inside a helper var,
354  * the variable will be read into an ssa def
355  */
356 static nir_def *
fork_condition(nir_builder * b,struct path_fork * fork)357 fork_condition(nir_builder *b, struct path_fork *fork)
358 {
359    nir_def *ret;
360    if (fork->is_var) {
361       ret = nir_load_var(b, fork->path_var);
362    } else
363       ret = fork->path_ssa;
364    return ret;
365 }
366 
367 /**
368  * Restores the routing after leaving a loop based on the loop_backup stack.
369  * Also handles multi level jump helper vars if existing and calls the nir
370  * builder to pop the nir loop
371  */
372 static void
loop_routing_end(struct routes * routing,nir_builder * b)373 loop_routing_end(struct routes *routing, nir_builder *b)
374 {
375    struct routes *routing_backup = routing->loop_backup;
376    assert(routing->cont.fork == routing->regular.fork);
377    assert(routing->cont.reachable == routing->regular.reachable);
378    nir_pop_loop(b, NULL);
379    if (routing->brk.fork && routing->brk.fork->paths[1].reachable ==
380                                routing_backup->cont.reachable) {
381       assert(!(routing->brk.fork->is_var &&
382                strcmp(routing->brk.fork->path_var->name, "path_continue")));
383       nir_push_if(b, fork_condition(b, routing->brk.fork));
384       nir_jump(b, nir_jump_continue);
385       nir_pop_if(b, NULL);
386       routing->brk = routing->brk.fork->paths[0];
387    }
388    if (routing->brk.fork && routing->brk.fork->paths[1].reachable ==
389                                routing_backup->brk.reachable) {
390       assert(!(routing->brk.fork->is_var &&
391                strcmp(routing->brk.fork->path_var->name, "path_break")));
392       nir_break_if(b, fork_condition(b, routing->brk.fork));
393       routing->brk = routing->brk.fork->paths[0];
394    }
395    assert(routing->brk.fork == routing_backup->regular.fork);
396    assert(routing->brk.reachable == routing_backup->regular.reachable);
397    *routing = *routing_backup;
398    ralloc_free(routing_backup);
399 }
400 
401 /**
402  * generates a list of all blocks dominated by the loop header, but the
403  * control flow can't go back to the loop header from the block.
404  * also generates a list of all blocks that can be reached from within the
405  * loop
406  *    | __
407  *    A´  \
408  *    | \  \
409  *    B  C-´
410  *   /
411  *  D
412  * here B and C are directly dominated by A but only C can reach back to the
413  * loop head A. B will be added to the outside set and to the reach set.
414  * \param  loop_heads  set of loop heads. All blocks inside the loop will be
415  *                     added to this set
416  * \param  outside  all blocks directly outside the loop will be added
417  * \param  reach  all blocks reachable from the loop will be added
418  */
419 static void
inside_outside(nir_block * block,struct set * loop_heads,struct set * outside,struct set * reach,struct set * brk_reachable,void * mem_ctx)420 inside_outside(nir_block *block, struct set *loop_heads, struct set *outside,
421                struct set *reach, struct set *brk_reachable, void *mem_ctx)
422 {
423    assert(_mesa_set_search(loop_heads, block));
424    struct set *remaining = _mesa_pointer_set_create(mem_ctx);
425    for (int i = 0; i < block->num_dom_children; i++) {
426       if (!_mesa_set_search(brk_reachable, block->dom_children[i]))
427          _mesa_set_add(remaining, block->dom_children[i]);
428    }
429 
430    if (NIR_LOWER_GOTO_IFS_DEBUG) {
431       printf("inside_outside(%u):\n", block->index);
432       printf("    loop_heads = ");
433       print_block_set(loop_heads);
434       printf("    reach =      ");
435       print_block_set(reach);
436       printf("    brk_reach =  ");
437       print_block_set(brk_reachable);
438       printf("    remaining =  ");
439       print_block_set(remaining);
440       printf("\n");
441    }
442 
443    bool progress = true;
444    while (remaining->entries && progress) {
445       progress = false;
446       set_foreach(remaining, child_entry) {
447          nir_block *dom_child = (nir_block *)child_entry->key;
448          bool can_jump_back = false;
449          set_foreach(dom_child->dom_frontier, entry) {
450             if (entry->key == dom_child)
451                continue;
452             if (_mesa_set_search_pre_hashed(remaining, entry->hash,
453                                             entry->key)) {
454                can_jump_back = true;
455                break;
456             }
457             if (_mesa_set_search_pre_hashed(loop_heads, entry->hash,
458                                             entry->key)) {
459                can_jump_back = true;
460                break;
461             }
462          }
463          if (!can_jump_back) {
464             _mesa_set_add_pre_hashed(outside, child_entry->hash,
465                                      child_entry->key);
466             _mesa_set_remove(remaining, child_entry);
467             progress = true;
468          }
469       }
470    }
471 
472    /* Add everything remaining to loop_heads */
473    set_foreach(remaining, entry)
474       _mesa_set_add_pre_hashed(loop_heads, entry->hash, entry->key);
475 
476    /* Recurse for each remaining */
477    set_foreach(remaining, entry) {
478       inside_outside((nir_block *)entry->key, loop_heads, outside, reach,
479                      brk_reachable, mem_ctx);
480    }
481 
482    for (int i = 0; i < 2; i++) {
483       if (block->successors[i] && block->successors[i]->successors[0] &&
484           !_mesa_set_search(loop_heads, block->successors[i])) {
485          _mesa_set_add(reach, block->successors[i]);
486       }
487    }
488 
489    if (NIR_LOWER_GOTO_IFS_DEBUG) {
490       printf("outside(%u) = ", block->index);
491       print_block_set(outside);
492       printf("reach(%u) =   ", block->index);
493       print_block_set(reach);
494    }
495 }
496 
497 static struct path_fork *
select_fork_recur(struct nir_block ** blocks,unsigned start,unsigned end,nir_function_impl * impl,bool need_var,void * mem_ctx)498 select_fork_recur(struct nir_block **blocks, unsigned start, unsigned end,
499                   nir_function_impl *impl, bool need_var, void *mem_ctx)
500 {
501    if (start == end - 1)
502       return NULL;
503 
504    struct path_fork *fork = rzalloc(mem_ctx, struct path_fork);
505    fork->is_var = need_var;
506    if (need_var)
507       fork->path_var = nir_local_variable_create(impl, glsl_bool_type(),
508                                                  "path_select");
509 
510    unsigned mid = start + (end - start) / 2;
511 
512    fork->paths[0].reachable = _mesa_pointer_set_create(fork);
513    for (unsigned i = start; i < mid; i++)
514       _mesa_set_add(fork->paths[0].reachable, blocks[i]);
515    fork->paths[0].fork =
516       select_fork_recur(blocks, start, mid, impl, need_var, mem_ctx);
517 
518    fork->paths[1].reachable = _mesa_pointer_set_create(fork);
519    for (unsigned i = mid; i < end; i++)
520       _mesa_set_add(fork->paths[1].reachable, blocks[i]);
521    fork->paths[1].fork =
522       select_fork_recur(blocks, mid, end, impl, need_var, mem_ctx);
523 
524    return fork;
525 }
526 
527 /**
528  * Gets a set of blocks organized into the same level by the organize_levels
529  * function and creates enough forks to be able to route to them.
530  * If the set only contains one block, the function has nothing to do.
531  * The set should almost never contain more than two blocks, but if so,
532  * then the function calls itself recursively
533  */
534 static struct path_fork *
select_fork(struct set * reachable,nir_function_impl * impl,bool need_var,void * mem_ctx)535 select_fork(struct set *reachable, nir_function_impl *impl, bool need_var,
536             void *mem_ctx)
537 {
538    assert(reachable->entries > 0);
539    if (reachable->entries <= 1)
540       return NULL;
541 
542    /* Hash set ordering is non-deterministic.  We're about to turn a set into
543     * a tree so we really want things to be in a deterministic ordering.
544     */
545    return select_fork_recur(sorted_block_arr_for_set(reachable, mem_ctx),
546                             0, reachable->entries, impl, need_var, mem_ctx);
547 }
548 
549 /**
550  * gets called when the organize_levels functions fails to find blocks that
551  * can't be reached by the other remaining blocks. This means, at least two
552  * dominance sibling blocks can reach each other. So we have a multi entry
553  * loop. This function tries to find the smallest possible set of blocks that
554  * must be part of the multi entry loop.
555  * example cf:  |    |
556  *              A<---B
557  *             / \__,^ \
558  *             \       /
559  *               \   /
560  *                 C
561  * The function choses a random block as candidate. for example C
562  * The function checks which remaining blocks can reach C, in this case A.
563  * So A becomes the new candidate and C is removed from the result set.
564  * B can reach A.
565  * So B becomes the new candidate and A is removed from the set.
566  * A can reach B.
567  * A was an old candidate. So it is added to the set containing B.
568  * No other remaining blocks can reach A or B.
569  * So only A and B must be part of the multi entry loop.
570  */
571 static void
handle_irreducible(struct set * remaining,struct strct_lvl * curr_level,struct set * brk_reachable,void * mem_ctx)572 handle_irreducible(struct set *remaining, struct strct_lvl *curr_level,
573                    struct set *brk_reachable, void *mem_ctx)
574 {
575    nir_block *candidate = (nir_block *)
576                              _mesa_set_next_entry(remaining, NULL)
577                                 ->key;
578    struct set *old_candidates = _mesa_pointer_set_create(mem_ctx);
579    while (candidate) {
580       _mesa_set_add(old_candidates, candidate);
581 
582       /* Start with just the candidate block */
583       _mesa_set_clear(curr_level->blocks, NULL);
584       _mesa_set_add(curr_level->blocks, candidate);
585 
586       candidate = NULL;
587       set_foreach(remaining, entry) {
588          nir_block *remaining_block = (nir_block *)entry->key;
589          if (!_mesa_set_search(curr_level->blocks, remaining_block) &&
590              _mesa_set_intersects(remaining_block->dom_frontier,
591                                   curr_level->blocks)) {
592             if (_mesa_set_search(old_candidates, remaining_block)) {
593                _mesa_set_add(curr_level->blocks, remaining_block);
594             } else {
595                candidate = remaining_block;
596                break;
597             }
598          }
599       }
600    }
601    _mesa_set_destroy(old_candidates, NULL);
602    old_candidates = NULL;
603 
604    struct set *loop_heads = _mesa_set_clone(curr_level->blocks, curr_level);
605    curr_level->reach = _mesa_pointer_set_create(curr_level);
606    set_foreach(curr_level->blocks, entry) {
607       _mesa_set_remove_key(remaining, entry->key);
608       inside_outside((nir_block *)entry->key, loop_heads, remaining,
609                      curr_level->reach, brk_reachable, mem_ctx);
610    }
611    _mesa_set_destroy(loop_heads, NULL);
612 }
613 
614 /**
615  * organize a set of blocks into a list of levels. Where every level contains
616  * one or more blocks. So that every block is before all blocks it can reach.
617  * Also creates all path variables needed, for the control flow between the
618  * block.
619  * For example if the control flow looks like this:
620  *       A
621  *     / |
622  *    B  C
623  *    | / \
624  *    E    |
625  *     \  /
626  *      F
627  * B, C, E and F are dominance children of A
628  * The level list should look like this:
629  *          blocks  irreducible   conditional
630  * level 0   B, C     false        false
631  * level 1    E       false        true
632  * level 2    F       false        false
633  * The final structure should look like this:
634  * A
635  * if (path_select) {
636  *    B
637  * } else {
638  *    C
639  * }
640  * if (path_conditional) {
641  *   E
642  * }
643  * F
644  *
645  * \param  levels  uninitialized list
646  * \param  is_dominated  if true, no helper variables will be created for the
647  *                       zeroth level
648  */
649 static void
organize_levels(struct list_head * levels,struct set * remaining,struct set * reach,struct routes * routing,nir_function_impl * impl,bool is_domminated,void * mem_ctx)650 organize_levels(struct list_head *levels, struct set *remaining,
651                 struct set *reach, struct routes *routing,
652                 nir_function_impl *impl, bool is_domminated, void *mem_ctx)
653 {
654    if (NIR_LOWER_GOTO_IFS_DEBUG) {
655       printf("organize_levels:\n");
656       printf("    reach =     ");
657       print_block_set(reach);
658    }
659 
660    /* blocks that can be reached by the remaining blocks */
661    struct set *remaining_frontier = _mesa_pointer_set_create(mem_ctx);
662 
663    /* targets of active skip path */
664    struct set *skip_targets = _mesa_pointer_set_create(mem_ctx);
665 
666    list_inithead(levels);
667    while (remaining->entries) {
668       _mesa_set_clear(remaining_frontier, NULL);
669       set_foreach(remaining, entry) {
670          nir_block *remain_block = (nir_block *)entry->key;
671          set_foreach(remain_block->dom_frontier, frontier_entry) {
672             nir_block *frontier = (nir_block *)frontier_entry->key;
673             if (frontier != remain_block) {
674                _mesa_set_add(remaining_frontier, frontier);
675             }
676          }
677       }
678 
679       struct strct_lvl *curr_level = rzalloc(mem_ctx, struct strct_lvl);
680       curr_level->blocks = _mesa_pointer_set_create(curr_level);
681       set_foreach(remaining, entry) {
682          nir_block *candidate = (nir_block *)entry->key;
683          if (!_mesa_set_search(remaining_frontier, candidate)) {
684             _mesa_set_add(curr_level->blocks, candidate);
685             _mesa_set_remove_key(remaining, candidate);
686          }
687       }
688 
689       curr_level->irreducible = !curr_level->blocks->entries;
690       if (curr_level->irreducible) {
691          handle_irreducible(remaining, curr_level,
692                             routing->brk.reachable, mem_ctx);
693       }
694       assert(curr_level->blocks->entries);
695 
696       struct strct_lvl *prev_level = NULL;
697       if (!list_is_empty(levels))
698          prev_level = list_last_entry(levels, struct strct_lvl, link);
699 
700       set_foreach(skip_targets, entry) {
701          if (_mesa_set_search_pre_hashed(curr_level->blocks,
702                                          entry->hash, entry->key)) {
703             _mesa_set_remove(skip_targets, entry);
704             prev_level->skip_end = 1;
705          }
706       }
707       curr_level->skip_start = skip_targets->entries != 0;
708 
709       struct set *prev_frontier = NULL;
710       if (!prev_level) {
711          prev_frontier = _mesa_set_clone(reach, curr_level);
712       } else if (prev_level->irreducible) {
713          prev_frontier = _mesa_set_clone(prev_level->reach, curr_level);
714       }
715 
716       set_foreach(curr_level->blocks, blocks_entry) {
717          nir_block *level_block = (nir_block *)blocks_entry->key;
718          if (prev_frontier == NULL) {
719             prev_frontier =
720                _mesa_set_clone(level_block->dom_frontier, curr_level);
721          } else {
722             set_foreach(level_block->dom_frontier, entry)
723                _mesa_set_add_pre_hashed(prev_frontier, entry->hash,
724                                         entry->key);
725          }
726       }
727 
728       bool is_in_skip = skip_targets->entries != 0;
729       set_foreach(prev_frontier, entry) {
730          if (_mesa_set_search(remaining, entry->key) ||
731              (_mesa_set_search(routing->regular.reachable, entry->key) &&
732               !_mesa_set_search(routing->brk.reachable, entry->key) &&
733               !_mesa_set_search(routing->cont.reachable, entry->key))) {
734             _mesa_set_add_pre_hashed(skip_targets, entry->hash, entry->key);
735             if (is_in_skip)
736                prev_level->skip_end = 1;
737             curr_level->skip_start = 1;
738          }
739       }
740 
741       curr_level->skip_end = 0;
742       list_addtail(&curr_level->link, levels);
743    }
744 
745    if (NIR_LOWER_GOTO_IFS_DEBUG) {
746       printf("    levels:\n");
747       list_for_each_entry(struct strct_lvl, level, levels, link) {
748          printf("        ");
749          print_block_set(level->blocks);
750       }
751       printf("\n");
752    }
753 
754    if (skip_targets->entries)
755       list_last_entry(levels, struct strct_lvl, link)->skip_end = 1;
756 
757    /* Iterate throught all levels reverse and create all the paths and forks */
758    struct path path_after_skip;
759 
760    list_for_each_entry_rev(struct strct_lvl, level, levels, link) {
761       bool need_var = !(is_domminated && level->link.prev == levels);
762       level->out_path = routing->regular;
763       if (level->skip_end) {
764          path_after_skip = routing->regular;
765       }
766       routing->regular.reachable = level->blocks;
767       routing->regular.fork = select_fork(routing->regular.reachable, impl,
768                                           need_var, mem_ctx);
769       if (level->skip_start) {
770          struct path_fork *fork = rzalloc(mem_ctx, struct path_fork);
771          fork->is_var = need_var;
772          if (need_var)
773             fork->path_var = nir_local_variable_create(impl, glsl_bool_type(),
774                                                        "path_conditional");
775          fork->paths[0] = path_after_skip;
776          fork->paths[1] = routing->regular;
777          routing->regular.fork = fork;
778          routing->regular.reachable = fork_reachable(fork);
779       }
780    }
781 }
782 
783 static void
784 nir_structurize(struct routes *routing, nir_builder *b,
785                 nir_block *block, void *mem_ctx);
786 
787 /**
788  * Places all the if else statements to select between all blocks in a select
789  * path
790  */
791 static void
select_blocks(struct routes * routing,nir_builder * b,struct path in_path,void * mem_ctx)792 select_blocks(struct routes *routing, nir_builder *b,
793               struct path in_path, void *mem_ctx)
794 {
795    if (!in_path.fork) {
796       nir_block *block = block_for_singular_set(in_path.reachable);
797       nir_structurize(routing, b, block, mem_ctx);
798    } else {
799       assert(!(in_path.fork->is_var &&
800                strcmp(in_path.fork->path_var->name, "path_select")));
801       nir_push_if(b, fork_condition(b, in_path.fork));
802       select_blocks(routing, b, in_path.fork->paths[1], mem_ctx);
803       nir_push_else(b, NULL);
804       select_blocks(routing, b, in_path.fork->paths[0], mem_ctx);
805       nir_pop_if(b, NULL);
806    }
807 }
808 
809 /**
810  * Builds the structurized nir code by the final level list.
811  */
812 static void
plant_levels(struct list_head * levels,struct routes * routing,nir_builder * b,void * mem_ctx)813 plant_levels(struct list_head *levels, struct routes *routing,
814              nir_builder *b, void *mem_ctx)
815 {
816    /* Place all dominated blocks and build the path forks */
817    list_for_each_entry(struct strct_lvl, level, levels, link) {
818       if (level->skip_start) {
819          assert(routing->regular.fork);
820          assert(!(routing->regular.fork->is_var && strcmp(
821                                                       routing->regular.fork->path_var->name, "path_conditional")));
822          nir_push_if(b, fork_condition(b, routing->regular.fork));
823          routing->regular = routing->regular.fork->paths[1];
824       }
825       struct path in_path = routing->regular;
826       routing->regular = level->out_path;
827       if (level->irreducible)
828          loop_routing_start(routing, b, in_path, level->reach, mem_ctx);
829       select_blocks(routing, b, in_path, mem_ctx);
830       if (level->irreducible)
831          loop_routing_end(routing, b);
832       if (level->skip_end)
833          nir_pop_if(b, NULL);
834    }
835 }
836 
837 /**
838  * builds the control flow of a block and all its dominance children
839  * \param  routing  the routing after the block and all dominated blocks
840  */
841 static void
nir_structurize(struct routes * routing,nir_builder * b,nir_block * block,void * mem_ctx)842 nir_structurize(struct routes *routing, nir_builder *b, nir_block *block,
843                 void *mem_ctx)
844 {
845    struct set *remaining = _mesa_pointer_set_create(mem_ctx);
846    for (int i = 0; i < block->num_dom_children; i++) {
847       if (!_mesa_set_search(routing->brk.reachable, block->dom_children[i]))
848          _mesa_set_add(remaining, block->dom_children[i]);
849    }
850 
851    /* If the block can reach back to itself, it is a loop head */
852    int is_looped = _mesa_set_search(block->dom_frontier, block) != NULL;
853    struct list_head outside_levels;
854    if (is_looped) {
855       struct set *loop_heads = _mesa_pointer_set_create(mem_ctx);
856       _mesa_set_add(loop_heads, block);
857 
858       struct set *outside = _mesa_pointer_set_create(mem_ctx);
859       struct set *reach = _mesa_pointer_set_create(mem_ctx);
860       inside_outside(block, loop_heads, outside, reach,
861                      routing->brk.reachable, mem_ctx);
862 
863       set_foreach(outside, entry)
864          _mesa_set_remove_key(remaining, entry->key);
865 
866       organize_levels(&outside_levels, outside, reach, routing, b->impl,
867                       false, mem_ctx);
868 
869       struct path loop_path = {
870          .reachable = _mesa_pointer_set_create(mem_ctx),
871          .fork = NULL,
872       };
873       _mesa_set_add(loop_path.reachable, block);
874 
875       loop_routing_start(routing, b, loop_path, reach, mem_ctx);
876    }
877 
878    struct set *reach = _mesa_pointer_set_create(mem_ctx);
879    if (block->successors[0]->successors[0]) /* it is not the end_block */
880       _mesa_set_add(reach, block->successors[0]);
881    if (block->successors[1] && block->successors[1]->successors[0])
882       _mesa_set_add(reach, block->successors[1]);
883 
884    struct list_head levels;
885    organize_levels(&levels, remaining, reach, routing, b->impl, true, mem_ctx);
886 
887    /* Push all instructions of this block, without the jump instr */
888    nir_jump_instr *jump_instr = NULL;
889    nir_foreach_instr_safe(instr, block) {
890       if (instr->type == nir_instr_type_jump) {
891          jump_instr = nir_instr_as_jump(instr);
892          break;
893       }
894       nir_instr_remove(instr);
895       nir_builder_instr_insert(b, instr);
896    }
897 
898    /* Find path to the successor blocks */
899    if (jump_instr->type == nir_jump_goto_if) {
900       route_to_cond(b, routing, jump_instr->condition.ssa,
901                     jump_instr->target, jump_instr->else_target);
902    } else {
903       route_to(b, routing, block->successors[0]);
904    }
905 
906    plant_levels(&levels, routing, b, mem_ctx);
907    if (is_looped) {
908       loop_routing_end(routing, b);
909       plant_levels(&outside_levels, routing, b, mem_ctx);
910    }
911 }
912 
913 static bool
nir_lower_goto_ifs_impl(nir_function_impl * impl)914 nir_lower_goto_ifs_impl(nir_function_impl *impl)
915 {
916    if (impl->structured) {
917       nir_metadata_preserve(impl, nir_metadata_all);
918       return false;
919    }
920 
921    nir_metadata_require(impl, nir_metadata_dominance);
922 
923    /* We're going to re-arrange blocks like crazy.  This is much easier to do
924     * if we don't have any phi nodes to fix up.
925     */
926    nir_foreach_block_unstructured(block, impl)
927       nir_lower_phis_to_regs_block(block);
928 
929    nir_cf_list cf_list;
930    nir_cf_extract(&cf_list, nir_before_impl(impl),
931                   nir_after_impl(impl));
932 
933    /* From this point on, it's structured */
934    impl->structured = true;
935 
936    nir_builder b = nir_builder_at(nir_before_impl(impl));
937 
938    void *mem_ctx = ralloc_context(b.shader);
939 
940    struct set *end_set = _mesa_pointer_set_create(mem_ctx);
941    _mesa_set_add(end_set, impl->end_block);
942    struct set *empty_set = _mesa_pointer_set_create(mem_ctx);
943 
944    nir_cf_node *start_node =
945       exec_node_data(nir_cf_node, exec_list_get_head(&cf_list.list), node);
946    nir_block *start_block = nir_cf_node_as_block(start_node);
947 
948    struct routes *routing = rzalloc(mem_ctx, struct routes);
949    *routing = (struct routes){
950       .regular.reachable = end_set,
951       .brk.reachable = empty_set,
952       .cont.reachable = empty_set,
953    };
954    nir_structurize(routing, &b, start_block, mem_ctx);
955    assert(routing->regular.fork == NULL);
956    assert(routing->brk.fork == NULL);
957    assert(routing->cont.fork == NULL);
958    assert(routing->brk.reachable == empty_set);
959    assert(routing->cont.reachable == empty_set);
960 
961    ralloc_free(mem_ctx);
962    nir_cf_delete(&cf_list);
963 
964    nir_metadata_preserve(impl, nir_metadata_none);
965 
966    nir_repair_ssa_impl(impl);
967    nir_lower_reg_intrinsics_to_ssa_impl(impl);
968 
969    return true;
970 }
971 
972 bool
nir_lower_goto_ifs(nir_shader * shader)973 nir_lower_goto_ifs(nir_shader *shader)
974 {
975    bool progress = true;
976 
977    nir_foreach_function_impl(impl, shader) {
978       if (nir_lower_goto_ifs_impl(impl))
979          progress = true;
980    }
981 
982    return progress;
983 }
984