1 /*
2 * Copyright © 2020 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24 #include "util/u_dynarray.h"
25 #include "util/u_math.h"
26 #include "nir.h"
27 #include "nir_builder.h"
28 #include "nir_phi_builder.h"
29
30 static bool
31 move_system_values_to_top(nir_shader *shader)
32 {
33 nir_function_impl *impl = nir_shader_get_entrypoint(shader);
34
35 bool progress = false;
36 nir_foreach_block(block, impl) {
37 nir_foreach_instr_safe(instr, block) {
38 if (instr->type != nir_instr_type_intrinsic)
39 continue;
40
41 /* These intrinsics not only can't be re-materialized but aren't
42 * preserved when moving to the continuation shader. We have to move
43 * them to the top to ensure they get spilled as needed.
44 */
45 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
46 switch (intrin->intrinsic) {
47 case nir_intrinsic_load_shader_record_ptr:
48 case nir_intrinsic_load_btd_local_arg_addr_intel:
49 nir_instr_remove(instr);
50 nir_instr_insert(nir_before_impl(impl), instr);
51 progress = true;
52 break;
53
54 default:
55 break;
56 }
57 }
58 }
59
60 if (progress) {
61 nir_metadata_preserve(impl, nir_metadata_control_flow);
62 } else {
63 nir_metadata_preserve(impl, nir_metadata_all);
64 }
65
66 return progress;
67 }
68
69 static bool
70 instr_is_shader_call(nir_instr *instr)
71 {
72 if (instr->type != nir_instr_type_intrinsic)
73 return false;
74
75 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
76 return intrin->intrinsic == nir_intrinsic_trace_ray ||
77 intrin->intrinsic == nir_intrinsic_report_ray_intersection ||
78 intrin->intrinsic == nir_intrinsic_execute_callable;
79 }
80
81 /* Previously named bitset, it had to be renamed as FreeBSD defines a struct
82 * named bitset in sys/_bitset.h required by pthread_np.h which is included
83 * from src/util/u_thread.h that is indirectly included by this file.
84 */
85 struct sized_bitset {
86 BITSET_WORD *set;
87 unsigned size;
88 };
89
90 static struct sized_bitset
91 bitset_create(void *mem_ctx, unsigned size)
92 {
93 return (struct sized_bitset){
94 .set = rzalloc_array(mem_ctx, BITSET_WORD, BITSET_WORDS(size)),
95 .size = size,
96 };
97 }
98
99 static bool
100 src_is_in_bitset(nir_src *src, void *_set)
101 {
102 struct sized_bitset *set = _set;
103
104 /* Any SSA values which were added after we generated liveness information
105 * are things generated by this pass and, while most of it is arithmetic
106 * which we could re-materialize, we don't need to because it's only used
107 * for a single load/store and so shouldn't cross any shader calls.
108 */
109 if (src->ssa->index >= set->size)
110 return false;
111
112 return BITSET_TEST(set->set, src->ssa->index);
113 }
114
115 static void
116 add_ssa_def_to_bitset(nir_def *def, struct sized_bitset *set)
117 {
118 if (def->index >= set->size)
119 return;
120
121 BITSET_SET(set->set, def->index);
122 }
123
124 static bool
125 can_remat_instr(nir_instr *instr, struct sized_bitset *remat)
126 {
127 /* Set of all values which are trivially re-materializable and which we
128 * should never spill. This includes:
129 *
130 * - Undef values
131 * - Constants
132 * - Uniforms (UBO or push constant)
133 * - ALU combinations of any of the above
134 * - Derefs which are either complete or casts of any of the above
135 *
136 * Because this pass rewrites things in-order and phis are always turned
137 * into register writes, we can use "is it SSA?" to answer the question
138 * "can my source be re-materialized?". Register writes happen via
139 * non-rematerializable intrinsics.
140 */
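/* As a rough illustration (pseudo-NIR, not exact syntax):
 *
 *    %c = load_const (0x00000004)
 *    %u = intrinsic load_ubo (%0, %1)
 *    %a = iadd %u, %c
 *    %s = intrinsic load_ssbo (%2, %3)
 *
 * %c, %u and %a can all be re-materialized on the far side of a call
 * because they only depend on constants and uniform data, whereas %s is
 * not in the list below and would have to be spilled if it were live
 * across the call.
 */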
141 switch (instr->type) {
142 case nir_instr_type_alu:
143 return nir_foreach_src(instr, src_is_in_bitset, remat);
144
145 case nir_instr_type_deref:
146 return nir_foreach_src(instr, src_is_in_bitset, remat);
147
148 case nir_instr_type_intrinsic: {
149 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
150 switch (intrin->intrinsic) {
151 case nir_intrinsic_load_uniform:
152 case nir_intrinsic_load_ubo:
153 case nir_intrinsic_vulkan_resource_index:
154 case nir_intrinsic_vulkan_resource_reindex:
155 case nir_intrinsic_load_vulkan_descriptor:
156 case nir_intrinsic_load_push_constant:
157 case nir_intrinsic_load_global_constant:
158 /* These intrinsics don't need to be spilled as long as they don't
159 * depend on any spilled values.
160 */
161 return nir_foreach_src(instr, src_is_in_bitset, remat);
162
163 case nir_intrinsic_load_scratch_base_ptr:
164 case nir_intrinsic_load_ray_launch_id:
165 case nir_intrinsic_load_topology_id_intel:
166 case nir_intrinsic_load_btd_global_arg_addr_intel:
167 case nir_intrinsic_load_btd_resume_sbt_addr_intel:
168 case nir_intrinsic_load_ray_base_mem_addr_intel:
169 case nir_intrinsic_load_ray_hw_stack_size_intel:
170 case nir_intrinsic_load_ray_sw_stack_size_intel:
171 case nir_intrinsic_load_ray_num_dss_rt_stacks_intel:
172 case nir_intrinsic_load_ray_hit_sbt_addr_intel:
173 case nir_intrinsic_load_ray_hit_sbt_stride_intel:
174 case nir_intrinsic_load_ray_miss_sbt_addr_intel:
175 case nir_intrinsic_load_ray_miss_sbt_stride_intel:
176 case nir_intrinsic_load_callable_sbt_addr_intel:
177 case nir_intrinsic_load_callable_sbt_stride_intel:
178 case nir_intrinsic_load_reloc_const_intel:
179 case nir_intrinsic_load_ray_query_global_intel:
180 case nir_intrinsic_load_ray_launch_size:
181 /* Notably missing from the above list is btd_local_arg_addr_intel.
182 * This is because the resume shader will have a different local
183 * argument pointer because it has a different BSR. Any access of
184 * the original shader's local arguments needs to be preserved so
185 * that pointer has to be saved on the stack.
186 *
187 * TODO: There may be some system values we want to avoid
188 * re-materializing as well but we have to be very careful
189 * to ensure that it's a system value which cannot change
190 * across a shader call.
191 */
192 return true;
193
194 case nir_intrinsic_resource_intel:
195 return nir_foreach_src(instr, src_is_in_bitset, remat);
196
197 default:
198 return false;
199 }
200 }
201
202 case nir_instr_type_undef:
203 case nir_instr_type_load_const:
204 return true;
205
206 default:
207 return false;
208 }
209 }
210
211 static bool
212 can_remat_ssa_def(nir_def *def, struct sized_bitset *remat)
213 {
214 return can_remat_instr(def->parent_instr, remat);
215 }
216
217 struct add_instr_data {
218 struct util_dynarray *buf;
219 struct sized_bitset *remat;
220 };
221
222 static bool
223 add_src_instr(nir_src *src, void *state)
224 {
225 struct add_instr_data *data = state;
226 if (BITSET_TEST(data->remat->set, src->ssa->index))
227 return true;
228
229 util_dynarray_foreach(data->buf, nir_instr *, instr_ptr) {
230 if (*instr_ptr == src->ssa->parent_instr)
231 return true;
232 }
233
234 /* Abort rematerializing an instruction chain if it is too long. */
235 if (data->buf->size >= data->buf->capacity)
236 return false;
237
238 util_dynarray_append(data->buf, nir_instr *, src->ssa->parent_instr);
239 return true;
240 }
241
242 static int
243 compare_instr_indexes(const void *_inst1, const void *_inst2)
244 {
245 const nir_instr *const *inst1 = _inst1;
246 const nir_instr *const *inst2 = _inst2;
247
248 return (*inst1)->index - (*inst2)->index;
249 }
250
251 static bool
252 can_remat_chain_ssa_def(nir_def *def, struct sized_bitset *remat, struct util_dynarray *buf)
253 {
254 assert(util_dynarray_num_elements(buf, nir_instr *) == 0);
255
256 void *mem_ctx = ralloc_context(NULL);
257
258 /* Add all the instructions involved in building this ssa_def */
259 util_dynarray_append(buf, nir_instr *, def->parent_instr);
260
261 unsigned idx = 0;
262 struct add_instr_data data = {
263 .buf = buf,
264 .remat = remat,
265 };
266 while (idx < util_dynarray_num_elements(buf, nir_instr *)) {
267 nir_instr *instr = *util_dynarray_element(buf, nir_instr *, idx++);
268 if (!nir_foreach_src(instr, add_src_instr, &data))
269 goto fail;
270 }
271
272 /* Sort instructions by index */
273 qsort(util_dynarray_begin(buf),
274 util_dynarray_num_elements(buf, nir_instr *),
275 sizeof(nir_instr *),
276 compare_instr_indexes);
277
278 /* Create a temporary bitset with all values already
279 * rematerialized/rematerializable. We'll add to this bit set as we go
280 * through values that might not be in that set but that we can
281 * rematerialize.
282 */
283 struct sized_bitset potential_remat = bitset_create(mem_ctx, remat->size);
284 memcpy(potential_remat.set, remat->set, BITSET_WORDS(remat->size) * sizeof(BITSET_WORD));
285
286 util_dynarray_foreach(buf, nir_instr *, instr_ptr) {
287 nir_def *instr_ssa_def = nir_instr_def(*instr_ptr);
288
289 /* If it's already in the potential rematerializable set, nothing to do. */
290 if (BITSET_TEST(potential_remat.set, instr_ssa_def->index))
291 continue;
292
293 if (!can_remat_instr(*instr_ptr, &potential_remat))
294 goto fail;
295
296 /* All the sources are rematerializable and the instruction is also
297 * rematerializable, mark it as rematerializable too.
298 */
299 BITSET_SET(potential_remat.set, instr_ssa_def->index);
300 }
301
302 ralloc_free(mem_ctx);
303
304 return true;
305
306 fail:
307 util_dynarray_clear(buf);
308 ralloc_free(mem_ctx);
309 return false;
310 }
311
312 static nir_def *
313 remat_ssa_def(nir_builder *b, nir_def *def, struct hash_table *remap_table)
314 {
315 nir_instr *clone = nir_instr_clone_deep(b->shader, def->parent_instr, remap_table);
316 nir_builder_instr_insert(b, clone);
317 return nir_instr_def(clone);
318 }
319
320 static nir_def *
321 remat_chain_ssa_def(nir_builder *b, struct util_dynarray *buf,
322 struct sized_bitset *remat, nir_def ***fill_defs,
323 unsigned call_idx, struct hash_table *remap_table)
324 {
325 nir_def *last_def = NULL;
326
327 util_dynarray_foreach(buf, nir_instr *, instr_ptr) {
328 nir_def *instr_ssa_def = nir_instr_def(*instr_ptr);
329 unsigned ssa_index = instr_ssa_def->index;
330
331 if (fill_defs[ssa_index] != NULL &&
332 fill_defs[ssa_index][call_idx] != NULL)
333 continue;
334
335 /* Clone the instruction we want to rematerialize */
336 nir_def *clone_ssa_def = remat_ssa_def(b, instr_ssa_def, remap_table);
337
338 if (fill_defs[ssa_index] == NULL) {
339 fill_defs[ssa_index] =
340 rzalloc_array(fill_defs, nir_def *, remat->size);
341 }
342
343 /* Add the new ssa_def to the fill_defs list and flag it as
344 * rematerialized
345 */
346 fill_defs[ssa_index][call_idx] = last_def = clone_ssa_def;
347 BITSET_SET(remat->set, ssa_index);
348
349 _mesa_hash_table_insert(remap_table, instr_ssa_def, last_def);
350 }
351
352 return last_def;
353 }
354
355 struct pbv_array {
356 struct nir_phi_builder_value **arr;
357 unsigned len;
358 };
359
360 static struct nir_phi_builder_value *
361 get_phi_builder_value_for_def(nir_def *def,
362 struct pbv_array *pbv_arr)
363 {
364 if (def->index >= pbv_arr->len)
365 return NULL;
366
367 return pbv_arr->arr[def->index];
368 }
369
370 static nir_def *
371 get_phi_builder_def_for_src(nir_src *src, struct pbv_array *pbv_arr,
372 nir_block *block)
373 {
374
375 struct nir_phi_builder_value *pbv =
376 get_phi_builder_value_for_def(src->ssa, pbv_arr);
377 if (pbv == NULL)
378 return NULL;
379
380 return nir_phi_builder_value_get_block_def(pbv, block);
381 }
382
383 static bool
384 rewrite_instr_src_from_phi_builder(nir_src *src, void *_pbv_arr)
385 {
386 nir_block *block;
387 if (nir_src_parent_instr(src)->type == nir_instr_type_phi) {
388 nir_phi_src *phi_src = exec_node_data(nir_phi_src, src, src);
389 block = phi_src->pred;
390 } else {
391 block = nir_src_parent_instr(src)->block;
392 }
393
394 nir_def *new_def = get_phi_builder_def_for_src(src, _pbv_arr, block);
395 if (new_def != NULL)
396 nir_src_rewrite(src, new_def);
397 return true;
398 }
399
400 static nir_def *
401 spill_fill(nir_builder *before, nir_builder *after, nir_def *def,
402 unsigned value_id, unsigned call_idx,
403 unsigned offset, unsigned stack_alignment)
404 {
405 const unsigned comp_size = def->bit_size / 8;
406
407 nir_store_stack(before, def,
408 .base = offset,
409 .call_idx = call_idx,
410 .align_mul = MIN2(comp_size, stack_alignment),
411 .value_id = value_id,
412 .write_mask = BITFIELD_MASK(def->num_components));
413 return nir_load_stack(after, def->num_components, def->bit_size,
414 .base = offset,
415 .call_idx = call_idx,
416 .value_id = value_id,
417 .align_mul = MIN2(comp_size, stack_alignment));
418 }
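/* Note that every spill/fill pair is tagged with a (value_id, call_idx)
 * pair in addition to its byte offset. The cleanup passes further down
 * (nir_opt_trim_stack_values and nir_opt_sort_and_pack_stack) use those
 * tags to match loads with their stores for each call site and to
 * recompute the stack offsets.
 */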
419
420 static bool
421 add_src_to_call_live_bitset(nir_src *src, void *state)
422 {
423 BITSET_WORD *call_live = state;
424
425 BITSET_SET(call_live, src->ssa->index);
426 return true;
427 }
428
429 static void
430 spill_ssa_defs_and_lower_shader_calls(nir_shader *shader, uint32_t num_calls,
431 const nir_lower_shader_calls_options *options)
432 {
433 /* TODO: If a SSA def is filled more than once, we probably want to just
434 * spill it at the LCM of the fill sites so we avoid unnecessary
435 * extra spills
436 *
437 * TODO: If a SSA def is defined outside a loop but live through some call
438 * inside the loop, we probably want to spill outside the loop. We
439 * may also want to fill outside the loop if it's not used in the
440 * loop.
441 *
442 * TODO: Right now, we only re-materialize things if their immediate
443 * sources are things which we filled. We probably want to expand
444 * that to re-materialize things whose sources are things we can
445 * re-materialize from things we filled. We may want some DAG depth
446 * heuristic on this.
447 */
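/* At a high level, every shader call gets rewritten so that live values
 * are spilled before it and filled after it. Roughly (pseudo-NIR):
 *
 *    intrinsic store_stack (%v) (base=0, call_idx=0, value_id=...)
 *    intrinsic rt_trace_ray (...) (call_idx=0, stack_size=N)
 *    intrinsic rt_resume () (call_idx=0, stack_size=N)
 *    %v' = intrinsic load_stack () (base=0, call_idx=0, value_id=...)
 *
 * The resume shader for a given call is later carved out starting at the
 * rt_resume with the matching call_idx (see lower_resume below).
 */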
448
449 /* This happens per-shader rather than per-impl because we mess with
450 * nir_shader::scratch_size.
451 */
452 nir_function_impl *impl = nir_shader_get_entrypoint(shader);
453
454 nir_metadata_require(impl, nir_metadata_live_defs |
455 nir_metadata_dominance |
456 nir_metadata_block_index |
457 nir_metadata_instr_index);
458
459 void *mem_ctx = ralloc_context(shader);
460
461 const unsigned num_ssa_defs = impl->ssa_alloc;
462 const unsigned live_words = BITSET_WORDS(num_ssa_defs);
463 struct sized_bitset trivial_remat = bitset_create(mem_ctx, num_ssa_defs);
464
465 /* Array of all live SSA defs which are spill candidates */
466 nir_def **spill_defs =
467 rzalloc_array(mem_ctx, nir_def *, num_ssa_defs);
468
469 /* For each spill candidate, an array of every time it's defined by a fill,
470 * indexed by call instruction index.
471 */
472 nir_def ***fill_defs =
473 rzalloc_array(mem_ctx, nir_def **, num_ssa_defs);
474
475 /* For each call instruction, the liveness set at the call */
476 const BITSET_WORD **call_live =
477 rzalloc_array(mem_ctx, const BITSET_WORD *, num_calls);
478
479 /* For each call instruction, the block index of the block it lives in */
480 uint32_t *call_block_indices = rzalloc_array(mem_ctx, uint32_t, num_calls);
481
482 /* Remap table when rebuilding instructions out of fill operations */
483 struct hash_table *trivial_remap_table =
484 _mesa_pointer_hash_table_create(mem_ctx);
485
486 /* Walk the call instructions and fetch the liveness set and block index
487 * for each one. We need to do this before we start modifying the shader
488 * so that liveness doesn't complain that it's been invalidated. Don't
489 * worry, we'll be very careful with our live sets. :-)
490 */
491 unsigned call_idx = 0;
492 nir_foreach_block(block, impl) {
493 nir_foreach_instr(instr, block) {
494 if (!instr_is_shader_call(instr))
495 continue;
496
497 call_block_indices[call_idx] = block->index;
498
499 /* The objective here is to preserve values around shader call
500 * instructions. Therefore, we use the live set after the
501 * instruction as the set of things we want to preserve. Because
502 * none of our shader call intrinsics return anything, we don't have
503 * to worry about spilling over a return value.
504 *
505 * TODO: This isn't quite true for report_intersection.
506 */
507 call_live[call_idx] =
508 nir_get_live_defs(nir_after_instr(instr), mem_ctx);
509
510 call_idx++;
511 }
512 }
513
514 /* If a should_remat_callback is given, call it on each of the live values
515 * for each call site. If it returns true we need to rematerialize that
516 * instruction (instead of spill/fill). Therefore we need to add the
517 * sources as live values so that we can rematerialize on top of those
518 * spilled/filled sources.
519 */
520 if (options->should_remat_callback) {
521 BITSET_WORD **updated_call_live =
522 rzalloc_array(mem_ctx, BITSET_WORD *, num_calls);
523
524 nir_foreach_block(block, impl) {
525 nir_foreach_instr(instr, block) {
526 nir_def *def = nir_instr_def(instr);
527 if (def == NULL)
528 continue;
529
530 for (unsigned c = 0; c < num_calls; c++) {
531 if (!BITSET_TEST(call_live[c], def->index))
532 continue;
533
534 if (!options->should_remat_callback(def->parent_instr,
535 options->should_remat_data))
536 continue;
537
538 if (updated_call_live[c] == NULL) {
539 const unsigned bitset_words = BITSET_WORDS(impl->ssa_alloc);
540 updated_call_live[c] = ralloc_array(mem_ctx, BITSET_WORD, bitset_words);
541 memcpy(updated_call_live[c], call_live[c], bitset_words * sizeof(BITSET_WORD));
542 }
543
544 nir_foreach_src(instr, add_src_to_call_live_bitset, updated_call_live[c]);
545 }
546 }
547 }
548
549 for (unsigned c = 0; c < num_calls; c++) {
550 if (updated_call_live[c] != NULL)
551 call_live[c] = updated_call_live[c];
552 }
553 }
554
555 nir_builder before, after;
556 before = nir_builder_create(impl);
557 after = nir_builder_create(impl);
558
559 call_idx = 0;
560 unsigned max_scratch_size = shader->scratch_size;
561 nir_foreach_block(block, impl) {
562 nir_foreach_instr_safe(instr, block) {
563 nir_def *def = nir_instr_def(instr);
564 if (def != NULL) {
565 if (can_remat_ssa_def(def, &trivial_remat)) {
566 add_ssa_def_to_bitset(def, &trivial_remat);
567 _mesa_hash_table_insert(trivial_remap_table, def, def);
568 } else {
569 spill_defs[def->index] = def;
570 }
571 }
572
573 if (!instr_is_shader_call(instr))
574 continue;
575
576 const BITSET_WORD *live = call_live[call_idx];
577
578 struct hash_table *remap_table =
579 _mesa_hash_table_clone(trivial_remap_table, mem_ctx);
580
581 /* Make a copy of trivial_remat that we'll update as we crawl through
582 * the live SSA defs and unspill them.
583 */
584 struct sized_bitset remat = bitset_create(mem_ctx, num_ssa_defs);
585 memcpy(remat.set, trivial_remat.set, live_words * sizeof(BITSET_WORD));
586
587 /* Because the two builders are always separated by the call
588 * instruction, it won't break anything to have two of them.
589 */
590 before.cursor = nir_before_instr(instr);
591 after.cursor = nir_after_instr(instr);
592
593 /* Array used to hold all the values needed to rematerialize a live
594 * value. The capacity is used to determine when we should abort testing
595 * a remat chain. In practice, shaders can have chains with more than
596 * 10k elements, while only chains with fewer than 16 elements are
597 * realistically worth rematerializing. There also isn't any performance
598 * benefit in rematerializing extremely long chains.
599 */
600 nir_instr *remat_chain_instrs[16];
601 struct util_dynarray remat_chain;
602 util_dynarray_init_from_stack(&remat_chain, remat_chain_instrs, sizeof(remat_chain_instrs));
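/* Because the dynarray is backed by this fixed 16-entry stack buffer,
 * add_src_instr() above refuses to append past its capacity, which is
 * what bounds the length of any candidate remat chain.
 */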
603
604 unsigned offset = shader->scratch_size;
605 for (unsigned w = 0; w < live_words; w++) {
606 BITSET_WORD spill_mask = live[w] & ~trivial_remat.set[w];
607 while (spill_mask) {
608 int i = u_bit_scan(&spill_mask);
609 assert(i >= 0);
610 unsigned index = w * BITSET_WORDBITS + i;
611 assert(index < num_ssa_defs);
612
613 def = spill_defs[index];
614 nir_def *original_def = def, *new_def;
615 if (can_remat_ssa_def(def, &remat)) {
616 /* If this SSA def is re-materializable or based on other
617 * things we've already spilled, re-materialize it rather
618 * than spilling and filling. Anything which is trivially
619 * re-materializable won't even get here because we take
620 * those into account in spill_mask above.
621 */
622 new_def = remat_ssa_def(&after, def, remap_table);
623 } else if (can_remat_chain_ssa_def(def, &remat, &remat_chain)) {
624 new_def = remat_chain_ssa_def(&after, &remat_chain, &remat,
625 fill_defs, call_idx,
626 remap_table);
627 util_dynarray_clear(&remat_chain);
628 } else {
629 bool is_bool = def->bit_size == 1;
630 if (is_bool)
631 def = nir_b2b32(&before, def);
632
633 const unsigned comp_size = def->bit_size / 8;
634 offset = ALIGN(offset, comp_size);
635
636 new_def = spill_fill(&before, &after, def,
637 index, call_idx,
638 offset, options->stack_alignment);
639
640 if (is_bool)
641 new_def = nir_b2b1(&after, new_def);
642
643 offset += def->num_components * comp_size;
644 }
645
646 /* Mark this SSA def as available in the remat set so that, if
647 * some other SSA def we need is computed based on it, we can
648 * just re-compute instead of fetching from memory.
649 */
650 BITSET_SET(remat.set, index);
651
652 /* For now, we just make a note of this new SSA def. We'll
653 * fix things up with the phi builder as a second pass.
654 */
655 if (fill_defs[index] == NULL) {
656 fill_defs[index] =
657 rzalloc_array(fill_defs, nir_def *, num_calls);
658 }
659 fill_defs[index][call_idx] = new_def;
660 _mesa_hash_table_insert(remap_table, original_def, new_def);
661 }
662 }
663
664 nir_builder *b = &before;
665
666 offset = ALIGN(offset, options->stack_alignment);
667 max_scratch_size = MAX2(max_scratch_size, offset);
668
669 /* First thing on the called shader's stack is the resume address
670 * followed by a pointer to the payload.
671 */
672 nir_intrinsic_instr *call = nir_instr_as_intrinsic(instr);
673
674 /* Lower to generic intrinsics with information about the stack & resume shader. */
675 switch (call->intrinsic) {
676 case nir_intrinsic_trace_ray: {
677 nir_rt_trace_ray(b, call->src[0].ssa, call->src[1].ssa,
678 call->src[2].ssa, call->src[3].ssa,
679 call->src[4].ssa, call->src[5].ssa,
680 call->src[6].ssa, call->src[7].ssa,
681 call->src[8].ssa, call->src[9].ssa,
682 call->src[10].ssa,
683 .call_idx = call_idx, .stack_size = offset);
684 break;
685 }
686
687 case nir_intrinsic_report_ray_intersection:
688 unreachable("Any-hit shaders must be inlined");
689
690 case nir_intrinsic_execute_callable: {
691 nir_rt_execute_callable(b, call->src[0].ssa, call->src[1].ssa, .call_idx = call_idx, .stack_size = offset);
692 break;
693 }
694
695 default:
696 unreachable("Invalid shader call instruction");
697 }
698
699 nir_rt_resume(b, .call_idx = call_idx, .stack_size = offset);
700
701 nir_instr_remove(&call->instr);
702
703 call_idx++;
704 }
705 }
706 assert(call_idx == num_calls);
707 shader->scratch_size = max_scratch_size;
708
709 struct nir_phi_builder *pb = nir_phi_builder_create(impl);
710 struct pbv_array pbv_arr = {
711 .arr = rzalloc_array(mem_ctx, struct nir_phi_builder_value *,
712 num_ssa_defs),
713 .len = num_ssa_defs,
714 };
715
716 const unsigned block_words = BITSET_WORDS(impl->num_blocks);
717 BITSET_WORD *def_blocks = ralloc_array(mem_ctx, BITSET_WORD, block_words);
718
719 /* Go through and set up phi builder values for each spillable value which
720 * we ever needed to spill at any point.
721 */
722 for (unsigned index = 0; index < num_ssa_defs; index++) {
723 if (fill_defs[index] == NULL)
724 continue;
725
726 nir_def *def = spill_defs[index];
727
728 memset(def_blocks, 0, block_words * sizeof(BITSET_WORD));
729 BITSET_SET(def_blocks, def->parent_instr->block->index);
730 for (unsigned call_idx = 0; call_idx < num_calls; call_idx++) {
731 if (fill_defs[index][call_idx] != NULL)
732 BITSET_SET(def_blocks, call_block_indices[call_idx]);
733 }
734
735 pbv_arr.arr[index] = nir_phi_builder_add_value(pb, def->num_components,
736 def->bit_size, def_blocks);
737 }
738
739 /* Walk the shader one more time and rewrite SSA defs as needed using the
740 * phi builder.
741 */
742 nir_foreach_block(block, impl) {
743 nir_foreach_instr_safe(instr, block) {
744 nir_def *def = nir_instr_def(instr);
745 if (def != NULL) {
746 struct nir_phi_builder_value *pbv =
747 get_phi_builder_value_for_def(def, &pbv_arr);
748 if (pbv != NULL)
749 nir_phi_builder_value_set_block_def(pbv, block, def);
750 }
751
752 if (instr->type == nir_instr_type_phi)
753 continue;
754
755 nir_foreach_src(instr, rewrite_instr_src_from_phi_builder, &pbv_arr);
756
757 if (instr->type != nir_instr_type_intrinsic)
758 continue;
759
760 nir_intrinsic_instr *resume = nir_instr_as_intrinsic(instr);
761 if (resume->intrinsic != nir_intrinsic_rt_resume)
762 continue;
763
764 call_idx = nir_intrinsic_call_idx(resume);
765
766 /* Technically, this is the wrong place to add the fill defs to the
767 * phi builder values because we haven't seen any of the load_scratch
768 * instructions for this call yet. However, we know based on how we
769 * emitted them that no value ever gets used until after the load
770 * instruction has been emitted so this should be safe. If we ever
771 * fail validation due to this, it likely means a bug in our spilling
772 * code and not the phi re-construction code here.
773 */
774 for (unsigned index = 0; index < num_ssa_defs; index++) {
775 if (fill_defs[index] && fill_defs[index][call_idx]) {
776 nir_phi_builder_value_set_block_def(pbv_arr.arr[index], block,
777 fill_defs[index][call_idx]);
778 }
779 }
780 }
781
782 nir_if *following_if = nir_block_get_following_if(block);
783 if (following_if) {
784 nir_def *new_def =
785 get_phi_builder_def_for_src(&following_if->condition,
786 &pbv_arr, block);
787 if (new_def != NULL)
788 nir_src_rewrite(&following_if->condition, new_def);
789 }
790
791 /* Handle phi sources that source from this block. We have to do this
792 * as a separate pass because the phi builder assumes that uses and
793 * defs are processed in an order that respects dominance. When we have
794 * loops, a phi source may be a back-edge so we have to handle it as if
795 * it were one of the last instructions in the predecessor block.
796 */
797 nir_foreach_phi_src_leaving_block(block,
798 rewrite_instr_src_from_phi_builder,
799 &pbv_arr);
800 }
801
802 nir_phi_builder_finish(pb);
803
804 ralloc_free(mem_ctx);
805
806 nir_metadata_preserve(impl, nir_metadata_control_flow);
807 }
808
809 static nir_instr *
810 find_resume_instr(nir_function_impl *impl, unsigned call_idx)
811 {
812 nir_foreach_block(block, impl) {
813 nir_foreach_instr(instr, block) {
814 if (instr->type != nir_instr_type_intrinsic)
815 continue;
816
817 nir_intrinsic_instr *resume = nir_instr_as_intrinsic(instr);
818 if (resume->intrinsic != nir_intrinsic_rt_resume)
819 continue;
820
821 if (nir_intrinsic_call_idx(resume) == call_idx)
822 return &resume->instr;
823 }
824 }
825 unreachable("Couldn't find resume instruction");
826 }
827
828 /* Walk the CF tree and duplicate the contents of every loop, one half runs on
829 * resume and the other half is for any post-resume loop iterations. We are
830 * careful in our duplication to ensure that resume_instr is in the resume
831 * half of the loop though a copy of resume_instr will remain in the other
832 * half as well in case the same shader call happens twice.
833 */
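/* Roughly, every loop containing the resume instruction goes from:
 *
 *    loop {
 *       body;
 *    }
 *
 * to:
 *
 *    resume_reg = true;
 *    ...
 *    loop {
 *       if (resume_reg) {
 *          body;      // original copy, contains resume_instr which
 *                     // sets resume_reg = false
 *       } else {
 *          body;      // clone, runs for post-resume iterations
 *       }
 *    }
 */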
834 static bool
835 duplicate_loop_bodies(nir_function_impl *impl, nir_instr *resume_instr)
836 {
837 nir_def *resume_reg = NULL;
838 for (nir_cf_node *node = resume_instr->block->cf_node.parent;
839 node->type != nir_cf_node_function; node = node->parent) {
840 if (node->type != nir_cf_node_loop)
841 continue;
842
843 nir_loop *loop = nir_cf_node_as_loop(node);
844 assert(!nir_loop_has_continue_construct(loop));
845
846 nir_builder b = nir_builder_create(impl);
847
848 if (resume_reg == NULL) {
849 /* We only create resume_reg if we encounter a loop. This way we can
850 * avoid re-validating the shader and calling ssa_to_reg_intrinsics in
851 * the case where it's just if-ladders.
852 */
853 resume_reg = nir_decl_reg(&b, 1, 1, 0);
854
855 /* Initialize resume to true at the start of the shader, right after
856 * the register is declared.
857 */
858 b.cursor = nir_after_instr(resume_reg->parent_instr);
859 nir_store_reg(&b, nir_imm_true(&b), resume_reg);
860
861 /* Set resume to false right after the resume instruction */
862 b.cursor = nir_after_instr(resume_instr);
863 nir_store_reg(&b, nir_imm_false(&b), resume_reg);
864 }
865
866 /* Before we go any further, make sure that everything which exits the
867 * loop or continues around to the top of the loop does so through
868 * registers. We're about to duplicate the loop body and we'll have
869 * serious trouble if we don't do this.
870 */
871 nir_convert_loop_to_lcssa(loop);
872 nir_lower_phis_to_regs_block(nir_loop_first_block(loop));
873 nir_lower_phis_to_regs_block(
874 nir_cf_node_as_block(nir_cf_node_next(&loop->cf_node)));
875
876 nir_cf_list cf_list;
877 nir_cf_list_extract(&cf_list, &loop->body);
878
879 nir_if *_if = nir_if_create(impl->function->shader);
880 b.cursor = nir_after_cf_list(&loop->body);
881 _if->condition = nir_src_for_ssa(nir_load_reg(&b, resume_reg));
882 nir_cf_node_insert(nir_after_cf_list(&loop->body), &_if->cf_node);
883
884 nir_cf_list clone;
885 nir_cf_list_clone(&clone, &cf_list, &loop->cf_node, NULL);
886
887 /* Insert the clone in the else and the original in the then so that
888 * the resume_instr remains valid even after the duplication.
889 */
890 nir_cf_reinsert(&cf_list, nir_before_cf_list(&_if->then_list));
891 nir_cf_reinsert(&clone, nir_before_cf_list(&_if->else_list));
892 }
893
894 if (resume_reg != NULL)
895 nir_metadata_preserve(impl, nir_metadata_none);
896
897 return resume_reg != NULL;
898 }
899
900 static bool
901 cf_node_contains_block(nir_cf_node *node, nir_block *block)
902 {
903 for (nir_cf_node *n = &block->cf_node; n != NULL; n = n->parent) {
904 if (n == node)
905 return true;
906 }
907
908 return false;
909 }
910
911 static void
912 rewrite_phis_to_pred(nir_block *block, nir_block *pred)
913 {
914 nir_foreach_phi(phi, block) {
915 ASSERTED bool found = false;
916 nir_foreach_phi_src(phi_src, phi) {
917 if (phi_src->pred == pred) {
918 found = true;
919 nir_def_rewrite_uses(&phi->def, phi_src->src.ssa);
920 break;
921 }
922 }
923 assert(found);
924 }
925 }
926
927 static bool
928 cursor_is_after_jump(nir_cursor cursor)
929 {
930 switch (cursor.option) {
931 case nir_cursor_before_instr:
932 case nir_cursor_before_block:
933 return false;
934 case nir_cursor_after_instr:
935 return cursor.instr->type == nir_instr_type_jump;
936 case nir_cursor_after_block:
937 return nir_block_ends_in_jump(cursor.block);
938 ;
939 }
940 unreachable("Invalid cursor option");
941 }
942
943 /** Flattens if ladders leading up to a resume
944 *
945 * Given a resume_instr, this function flattens any if ladders leading to the
946 * resume instruction and deletes any code that cannot be encountered on a
947 * direct path to the resume instruction. This way we get, for the most part,
948 * straight-line control-flow up to the resume instruction.
949 *
950 * While we do this flattening, we also move any code which is in the remat
951 * set up to the top of the function or to the top of the resume portion of
952 * the current loop. We don't worry about control-flow as we do this because
953 * phis will never be in the remat set (see can_remat_instr) and so nothing
954 * control-dependent will ever need to be re-materialized. It is possible
955 * that this algorithm will preserve too many instructions by moving them to
956 * the top but we leave that for DCE to clean up. Any code not in the remat
957 * set is deleted because it's either unused in the continuation or else
958 * unspilled from a previous continuation and the unspill code is after the
959 * resume instruction.
960 *
961 * If, for instance, we have something like this:
962 *
963 * // block 0
964 * if (cond1) {
965 * // block 1
966 * } else {
967 * // block 2
968 * if (cond2) {
969 * // block 3
970 * resume;
971 * if (cond3) {
972 * // block 4
973 * }
974 * } else {
975 * // block 5
976 * }
977 * }
978 *
979 * then we know, because we know the resume instruction had to be encountered,
980 * that cond1 = false and cond2 = true and we lower as follows:
981 *
982 * // block 0
983 * // block 2
984 * // block 3
985 * resume;
986 * if (cond3) {
987 * // block 4
988 * }
989 *
990 * As you can see, the code in blocks 1 and 5 was removed because there is no
991 * path from the start of the shader to the resume instruction which executes
992 * blocks 1 or 5. Any remat code from blocks 0, 2, and 3 is preserved and
993 * moved to the top. If the resume instruction is inside a loop then we know
994 * a priori that it is of the form
995 *
996 * loop {
997 * if (resume) {
998 * // Contents containing resume_instr
999 * } else {
1000 * // Second copy of contents
1001 * }
1002 * }
1003 *
1004 * In this case, we only descend into the first half of the loop. The second
1005 * half is left alone as that portion is only ever executed after the resume
1006 * instruction.
1007 */
1008 static bool
1009 flatten_resume_if_ladder(nir_builder *b,
1010 nir_cf_node *parent_node,
1011 struct exec_list *child_list,
1012 bool child_list_contains_cursor,
1013 nir_instr *resume_instr,
1014 struct sized_bitset *remat)
1015 {
1016 nir_cf_list cf_list;
1017
1018 /* If our child list contains the cursor instruction then we start out
1019 * before the cursor instruction. We need to know this so that we can skip
1020 * moving instructions which are already before the cursor.
1021 */
1022 bool before_cursor = child_list_contains_cursor;
1023
1024 nir_cf_node *resume_node = NULL;
1025 foreach_list_typed_safe(nir_cf_node, child, node, child_list) {
1026 switch (child->type) {
1027 case nir_cf_node_block: {
1028 nir_block *block = nir_cf_node_as_block(child);
1029 if (b->cursor.option == nir_cursor_before_block &&
1030 b->cursor.block == block) {
1031 assert(before_cursor);
1032 before_cursor = false;
1033 }
1034 nir_foreach_instr_safe(instr, block) {
1035 if ((b->cursor.option == nir_cursor_before_instr ||
1036 b->cursor.option == nir_cursor_after_instr) &&
1037 b->cursor.instr == instr) {
1038 assert(nir_cf_node_is_first(&block->cf_node));
1039 assert(before_cursor);
1040 before_cursor = false;
1041 continue;
1042 }
1043
1044 if (instr == resume_instr)
1045 goto found_resume;
1046
1047 if (!before_cursor && can_remat_instr(instr, remat)) {
1048 nir_instr_remove(instr);
1049 nir_instr_insert(b->cursor, instr);
1050 b->cursor = nir_after_instr(instr);
1051
1052 nir_def *def = nir_instr_def(instr);
1053 BITSET_SET(remat->set, def->index);
1054 }
1055 }
1056 if (b->cursor.option == nir_cursor_after_block &&
1057 b->cursor.block == block) {
1058 assert(before_cursor);
1059 before_cursor = false;
1060 }
1061 break;
1062 }
1063
1064 case nir_cf_node_if: {
1065 assert(!before_cursor);
1066 nir_if *_if = nir_cf_node_as_if(child);
1067 if (flatten_resume_if_ladder(b, &_if->cf_node, &_if->then_list,
1068 false, resume_instr, remat)) {
1069 resume_node = child;
1070 rewrite_phis_to_pred(nir_cf_node_as_block(nir_cf_node_next(child)),
1071 nir_if_last_then_block(_if));
1072 goto found_resume;
1073 }
1074
1075 if (flatten_resume_if_ladder(b, &_if->cf_node, &_if->else_list,
1076 false, resume_instr, remat)) {
1077 resume_node = child;
1078 rewrite_phis_to_pred(nir_cf_node_as_block(nir_cf_node_next(child)),
1079 nir_if_last_else_block(_if));
1080 goto found_resume;
1081 }
1082 break;
1083 }
1084
1085 case nir_cf_node_loop: {
1086 assert(!before_cursor);
1087 nir_loop *loop = nir_cf_node_as_loop(child);
1088 assert(!nir_loop_has_continue_construct(loop));
1089
1090 if (cf_node_contains_block(&loop->cf_node, resume_instr->block)) {
1091 /* Thanks to our loop body duplication pass, every level of loop
1092 * containing the resume instruction contains exactly three nodes:
1093 * two blocks and an if. We don't want to lower away this if
1094 * because it's the resume selection if. The resume half is
1095 * always the then_list so that's what we want to flatten.
1096 */
1097 nir_block *header = nir_loop_first_block(loop);
1098 nir_if *_if = nir_cf_node_as_if(nir_cf_node_next(&header->cf_node));
1099
1100 /* We want to place anything re-materialized from inside the loop
1101 * at the top of the resume half of the loop.
1102 */
1103 nir_builder bl = nir_builder_at(nir_before_cf_list(&_if->then_list));
1104
1105 ASSERTED bool found =
1106 flatten_resume_if_ladder(&bl, &_if->cf_node, &_if->then_list,
1107 true, resume_instr, remat);
1108 assert(found);
1109 resume_node = child;
1110 goto found_resume;
1111 } else {
1112 ASSERTED bool found =
1113 flatten_resume_if_ladder(b, &loop->cf_node, &loop->body,
1114 false, resume_instr, remat);
1115 assert(!found);
1116 }
1117 break;
1118 }
1119
1120 case nir_cf_node_function:
1121 unreachable("Unsupported CF node type");
1122 }
1123 }
1124 assert(!before_cursor);
1125
1126 /* If we got here, we didn't find the resume node or instruction. */
1127 return false;
1128
1129 found_resume:
1130 /* If we got here then we found either the resume node or the resume
1131 * instruction in this CF list.
1132 */
1133 if (resume_node) {
1134 /* If the resume instruction is buried inside one of our child CF
1135 * nodes, resume_node now points to that child.
1136 */
1137 if (resume_node->type == nir_cf_node_if) {
1138 /* Thanks to the recursive call, all of the interesting contents of
1139 * resume_node have been copied before the cursor. We just need to
1140 * copy the stuff after resume_node.
1141 */
1142 nir_cf_extract(&cf_list, nir_after_cf_node(resume_node),
1143 nir_after_cf_list(child_list));
1144 } else {
1145 /* The loop contains its own cursor and still has useful stuff in it.
1146 * We want to move everything after and including the loop to before
1147 * the cursor.
1148 */
1149 assert(resume_node->type == nir_cf_node_loop);
1150 nir_cf_extract(&cf_list, nir_before_cf_node(resume_node),
1151 nir_after_cf_list(child_list));
1152 }
1153 } else {
1154 /* If we found the resume instruction in one of our blocks, grab
1155 * everything after it in the entire list (not just the one block), and
1156 * place it before the cursor instr.
1157 */
1158 nir_cf_extract(&cf_list, nir_after_instr(resume_instr),
1159 nir_after_cf_list(child_list));
1160 }
1161
1162 /* If the resume instruction is in the first block of the child_list,
1163 * and the cursor is still before that block, the nir_cf_extract() may
1164 * extract the block object pointed by the cursor, and instead create
1165 * a new one for the code before the resume. In such case the cursor
1166 * will be broken, as it will point to a block which is no longer
1167 * in a function.
1168 *
1169 * Luckily, in both cases when this is possible, the intended cursor
1170 * position is right before the child_list, so we can fix the cursor here.
1171 */
1172 if (child_list_contains_cursor &&
1173 b->cursor.option == nir_cursor_before_block &&
1174 b->cursor.block->cf_node.parent == NULL)
1175 b->cursor = nir_before_cf_list(child_list);
1176
1177 if (cursor_is_after_jump(b->cursor)) {
1178 /* If the resume instruction is in a loop, it's possible cf_list ends
1179 * in a break or continue instruction, in which case we don't want to
1180 * insert anything. It's also possible we have an early return if
1181 * someone hasn't lowered those yet. In either case, nothing after that
1182 * point executes in this context so we can delete it.
1183 */
1184 nir_cf_delete(&cf_list);
1185 } else {
1186 b->cursor = nir_cf_reinsert(&cf_list, b->cursor);
1187 }
1188
1189 if (!resume_node) {
1190 /* We want the resume to be the first "interesting" instruction */
1191 nir_instr_remove(resume_instr);
1192 nir_instr_insert(nir_before_impl(b->impl), resume_instr);
1193 }
1194
1195 /* We've copied everything interesting out of this CF list to before the
1196 * cursor. Delete everything else.
1197 */
1198 if (child_list_contains_cursor) {
1199 nir_cf_extract(&cf_list, b->cursor, nir_after_cf_list(child_list));
1200 } else {
1201 nir_cf_list_extract(&cf_list, child_list);
1202 }
1203 nir_cf_delete(&cf_list);
1204
1205 return true;
1206 }
1207
1208 typedef bool (*wrap_instr_callback)(nir_instr *instr);
1209
1210 static bool
1211 wrap_instr(nir_builder *b, nir_instr *instr, void *data)
1212 {
1213 wrap_instr_callback callback = data;
1214 if (!callback(instr))
1215 return false;
1216
1217 b->cursor = nir_before_instr(instr);
1218
1219 nir_if *_if = nir_push_if(b, nir_imm_true(b));
1220 nir_pop_if(b, NULL);
1221
1222 nir_cf_list cf_list;
1223 nir_cf_extract(&cf_list, nir_before_instr(instr), nir_after_instr(instr));
1224 nir_cf_reinsert(&cf_list, nir_before_block(nir_if_first_then_block(_if)));
1225
1226 return true;
1227 }
1228
1229 /* This pass wraps jump instructions in a dummy if block so that when
1230 * flatten_resume_if_ladder() does its job, it doesn't move a jump instruction
1231 * directly in front of another instruction, which the NIR control flow helpers
1232 * do not allow.
1233 */
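/* For example, a lone "break;" inside a loop becomes:
 *
 *    if (true) {
 *       break;
 *    }
 *
 * so that CF lists can be re-inserted around it without ever placing an
 * instruction directly after the jump.
 */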
1234 static bool
1235 wrap_instrs(nir_shader *shader, wrap_instr_callback callback)
1236 {
1237 return nir_shader_instructions_pass(shader, wrap_instr,
1238 nir_metadata_none, callback);
1239 }
1240
1241 static bool
1242 instr_is_jump(nir_instr *instr)
1243 {
1244 return instr->type == nir_instr_type_jump;
1245 }
1246
1247 static nir_instr *
1248 lower_resume(nir_shader *shader, int call_idx)
1249 {
1250 wrap_instrs(shader, instr_is_jump);
1251
1252 nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1253 nir_instr *resume_instr = find_resume_instr(impl, call_idx);
1254
1255 if (duplicate_loop_bodies(impl, resume_instr)) {
1256 nir_validate_shader(shader, "after duplicate_loop_bodies in "
1257 "nir_lower_shader_calls");
1258 /* If we duplicated the bodies of any loops, run reg_intrinsics_to_ssa to
1259 * get rid of all those pesky registers we just added.
1260 */
1261 NIR_PASS_V(shader, nir_lower_reg_intrinsics_to_ssa);
1262 }
1263
1264 /* Re-index nir_def::index. We don't care about actual liveness in
1265 * this pass but, so we can use the same helpers as the spilling pass, we
1266 * need to make sure that live_index is something sane. It's used
1267 * constantly for determining if an SSA value has been added since the
1268 * start of the pass.
1269 */
1270 nir_index_ssa_defs(impl);
1271
1272 void *mem_ctx = ralloc_context(shader);
1273
1274 /* Used to track which things may have been assumed to be re-materialized
1275 * by the spilling pass and which we shouldn't delete.
1276 */
1277 struct sized_bitset remat = bitset_create(mem_ctx, impl->ssa_alloc);
1278
1279 /* Place a cursor at the start of the function to use as we extract and
1280 * re-insert stuff into the CFG.
1281 */
1282 nir_builder b = nir_builder_at(nir_before_impl(impl));
1283 ASSERTED bool found =
1284 flatten_resume_if_ladder(&b, &impl->cf_node, &impl->body,
1285 true, resume_instr, &remat);
1286 assert(found);
1287
1288 ralloc_free(mem_ctx);
1289
1290 nir_metadata_preserve(impl, nir_metadata_none);
1291
1292 nir_validate_shader(shader, "after flatten_resume_if_ladder in "
1293 "nir_lower_shader_calls");
1294
1295 return resume_instr;
1296 }
1297
1298 static void
1299 replace_resume_with_halt(nir_shader *shader, nir_instr *keep)
1300 {
1301 nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1302
1303 nir_builder b = nir_builder_create(impl);
1304
1305 nir_foreach_block_safe(block, impl) {
1306 nir_foreach_instr_safe(instr, block) {
1307 if (instr == keep)
1308 continue;
1309
1310 if (instr->type != nir_instr_type_intrinsic)
1311 continue;
1312
1313 nir_intrinsic_instr *resume = nir_instr_as_intrinsic(instr);
1314 if (resume->intrinsic != nir_intrinsic_rt_resume)
1315 continue;
1316
1317 /* If this is some other resume, then we've kicked off a ray or
1318 * bindless thread and we don't want to go any further in this
1319 * shader. Insert a halt so that NIR will delete any instructions
1320 * dominated by this call instruction including the scratch_load
1321 * instructions we inserted.
1322 */
1323 nir_cf_list cf_list;
1324 nir_cf_extract(&cf_list, nir_after_instr(&resume->instr),
1325 nir_after_block(block));
1326 nir_cf_delete(&cf_list);
1327 b.cursor = nir_instr_remove(&resume->instr);
1328 nir_jump(&b, nir_jump_halt);
1329 break;
1330 }
1331 }
1332 }
1333
1334 struct lower_scratch_state {
1335 nir_address_format address_format;
1336 };
1337
1338 static bool
1339 lower_stack_instr_to_scratch(struct nir_builder *b, nir_instr *instr, void *data)
1340 {
1341 struct lower_scratch_state *state = data;
1342
1343 if (instr->type != nir_instr_type_intrinsic)
1344 return false;
1345
1346 nir_intrinsic_instr *stack = nir_instr_as_intrinsic(instr);
1347 switch (stack->intrinsic) {
1348 case nir_intrinsic_load_stack: {
1349 b->cursor = nir_instr_remove(instr);
1350 nir_def *data, *old_data = nir_instr_def(instr);
1351
1352 if (state->address_format == nir_address_format_64bit_global) {
1353 nir_def *addr = nir_iadd_imm(b,
1354 nir_load_scratch_base_ptr(b, 1, 64, 1),
1355 nir_intrinsic_base(stack));
1356 data = nir_build_load_global(b,
1357 stack->def.num_components,
1358 stack->def.bit_size,
1359 addr,
1360 .align_mul = nir_intrinsic_align_mul(stack),
1361 .align_offset = nir_intrinsic_align_offset(stack));
1362 } else {
1363 assert(state->address_format == nir_address_format_32bit_offset);
1364 data = nir_load_scratch(b,
1365 old_data->num_components,
1366 old_data->bit_size,
1367 nir_imm_int(b, nir_intrinsic_base(stack)),
1368 .align_mul = nir_intrinsic_align_mul(stack),
1369 .align_offset = nir_intrinsic_align_offset(stack));
1370 }
1371 nir_def_rewrite_uses(old_data, data);
1372 break;
1373 }
1374
1375 case nir_intrinsic_store_stack: {
1376 b->cursor = nir_instr_remove(instr);
1377 nir_def *data = stack->src[0].ssa;
1378
1379 if (state->address_format == nir_address_format_64bit_global) {
1380 nir_def *addr = nir_iadd_imm(b,
1381 nir_load_scratch_base_ptr(b, 1, 64, 1),
1382 nir_intrinsic_base(stack));
1383 nir_store_global(b, addr,
1384 nir_intrinsic_align_mul(stack),
1385 data,
1386 nir_component_mask(data->num_components));
1387 } else {
1388 assert(state->address_format == nir_address_format_32bit_offset);
1389 nir_store_scratch(b, data,
1390 nir_imm_int(b, nir_intrinsic_base(stack)),
1391 .align_mul = nir_intrinsic_align_mul(stack),
1392 .write_mask = BITFIELD_MASK(data->num_components));
1393 }
1394 break;
1395 }
1396
1397 default:
1398 return false;
1399 }
1400
1401 return true;
1402 }
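/* For example, with nir_address_format_64bit_global a load_stack with
 * base=32 becomes roughly:
 *
 *    %base = intrinsic load_scratch_base_ptr ()
 *    %addr = iadd %base, 32
 *    %v    = intrinsic load_global (%addr) (...)
 *
 * whereas with nir_address_format_32bit_offset it becomes a plain
 * load_scratch at offset 32.
 */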
1403
1404 static bool
1405 nir_lower_stack_to_scratch(nir_shader *shader,
1406 nir_address_format address_format)
1407 {
1408 struct lower_scratch_state state = {
1409 .address_format = address_format,
1410 };
1411
1412 return nir_shader_instructions_pass(shader,
1413 lower_stack_instr_to_scratch,
1414 nir_metadata_control_flow,
1415 &state);
1416 }
1417
1418 static bool
1419 opt_remove_respills_instr(struct nir_builder *b,
1420 nir_intrinsic_instr *store_intrin, void *data)
1421 {
1422 if (store_intrin->intrinsic != nir_intrinsic_store_stack)
1423 return false;
1424
1425 nir_instr *value_instr = store_intrin->src[0].ssa->parent_instr;
1426 if (value_instr->type != nir_instr_type_intrinsic)
1427 return false;
1428
1429 nir_intrinsic_instr *load_intrin = nir_instr_as_intrinsic(value_instr);
1430 if (load_intrin->intrinsic != nir_intrinsic_load_stack)
1431 return false;
1432
1433 if (nir_intrinsic_base(load_intrin) != nir_intrinsic_base(store_intrin))
1434 return false;
1435
1436 nir_instr_remove(&store_intrin->instr);
1437 return true;
1438 }
1439
1440 /* After shader split, look at stack load/store operations. If we're loading
1441 * and storing the same value at the same location, we can drop the store
1442 * instruction.
1443 */
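/* For example (pseudo-NIR):
 *
 *    %v = intrinsic load_stack () (base=16, ...)
 *    ...
 *    intrinsic store_stack (%v) (base=16, ...)   <- removed
 *
 * The store would write back exactly what was just loaded from the same
 * offset, so it is redundant.
 */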
1444 static bool
1445 nir_opt_remove_respills(nir_shader *shader)
1446 {
1447 return nir_shader_intrinsics_pass(shader, opt_remove_respills_instr,
1448 nir_metadata_control_flow,
1449 NULL);
1450 }
1451
1452 static void
1453 add_use_mask(struct hash_table_u64 *offset_to_mask,
1454 unsigned offset, unsigned mask)
1455 {
1456 uintptr_t old_mask = (uintptr_t)
1457 _mesa_hash_table_u64_search(offset_to_mask, offset);
1458
1459 _mesa_hash_table_u64_insert(offset_to_mask, offset,
1460 (void *)(uintptr_t)(old_mask | mask));
1461 }
1462
1463 /* When splitting the shaders, we might have inserted store & loads of vec4s,
1464 * because a live value has 4 components. But sometimes only some components
1465 * of that vec4 will be used after the scratch load. This pass removes the
1466 * unused components of scratch load/stores.
1467 */
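/* For example, if only the x and z components of a spilled vec4 %v are
 * read after the fill, the store is shrunk to store_stack(%v.xz), the
 * matching load_stack goes from 4 components to 2, and the swizzles of
 * the load's users are remapped (x -> x, z -> y).
 */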
1468 static bool
1469 nir_opt_trim_stack_values(nir_shader *shader)
1470 {
1471 nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1472
1473 struct hash_table_u64 *value_id_to_mask = _mesa_hash_table_u64_create(NULL);
1474 bool progress = false;
1475
1476 /* Find all the loads and how their value is being used */
1477 nir_foreach_block_safe(block, impl) {
1478 nir_foreach_instr_safe(instr, block) {
1479 if (instr->type != nir_instr_type_intrinsic)
1480 continue;
1481
1482 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1483 if (intrin->intrinsic != nir_intrinsic_load_stack)
1484 continue;
1485
1486 const unsigned value_id = nir_intrinsic_value_id(intrin);
1487
1488 const unsigned mask =
1489 nir_def_components_read(nir_instr_def(instr));
1490 add_use_mask(value_id_to_mask, value_id, mask);
1491 }
1492 }
1493
1494 /* For each store, if it stores more than is being used, trim it.
1495 * Otherwise, remove it from the hash table.
1496 */
1497 nir_foreach_block_safe(block, impl) {
1498 nir_foreach_instr_safe(instr, block) {
1499 if (instr->type != nir_instr_type_intrinsic)
1500 continue;
1501
1502 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1503 if (intrin->intrinsic != nir_intrinsic_store_stack)
1504 continue;
1505
1506 const unsigned value_id = nir_intrinsic_value_id(intrin);
1507
1508 const unsigned write_mask = nir_intrinsic_write_mask(intrin);
1509 const unsigned read_mask = (uintptr_t)
1510 _mesa_hash_table_u64_search(value_id_to_mask, value_id);
1511
1512 /* Already removed from the table, nothing to do */
1513 if (read_mask == 0)
1514 continue;
1515
1516 /* Matching read/write mask, nothing to do, remove from the table. */
1517 if (write_mask == read_mask) {
1518 _mesa_hash_table_u64_remove(value_id_to_mask, value_id);
1519 continue;
1520 }
1521
1522 nir_builder b = nir_builder_at(nir_before_instr(instr));
1523
1524 nir_def *value = nir_channels(&b, intrin->src[0].ssa, read_mask);
1525 nir_src_rewrite(&intrin->src[0], value);
1526
1527 intrin->num_components = util_bitcount(read_mask);
1528 nir_intrinsic_set_write_mask(intrin, (1u << intrin->num_components) - 1);
1529
1530 progress = true;
1531 }
1532 }
1533
1534 /* For each load remaining in the hash table (only the ones we changed the
1535 * number of components of), apply trimming/reswizzling.
1536 */
1537 nir_foreach_block_safe(block, impl) {
1538 nir_foreach_instr_safe(instr, block) {
1539 if (instr->type != nir_instr_type_intrinsic)
1540 continue;
1541
1542 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1543 if (intrin->intrinsic != nir_intrinsic_load_stack)
1544 continue;
1545
1546 const unsigned value_id = nir_intrinsic_value_id(intrin);
1547
1548 unsigned read_mask = (uintptr_t)
1549 _mesa_hash_table_u64_search(value_id_to_mask, value_id);
1550 if (read_mask == 0)
1551 continue;
1552
1553 unsigned swiz_map[NIR_MAX_VEC_COMPONENTS] = {
1554 0,
1555 };
1556 unsigned swiz_count = 0;
1557 u_foreach_bit(idx, read_mask)
1558 swiz_map[idx] = swiz_count++;
1559
1560 nir_def *def = nir_instr_def(instr);
1561
1562 nir_foreach_use_safe(use_src, def) {
1563 if (nir_src_parent_instr(use_src)->type == nir_instr_type_alu) {
1564 nir_alu_instr *alu = nir_instr_as_alu(nir_src_parent_instr(use_src));
1565 nir_alu_src *alu_src = exec_node_data(nir_alu_src, use_src, src);
1566
1567 unsigned count = alu->def.num_components;
1568 for (unsigned idx = 0; idx < count; ++idx)
1569 alu_src->swizzle[idx] = swiz_map[alu_src->swizzle[idx]];
1570 } else if (nir_src_parent_instr(use_src)->type == nir_instr_type_intrinsic) {
1571 nir_intrinsic_instr *use_intrin =
1572 nir_instr_as_intrinsic(nir_src_parent_instr(use_src));
1573 assert(nir_intrinsic_has_write_mask(use_intrin));
1574 unsigned write_mask = nir_intrinsic_write_mask(use_intrin);
1575 unsigned new_write_mask = 0;
1576 u_foreach_bit(idx, write_mask)
1577 new_write_mask |= 1 << swiz_map[idx];
1578 nir_intrinsic_set_write_mask(use_intrin, new_write_mask);
1579 } else {
1580 unreachable("invalid instruction type");
1581 }
1582 }
1583
1584 intrin->def.num_components = intrin->num_components = swiz_count;
1585
1586 progress = true;
1587 }
1588 }
1589
1590 nir_metadata_preserve(impl,
1591 progress ? (nir_metadata_control_flow |
1592 nir_metadata_loop_analysis)
1593 : nir_metadata_all);
1594
1595 _mesa_hash_table_u64_destroy(value_id_to_mask);
1596
1597 return progress;
1598 }
1599
1600 struct scratch_item {
1601 unsigned old_offset;
1602 unsigned new_offset;
1603 unsigned bit_size;
1604 unsigned num_components;
1605 unsigned value;
1606 unsigned call_idx;
1607 };
1608
1609 static int
1610 sort_scratch_item_by_size_and_value_id(const void *_item1, const void *_item2)
1611 {
1612 const struct scratch_item *item1 = _item1;
1613 const struct scratch_item *item2 = _item2;
1614
1615 /* By ascending value_id */
1616 if (item1->bit_size == item2->bit_size)
1617 return (int)item1->value - (int)item2->value;
1618
1619 /* By descending size */
1620 return (int)item2->bit_size - (int)item1->bit_size;
1621 }
1622
1623 static bool
1624 nir_opt_sort_and_pack_stack(nir_shader *shader,
1625 unsigned start_call_scratch,
1626 unsigned stack_alignment,
1627 unsigned num_calls)
1628 {
1629 nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1630
1631 void *mem_ctx = ralloc_context(NULL);
1632
1633 struct hash_table_u64 *value_id_to_item =
1634 _mesa_hash_table_u64_create(mem_ctx);
1635 struct util_dynarray ops;
1636 util_dynarray_init(&ops, mem_ctx);
1637
1638 for (unsigned call_idx = 0; call_idx < num_calls; call_idx++) {
1639 _mesa_hash_table_u64_clear(value_id_to_item);
1640 util_dynarray_clear(&ops);
1641
1642 /* Find all the stack loads and their offsets. */
1643 nir_foreach_block_safe(block, impl) {
1644 nir_foreach_instr_safe(instr, block) {
1645 if (instr->type != nir_instr_type_intrinsic)
1646 continue;
1647
1648 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1649 if (intrin->intrinsic != nir_intrinsic_load_stack)
1650 continue;
1651
1652 if (nir_intrinsic_call_idx(intrin) != call_idx)
1653 continue;
1654
1655 const unsigned value_id = nir_intrinsic_value_id(intrin);
1656 nir_def *def = nir_instr_def(instr);
1657
1658 assert(_mesa_hash_table_u64_search(value_id_to_item,
1659 value_id) == NULL);
1660
1661 struct scratch_item item = {
1662 .old_offset = nir_intrinsic_base(intrin),
1663 .bit_size = def->bit_size,
1664 .num_components = def->num_components,
1665 .value = value_id,
1666 };
1667
1668 util_dynarray_append(&ops, struct scratch_item, item);
1669 _mesa_hash_table_u64_insert(value_id_to_item, value_id, (void *)(uintptr_t) true);
1670 }
1671 }
1672
1673 /* Sort scratch items by descending bit size, then by ascending value id. */
1674 if (util_dynarray_num_elements(&ops, struct scratch_item)) {
1675 qsort(util_dynarray_begin(&ops),
1676 util_dynarray_num_elements(&ops, struct scratch_item),
1677 sizeof(struct scratch_item),
1678 sort_scratch_item_by_size_and_value_id);
1679 }
1680
1681 /* Reorder things on the stack */
1682 _mesa_hash_table_u64_clear(value_id_to_item);
1683
1684 unsigned scratch_size = start_call_scratch;
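/* Pack the items back to back, aligning each one to its component byte size;
 * e.g. with scratch_size = 20, a 64-bit vec2 is placed at offset 24 and
 * advances scratch_size to 40.
 */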
1685 util_dynarray_foreach(&ops, struct scratch_item, item) {
1686 item->new_offset = ALIGN(scratch_size, item->bit_size / 8);
1687 scratch_size = item->new_offset + (item->bit_size * item->num_components) / 8;
1688 _mesa_hash_table_u64_insert(value_id_to_item, item->value, item);
1689 }
1690 shader->scratch_size = ALIGN(scratch_size, stack_alignment);
1691
1692 /* Update offsets in the instructions */
1693 nir_foreach_block_safe(block, impl) {
1694 nir_foreach_instr_safe(instr, block) {
1695 if (instr->type != nir_instr_type_intrinsic)
1696 continue;
1697
1698 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1699 switch (intrin->intrinsic) {
1700 case nir_intrinsic_load_stack:
1701 case nir_intrinsic_store_stack: {
1702 if (nir_intrinsic_call_idx(intrin) != call_idx)
1703 continue;
1704
1705 struct scratch_item *item =
1706 _mesa_hash_table_u64_search(value_id_to_item,
1707 nir_intrinsic_value_id(intrin));
1708 assert(item);
1709
1710 nir_intrinsic_set_base(intrin, item->new_offset);
1711 break;
1712 }
1713
1714 case nir_intrinsic_rt_trace_ray:
1715 case nir_intrinsic_rt_execute_callable:
1716 case nir_intrinsic_rt_resume:
1717 if (nir_intrinsic_call_idx(intrin) != call_idx)
1718 continue;
1719 nir_intrinsic_set_stack_size(intrin, shader->scratch_size);
1720 break;
1721
1722 default:
1723 break;
1724 }
1725 }
1726 }
1727 }
1728
1729 ralloc_free(mem_ctx);
1730
1731 nir_shader_preserve_all_metadata(shader);
1732
1733 return true;
1734 }
1735
1736 static unsigned
1737 nir_block_loop_depth(nir_block *block)
1738 {
1739 nir_cf_node *node = &block->cf_node;
1740 unsigned loop_depth = 0;
1741
1742 while (node != NULL) {
1743 if (node->type == nir_cf_node_loop)
1744 loop_depth++;
1745 node = node->parent;
1746 }
1747
1748 return loop_depth;
1749 }
1750
1751 /* Find the last block dominating all the uses of a SSA value. */
1752 static nir_block *
1753 find_last_dominant_use_block(nir_function_impl *impl, nir_def *value)
1754 {
1755 nir_block *old_block = value->parent_instr->block;
1756 unsigned old_block_loop_depth = nir_block_loop_depth(old_block);
1757
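/* Walk the blocks bottom-up: the first (i.e. latest) block that dominates
 * every use without sitting deeper in a loop than the def is the furthest
 * point the load can be sunk to.
 */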
1758 nir_foreach_block_reverse_safe(block, impl) {
1759 bool fits = true;
1760
1761 /* We've reached the block where the value is defined; it can always stay there. */
1762 if (block == old_block)
1763 return block;
1764
1765 /* Don't move instructions deeper into loops; that would generate more
1766 * memory traffic.
1767 */
1768 unsigned block_loop_depth = nir_block_loop_depth(block);
1769 if (block_loop_depth > old_block_loop_depth)
1770 continue;
1771
1772 nir_foreach_if_use(src, value) {
1773 nir_block *block_before_if =
1774 nir_cf_node_as_block(nir_cf_node_prev(&nir_src_parent_if(src)->cf_node));
1775 if (!nir_block_dominates(block, block_before_if)) {
1776 fits = false;
1777 break;
1778 }
1779 }
1780 if (!fits)
1781 continue;
1782
1783 nir_foreach_use(src, value) {
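/* A value consumed by a phi must dominate the corresponding predecessor
 * block, not the phi's block, so the phi's own block is not a legal
 * placement.
 */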
1784 if (nir_src_parent_instr(src)->type == nir_instr_type_phi &&
1785 block == nir_src_parent_instr(src)->block) {
1786 fits = false;
1787 break;
1788 }
1789
1790 if (!nir_block_dominates(block, nir_src_parent_instr(src)->block)) {
1791 fits = false;
1792 break;
1793 }
1794 }
1795 if (!fits)
1796 continue;
1797
1798 return block;
1799 }
1800 unreachable("Cannot find block");
1801 }
1802
1803 /* Put the scratch loads in the branches where they're needed. */
1804 static bool
1805 nir_opt_stack_loads(nir_shader *shader)
1806 {
1807 bool progress = false;
1808
1809 nir_foreach_function_impl(impl, shader) {
1810 nir_metadata_require(impl, nir_metadata_control_flow);
1811
1812 bool func_progress = false;
1813 nir_foreach_block_safe(block, impl) {
1814 nir_foreach_instr_safe(instr, block) {
1815 if (instr->type != nir_instr_type_intrinsic)
1816 continue;
1817
1818 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1819 if (intrin->intrinsic != nir_intrinsic_load_stack)
1820 continue;
1821
1822 nir_def *value = &intrin->def;
1823 nir_block *new_block = find_last_dominant_use_block(impl, value);
1824 if (new_block == block)
1825 continue;
1826
1827 /* Move the scratch load into the new block, after the phis. */
1828 nir_instr_remove(instr);
1829 nir_instr_insert(nir_before_block_after_phis(new_block), instr);
1830
1831 func_progress = true;
1832 }
1833 }
1834
1835 nir_metadata_preserve(impl,
1836 func_progress ? (nir_metadata_control_flow |
1837 nir_metadata_loop_analysis)
1838 : nir_metadata_all);
1839
1840 progress |= func_progress;
1841 }
1842
1843 return progress;
1844 }
1845
1846 static bool
1847 split_stack_components_instr(struct nir_builder *b,
1848 nir_intrinsic_instr *intrin, void *data)
1849 {
1850 if (intrin->intrinsic != nir_intrinsic_load_stack &&
1851 intrin->intrinsic != nir_intrinsic_store_stack)
1852 return false;
1853
1854 if (intrin->intrinsic == nir_intrinsic_load_stack &&
1855 intrin->def.num_components == 1)
1856 return false;
1857
1858 if (intrin->intrinsic == nir_intrinsic_store_stack &&
1859 intrin->src[0].ssa->num_components == 1)
1860 return false;
1861
1862 b->cursor = nir_before_instr(&intrin->instr);
1863
1864 unsigned align_mul = nir_intrinsic_align_mul(intrin);
1865 unsigned align_offset = nir_intrinsic_align_offset(intrin);
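/* Each scalar access keeps the original align_mul; only its align_offset
 * shifts by the component's byte offset within the vector.
 */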
1866 if (intrin->intrinsic == nir_intrinsic_load_stack) {
1867 nir_def *components[NIR_MAX_VEC_COMPONENTS] = {
1868 0,
1869 };
1870 for (unsigned c = 0; c < intrin->def.num_components; c++) {
1871 unsigned offset = c * intrin->def.bit_size / 8;
1872 components[c] = nir_load_stack(b, 1, intrin->def.bit_size,
1873 .base = nir_intrinsic_base(intrin) + offset,
1874 .call_idx = nir_intrinsic_call_idx(intrin),
1875 .value_id = nir_intrinsic_value_id(intrin),
1876 .align_mul = align_mul,
1877 .align_offset = (align_offset + offset) % align_mul);
1878 }
1879
1880 nir_def_rewrite_uses(&intrin->def,
1881 nir_vec(b, components,
1882 intrin->def.num_components));
1883 } else {
1884 assert(intrin->intrinsic == nir_intrinsic_store_stack);
1885 for (unsigned c = 0; c < intrin->src[0].ssa->num_components; c++) {
1886 unsigned offset = c * intrin->src[0].ssa->bit_size / 8;
1887 nir_store_stack(b, nir_channel(b, intrin->src[0].ssa, c),
1888 .base = nir_intrinsic_base(intrin) + offset,
1889 .call_idx = nir_intrinsic_call_idx(intrin),
1890 .align_mul = align_mul,
1891 .align_offset = (align_offset + offset) % align_mul,
1892 .value_id = nir_intrinsic_value_id(intrin),
1893 .write_mask = 0x1);
1894 }
1895 }
1896
1897 nir_instr_remove(&intrin->instr);
1898
1899 return true;
1900 }
1901
1902 /* Break the load_stack/store_stack intrinsics into single components. This
1903 * helps the vectorizer to pack components.
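 *
 * For example, a 32-bit vec3 store_stack at base B is rewritten into three
 * scalar store_stack intrinsics at bases B, B+4 and B+8, each with its
 * align_offset adjusted for its position within the original vector.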
1904 */
1905 static bool
1906 nir_split_stack_components(nir_shader *shader)
1907 {
1908 return nir_shader_intrinsics_pass(shader, split_stack_components_instr,
1909 nir_metadata_control_flow,
1910 NULL);
1911 }
1912
1913 struct stack_op_vectorizer_state {
1914 nir_should_vectorize_mem_func driver_callback;
1915 void *driver_data;
1916 };
1917
1918 static bool
1919 should_vectorize(unsigned align_mul,
1920 unsigned align_offset,
1921 unsigned bit_size,
1922 unsigned num_components,
1923 nir_intrinsic_instr *low, nir_intrinsic_instr *high,
1924 void *data)
1925 {
1926 /* We only care about the stack load/store intrinsics. */
1927 if ((low->intrinsic != nir_intrinsic_load_stack &&
1928 low->intrinsic != nir_intrinsic_store_stack) ||
1929 (high->intrinsic != nir_intrinsic_load_stack &&
1930 high->intrinsic != nir_intrinsic_store_stack))
1931 return false;
1932
1933 struct stack_op_vectorizer_state *state = data;
1934
1935 return state->driver_callback(align_mul, align_offset,
1936 bit_size, num_components,
1937 low, high, state->driver_data);
1938 }
1939
1940 /** Lower shader call instructions to split shaders.
1941 *
1942 * Shader calls can be split into an initial shader and a series of "resume"
1943 * shaders. When the shader is first invoked, it is the initial shader which
1944 * is executed. At any point in the initial shader or any one of the resume
1945 * shaders, a shader call operation may be performed. The possible shader call
1946 * operations are:
1947 *
1948 * - trace_ray
1949 * - report_ray_intersection
1950 * - execute_callable
1951 *
1952 * When a shader call operation is performed, we push all live values onto
1953 * the stack, call rt_trace_ray/rt_execute_callable, and then kill the
1954 * shader. Once the invoked operation is complete, the callee returns
1955 * execution to the corresponding resume shader. The resume shader pops the
1956 * contents off the stack and picks up where the calling shader left off.
1957 *
1958 * Stack management is assumed to be done after this pass. Call
1959 * instructions and their resumes get annotated with stack information that
1960 * should be enough for the backend to implement proper stack management.
1961 */
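/* A minimal, hypothetical usage sketch (the callback name, alignment value,
 * address format choice and ralloc context below are placeholder assumptions,
 * not something this file defines):
 *
 *    nir_lower_shader_calls_options opts = {
 *       .address_format = nir_address_format_64bit_global,
 *       .stack_alignment = 16,
 *       .localized_loads = true,
 *       .vectorizer_callback = driver_should_vectorize_cb,
 *    };
 *    nir_shader **resume_shaders = NULL;
 *    uint32_t num_resume_shaders = 0;
 *    nir_lower_shader_calls(shader, &opts, &resume_shaders,
 *                           &num_resume_shaders, mem_ctx);
 *
 * Each resume_shaders[i] is then compiled as its own binary alongside the
 * (now truncated) initial shader.
 */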
1962 bool
1963 nir_lower_shader_calls(nir_shader *shader,
1964 const nir_lower_shader_calls_options *options,
1965 nir_shader ***resume_shaders_out,
1966 uint32_t *num_resume_shaders_out,
1967 void *mem_ctx)
1968 {
1969 nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1970
1971 int num_calls = 0;
1972 nir_foreach_block(block, impl) {
1973 nir_foreach_instr_safe(instr, block) {
1974 if (instr_is_shader_call(instr))
1975 num_calls++;
1976 }
1977 }
1978
1979 if (num_calls == 0) {
1980 nir_shader_preserve_all_metadata(shader);
1981 *num_resume_shaders_out = 0;
1982 return false;
1983 }
1984
1985 /* Some intrinsics not only can't be re-materialized but aren't preserved
1986 * when moving to the continuation shader. We have to move them to the top
1987 * to ensure they get spilled as needed.
1988 */
1989 {
1990 bool progress = false;
1991 NIR_PASS(progress, shader, move_system_values_to_top);
1992 if (progress)
1993 NIR_PASS(progress, shader, nir_opt_cse);
1994 }
1995
1996 /* Deref chains contain metadata information that is needed by other passes
1997 * after this one. If we don't rematerialize the derefs in the blocks where
1998 * they're used here, the following lowerings will insert phis which can
1999 * prevent other passes from chasing deref chains. Additionally, derefs need
2000 * to be rematerialized after shader call instructions to avoid spilling.
2001 */
2002 {
2003 bool progress = false;
2004 NIR_PASS(progress, shader, wrap_instrs, instr_is_shader_call);
2005
2006 nir_rematerialize_derefs_in_use_blocks_impl(impl);
2007
2008 if (progress)
2009 NIR_PASS(_, shader, nir_opt_dead_cf);
2010 }
2011
2012 /* Save the start point of the call stack in scratch */
2013 unsigned start_call_scratch = shader->scratch_size;
2014
2015 NIR_PASS_V(shader, spill_ssa_defs_and_lower_shader_calls,
2016 num_calls, options);
2017
2018 NIR_PASS_V(shader, nir_opt_remove_phis);
2019
2020 NIR_PASS_V(shader, nir_opt_trim_stack_values);
2021 NIR_PASS_V(shader, nir_opt_sort_and_pack_stack,
2022 start_call_scratch, options->stack_alignment, num_calls);
2023
2024 /* Make N copies of our shader */
2025 nir_shader **resume_shaders = ralloc_array(mem_ctx, nir_shader *, num_calls);
2026 for (unsigned i = 0; i < num_calls; i++) {
2027 resume_shaders[i] = nir_shader_clone(mem_ctx, shader);
2028
2029 /* Give them a recognizable name */
2030 resume_shaders[i]->info.name =
2031 ralloc_asprintf(mem_ctx, "%s%sresume_%u",
2032 shader->info.name ? shader->info.name : "",
2033 shader->info.name ? "-" : "",
2034 i);
2035 }
2036
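/* In the initial shader, execution stops at the first shader call, so the
 * resume points can never be reached; turn them into halts and clean up the
 * now-dead code and control flow.
 */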
2037 replace_resume_with_halt(shader, NULL);
2038 nir_opt_dce(shader);
2039 nir_opt_dead_cf(shader);
2040 for (unsigned i = 0; i < num_calls; i++) {
2041 nir_instr *resume_instr = lower_resume(resume_shaders[i], i);
2042 replace_resume_with_halt(resume_shaders[i], resume_instr);
2043 /* Remove CF after halt before nir_opt_if(). */
2044 nir_opt_dead_cf(resume_shaders[i]);
2045 /* Remove the dummy blocks added by flatten_resume_if_ladder() */
2046 nir_opt_if(resume_shaders[i], nir_opt_if_optimize_phi_true_false);
2047 nir_opt_dce(resume_shaders[i]);
2048 nir_opt_dead_cf(resume_shaders[i]);
2049 nir_opt_remove_phis(resume_shaders[i]);
2050 }
2051
2052 for (unsigned i = 0; i < num_calls; i++)
2053 NIR_PASS_V(resume_shaders[i], nir_opt_remove_respills);
2054
2055 if (options->localized_loads) {
2056 /* Once loads have been combined we can try to put them closer to where
2057 * they're needed.
2058 */
2059 for (unsigned i = 0; i < num_calls; i++)
2060 NIR_PASS_V(resume_shaders[i], nir_opt_stack_loads);
2061 }
2062
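/* Only nir_var_shader_temp (scratch) accesses are considered, and
 * should_vectorize() narrows that further to the stack intrinsics before
 * deferring the final decision to the driver's callback.
 */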
2063 struct stack_op_vectorizer_state vectorizer_state = {
2064 .driver_callback = options->vectorizer_callback,
2065 .driver_data = options->vectorizer_data,
2066 };
2067 nir_load_store_vectorize_options vect_opts = {
2068 .modes = nir_var_shader_temp,
2069 .callback = should_vectorize,
2070 .cb_data = &vectorizer_state,
2071 };
2072
2073 if (options->vectorizer_callback != NULL) {
2074 NIR_PASS_V(shader, nir_split_stack_components);
2075 NIR_PASS_V(shader, nir_opt_load_store_vectorize, &vect_opts);
2076 }
2077 NIR_PASS_V(shader, nir_lower_stack_to_scratch, options->address_format);
2078 nir_opt_cse(shader);
2079 for (unsigned i = 0; i < num_calls; i++) {
2080 if (options->vectorizer_callback != NULL) {
2081 NIR_PASS_V(resume_shaders[i], nir_split_stack_components);
2082 NIR_PASS_V(resume_shaders[i], nir_opt_load_store_vectorize, &vect_opts);
2083 }
2084 NIR_PASS_V(resume_shaders[i], nir_lower_stack_to_scratch,
2085 options->address_format);
2086 nir_opt_cse(resume_shaders[i]);
2087 }
2088
2089 *resume_shaders_out = resume_shaders;
2090 *num_resume_shaders_out = num_calls;
2091
2092 return true;
2093 }
2094