xref: /aosp_15_r20/external/mesa3d/src/compiler/nir/nir_lower_shader_calls.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright © 2020 Intel Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include "util/u_dynarray.h"
25 #include "util/u_math.h"
26 #include "nir.h"
27 #include "nir_builder.h"
28 #include "nir_phi_builder.h"
29 
30 static bool
move_system_values_to_top(nir_shader * shader)31 move_system_values_to_top(nir_shader *shader)
32 {
33    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
34 
35    bool progress = false;
36    nir_foreach_block(block, impl) {
37       nir_foreach_instr_safe(instr, block) {
38          if (instr->type != nir_instr_type_intrinsic)
39             continue;
40 
41          /* These intrinsics not only can't be re-materialized but aren't
42           * preserved when moving to the continuation shader.  We have to move
43           * them to the top to ensure they get spilled as needed.
44           */
45          nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
46          switch (intrin->intrinsic) {
47          case nir_intrinsic_load_shader_record_ptr:
48          case nir_intrinsic_load_btd_local_arg_addr_intel:
49             nir_instr_remove(instr);
50             nir_instr_insert(nir_before_impl(impl), instr);
51             progress = true;
52             break;
53 
54          default:
55             break;
56          }
57       }
58    }
59 
60    if (progress) {
61       nir_metadata_preserve(impl, nir_metadata_control_flow);
62    } else {
63       nir_metadata_preserve(impl, nir_metadata_all);
64    }
65 
66    return progress;
67 }
68 
69 static bool
instr_is_shader_call(nir_instr * instr)70 instr_is_shader_call(nir_instr *instr)
71 {
72    if (instr->type != nir_instr_type_intrinsic)
73       return false;
74 
75    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
76    return intrin->intrinsic == nir_intrinsic_trace_ray ||
77           intrin->intrinsic == nir_intrinsic_report_ray_intersection ||
78           intrin->intrinsic == nir_intrinsic_execute_callable;
79 }
80 
81 /* Previously named bitset, it had to be renamed as FreeBSD defines a struct
82  * named bitset in sys/_bitset.h required by pthread_np.h which is included
83  * from src/util/u_thread.h that is indirectly included by this file.
84  */
85 struct sized_bitset {
86    BITSET_WORD *set;
87    unsigned size;
88 };
89 
90 static struct sized_bitset
bitset_create(void * mem_ctx,unsigned size)91 bitset_create(void *mem_ctx, unsigned size)
92 {
93    return (struct sized_bitset){
94       .set = rzalloc_array(mem_ctx, BITSET_WORD, BITSET_WORDS(size)),
95       .size = size,
96    };
97 }
98 
99 static bool
src_is_in_bitset(nir_src * src,void * _set)100 src_is_in_bitset(nir_src *src, void *_set)
101 {
102    struct sized_bitset *set = _set;
103 
104    /* Any SSA values which were added after we generated liveness information
105     * are things generated by this pass and, while most of it is arithmetic
106     * which we could re-materialize, we don't need to because it's only used
107     * for a single load/store and so shouldn't cross any shader calls.
108     */
109    if (src->ssa->index >= set->size)
110       return false;
111 
112    return BITSET_TEST(set->set, src->ssa->index);
113 }
114 
115 static void
add_ssa_def_to_bitset(nir_def * def,struct sized_bitset * set)116 add_ssa_def_to_bitset(nir_def *def, struct sized_bitset *set)
117 {
118    if (def->index >= set->size)
119       return;
120 
121    BITSET_SET(set->set, def->index);
122 }
123 
124 static bool
can_remat_instr(nir_instr * instr,struct sized_bitset * remat)125 can_remat_instr(nir_instr *instr, struct sized_bitset *remat)
126 {
127    /* Set of all values which are trivially re-materializable and we shouldn't
128     * ever spill them.  This includes:
129     *
130     *   - Undef values
131     *   - Constants
132     *   - Uniforms (UBO or push constant)
133     *   - ALU combinations of any of the above
134     *   - Derefs which are either complete or casts of any of the above
135     *
136     * Because this pass rewrites things in-order and phis are always turned
137     * into register writes, we can use "is it SSA?" to answer the question
138     * "can my source be re-materialized?". Register writes happen via
139     * non-rematerializable intrinsics.
140     */
141    switch (instr->type) {
142    case nir_instr_type_alu:
143       return nir_foreach_src(instr, src_is_in_bitset, remat);
144 
145    case nir_instr_type_deref:
146       return nir_foreach_src(instr, src_is_in_bitset, remat);
147 
148    case nir_instr_type_intrinsic: {
149       nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
150       switch (intrin->intrinsic) {
151       case nir_intrinsic_load_uniform:
152       case nir_intrinsic_load_ubo:
153       case nir_intrinsic_vulkan_resource_index:
154       case nir_intrinsic_vulkan_resource_reindex:
155       case nir_intrinsic_load_vulkan_descriptor:
156       case nir_intrinsic_load_push_constant:
157       case nir_intrinsic_load_global_constant:
158          /* These intrinsics don't need to be spilled as long as they don't
159           * depend on any spilled values.
160           */
161          return nir_foreach_src(instr, src_is_in_bitset, remat);
162 
163       case nir_intrinsic_load_scratch_base_ptr:
164       case nir_intrinsic_load_ray_launch_id:
165       case nir_intrinsic_load_topology_id_intel:
166       case nir_intrinsic_load_btd_global_arg_addr_intel:
167       case nir_intrinsic_load_btd_resume_sbt_addr_intel:
168       case nir_intrinsic_load_ray_base_mem_addr_intel:
169       case nir_intrinsic_load_ray_hw_stack_size_intel:
170       case nir_intrinsic_load_ray_sw_stack_size_intel:
171       case nir_intrinsic_load_ray_num_dss_rt_stacks_intel:
172       case nir_intrinsic_load_ray_hit_sbt_addr_intel:
173       case nir_intrinsic_load_ray_hit_sbt_stride_intel:
174       case nir_intrinsic_load_ray_miss_sbt_addr_intel:
175       case nir_intrinsic_load_ray_miss_sbt_stride_intel:
176       case nir_intrinsic_load_callable_sbt_addr_intel:
177       case nir_intrinsic_load_callable_sbt_stride_intel:
178       case nir_intrinsic_load_reloc_const_intel:
179       case nir_intrinsic_load_ray_query_global_intel:
180       case nir_intrinsic_load_ray_launch_size:
181          /* Notably missing from the above list is btd_local_arg_addr_intel.
182           * This is because the resume shader will have a different local
183           * argument pointer because it has a different BSR.  Any access of
184           * the original shader's local arguments needs to be preserved so
185           * that pointer has to be saved on the stack.
186           *
187           * TODO: There may be some system values we want to avoid
188           *       re-materializing as well but we have to be very careful
189           *       to ensure that it's a system value which cannot change
190           *       across a shader call.
191           */
192          return true;
193 
194       case nir_intrinsic_resource_intel:
195          return nir_foreach_src(instr, src_is_in_bitset, remat);
196 
197       default:
198          return false;
199       }
200    }
201 
202    case nir_instr_type_undef:
203    case nir_instr_type_load_const:
204       return true;
205 
206    default:
207       return false;
208    }
209 }
210 
211 static bool
can_remat_ssa_def(nir_def * def,struct sized_bitset * remat)212 can_remat_ssa_def(nir_def *def, struct sized_bitset *remat)
213 {
214    return can_remat_instr(def->parent_instr, remat);
215 }
216 
217 struct add_instr_data {
218    struct util_dynarray *buf;
219    struct sized_bitset *remat;
220 };
221 
222 static bool
add_src_instr(nir_src * src,void * state)223 add_src_instr(nir_src *src, void *state)
224 {
225    struct add_instr_data *data = state;
226    if (BITSET_TEST(data->remat->set, src->ssa->index))
227       return true;
228 
229    util_dynarray_foreach(data->buf, nir_instr *, instr_ptr) {
230       if (*instr_ptr == src->ssa->parent_instr)
231          return true;
232    }
233 
234    /* Abort rematerializing an instruction chain if it is too long. */
235    if (data->buf->size >= data->buf->capacity)
236       return false;
237 
238    util_dynarray_append(data->buf, nir_instr *, src->ssa->parent_instr);
239    return true;
240 }
241 
242 static int
compare_instr_indexes(const void * _inst1,const void * _inst2)243 compare_instr_indexes(const void *_inst1, const void *_inst2)
244 {
245    const nir_instr *const *inst1 = _inst1;
246    const nir_instr *const *inst2 = _inst2;
247 
248    return (*inst1)->index - (*inst2)->index;
249 }
250 
251 static bool
can_remat_chain_ssa_def(nir_def * def,struct sized_bitset * remat,struct util_dynarray * buf)252 can_remat_chain_ssa_def(nir_def *def, struct sized_bitset *remat, struct util_dynarray *buf)
253 {
254    assert(util_dynarray_num_elements(buf, nir_instr *) == 0);
255 
256    void *mem_ctx = ralloc_context(NULL);
257 
258    /* Add all the instructions involved in build this ssa_def */
259    util_dynarray_append(buf, nir_instr *, def->parent_instr);
260 
261    unsigned idx = 0;
262    struct add_instr_data data = {
263       .buf = buf,
264       .remat = remat,
265    };
266    while (idx < util_dynarray_num_elements(buf, nir_instr *)) {
267       nir_instr *instr = *util_dynarray_element(buf, nir_instr *, idx++);
268       if (!nir_foreach_src(instr, add_src_instr, &data))
269          goto fail;
270    }
271 
272    /* Sort instructions by index */
273    qsort(util_dynarray_begin(buf),
274          util_dynarray_num_elements(buf, nir_instr *),
275          sizeof(nir_instr *),
276          compare_instr_indexes);
277 
278    /* Create a temporary bitset with all values already
279     * rematerialized/rematerializable. We'll add to this bit set as we go
280     * through values that might not be in that set but that we can
281     * rematerialize.
282     */
283    struct sized_bitset potential_remat = bitset_create(mem_ctx, remat->size);
284    memcpy(potential_remat.set, remat->set, BITSET_WORDS(remat->size) * sizeof(BITSET_WORD));
285 
286    util_dynarray_foreach(buf, nir_instr *, instr_ptr) {
287       nir_def *instr_ssa_def = nir_instr_def(*instr_ptr);
288 
289       /* If already in the potential rematerializable, nothing to do. */
290       if (BITSET_TEST(potential_remat.set, instr_ssa_def->index))
291          continue;
292 
293       if (!can_remat_instr(*instr_ptr, &potential_remat))
294          goto fail;
295 
296       /* All the sources are rematerializable and the instruction is also
297        * rematerializable, mark it as rematerializable too.
298        */
299       BITSET_SET(potential_remat.set, instr_ssa_def->index);
300    }
301 
302    ralloc_free(mem_ctx);
303 
304    return true;
305 
306 fail:
307    util_dynarray_clear(buf);
308    ralloc_free(mem_ctx);
309    return false;
310 }
311 
312 static nir_def *
remat_ssa_def(nir_builder * b,nir_def * def,struct hash_table * remap_table)313 remat_ssa_def(nir_builder *b, nir_def *def, struct hash_table *remap_table)
314 {
315    nir_instr *clone = nir_instr_clone_deep(b->shader, def->parent_instr, remap_table);
316    nir_builder_instr_insert(b, clone);
317    return nir_instr_def(clone);
318 }
319 
320 static nir_def *
remat_chain_ssa_def(nir_builder * b,struct util_dynarray * buf,struct sized_bitset * remat,nir_def *** fill_defs,unsigned call_idx,struct hash_table * remap_table)321 remat_chain_ssa_def(nir_builder *b, struct util_dynarray *buf,
322                     struct sized_bitset *remat, nir_def ***fill_defs,
323                     unsigned call_idx, struct hash_table *remap_table)
324 {
325    nir_def *last_def = NULL;
326 
327    util_dynarray_foreach(buf, nir_instr *, instr_ptr) {
328       nir_def *instr_ssa_def = nir_instr_def(*instr_ptr);
329       unsigned ssa_index = instr_ssa_def->index;
330 
331       if (fill_defs[ssa_index] != NULL &&
332           fill_defs[ssa_index][call_idx] != NULL)
333          continue;
334 
335       /* Clone the instruction we want to rematerialize */
336       nir_def *clone_ssa_def = remat_ssa_def(b, instr_ssa_def, remap_table);
337 
338       if (fill_defs[ssa_index] == NULL) {
339          fill_defs[ssa_index] =
340             rzalloc_array(fill_defs, nir_def *, remat->size);
341       }
342 
343       /* Add the new ssa_def to the list fill_defs and flag it as
344        * rematerialized
345        */
346       fill_defs[ssa_index][call_idx] = last_def = clone_ssa_def;
347       BITSET_SET(remat->set, ssa_index);
348 
349       _mesa_hash_table_insert(remap_table, instr_ssa_def, last_def);
350    }
351 
352    return last_def;
353 }
354 
355 struct pbv_array {
356    struct nir_phi_builder_value **arr;
357    unsigned len;
358 };
359 
360 static struct nir_phi_builder_value *
get_phi_builder_value_for_def(nir_def * def,struct pbv_array * pbv_arr)361 get_phi_builder_value_for_def(nir_def *def,
362                               struct pbv_array *pbv_arr)
363 {
364    if (def->index >= pbv_arr->len)
365       return NULL;
366 
367    return pbv_arr->arr[def->index];
368 }
369 
370 static nir_def *
get_phi_builder_def_for_src(nir_src * src,struct pbv_array * pbv_arr,nir_block * block)371 get_phi_builder_def_for_src(nir_src *src, struct pbv_array *pbv_arr,
372                             nir_block *block)
373 {
374 
375    struct nir_phi_builder_value *pbv =
376       get_phi_builder_value_for_def(src->ssa, pbv_arr);
377    if (pbv == NULL)
378       return NULL;
379 
380    return nir_phi_builder_value_get_block_def(pbv, block);
381 }
382 
383 static bool
rewrite_instr_src_from_phi_builder(nir_src * src,void * _pbv_arr)384 rewrite_instr_src_from_phi_builder(nir_src *src, void *_pbv_arr)
385 {
386    nir_block *block;
387    if (nir_src_parent_instr(src)->type == nir_instr_type_phi) {
388       nir_phi_src *phi_src = exec_node_data(nir_phi_src, src, src);
389       block = phi_src->pred;
390    } else {
391       block = nir_src_parent_instr(src)->block;
392    }
393 
394    nir_def *new_def = get_phi_builder_def_for_src(src, _pbv_arr, block);
395    if (new_def != NULL)
396       nir_src_rewrite(src, new_def);
397    return true;
398 }
399 
400 static nir_def *
spill_fill(nir_builder * before,nir_builder * after,nir_def * def,unsigned value_id,unsigned call_idx,unsigned offset,unsigned stack_alignment)401 spill_fill(nir_builder *before, nir_builder *after, nir_def *def,
402            unsigned value_id, unsigned call_idx,
403            unsigned offset, unsigned stack_alignment)
404 {
405    const unsigned comp_size = def->bit_size / 8;
406 
407    nir_store_stack(before, def,
408                    .base = offset,
409                    .call_idx = call_idx,
410                    .align_mul = MIN2(comp_size, stack_alignment),
411                    .value_id = value_id,
412                    .write_mask = BITFIELD_MASK(def->num_components));
413    return nir_load_stack(after, def->num_components, def->bit_size,
414                          .base = offset,
415                          .call_idx = call_idx,
416                          .value_id = value_id,
417                          .align_mul = MIN2(comp_size, stack_alignment));
418 }
419 
420 static bool
add_src_to_call_live_bitset(nir_src * src,void * state)421 add_src_to_call_live_bitset(nir_src *src, void *state)
422 {
423    BITSET_WORD *call_live = state;
424 
425    BITSET_SET(call_live, src->ssa->index);
426    return true;
427 }
428 
429 static void
spill_ssa_defs_and_lower_shader_calls(nir_shader * shader,uint32_t num_calls,const nir_lower_shader_calls_options * options)430 spill_ssa_defs_and_lower_shader_calls(nir_shader *shader, uint32_t num_calls,
431                                       const nir_lower_shader_calls_options *options)
432 {
433    /* TODO: If a SSA def is filled more than once, we probably want to just
434     *       spill it at the LCM of the fill sites so we avoid unnecessary
435     *       extra spills
436     *
437     * TODO: If a SSA def is defined outside a loop but live through some call
438     *       inside the loop, we probably want to spill outside the loop.  We
439     *       may also want to fill outside the loop if it's not used in the
440     *       loop.
441     *
442     * TODO: Right now, we only re-materialize things if their immediate
443     *       sources are things which we filled.  We probably want to expand
444     *       that to re-materialize things whose sources are things we can
445     *       re-materialize from things we filled.  We may want some DAG depth
446     *       heuristic on this.
447     */
448 
449    /* This happens per-shader rather than per-impl because we mess with
450     * nir_shader::scratch_size.
451     */
452    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
453 
454    nir_metadata_require(impl, nir_metadata_live_defs |
455                                  nir_metadata_dominance |
456                                  nir_metadata_block_index |
457                                  nir_metadata_instr_index);
458 
459    void *mem_ctx = ralloc_context(shader);
460 
461    const unsigned num_ssa_defs = impl->ssa_alloc;
462    const unsigned live_words = BITSET_WORDS(num_ssa_defs);
463    struct sized_bitset trivial_remat = bitset_create(mem_ctx, num_ssa_defs);
464 
465    /* Array of all live SSA defs which are spill candidates */
466    nir_def **spill_defs =
467       rzalloc_array(mem_ctx, nir_def *, num_ssa_defs);
468 
469    /* For each spill candidate, an array of every time it's defined by a fill,
470     * indexed by call instruction index.
471     */
472    nir_def ***fill_defs =
473       rzalloc_array(mem_ctx, nir_def **, num_ssa_defs);
474 
475    /* For each call instruction, the liveness set at the call */
476    const BITSET_WORD **call_live =
477       rzalloc_array(mem_ctx, const BITSET_WORD *, num_calls);
478 
479    /* For each call instruction, the block index of the block it lives in */
480    uint32_t *call_block_indices = rzalloc_array(mem_ctx, uint32_t, num_calls);
481 
482    /* Remap table when rebuilding instructions out of fill operations */
483    struct hash_table *trivial_remap_table =
484       _mesa_pointer_hash_table_create(mem_ctx);
485 
486    /* Walk the call instructions and fetch the liveness set and block index
487     * for each one.  We need to do this before we start modifying the shader
488     * so that liveness doesn't complain that it's been invalidated.  Don't
489     * worry, we'll be very careful with our live sets. :-)
490     */
491    unsigned call_idx = 0;
492    nir_foreach_block(block, impl) {
493       nir_foreach_instr(instr, block) {
494          if (!instr_is_shader_call(instr))
495             continue;
496 
497          call_block_indices[call_idx] = block->index;
498 
499          /* The objective here is to preserve values around shader call
500           * instructions.  Therefore, we use the live set after the
501           * instruction as the set of things we want to preserve.  Because
502           * none of our shader call intrinsics return anything, we don't have
503           * to worry about spilling over a return value.
504           *
505           * TODO: This isn't quite true for report_intersection.
506           */
507          call_live[call_idx] =
508             nir_get_live_defs(nir_after_instr(instr), mem_ctx);
509 
510          call_idx++;
511       }
512    }
513 
514    /* If a should_remat_callback is given, call it on each of the live values
515     * for each call site. If it returns true we need to rematerialize that
516     * instruction (instead of spill/fill). Therefore we need to add the
517     * sources as live values so that we can rematerialize on top of those
518     * spilled/filled sources.
519     */
520    if (options->should_remat_callback) {
521       BITSET_WORD **updated_call_live =
522          rzalloc_array(mem_ctx, BITSET_WORD *, num_calls);
523 
524       nir_foreach_block(block, impl) {
525          nir_foreach_instr(instr, block) {
526             nir_def *def = nir_instr_def(instr);
527             if (def == NULL)
528                continue;
529 
530             for (unsigned c = 0; c < num_calls; c++) {
531                if (!BITSET_TEST(call_live[c], def->index))
532                   continue;
533 
534                if (!options->should_remat_callback(def->parent_instr,
535                                                    options->should_remat_data))
536                   continue;
537 
538                if (updated_call_live[c] == NULL) {
539                   const unsigned bitset_words = BITSET_WORDS(impl->ssa_alloc);
540                   updated_call_live[c] = ralloc_array(mem_ctx, BITSET_WORD, bitset_words);
541                   memcpy(updated_call_live[c], call_live[c], bitset_words * sizeof(BITSET_WORD));
542                }
543 
544                nir_foreach_src(instr, add_src_to_call_live_bitset, updated_call_live[c]);
545             }
546          }
547       }
548 
549       for (unsigned c = 0; c < num_calls; c++) {
550          if (updated_call_live[c] != NULL)
551             call_live[c] = updated_call_live[c];
552       }
553    }
554 
555    nir_builder before, after;
556    before = nir_builder_create(impl);
557    after = nir_builder_create(impl);
558 
559    call_idx = 0;
560    unsigned max_scratch_size = shader->scratch_size;
561    nir_foreach_block(block, impl) {
562       nir_foreach_instr_safe(instr, block) {
563          nir_def *def = nir_instr_def(instr);
564          if (def != NULL) {
565             if (can_remat_ssa_def(def, &trivial_remat)) {
566                add_ssa_def_to_bitset(def, &trivial_remat);
567                _mesa_hash_table_insert(trivial_remap_table, def, def);
568             } else {
569                spill_defs[def->index] = def;
570             }
571          }
572 
573          if (!instr_is_shader_call(instr))
574             continue;
575 
576          const BITSET_WORD *live = call_live[call_idx];
577 
578          struct hash_table *remap_table =
579             _mesa_hash_table_clone(trivial_remap_table, mem_ctx);
580 
581          /* Make a copy of trivial_remat that we'll update as we crawl through
582           * the live SSA defs and unspill them.
583           */
584          struct sized_bitset remat = bitset_create(mem_ctx, num_ssa_defs);
585          memcpy(remat.set, trivial_remat.set, live_words * sizeof(BITSET_WORD));
586 
587          /* Before the two builders are always separated by the call
588           * instruction, it won't break anything to have two of them.
589           */
590          before.cursor = nir_before_instr(instr);
591          after.cursor = nir_after_instr(instr);
592 
593          /* Array used to hold all the values needed to rematerialize a live
594           * value. The capacity is used to determine when we should abort testing
595           * a remat chain. In practice, shaders can have chains with more than
596           * 10k elements while only chains with less than 16 have realistic
597           * chances. There also isn't any performance benefit in rematerializing
598           * extremely long chains.
599           */
600          nir_instr *remat_chain_instrs[16];
601          struct util_dynarray remat_chain;
602          util_dynarray_init_from_stack(&remat_chain, remat_chain_instrs, sizeof(remat_chain_instrs));
603 
604          unsigned offset = shader->scratch_size;
605          for (unsigned w = 0; w < live_words; w++) {
606             BITSET_WORD spill_mask = live[w] & ~trivial_remat.set[w];
607             while (spill_mask) {
608                int i = u_bit_scan(&spill_mask);
609                assert(i >= 0);
610                unsigned index = w * BITSET_WORDBITS + i;
611                assert(index < num_ssa_defs);
612 
613                def = spill_defs[index];
614                nir_def *original_def = def, *new_def;
615                if (can_remat_ssa_def(def, &remat)) {
616                   /* If this SSA def is re-materializable or based on other
617                    * things we've already spilled, re-materialize it rather
618                    * than spilling and filling.  Anything which is trivially
619                    * re-materializable won't even get here because we take
620                    * those into account in spill_mask above.
621                    */
622                   new_def = remat_ssa_def(&after, def, remap_table);
623                } else if (can_remat_chain_ssa_def(def, &remat, &remat_chain)) {
624                   new_def = remat_chain_ssa_def(&after, &remat_chain, &remat,
625                                                 fill_defs, call_idx,
626                                                 remap_table);
627                   util_dynarray_clear(&remat_chain);
628                } else {
629                   bool is_bool = def->bit_size == 1;
630                   if (is_bool)
631                      def = nir_b2b32(&before, def);
632 
633                   const unsigned comp_size = def->bit_size / 8;
634                   offset = ALIGN(offset, comp_size);
635 
636                   new_def = spill_fill(&before, &after, def,
637                                        index, call_idx,
638                                        offset, options->stack_alignment);
639 
640                   if (is_bool)
641                      new_def = nir_b2b1(&after, new_def);
642 
643                   offset += def->num_components * comp_size;
644                }
645 
646                /* Mark this SSA def as available in the remat set so that, if
647                 * some other SSA def we need is computed based on it, we can
648                 * just re-compute instead of fetching from memory.
649                 */
650                BITSET_SET(remat.set, index);
651 
652                /* For now, we just make a note of this new SSA def.  We'll
653                 * fix things up with the phi builder as a second pass.
654                 */
655                if (fill_defs[index] == NULL) {
656                   fill_defs[index] =
657                      rzalloc_array(fill_defs, nir_def *, num_calls);
658                }
659                fill_defs[index][call_idx] = new_def;
660                _mesa_hash_table_insert(remap_table, original_def, new_def);
661             }
662          }
663 
664          nir_builder *b = &before;
665 
666          offset = ALIGN(offset, options->stack_alignment);
667          max_scratch_size = MAX2(max_scratch_size, offset);
668 
669          /* First thing on the called shader's stack is the resume address
670           * followed by a pointer to the payload.
671           */
672          nir_intrinsic_instr *call = nir_instr_as_intrinsic(instr);
673 
674          /* Lower to generic intrinsics with information about the stack & resume shader. */
675          switch (call->intrinsic) {
676          case nir_intrinsic_trace_ray: {
677             nir_rt_trace_ray(b, call->src[0].ssa, call->src[1].ssa,
678                              call->src[2].ssa, call->src[3].ssa,
679                              call->src[4].ssa, call->src[5].ssa,
680                              call->src[6].ssa, call->src[7].ssa,
681                              call->src[8].ssa, call->src[9].ssa,
682                              call->src[10].ssa,
683                              .call_idx = call_idx, .stack_size = offset);
684             break;
685          }
686 
687          case nir_intrinsic_report_ray_intersection:
688             unreachable("Any-hit shaders must be inlined");
689 
690          case nir_intrinsic_execute_callable: {
691             nir_rt_execute_callable(b, call->src[0].ssa, call->src[1].ssa, .call_idx = call_idx, .stack_size = offset);
692             break;
693          }
694 
695          default:
696             unreachable("Invalid shader call instruction");
697          }
698 
699          nir_rt_resume(b, .call_idx = call_idx, .stack_size = offset);
700 
701          nir_instr_remove(&call->instr);
702 
703          call_idx++;
704       }
705    }
706    assert(call_idx == num_calls);
707    shader->scratch_size = max_scratch_size;
708 
709    struct nir_phi_builder *pb = nir_phi_builder_create(impl);
710    struct pbv_array pbv_arr = {
711       .arr = rzalloc_array(mem_ctx, struct nir_phi_builder_value *,
712                            num_ssa_defs),
713       .len = num_ssa_defs,
714    };
715 
716    const unsigned block_words = BITSET_WORDS(impl->num_blocks);
717    BITSET_WORD *def_blocks = ralloc_array(mem_ctx, BITSET_WORD, block_words);
718 
719    /* Go through and set up phi builder values for each spillable value which
720     * we ever needed to spill at any point.
721     */
722    for (unsigned index = 0; index < num_ssa_defs; index++) {
723       if (fill_defs[index] == NULL)
724          continue;
725 
726       nir_def *def = spill_defs[index];
727 
728       memset(def_blocks, 0, block_words * sizeof(BITSET_WORD));
729       BITSET_SET(def_blocks, def->parent_instr->block->index);
730       for (unsigned call_idx = 0; call_idx < num_calls; call_idx++) {
731          if (fill_defs[index][call_idx] != NULL)
732             BITSET_SET(def_blocks, call_block_indices[call_idx]);
733       }
734 
735       pbv_arr.arr[index] = nir_phi_builder_add_value(pb, def->num_components,
736                                                      def->bit_size, def_blocks);
737    }
738 
739    /* Walk the shader one more time and rewrite SSA defs as needed using the
740     * phi builder.
741     */
742    nir_foreach_block(block, impl) {
743       nir_foreach_instr_safe(instr, block) {
744          nir_def *def = nir_instr_def(instr);
745          if (def != NULL) {
746             struct nir_phi_builder_value *pbv =
747                get_phi_builder_value_for_def(def, &pbv_arr);
748             if (pbv != NULL)
749                nir_phi_builder_value_set_block_def(pbv, block, def);
750          }
751 
752          if (instr->type == nir_instr_type_phi)
753             continue;
754 
755          nir_foreach_src(instr, rewrite_instr_src_from_phi_builder, &pbv_arr);
756 
757          if (instr->type != nir_instr_type_intrinsic)
758             continue;
759 
760          nir_intrinsic_instr *resume = nir_instr_as_intrinsic(instr);
761          if (resume->intrinsic != nir_intrinsic_rt_resume)
762             continue;
763 
764          call_idx = nir_intrinsic_call_idx(resume);
765 
766          /* Technically, this is the wrong place to add the fill defs to the
767           * phi builder values because we haven't seen any of the load_scratch
768           * instructions for this call yet.  However, we know based on how we
769           * emitted them that no value ever gets used until after the load
770           * instruction has been emitted so this should be safe.  If we ever
771           * fail validation due this it likely means a bug in our spilling
772           * code and not the phi re-construction code here.
773           */
774          for (unsigned index = 0; index < num_ssa_defs; index++) {
775             if (fill_defs[index] && fill_defs[index][call_idx]) {
776                nir_phi_builder_value_set_block_def(pbv_arr.arr[index], block,
777                                                    fill_defs[index][call_idx]);
778             }
779          }
780       }
781 
782       nir_if *following_if = nir_block_get_following_if(block);
783       if (following_if) {
784          nir_def *new_def =
785             get_phi_builder_def_for_src(&following_if->condition,
786                                         &pbv_arr, block);
787          if (new_def != NULL)
788             nir_src_rewrite(&following_if->condition, new_def);
789       }
790 
791       /* Handle phi sources that source from this block.  We have to do this
792        * as a separate pass because the phi builder assumes that uses and
793        * defs are processed in an order that respects dominance.  When we have
794        * loops, a phi source may be a back-edge so we have to handle it as if
795        * it were one of the last instructions in the predecessor block.
796        */
797       nir_foreach_phi_src_leaving_block(block,
798                                         rewrite_instr_src_from_phi_builder,
799                                         &pbv_arr);
800    }
801 
802    nir_phi_builder_finish(pb);
803 
804    ralloc_free(mem_ctx);
805 
806    nir_metadata_preserve(impl, nir_metadata_control_flow);
807 }
808 
809 static nir_instr *
find_resume_instr(nir_function_impl * impl,unsigned call_idx)810 find_resume_instr(nir_function_impl *impl, unsigned call_idx)
811 {
812    nir_foreach_block(block, impl) {
813       nir_foreach_instr(instr, block) {
814          if (instr->type != nir_instr_type_intrinsic)
815             continue;
816 
817          nir_intrinsic_instr *resume = nir_instr_as_intrinsic(instr);
818          if (resume->intrinsic != nir_intrinsic_rt_resume)
819             continue;
820 
821          if (nir_intrinsic_call_idx(resume) == call_idx)
822             return &resume->instr;
823       }
824    }
825    unreachable("Couldn't find resume instruction");
826 }
827 
828 /* Walk the CF tree and duplicate the contents of every loop, one half runs on
829  * resume and the other half is for any post-resume loop iterations.  We are
830  * careful in our duplication to ensure that resume_instr is in the resume
831  * half of the loop though a copy of resume_instr will remain in the other
832  * half as well in case the same shader call happens twice.
833  */
834 static bool
duplicate_loop_bodies(nir_function_impl * impl,nir_instr * resume_instr)835 duplicate_loop_bodies(nir_function_impl *impl, nir_instr *resume_instr)
836 {
837    nir_def *resume_reg = NULL;
838    for (nir_cf_node *node = resume_instr->block->cf_node.parent;
839         node->type != nir_cf_node_function; node = node->parent) {
840       if (node->type != nir_cf_node_loop)
841          continue;
842 
843       nir_loop *loop = nir_cf_node_as_loop(node);
844       assert(!nir_loop_has_continue_construct(loop));
845 
846       nir_builder b = nir_builder_create(impl);
847 
848       if (resume_reg == NULL) {
849          /* We only create resume_reg if we encounter a loop.  This way we can
850           * avoid re-validating the shader and calling ssa_to_reg_intrinsics in
851           * the case where it's just if-ladders.
852           */
853          resume_reg = nir_decl_reg(&b, 1, 1, 0);
854 
855          /* Initialize resume to true at the start of the shader, right after
856           * the register is declared at the start.
857           */
858          b.cursor = nir_after_instr(resume_reg->parent_instr);
859          nir_store_reg(&b, nir_imm_true(&b), resume_reg);
860 
861          /* Set resume to false right after the resume instruction */
862          b.cursor = nir_after_instr(resume_instr);
863          nir_store_reg(&b, nir_imm_false(&b), resume_reg);
864       }
865 
866       /* Before we go any further, make sure that everything which exits the
867        * loop or continues around to the top of the loop does so through
868        * registers.  We're about to duplicate the loop body and we'll have
869        * serious trouble if we don't do this.
870        */
871       nir_convert_loop_to_lcssa(loop);
872       nir_lower_phis_to_regs_block(nir_loop_first_block(loop));
873       nir_lower_phis_to_regs_block(
874          nir_cf_node_as_block(nir_cf_node_next(&loop->cf_node)));
875 
876       nir_cf_list cf_list;
877       nir_cf_list_extract(&cf_list, &loop->body);
878 
879       nir_if *_if = nir_if_create(impl->function->shader);
880       b.cursor = nir_after_cf_list(&loop->body);
881       _if->condition = nir_src_for_ssa(nir_load_reg(&b, resume_reg));
882       nir_cf_node_insert(nir_after_cf_list(&loop->body), &_if->cf_node);
883 
884       nir_cf_list clone;
885       nir_cf_list_clone(&clone, &cf_list, &loop->cf_node, NULL);
886 
887       /* Insert the clone in the else and the original in the then so that
888        * the resume_instr remains valid even after the duplication.
889        */
890       nir_cf_reinsert(&cf_list, nir_before_cf_list(&_if->then_list));
891       nir_cf_reinsert(&clone, nir_before_cf_list(&_if->else_list));
892    }
893 
894    if (resume_reg != NULL)
895       nir_metadata_preserve(impl, nir_metadata_none);
896 
897    return resume_reg != NULL;
898 }
899 
900 static bool
cf_node_contains_block(nir_cf_node * node,nir_block * block)901 cf_node_contains_block(nir_cf_node *node, nir_block *block)
902 {
903    for (nir_cf_node *n = &block->cf_node; n != NULL; n = n->parent) {
904       if (n == node)
905          return true;
906    }
907 
908    return false;
909 }
910 
911 static void
rewrite_phis_to_pred(nir_block * block,nir_block * pred)912 rewrite_phis_to_pred(nir_block *block, nir_block *pred)
913 {
914    nir_foreach_phi(phi, block) {
915       ASSERTED bool found = false;
916       nir_foreach_phi_src(phi_src, phi) {
917          if (phi_src->pred == pred) {
918             found = true;
919             nir_def_rewrite_uses(&phi->def, phi_src->src.ssa);
920             break;
921          }
922       }
923       assert(found);
924    }
925 }
926 
927 static bool
cursor_is_after_jump(nir_cursor cursor)928 cursor_is_after_jump(nir_cursor cursor)
929 {
930    switch (cursor.option) {
931    case nir_cursor_before_instr:
932    case nir_cursor_before_block:
933       return false;
934    case nir_cursor_after_instr:
935       return cursor.instr->type == nir_instr_type_jump;
936    case nir_cursor_after_block:
937       return nir_block_ends_in_jump(cursor.block);
938       ;
939    }
940    unreachable("Invalid cursor option");
941 }
942 
943 /** Flattens if ladders leading up to a resume
944  *
945  * Given a resume_instr, this function flattens any if ladders leading to the
946  * resume instruction and deletes any code that cannot be encountered on a
947  * direct path to the resume instruction.  This way we get, for the most part,
948  * straight-line control-flow up to the resume instruction.
949  *
950  * While we do this flattening, we also move any code which is in the remat
951  * set up to the top of the function or to the top of the resume portion of
952  * the current loop.  We don't worry about control-flow as we do this because
953  * phis will never be in the remat set (see can_remat_instr) and so nothing
954  * control-dependent will ever need to be re-materialized.  It is possible
955  * that this algorithm will preserve too many instructions by moving them to
956  * the top but we leave that for DCE to clean up.  Any code not in the remat
957  * set is deleted because it's either unused in the continuation or else
958  * unspilled from a previous continuation and the unspill code is after the
959  * resume instruction.
960  *
961  * If, for instance, we have something like this:
962  *
963  *    // block 0
964  *    if (cond1) {
965  *       // block 1
966  *    } else {
967  *       // block 2
968  *       if (cond2) {
969  *          // block 3
970  *          resume;
971  *          if (cond3) {
972  *             // block 4
973  *          }
974  *       } else {
975  *          // block 5
976  *       }
977  *    }
978  *
979  * then we know, because we know the resume instruction had to be encoutered,
980  * that cond1 = false and cond2 = true and we lower as follows:
981  *
982  *    // block 0
983  *    // block 2
984  *    // block 3
985  *    resume;
986  *    if (cond3) {
987  *       // block 4
988  *    }
989  *
990  * As you can see, the code in blocks 1 and 5 was removed because there is no
991  * path from the start of the shader to the resume instruction which execute
992  * blocks 1 or 5.  Any remat code from blocks 0, 2, and 3 is preserved and
993  * moved to the top.  If the resume instruction is inside a loop then we know
994  * a priori that it is of the form
995  *
996  *    loop {
997  *       if (resume) {
998  *          // Contents containing resume_instr
999  *       } else {
1000  *          // Second copy of contents
1001  *       }
1002  *    }
1003  *
1004  * In this case, we only descend into the first half of the loop.  The second
1005  * half is left alone as that portion is only ever executed after the resume
1006  * instruction.
1007  */
1008 static bool
flatten_resume_if_ladder(nir_builder * b,nir_cf_node * parent_node,struct exec_list * child_list,bool child_list_contains_cursor,nir_instr * resume_instr,struct sized_bitset * remat)1009 flatten_resume_if_ladder(nir_builder *b,
1010                          nir_cf_node *parent_node,
1011                          struct exec_list *child_list,
1012                          bool child_list_contains_cursor,
1013                          nir_instr *resume_instr,
1014                          struct sized_bitset *remat)
1015 {
1016    nir_cf_list cf_list;
1017 
1018    /* If our child list contains the cursor instruction then we start out
1019     * before the cursor instruction.  We need to know this so that we can skip
1020     * moving instructions which are already before the cursor.
1021     */
1022    bool before_cursor = child_list_contains_cursor;
1023 
1024    nir_cf_node *resume_node = NULL;
1025    foreach_list_typed_safe(nir_cf_node, child, node, child_list) {
1026       switch (child->type) {
1027       case nir_cf_node_block: {
1028          nir_block *block = nir_cf_node_as_block(child);
1029          if (b->cursor.option == nir_cursor_before_block &&
1030              b->cursor.block == block) {
1031             assert(before_cursor);
1032             before_cursor = false;
1033          }
1034          nir_foreach_instr_safe(instr, block) {
1035             if ((b->cursor.option == nir_cursor_before_instr ||
1036                  b->cursor.option == nir_cursor_after_instr) &&
1037                 b->cursor.instr == instr) {
1038                assert(nir_cf_node_is_first(&block->cf_node));
1039                assert(before_cursor);
1040                before_cursor = false;
1041                continue;
1042             }
1043 
1044             if (instr == resume_instr)
1045                goto found_resume;
1046 
1047             if (!before_cursor && can_remat_instr(instr, remat)) {
1048                nir_instr_remove(instr);
1049                nir_instr_insert(b->cursor, instr);
1050                b->cursor = nir_after_instr(instr);
1051 
1052                nir_def *def = nir_instr_def(instr);
1053                BITSET_SET(remat->set, def->index);
1054             }
1055          }
1056          if (b->cursor.option == nir_cursor_after_block &&
1057              b->cursor.block == block) {
1058             assert(before_cursor);
1059             before_cursor = false;
1060          }
1061          break;
1062       }
1063 
1064       case nir_cf_node_if: {
1065          assert(!before_cursor);
1066          nir_if *_if = nir_cf_node_as_if(child);
1067          if (flatten_resume_if_ladder(b, &_if->cf_node, &_if->then_list,
1068                                       false, resume_instr, remat)) {
1069             resume_node = child;
1070             rewrite_phis_to_pred(nir_cf_node_as_block(nir_cf_node_next(child)),
1071                                  nir_if_last_then_block(_if));
1072             goto found_resume;
1073          }
1074 
1075          if (flatten_resume_if_ladder(b, &_if->cf_node, &_if->else_list,
1076                                       false, resume_instr, remat)) {
1077             resume_node = child;
1078             rewrite_phis_to_pred(nir_cf_node_as_block(nir_cf_node_next(child)),
1079                                  nir_if_last_else_block(_if));
1080             goto found_resume;
1081          }
1082          break;
1083       }
1084 
1085       case nir_cf_node_loop: {
1086          assert(!before_cursor);
1087          nir_loop *loop = nir_cf_node_as_loop(child);
1088          assert(!nir_loop_has_continue_construct(loop));
1089 
1090          if (cf_node_contains_block(&loop->cf_node, resume_instr->block)) {
1091             /* Thanks to our loop body duplication pass, every level of loop
1092              * containing the resume instruction contains exactly three nodes:
1093              * two blocks and an if.  We don't want to lower away this if
1094              * because it's the resume selection if.  The resume half is
1095              * always the then_list so that's what we want to flatten.
1096              */
1097             nir_block *header = nir_loop_first_block(loop);
1098             nir_if *_if = nir_cf_node_as_if(nir_cf_node_next(&header->cf_node));
1099 
1100             /* We want to place anything re-materialized from inside the loop
1101              * at the top of the resume half of the loop.
1102              */
1103             nir_builder bl = nir_builder_at(nir_before_cf_list(&_if->then_list));
1104 
1105             ASSERTED bool found =
1106                flatten_resume_if_ladder(&bl, &_if->cf_node, &_if->then_list,
1107                                         true, resume_instr, remat);
1108             assert(found);
1109             resume_node = child;
1110             goto found_resume;
1111          } else {
1112             ASSERTED bool found =
1113                flatten_resume_if_ladder(b, &loop->cf_node, &loop->body,
1114                                         false, resume_instr, remat);
1115             assert(!found);
1116          }
1117          break;
1118       }
1119 
1120       case nir_cf_node_function:
1121          unreachable("Unsupported CF node type");
1122       }
1123    }
1124    assert(!before_cursor);
1125 
1126    /* If we got here, we didn't find the resume node or instruction. */
1127    return false;
1128 
1129 found_resume:
1130    /* If we got here then we found either the resume node or the resume
1131     * instruction in this CF list.
1132     */
1133    if (resume_node) {
1134       /* If the resume instruction is buried in side one of our children CF
1135        * nodes, resume_node now points to that child.
1136        */
1137       if (resume_node->type == nir_cf_node_if) {
1138          /* Thanks to the recursive call, all of the interesting contents of
1139           * resume_node have been copied before the cursor.  We just need to
1140           * copy the stuff after resume_node.
1141           */
1142          nir_cf_extract(&cf_list, nir_after_cf_node(resume_node),
1143                         nir_after_cf_list(child_list));
1144       } else {
1145          /* The loop contains its own cursor and still has useful stuff in it.
1146           * We want to move everything after and including the loop to before
1147           * the cursor.
1148           */
1149          assert(resume_node->type == nir_cf_node_loop);
1150          nir_cf_extract(&cf_list, nir_before_cf_node(resume_node),
1151                         nir_after_cf_list(child_list));
1152       }
1153    } else {
1154       /* If we found the resume instruction in one of our blocks, grab
1155        * everything after it in the entire list (not just the one block), and
1156        * place it before the cursor instr.
1157        */
1158       nir_cf_extract(&cf_list, nir_after_instr(resume_instr),
1159                      nir_after_cf_list(child_list));
1160    }
1161 
1162    /* If the resume instruction is in the first block of the child_list,
1163     * and the cursor is still before that block, the nir_cf_extract() may
1164     * extract the block object pointed by the cursor, and instead create
1165     * a new one for the code before the resume. In such case the cursor
1166     * will be broken, as it will point to a block which is no longer
1167     * in a function.
1168     *
1169     * Luckily, in both cases when this is possible, the intended cursor
1170     * position is right before the child_list, so we can fix the cursor here.
1171     */
1172    if (child_list_contains_cursor &&
1173        b->cursor.option == nir_cursor_before_block &&
1174        b->cursor.block->cf_node.parent == NULL)
1175       b->cursor = nir_before_cf_list(child_list);
1176 
1177    if (cursor_is_after_jump(b->cursor)) {
1178       /* If the resume instruction is in a loop, it's possible cf_list ends
1179        * in a break or continue instruction, in which case we don't want to
1180        * insert anything.  It's also possible we have an early return if
1181        * someone hasn't lowered those yet.  In either case, nothing after that
1182        * point executes in this context so we can delete it.
1183        */
1184       nir_cf_delete(&cf_list);
1185    } else {
1186       b->cursor = nir_cf_reinsert(&cf_list, b->cursor);
1187    }
1188 
1189    if (!resume_node) {
1190       /* We want the resume to be the first "interesting" instruction */
1191       nir_instr_remove(resume_instr);
1192       nir_instr_insert(nir_before_impl(b->impl), resume_instr);
1193    }
1194 
1195    /* We've copied everything interesting out of this CF list to before the
1196     * cursor.  Delete everything else.
1197     */
1198    if (child_list_contains_cursor) {
1199       nir_cf_extract(&cf_list, b->cursor, nir_after_cf_list(child_list));
1200    } else {
1201       nir_cf_list_extract(&cf_list, child_list);
1202    }
1203    nir_cf_delete(&cf_list);
1204 
1205    return true;
1206 }
1207 
1208 typedef bool (*wrap_instr_callback)(nir_instr *instr);
1209 
1210 static bool
wrap_instr(nir_builder * b,nir_instr * instr,void * data)1211 wrap_instr(nir_builder *b, nir_instr *instr, void *data)
1212 {
1213    wrap_instr_callback callback = data;
1214    if (!callback(instr))
1215       return false;
1216 
1217    b->cursor = nir_before_instr(instr);
1218 
1219    nir_if *_if = nir_push_if(b, nir_imm_true(b));
1220    nir_pop_if(b, NULL);
1221 
1222    nir_cf_list cf_list;
1223    nir_cf_extract(&cf_list, nir_before_instr(instr), nir_after_instr(instr));
1224    nir_cf_reinsert(&cf_list, nir_before_block(nir_if_first_then_block(_if)));
1225 
1226    return true;
1227 }
1228 
1229 /* This pass wraps jump instructions in a dummy if block so that when
1230  * flatten_resume_if_ladder() does its job, it doesn't move a jump instruction
1231  * directly in front of another instruction which the NIR control flow helpers
1232  * do not allow.
1233  */
1234 static bool
wrap_instrs(nir_shader * shader,wrap_instr_callback callback)1235 wrap_instrs(nir_shader *shader, wrap_instr_callback callback)
1236 {
1237    return nir_shader_instructions_pass(shader, wrap_instr,
1238                                        nir_metadata_none, callback);
1239 }
1240 
1241 static bool
instr_is_jump(nir_instr * instr)1242 instr_is_jump(nir_instr *instr)
1243 {
1244    return instr->type == nir_instr_type_jump;
1245 }
1246 
1247 static nir_instr *
lower_resume(nir_shader * shader,int call_idx)1248 lower_resume(nir_shader *shader, int call_idx)
1249 {
1250    wrap_instrs(shader, instr_is_jump);
1251 
1252    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1253    nir_instr *resume_instr = find_resume_instr(impl, call_idx);
1254 
1255    if (duplicate_loop_bodies(impl, resume_instr)) {
1256       nir_validate_shader(shader, "after duplicate_loop_bodies in "
1257                                   "nir_lower_shader_calls");
1258       /* If we duplicated the bodies of any loops, run reg_intrinsics_to_ssa to
1259        * get rid of all those pesky registers we just added.
1260        */
1261       NIR_PASS_V(shader, nir_lower_reg_intrinsics_to_ssa);
1262    }
1263 
1264    /* Re-index nir_def::index.  We don't care about actual liveness in
1265     * this pass but, so we can use the same helpers as the spilling pass, we
1266     * need to make sure that live_index is something sane.  It's used
1267     * constantly for determining if an SSA value has been added since the
1268     * start of the pass.
1269     */
1270    nir_index_ssa_defs(impl);
1271 
1272    void *mem_ctx = ralloc_context(shader);
1273 
1274    /* Used to track which things may have been assumed to be re-materialized
1275     * by the spilling pass and which we shouldn't delete.
1276     */
1277    struct sized_bitset remat = bitset_create(mem_ctx, impl->ssa_alloc);
1278 
1279    /* Create a nop instruction to use as a cursor as we extract and re-insert
1280     * stuff into the CFG.
1281     */
1282    nir_builder b = nir_builder_at(nir_before_impl(impl));
1283    ASSERTED bool found =
1284       flatten_resume_if_ladder(&b, &impl->cf_node, &impl->body,
1285                                true, resume_instr, &remat);
1286    assert(found);
1287 
1288    ralloc_free(mem_ctx);
1289 
1290    nir_metadata_preserve(impl, nir_metadata_none);
1291 
1292    nir_validate_shader(shader, "after flatten_resume_if_ladder in "
1293                                "nir_lower_shader_calls");
1294 
1295    return resume_instr;
1296 }
1297 
1298 static void
replace_resume_with_halt(nir_shader * shader,nir_instr * keep)1299 replace_resume_with_halt(nir_shader *shader, nir_instr *keep)
1300 {
1301    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1302 
1303    nir_builder b = nir_builder_create(impl);
1304 
1305    nir_foreach_block_safe(block, impl) {
1306       nir_foreach_instr_safe(instr, block) {
1307          if (instr == keep)
1308             continue;
1309 
1310          if (instr->type != nir_instr_type_intrinsic)
1311             continue;
1312 
1313          nir_intrinsic_instr *resume = nir_instr_as_intrinsic(instr);
1314          if (resume->intrinsic != nir_intrinsic_rt_resume)
1315             continue;
1316 
1317          /* If this is some other resume, then we've kicked off a ray or
1318           * bindless thread and we don't want to go any further in this
1319           * shader.  Insert a halt so that NIR will delete any instructions
1320           * dominated by this call instruction including the scratch_load
1321           * instructions we inserted.
1322           */
1323          nir_cf_list cf_list;
1324          nir_cf_extract(&cf_list, nir_after_instr(&resume->instr),
1325                         nir_after_block(block));
1326          nir_cf_delete(&cf_list);
1327          b.cursor = nir_instr_remove(&resume->instr);
1328          nir_jump(&b, nir_jump_halt);
1329          break;
1330       }
1331    }
1332 }
1333 
1334 struct lower_scratch_state {
1335    nir_address_format address_format;
1336 };
1337 
1338 static bool
lower_stack_instr_to_scratch(struct nir_builder * b,nir_instr * instr,void * data)1339 lower_stack_instr_to_scratch(struct nir_builder *b, nir_instr *instr, void *data)
1340 {
1341    struct lower_scratch_state *state = data;
1342 
1343    if (instr->type != nir_instr_type_intrinsic)
1344       return false;
1345 
1346    nir_intrinsic_instr *stack = nir_instr_as_intrinsic(instr);
1347    switch (stack->intrinsic) {
1348    case nir_intrinsic_load_stack: {
1349       b->cursor = nir_instr_remove(instr);
1350       nir_def *data, *old_data = nir_instr_def(instr);
1351 
1352       if (state->address_format == nir_address_format_64bit_global) {
1353          nir_def *addr = nir_iadd_imm(b,
1354                                       nir_load_scratch_base_ptr(b, 1, 64, 1),
1355                                       nir_intrinsic_base(stack));
1356          data = nir_build_load_global(b,
1357                                       stack->def.num_components,
1358                                       stack->def.bit_size,
1359                                       addr,
1360                                       .align_mul = nir_intrinsic_align_mul(stack),
1361                                       .align_offset = nir_intrinsic_align_offset(stack));
1362       } else {
1363          assert(state->address_format == nir_address_format_32bit_offset);
1364          data = nir_load_scratch(b,
1365                                  old_data->num_components,
1366                                  old_data->bit_size,
1367                                  nir_imm_int(b, nir_intrinsic_base(stack)),
1368                                  .align_mul = nir_intrinsic_align_mul(stack),
1369                                  .align_offset = nir_intrinsic_align_offset(stack));
1370       }
1371       nir_def_rewrite_uses(old_data, data);
1372       break;
1373    }
1374 
1375    case nir_intrinsic_store_stack: {
1376       b->cursor = nir_instr_remove(instr);
1377       nir_def *data = stack->src[0].ssa;
1378 
1379       if (state->address_format == nir_address_format_64bit_global) {
1380          nir_def *addr = nir_iadd_imm(b,
1381                                       nir_load_scratch_base_ptr(b, 1, 64, 1),
1382                                       nir_intrinsic_base(stack));
1383          nir_store_global(b, addr,
1384                           nir_intrinsic_align_mul(stack),
1385                           data,
1386                           nir_component_mask(data->num_components));
1387       } else {
1388          assert(state->address_format == nir_address_format_32bit_offset);
1389          nir_store_scratch(b, data,
1390                            nir_imm_int(b, nir_intrinsic_base(stack)),
1391                            .align_mul = nir_intrinsic_align_mul(stack),
1392                            .write_mask = BITFIELD_MASK(data->num_components));
1393       }
1394       break;
1395    }
1396 
1397    default:
1398       return false;
1399    }
1400 
1401    return true;
1402 }
1403 
1404 static bool
1405 nir_lower_stack_to_scratch(nir_shader *shader,
1406                            nir_address_format address_format)
1407 {
1408    struct lower_scratch_state state = {
1409       .address_format = address_format,
1410    };
1411 
1412    return nir_shader_instructions_pass(shader,
1413                                        lower_stack_instr_to_scratch,
1414                                        nir_metadata_control_flow,
1415                                        &state);
1416 }
1417 
1418 static bool
1419 opt_remove_respills_instr(struct nir_builder *b,
1420                           nir_intrinsic_instr *store_intrin, void *data)
1421 {
1422    if (store_intrin->intrinsic != nir_intrinsic_store_stack)
1423       return false;
1424 
1425    nir_instr *value_instr = store_intrin->src[0].ssa->parent_instr;
1426    if (value_instr->type != nir_instr_type_intrinsic)
1427       return false;
1428 
1429    nir_intrinsic_instr *load_intrin = nir_instr_as_intrinsic(value_instr);
1430    if (load_intrin->intrinsic != nir_intrinsic_load_stack)
1431       return false;
1432 
1433    if (nir_intrinsic_base(load_intrin) != nir_intrinsic_base(store_intrin))
1434       return false;
1435 
1436    nir_instr_remove(&store_intrin->instr);
1437    return true;
1438 }
1439 
1440 /* After shader splitting, look at stack load/store operations. If we're loading
1441  * and storing the same value at the same location, we can drop the store
1442  * instruction.
1443  */
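/* A rough before/after sketch in NIR-like pseudocode (illustrative only; the
 * base/call_idx/value_id values are assumed rather than copied from real
 * output):
 *
 *    %v = @load_stack (base=16, call_idx=0, value_id=3)
 *    @store_stack (%v) (base=16, call_idx=0, value_id=3)   <- redundant respill
 *
 * The store writes back exactly the value that was just loaded from the same
 * base offset, so opt_remove_respills_instr() above deletes it.
 */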
1444 static bool
1445 nir_opt_remove_respills(nir_shader *shader)
1446 {
1447    return nir_shader_intrinsics_pass(shader, opt_remove_respills_instr,
1448                                        nir_metadata_control_flow,
1449                                        NULL);
1450 }
1451 
1452 static void
1453 add_use_mask(struct hash_table_u64 *offset_to_mask,
1454              unsigned offset, unsigned mask)
1455 {
1456    uintptr_t old_mask = (uintptr_t)
1457       _mesa_hash_table_u64_search(offset_to_mask, offset);
1458 
1459    _mesa_hash_table_u64_insert(offset_to_mask, offset,
1460                                (void *)(uintptr_t)(old_mask | mask));
1461 }
1462 
1463 /* When splitting the shaders, we might have inserted stores & loads of vec4s,
1464  * because a live value has 4 components. But sometimes, only some of the
1465  * components of that vec4 will be used after the scratch load. This pass
1466  * removes the unused components of scratch loads/stores.
1467  */
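/* For example (an illustrative sketch, not literal IR): if a vec4 is stored
 * with write_mask 0xf but the matching load is only read through components
 * .x and .z (read_mask 0x5), the store is trimmed to a 2-component value, the
 * load's num_components becomes 2, and the users' swizzles/write masks are
 * remapped so that .x -> .x and .z -> .y.
 */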
1468 static bool
1469 nir_opt_trim_stack_values(nir_shader *shader)
1470 {
1471    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1472 
1473    struct hash_table_u64 *value_id_to_mask = _mesa_hash_table_u64_create(NULL);
1474    bool progress = false;
1475 
1476    /* Find all the loads and how their value is being used */
1477    nir_foreach_block_safe(block, impl) {
1478       nir_foreach_instr_safe(instr, block) {
1479          if (instr->type != nir_instr_type_intrinsic)
1480             continue;
1481 
1482          nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1483          if (intrin->intrinsic != nir_intrinsic_load_stack)
1484             continue;
1485 
1486          const unsigned value_id = nir_intrinsic_value_id(intrin);
1487 
1488          const unsigned mask =
1489             nir_def_components_read(nir_instr_def(instr));
1490          add_use_mask(value_id_to_mask, value_id, mask);
1491       }
1492    }
1493 
1494    /* For each store, if it stores more than is being used, trim it.
1495     * Otherwise, remove it from the hash table.
1496     */
1497    nir_foreach_block_safe(block, impl) {
1498       nir_foreach_instr_safe(instr, block) {
1499          if (instr->type != nir_instr_type_intrinsic)
1500             continue;
1501 
1502          nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1503          if (intrin->intrinsic != nir_intrinsic_store_stack)
1504             continue;
1505 
1506          const unsigned value_id = nir_intrinsic_value_id(intrin);
1507 
1508          const unsigned write_mask = nir_intrinsic_write_mask(intrin);
1509          const unsigned read_mask = (uintptr_t)
1510             _mesa_hash_table_u64_search(value_id_to_mask, value_id);
1511 
1512          /* Already removed from the table, nothing to do */
1513          if (read_mask == 0)
1514             continue;
1515 
1516          /* Matching read/write mask, nothing to do, remove from the table. */
1517          if (write_mask == read_mask) {
1518             _mesa_hash_table_u64_remove(value_id_to_mask, value_id);
1519             continue;
1520          }
1521 
1522          nir_builder b = nir_builder_at(nir_before_instr(instr));
1523 
1524          nir_def *value = nir_channels(&b, intrin->src[0].ssa, read_mask);
1525          nir_src_rewrite(&intrin->src[0], value);
1526 
1527          intrin->num_components = util_bitcount(read_mask);
1528          nir_intrinsic_set_write_mask(intrin, (1u << intrin->num_components) - 1);
1529 
1530          progress = true;
1531       }
1532    }
1533 
1534    /* For each load remaining in the hash table (only the ones whose number of
1535     * components we changed), apply trimming/reswizzling.
1536     */
1537    nir_foreach_block_safe(block, impl) {
1538       nir_foreach_instr_safe(instr, block) {
1539          if (instr->type != nir_instr_type_intrinsic)
1540             continue;
1541 
1542          nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1543          if (intrin->intrinsic != nir_intrinsic_load_stack)
1544             continue;
1545 
1546          const unsigned value_id = nir_intrinsic_value_id(intrin);
1547 
1548          unsigned read_mask = (uintptr_t)
1549             _mesa_hash_table_u64_search(value_id_to_mask, value_id);
1550          if (read_mask == 0)
1551             continue;
1552 
1553          unsigned swiz_map[NIR_MAX_VEC_COMPONENTS] = {
1554             0,
1555          };
1556          unsigned swiz_count = 0;
1557          u_foreach_bit(idx, read_mask)
1558             swiz_map[idx] = swiz_count++;
1559 
1560          nir_def *def = nir_instr_def(instr);
1561 
1562          nir_foreach_use_safe(use_src, def) {
1563             if (nir_src_parent_instr(use_src)->type == nir_instr_type_alu) {
1564                nir_alu_instr *alu = nir_instr_as_alu(nir_src_parent_instr(use_src));
1565                nir_alu_src *alu_src = exec_node_data(nir_alu_src, use_src, src);
1566 
1567                unsigned count = alu->def.num_components;
1568                for (unsigned idx = 0; idx < count; ++idx)
1569                   alu_src->swizzle[idx] = swiz_map[alu_src->swizzle[idx]];
1570             } else if (nir_src_parent_instr(use_src)->type == nir_instr_type_intrinsic) {
1571                nir_intrinsic_instr *use_intrin =
1572                   nir_instr_as_intrinsic(nir_src_parent_instr(use_src));
1573                assert(nir_intrinsic_has_write_mask(use_intrin));
1574                unsigned write_mask = nir_intrinsic_write_mask(use_intrin);
1575                unsigned new_write_mask = 0;
1576                u_foreach_bit(idx, write_mask)
1577                   new_write_mask |= 1 << swiz_map[idx];
1578                nir_intrinsic_set_write_mask(use_intrin, new_write_mask);
1579             } else {
1580                unreachable("invalid instruction type");
1581             }
1582          }
1583 
1584          intrin->def.num_components = intrin->num_components = swiz_count;
1585 
1586          progress = true;
1587       }
1588    }
1589 
1590    nir_metadata_preserve(impl,
1591                          progress ? (nir_metadata_control_flow |
1592                                      nir_metadata_loop_analysis)
1593                                   : nir_metadata_all);
1594 
1595    _mesa_hash_table_u64_destroy(value_id_to_mask);
1596 
1597    return progress;
1598 }
1599 
1600 struct scratch_item {
1601    unsigned old_offset;
1602    unsigned new_offset;
1603    unsigned bit_size;
1604    unsigned num_components;
1605    unsigned value;
1606    unsigned call_idx;
1607 };
1608 
1609 static int
1610 sort_scratch_item_by_size_and_value_id(const void *_item1, const void *_item2)
1611 {
1612    const struct scratch_item *item1 = _item1;
1613    const struct scratch_item *item2 = _item2;
1614 
1615    /* By ascending value_id */
1616    if (item1->bit_size == item2->bit_size)
1617       return (int)item1->value - (int)item2->value;
1618 
1619    /* By descending size */
1620    return (int)item2->bit_size - (int)item1->bit_size;
1621 }
1622 
1623 static bool
1624 nir_opt_sort_and_pack_stack(nir_shader *shader,
1625                             unsigned start_call_scratch,
1626                             unsigned stack_alignment,
1627                             unsigned num_calls)
1628 {
1629    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1630 
1631    void *mem_ctx = ralloc_context(NULL);
1632 
1633    struct hash_table_u64 *value_id_to_item =
1634       _mesa_hash_table_u64_create(mem_ctx);
1635    struct util_dynarray ops;
1636    util_dynarray_init(&ops, mem_ctx);
1637 
1638    for (unsigned call_idx = 0; call_idx < num_calls; call_idx++) {
1639       _mesa_hash_table_u64_clear(value_id_to_item);
1640       util_dynarray_clear(&ops);
1641 
1642       /* Find all the stack loads and their offsets. */
1643       nir_foreach_block_safe(block, impl) {
1644          nir_foreach_instr_safe(instr, block) {
1645             if (instr->type != nir_instr_type_intrinsic)
1646                continue;
1647 
1648             nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1649             if (intrin->intrinsic != nir_intrinsic_load_stack)
1650                continue;
1651 
1652             if (nir_intrinsic_call_idx(intrin) != call_idx)
1653                continue;
1654 
1655             const unsigned value_id = nir_intrinsic_value_id(intrin);
1656             nir_def *def = nir_instr_def(instr);
1657 
1658             assert(_mesa_hash_table_u64_search(value_id_to_item,
1659                                                value_id) == NULL);
1660 
1661             struct scratch_item item = {
1662                .old_offset = nir_intrinsic_base(intrin),
1663                .bit_size = def->bit_size,
1664                .num_components = def->num_components,
1665                .value = value_id,
1666             };
1667 
1668             util_dynarray_append(&ops, struct scratch_item, item);
1669             _mesa_hash_table_u64_insert(value_id_to_item, value_id, (void *)(uintptr_t) true);
1670          }
1671       }
1672 
1673       /* Sort scratch items by descending bit size, then ascending value_id. */
1674       if (util_dynarray_num_elements(&ops, struct scratch_item)) {
1675          qsort(util_dynarray_begin(&ops),
1676                util_dynarray_num_elements(&ops, struct scratch_item),
1677                sizeof(struct scratch_item),
1678                sort_scratch_item_by_size_and_value_id);
1679       }
1680 
1681       /* Reorder things on the stack */
1682       _mesa_hash_table_u64_clear(value_id_to_item);
1683 
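      /* Worked example (hypothetical values): with start_call_scratch = 8 and
       * items sorted as a 64-bit vec2 followed by a 32-bit scalar, the vec2
       * gets new_offset = ALIGN(8, 8) = 8 and advances scratch_size to
       * 8 + (64 * 2) / 8 = 24; the scalar then gets ALIGN(24, 4) = 24 and
       * scratch_size becomes 28, which is finally rounded up to
       * stack_alignment below.
       */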
1684       unsigned scratch_size = start_call_scratch;
1685       util_dynarray_foreach(&ops, struct scratch_item, item) {
1686          item->new_offset = ALIGN(scratch_size, item->bit_size / 8);
1687          scratch_size = item->new_offset + (item->bit_size * item->num_components) / 8;
1688          _mesa_hash_table_u64_insert(value_id_to_item, item->value, item);
1689       }
1690       shader->scratch_size = ALIGN(scratch_size, stack_alignment);
1691 
1692       /* Update offsets in the instructions */
1693       nir_foreach_block_safe(block, impl) {
1694          nir_foreach_instr_safe(instr, block) {
1695             if (instr->type != nir_instr_type_intrinsic)
1696                continue;
1697 
1698             nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1699             switch (intrin->intrinsic) {
1700             case nir_intrinsic_load_stack:
1701             case nir_intrinsic_store_stack: {
1702                if (nir_intrinsic_call_idx(intrin) != call_idx)
1703                   continue;
1704 
1705                struct scratch_item *item =
1706                   _mesa_hash_table_u64_search(value_id_to_item,
1707                                               nir_intrinsic_value_id(intrin));
1708                assert(item);
1709 
1710                nir_intrinsic_set_base(intrin, item->new_offset);
1711                break;
1712             }
1713 
1714             case nir_intrinsic_rt_trace_ray:
1715             case nir_intrinsic_rt_execute_callable:
1716             case nir_intrinsic_rt_resume:
1717                if (nir_intrinsic_call_idx(intrin) != call_idx)
1718                   continue;
1719                nir_intrinsic_set_stack_size(intrin, shader->scratch_size);
1720                break;
1721 
1722             default:
1723                break;
1724             }
1725          }
1726       }
1727    }
1728 
1729    ralloc_free(mem_ctx);
1730 
1731    nir_shader_preserve_all_metadata(shader);
1732 
1733    return true;
1734 }
1735 
1736 static unsigned
1737 nir_block_loop_depth(nir_block *block)
1738 {
1739    nir_cf_node *node = &block->cf_node;
1740    unsigned loop_depth = 0;
1741 
1742    while (node != NULL) {
1743       if (node->type == nir_cf_node_loop)
1744          loop_depth++;
1745       node = node->parent;
1746    }
1747 
1748    return loop_depth;
1749 }
1750 
1751 /* Find the last block dominating all the uses of an SSA value. */
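/* Sketch of the walk below (informal, assuming valid dominance metadata):
 * blocks are visited in reverse order, and the first block found that is not
 * deeper in a loop than the defining block, dominates every use (including
 * the block preceding each if-condition use), and is not the block of a phi
 * use is returned.  Failing all that, the defining block itself is used.
 */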
1752 static nir_block *
1753 find_last_dominant_use_block(nir_function_impl *impl, nir_def *value)
1754 {
1755    nir_block *old_block = value->parent_instr->block;
1756    unsigned old_block_loop_depth = nir_block_loop_depth(old_block);
1757 
1758    nir_foreach_block_reverse_safe(block, impl) {
1759       bool fits = true;
1760 
1761       /* Stop at the block where the value is defined */
1762       if (block == old_block)
1763          return block;
1764 
1765       /* Don't move instructions deeper into loops; that would generate more
1766        * memory traffic.
1767        */
1768       unsigned block_loop_depth = nir_block_loop_depth(block);
1769       if (block_loop_depth > old_block_loop_depth)
1770          continue;
1771 
1772       nir_foreach_if_use(src, value) {
1773          nir_block *block_before_if =
1774             nir_cf_node_as_block(nir_cf_node_prev(&nir_src_parent_if(src)->cf_node));
1775          if (!nir_block_dominates(block, block_before_if)) {
1776             fits = false;
1777             break;
1778          }
1779       }
1780       if (!fits)
1781          continue;
1782 
1783       nir_foreach_use(src, value) {
1784          if (nir_src_parent_instr(src)->type == nir_instr_type_phi &&
1785              block == nir_src_parent_instr(src)->block) {
1786             fits = false;
1787             break;
1788          }
1789 
1790          if (!nir_block_dominates(block, nir_src_parent_instr(src)->block)) {
1791             fits = false;
1792             break;
1793          }
1794       }
1795       if (!fits)
1796          continue;
1797 
1798       return block;
1799    }
1800    unreachable("Cannot find block");
1801 }
1802 
1803 /* Put the scratch loads in the branches where they're needed. */
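/* For instance (illustrative): a load_stack emitted at the top of a resume
 * shader whose value is only consumed inside one side of an if gets moved to
 * the start of that branch (after its phis), so paths that never need the
 * value skip the memory access entirely.
 */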
1804 static bool
1805 nir_opt_stack_loads(nir_shader *shader)
1806 {
1807    bool progress = false;
1808 
1809    nir_foreach_function_impl(impl, shader) {
1810       nir_metadata_require(impl, nir_metadata_control_flow);
1811 
1812       bool func_progress = false;
1813       nir_foreach_block_safe(block, impl) {
1814          nir_foreach_instr_safe(instr, block) {
1815             if (instr->type != nir_instr_type_intrinsic)
1816                continue;
1817 
1818             nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1819             if (intrin->intrinsic != nir_intrinsic_load_stack)
1820                continue;
1821 
1822             nir_def *value = &intrin->def;
1823             nir_block *new_block = find_last_dominant_use_block(impl, value);
1824             if (new_block == block)
1825                continue;
1826 
1827             /* Move the scratch load into the new block, after the phis. */
1828             nir_instr_remove(instr);
1829             nir_instr_insert(nir_before_block_after_phis(new_block), instr);
1830 
1831             func_progress = true;
1832          }
1833       }
1834 
1835       nir_metadata_preserve(impl,
1836                             func_progress ? (nir_metadata_control_flow |
1837                                              nir_metadata_loop_analysis)
1838                                           : nir_metadata_all);
1839 
1840       progress |= func_progress;
1841    }
1842 
1843    return progress;
1844 }
1845 
1846 static bool
1847 split_stack_components_instr(struct nir_builder *b,
1848                              nir_intrinsic_instr *intrin, void *data)
1849 {
1850    if (intrin->intrinsic != nir_intrinsic_load_stack &&
1851        intrin->intrinsic != nir_intrinsic_store_stack)
1852       return false;
1853 
1854    if (intrin->intrinsic == nir_intrinsic_load_stack &&
1855        intrin->def.num_components == 1)
1856       return false;
1857 
1858    if (intrin->intrinsic == nir_intrinsic_store_stack &&
1859        intrin->src[0].ssa->num_components == 1)
1860       return false;
1861 
1862    b->cursor = nir_before_instr(&intrin->instr);
1863 
1864    unsigned align_mul = nir_intrinsic_align_mul(intrin);
1865    unsigned align_offset = nir_intrinsic_align_offset(intrin);
1866    if (intrin->intrinsic == nir_intrinsic_load_stack) {
1867       nir_def *components[NIR_MAX_VEC_COMPONENTS] = {
1868          0,
1869       };
1870       for (unsigned c = 0; c < intrin->def.num_components; c++) {
1871          unsigned offset = c * intrin->def.bit_size / 8;
1872          components[c] = nir_load_stack(b, 1, intrin->def.bit_size,
1873                                         .base = nir_intrinsic_base(intrin) + offset,
1874                                         .call_idx = nir_intrinsic_call_idx(intrin),
1875                                         .value_id = nir_intrinsic_value_id(intrin),
1876                                         .align_mul = align_mul,
1877                                         .align_offset = (align_offset + offset) % align_mul);
1878       }
1879 
1880       nir_def_rewrite_uses(&intrin->def,
1881                            nir_vec(b, components,
1882                                    intrin->def.num_components));
1883    } else {
1884       assert(intrin->intrinsic == nir_intrinsic_store_stack);
1885       for (unsigned c = 0; c < intrin->src[0].ssa->num_components; c++) {
1886          unsigned offset = c * intrin->src[0].ssa->bit_size / 8;
1887          nir_store_stack(b, nir_channel(b, intrin->src[0].ssa, c),
1888                          .base = nir_intrinsic_base(intrin) + offset,
1889                          .call_idx = nir_intrinsic_call_idx(intrin),
1890                          .align_mul = align_mul,
1891                          .align_offset = (align_offset + offset) % align_mul,
1892                          .value_id = nir_intrinsic_value_id(intrin),
1893                          .write_mask = 0x1);
1894       }
1895    }
1896 
1897    nir_instr_remove(&intrin->instr);
1898 
1899    return true;
1900 }
1901 
1902 /* Break the load_stack/store_stack intrinsics into single components. This
1903  * helps the vectorizer pack components.
1904  */
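/* Illustrative sketch (NIR-like pseudocode; the offsets assume a 32-bit
 * value, i.e. 4 bytes per component):
 *
 *    %v = @load_stack (base=32, ...) : vec2
 *
 * becomes
 *
 *    %x = @load_stack (base=32, ...) : scalar
 *    %y = @load_stack (base=36, ...) : scalar
 *    %v = vec2(%x, %y)
 */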
1905 static bool
1906 nir_split_stack_components(nir_shader *shader)
1907 {
1908    return nir_shader_intrinsics_pass(shader, split_stack_components_instr,
1909                                        nir_metadata_control_flow,
1910                                        NULL);
1911 }
1912 
1913 struct stack_op_vectorizer_state {
1914    nir_should_vectorize_mem_func driver_callback;
1915    void *driver_data;
1916 };
1917 
1918 static bool
1919 should_vectorize(unsigned align_mul,
1920                  unsigned align_offset,
1921                  unsigned bit_size,
1922                  unsigned num_components,
1923                  nir_intrinsic_instr *low, nir_intrinsic_instr *high,
1924                  void *data)
1925 {
1926    /* We only care about these intrinsics */
1927    if ((low->intrinsic != nir_intrinsic_load_stack &&
1928         low->intrinsic != nir_intrinsic_store_stack) ||
1929        (high->intrinsic != nir_intrinsic_load_stack &&
1930         high->intrinsic != nir_intrinsic_store_stack))
1931       return false;
1932 
1933    struct stack_op_vectorizer_state *state = data;
1934 
1935    return state->driver_callback(align_mul, align_offset,
1936                                  bit_size, num_components,
1937                                  low, high, state->driver_data);
1938 }
1939 
1940 /** Lower shader call instructions to split shaders.
1941  *
1942  * Shader calls can be split into an initial shader and a series of "resume"
1943  * shaders.  When the shader is first invoked, the initial shader is the one
1944  * executed.  At any point in the initial shader or any one of the resume
1945  * shaders, a shader call operation may be performed.  The possible shader call
1946  * operations are:
1947  *
1948  *  - trace_ray
1949  *  - report_ray_intersection
1950  *  - execute_callable
1951  *
1952  * When a shader call operation is performed, we push all live values onto the
1953  * stack, call rt_trace_ray/rt_execute_callable, and then kill the shader. Once
1954  * the operation we invoked is complete, a callee shader will return execution
1955  * to the respective resume shader. The resume shader pops the contents off
1956  * the stack and picks up where the calling shader left off.
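 *
 * As a rough illustration (NIR-like pseudocode, not the literal output of
 * this pass): a shader that computes %a, calls trace_ray, and then uses %a
 * is split into an initial shader that ends with
 *
 *    @store_stack (%a) (base=..., call_idx=0, ...)
 *    @rt_trace_ray (..., call_idx=0)
 *
 * and a resume shader that begins with
 *
 *    @rt_resume (call_idx=0)
 *    %a = @load_stack (base=..., call_idx=0, ...)
 *
 * before continuing with the code that uses %a.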
1957  *
1958  * Stack management is assumed to be done after this pass. Call
1959  * instructions and their resumes get annotated with stack information that
1960  * should be enough for the backend to implement proper stack management.
1961  */
1962 bool
1963 nir_lower_shader_calls(nir_shader *shader,
1964                        const nir_lower_shader_calls_options *options,
1965                        nir_shader ***resume_shaders_out,
1966                        uint32_t *num_resume_shaders_out,
1967                        void *mem_ctx)
1968 {
1969    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1970 
1971    int num_calls = 0;
1972    nir_foreach_block(block, impl) {
1973       nir_foreach_instr_safe(instr, block) {
1974          if (instr_is_shader_call(instr))
1975             num_calls++;
1976       }
1977    }
1978 
1979    if (num_calls == 0) {
1980       nir_shader_preserve_all_metadata(shader);
1981       *num_resume_shaders_out = 0;
1982       return false;
1983    }
1984 
1985    /* Some intrinsics not only can't be re-materialized but aren't preserved
1986     * when moving to the continuation shader.  We have to move them to the top
1987     * to ensure they get spilled as needed.
1988     */
1989    {
1990       bool progress = false;
1991       NIR_PASS(progress, shader, move_system_values_to_top);
1992       if (progress)
1993          NIR_PASS(progress, shader, nir_opt_cse);
1994    }
1995 
1996    /* Deref chains contain metadata information that is needed by other passes
1997     * after this one. If we don't rematerialize the derefs in the blocks where
1998     * they're used here, the following lowerings will insert phis which can
1999     * prevent other passes from chasing deref chains. Additionally, derefs need
2000     * to be rematerialized after shader call instructions to avoid spilling.
2001     */
2002    {
2003       bool progress = false;
2004       NIR_PASS(progress, shader, wrap_instrs, instr_is_shader_call);
2005 
2006       nir_rematerialize_derefs_in_use_blocks_impl(impl);
2007 
2008       if (progress)
2009          NIR_PASS(_, shader, nir_opt_dead_cf);
2010    }
2011 
2012    /* Save the start point of the call stack in scratch */
2013    unsigned start_call_scratch = shader->scratch_size;
2014 
2015    NIR_PASS_V(shader, spill_ssa_defs_and_lower_shader_calls,
2016               num_calls, options);
2017 
2018    NIR_PASS_V(shader, nir_opt_remove_phis);
2019 
2020    NIR_PASS_V(shader, nir_opt_trim_stack_values);
2021    NIR_PASS_V(shader, nir_opt_sort_and_pack_stack,
2022               start_call_scratch, options->stack_alignment, num_calls);
2023 
2024    /* Make N copies of our shader */
2025    nir_shader **resume_shaders = ralloc_array(mem_ctx, nir_shader *, num_calls);
2026    for (unsigned i = 0; i < num_calls; i++) {
2027       resume_shaders[i] = nir_shader_clone(mem_ctx, shader);
2028 
2029       /* Give them a recognizable name */
2030       resume_shaders[i]->info.name =
2031          ralloc_asprintf(mem_ctx, "%s%sresume_%u",
2032                          shader->info.name ? shader->info.name : "",
2033                          shader->info.name ? "-" : "",
2034                          i);
2035    }
2036 
2037    replace_resume_with_halt(shader, NULL);
2038    nir_opt_dce(shader);
2039    nir_opt_dead_cf(shader);
2040    for (unsigned i = 0; i < num_calls; i++) {
2041       nir_instr *resume_instr = lower_resume(resume_shaders[i], i);
2042       replace_resume_with_halt(resume_shaders[i], resume_instr);
2043       /* Remove CF after halt before nir_opt_if(). */
2044       nir_opt_dead_cf(resume_shaders[i]);
2045       /* Remove the dummy blocks added by flatten_resume_if_ladder() */
2046       nir_opt_if(resume_shaders[i], nir_opt_if_optimize_phi_true_false);
2047       nir_opt_dce(resume_shaders[i]);
2048       nir_opt_dead_cf(resume_shaders[i]);
2049       nir_opt_remove_phis(resume_shaders[i]);
2050    }
2051 
2052    for (unsigned i = 0; i < num_calls; i++)
2053       NIR_PASS_V(resume_shaders[i], nir_opt_remove_respills);
2054 
2055    if (options->localized_loads) {
2056       /* Once loads have been combined we can try to put them closer to where
2057        * they're needed.
2058        */
2059       for (unsigned i = 0; i < num_calls; i++)
2060          NIR_PASS_V(resume_shaders[i], nir_opt_stack_loads);
2061    }
2062 
2063    struct stack_op_vectorizer_state vectorizer_state = {
2064       .driver_callback = options->vectorizer_callback,
2065       .driver_data = options->vectorizer_data,
2066    };
2067    nir_load_store_vectorize_options vect_opts = {
2068       .modes = nir_var_shader_temp,
2069       .callback = should_vectorize,
2070       .cb_data = &vectorizer_state,
2071    };
2072 
2073    if (options->vectorizer_callback != NULL) {
2074       NIR_PASS_V(shader, nir_split_stack_components);
2075       NIR_PASS_V(shader, nir_opt_load_store_vectorize, &vect_opts);
2076    }
2077    NIR_PASS_V(shader, nir_lower_stack_to_scratch, options->address_format);
2078    nir_opt_cse(shader);
2079    for (unsigned i = 0; i < num_calls; i++) {
2080       if (options->vectorizer_callback != NULL) {
2081          NIR_PASS_V(resume_shaders[i], nir_split_stack_components);
2082          NIR_PASS_V(resume_shaders[i], nir_opt_load_store_vectorize, &vect_opts);
2083       }
2084       NIR_PASS_V(resume_shaders[i], nir_lower_stack_to_scratch,
2085                  options->address_format);
2086       nir_opt_cse(resume_shaders[i]);
2087    }
2088 
2089    *resume_shaders_out = resume_shaders;
2090    *num_resume_shaders_out = num_calls;
2091 
2092    return true;
2093 }
2094