xref: /aosp_15_r20/external/mesa3d/src/microsoft/compiler/dxil_nir_tess.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright © Microsoft Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include "nir.h"
25 #include "nir_builder.h"
26 #include "nir_control_flow.h"
27 
28 #include "dxil_nir.h"
29 
30 static bool
is_memory_barrier_tcs_patch(const nir_intrinsic_instr * intr)31 is_memory_barrier_tcs_patch(const nir_intrinsic_instr *intr)
32 {
33    if (intr->intrinsic == nir_intrinsic_barrier &&
34        nir_intrinsic_memory_modes(intr) & nir_var_shader_out) {
35       assert(nir_intrinsic_memory_modes(intr) == nir_var_shader_out);
36       assert(nir_intrinsic_memory_scope(intr) == SCOPE_WORKGROUP || nir_intrinsic_memory_scope(intr) == SCOPE_INVOCATION);
37       return true;
38    } else {
39       return false;
40    }
41 }
42 
43 static void
remove_hs_intrinsics(nir_function_impl * impl)44 remove_hs_intrinsics(nir_function_impl *impl)
45 {
46    nir_foreach_block(block, impl) {
47       nir_foreach_instr_safe(instr, block) {
48          if (instr->type != nir_instr_type_intrinsic)
49             continue;
50          nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
51          if (intr->intrinsic != nir_intrinsic_store_output &&
52              !is_memory_barrier_tcs_patch(intr))
53             continue;
54          nir_instr_remove(instr);
55       }
56    }
57    nir_metadata_preserve(impl, nir_metadata_control_flow);
58 }
59 
60 static void
61 add_instr_and_srcs_to_set(struct set *instr_set, nir_instr *instr);
62 
63 static bool
add_srcs_to_set(nir_src * src,void * state)64 add_srcs_to_set(nir_src *src, void *state)
65 {
66    add_instr_and_srcs_to_set(state, src->ssa->parent_instr);
67    return true;
68 }
69 
70 static void
add_instr_and_srcs_to_set(struct set * instr_set,nir_instr * instr)71 add_instr_and_srcs_to_set(struct set *instr_set, nir_instr *instr)
72 {
73    bool was_already_found = false;
74    _mesa_set_search_or_add(instr_set, instr, &was_already_found);
75    if (!was_already_found)
76       nir_foreach_src(instr, add_srcs_to_set, instr_set);
77 }
78 
79 static void
prune_patch_function_to_intrinsic_and_srcs(nir_function_impl * impl)80 prune_patch_function_to_intrinsic_and_srcs(nir_function_impl *impl)
81 {
82    struct set *instr_set = _mesa_pointer_set_create(NULL);
83 
84    /* Do this in two phases:
85     * 1. Find all instructions that contribute to a store_output and add them to
86     *    the set. Also, add instructions that contribute to control flow.
87     * 2. Erase every instruction that isn't in the set
88     */
89    nir_foreach_block(block, impl) {
90       nir_if *following_if = nir_block_get_following_if(block);
91       if (following_if) {
92          add_instr_and_srcs_to_set(instr_set, following_if->condition.ssa->parent_instr);
93       }
94       nir_foreach_instr_safe(instr, block) {
95          if (instr->type == nir_instr_type_intrinsic) {
96             nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
97             if (intr->intrinsic != nir_intrinsic_store_output &&
98                 !is_memory_barrier_tcs_patch(intr))
99                continue;
100          } else if (instr->type != nir_instr_type_jump)
101             continue;
102          add_instr_and_srcs_to_set(instr_set, instr);
103       }
104    }
105 
106    nir_foreach_block_reverse(block, impl) {
107       nir_foreach_instr_reverse_safe(instr, block) {
108          struct set_entry *entry = _mesa_set_search(instr_set, instr);
109          if (!entry)
110             nir_instr_remove(instr);
111       }
112    }
113 
114    _mesa_set_destroy(instr_set, NULL);
115 }
116 
117 static nir_cursor
get_cursor_for_instr_without_cf(nir_instr * instr)118 get_cursor_for_instr_without_cf(nir_instr *instr)
119 {
120    nir_block *block = instr->block;
121    if (block->cf_node.parent->type == nir_cf_node_function)
122       return nir_before_instr(instr);
123 
124    do {
125       block = nir_cf_node_as_block(nir_cf_node_prev(block->cf_node.parent));
126    } while (block->cf_node.parent->type != nir_cf_node_function);
127    return nir_after_block_before_jump(block);
128 }
129 
130 struct tcs_patch_loop_state {
131    nir_def *deref, *count;
132    nir_cursor begin_cursor, end_cursor, insert_cursor;
133    nir_loop *loop;
134 };
135 
136 static void
start_tcs_loop(nir_builder * b,struct tcs_patch_loop_state * state,nir_deref_instr * loop_var_deref)137 start_tcs_loop(nir_builder *b, struct tcs_patch_loop_state *state, nir_deref_instr *loop_var_deref)
138 {
139    if (!loop_var_deref)
140       return;
141 
142    nir_store_deref(b, loop_var_deref, nir_imm_int(b, 0), 1);
143    state->loop = nir_push_loop(b);
144    state->count = nir_load_deref(b, loop_var_deref);
145    nir_break_if(b, nir_ige_imm(b, state->count, b->impl->function->shader->info.tess.tcs_vertices_out));
146    state->insert_cursor = b->cursor;
147    nir_store_deref(b, loop_var_deref, nir_iadd_imm(b, state->count, 1), 1);
148    nir_pop_loop(b, state->loop);
149 }
150 
151 static void
end_tcs_loop(nir_builder * b,struct tcs_patch_loop_state * state)152 end_tcs_loop(nir_builder *b, struct tcs_patch_loop_state *state)
153 {
154    if (!state->loop)
155       return;
156 
157    nir_cf_list extracted;
158    nir_cf_extract(&extracted, state->begin_cursor, state->end_cursor);
159    nir_cf_reinsert(&extracted, state->insert_cursor);
160 
161    *state = (struct tcs_patch_loop_state ){ 0 };
162 }
163 
164 /* In HLSL/DXIL, the hull (tesselation control) shader is split into two:
165  * 1. The main hull shader, which runs once per output control point.
166  * 2. A patch constant function, which runs once overall.
167  * In GLSL/NIR, these are combined. Each invocation must write to the output
168  * array with a constant gl_InvocationID, which is (apparently) lowered to an
169  * if/else ladder in nir. Each invocation must write the same value to patch
170  * constants - or else undefined behavior strikes. NIR uses store_output to
171  * write the patch constants, and store_per_vertex_output to write the control
172  * point values.
173  *
174  * We clone the NIR function to produce 2: one with the store_output intrinsics
175  * removed, which becomes the main shader (only writes control points), and one
176  * with everything that doesn't contribute to store_output removed, which becomes
177  * the patch constant function.
178  *
179  * For the patch constant function, if the expressions rely on gl_InvocationID,
180  * then we need to run the resulting logic in a loop, using the loop counter to
181  * replace gl_InvocationID. This loop can be terminated when a barrier is hit. If
182  * gl_InvocationID is used again after the barrier, then another loop needs to begin.
183  */
184 void
dxil_nir_split_tess_ctrl(nir_shader * nir,nir_function ** patch_const_func)185 dxil_nir_split_tess_ctrl(nir_shader *nir, nir_function **patch_const_func)
186 {
187    assert(nir->info.stage == MESA_SHADER_TESS_CTRL);
188    assert(exec_list_length(&nir->functions) == 1);
189    nir_function_impl *entrypoint = nir_shader_get_entrypoint(nir);
190 
191    *patch_const_func = nir_function_create(nir, "PatchConstantFunc");
192    nir_function_impl *patch_const_func_impl = nir_function_impl_clone(nir, entrypoint);
193    nir_function_set_impl(*patch_const_func, patch_const_func_impl);
194 
195    remove_hs_intrinsics(entrypoint);
196    prune_patch_function_to_intrinsic_and_srcs(patch_const_func_impl);
197 
198    /* Kill dead references to the invocation ID from the patch const func so we don't
199     * insert unnecessarily loops
200     */
201    bool progress;
202    do {
203       progress = false;
204       progress |= nir_opt_dead_cf(nir);
205       progress |= nir_opt_dce(nir);
206    } while (progress);
207 
208    /* Now, the patch constant function needs to be split into blocks and loops.
209     * The series of instructions up to the first block containing a load_invocation_id
210     * will run sequentially. Then a loop is inserted so load_invocation_id will load the
211     * loop counter. This loop continues until a barrier is reached, when the loop
212     * is closed and the process begins again.
213     *
214     * First, sink load_invocation_id so that it's present on both sides of barriers.
215     * Each use gets a unique load of the invocation ID.
216     */
217    nir_builder b = nir_builder_create(patch_const_func_impl);
218    nir_foreach_block(block, patch_const_func_impl) {
219       nir_foreach_instr_safe(instr, block) {
220          if (instr->type != nir_instr_type_intrinsic)
221             continue;
222          nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
223          if (intr->intrinsic != nir_intrinsic_load_invocation_id ||
224              list_is_empty(&intr->def.uses) ||
225              list_is_singular(&intr->def.uses))
226             continue;
227          nir_foreach_use_including_if_safe(src, &intr->def) {
228             b.cursor = nir_before_src(src);
229             nir_src_rewrite(src, nir_load_invocation_id(&b));
230          }
231          nir_instr_remove(instr);
232       }
233    }
234 
235    /* Now replace those invocation ID loads with loads of a local variable that's used as a loop counter */
236    nir_variable *loop_var = NULL;
237    nir_deref_instr *loop_var_deref = NULL;
238    struct tcs_patch_loop_state state = { 0 };
239    nir_foreach_block_safe(block, patch_const_func_impl) {
240       nir_foreach_instr_safe(instr, block) {
241          if (instr->type != nir_instr_type_intrinsic)
242             continue;
243          nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
244          switch (intr->intrinsic) {
245          case nir_intrinsic_load_invocation_id: {
246             if (!loop_var) {
247                loop_var = nir_local_variable_create(patch_const_func_impl, glsl_int_type(), "PatchConstInvocId");
248                b.cursor = nir_before_impl(patch_const_func_impl);
249                loop_var_deref = nir_build_deref_var(&b, loop_var);
250             }
251             if (!state.loop) {
252                b.cursor = state.begin_cursor = get_cursor_for_instr_without_cf(instr);
253                start_tcs_loop(&b, &state, loop_var_deref);
254             }
255             nir_def_rewrite_uses(&intr->def, state.count);
256             break;
257          }
258          case nir_intrinsic_barrier:
259             if (!is_memory_barrier_tcs_patch(intr))
260                break;
261 
262             /* The GL tessellation spec says:
263              * The barrier() function may only be called inside the main entry point of the tessellation control shader
264              * and may not be called in potentially divergent flow control.  In particular, barrier() may not be called
265              * inside a switch statement, in either sub-statement of an if statement, inside a do, for, or while loop,
266              * or at any point after a return statement in the function main().
267              *
268              * Therefore, we should be at function-level control flow.
269              */
270             assert(nir_cursors_equal(nir_before_instr(instr), get_cursor_for_instr_without_cf(instr)));
271             state.end_cursor = nir_before_instr(instr);
272             end_tcs_loop(&b, &state);
273             nir_instr_remove(instr);
274             break;
275          default:
276             break;
277          }
278       }
279    }
280    state.end_cursor = nir_after_block_before_jump(nir_impl_last_block(patch_const_func_impl));
281    end_tcs_loop(&b, &state);
282 }
283 
/* Parameters for the remove_tess_level_accesses() instruction callback. */
struct remove_tess_level_accesses_data {
   unsigned location; /* varying slot whose accesses are filtered */
   unsigned size;     /* accesses to components below this index are kept */
};
288 
289 static bool
remove_tess_level_accesses(nir_builder * b,nir_instr * instr,void * _data)290 remove_tess_level_accesses(nir_builder *b, nir_instr *instr, void *_data)
291 {
292    struct remove_tess_level_accesses_data *data = _data;
293    if (instr->type != nir_instr_type_intrinsic)
294       return false;
295 
296    nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
297    if (intr->intrinsic != nir_intrinsic_store_output &&
298        intr->intrinsic != nir_intrinsic_load_input)
299       return false;
300 
301    nir_io_semantics io = nir_intrinsic_io_semantics(intr);
302    if (io.location != data->location)
303       return false;
304 
305    if (nir_intrinsic_component(intr) < data->size)
306       return false;
307 
308    if (intr->intrinsic == nir_intrinsic_store_output) {
309       assert(nir_src_num_components(intr->src[0]) == 1);
310       nir_instr_remove(instr);
311    } else {
312       b->cursor = nir_after_instr(instr);
313       assert(intr->def.num_components == 1);
314       nir_def_rewrite_uses(&intr->def, nir_undef(b, 1, intr->def.bit_size));
315    }
316    return true;
317 }
318 
319 /* Update the types of the tess level variables and remove writes to removed components.
320  * GL always has a 4-component outer tess level and 2-component inner, while D3D requires
321  * the number of components to vary based on the primitive mode.
322  * The 4 and 2 is for quads, while triangles are 3 and 1, and lines are 2 and 0.
323  */
324 bool
dxil_nir_fixup_tess_level_for_domain(nir_shader * nir)325 dxil_nir_fixup_tess_level_for_domain(nir_shader *nir)
326 {
327    bool progress = false;
328    if (nir->info.tess._primitive_mode != TESS_PRIMITIVE_QUADS) {
329       nir_foreach_variable_with_modes_safe(var, nir, nir_var_shader_out | nir_var_shader_in) {
330          unsigned new_array_size = 4;
331          unsigned old_array_size = glsl_array_size(var->type);
332          if (var->data.location == VARYING_SLOT_TESS_LEVEL_OUTER) {
333             new_array_size = nir->info.tess._primitive_mode == TESS_PRIMITIVE_TRIANGLES ? 3 : 2;
334             assert(var->data.compact && (old_array_size == 4 || old_array_size == new_array_size));
335          } else if (var->data.location == VARYING_SLOT_TESS_LEVEL_INNER) {
336             new_array_size = nir->info.tess._primitive_mode == TESS_PRIMITIVE_TRIANGLES ? 1 : 0;
337             assert(var->data.compact && (old_array_size == 2 || old_array_size == new_array_size));
338          } else
339             continue;
340 
341          if (new_array_size == old_array_size)
342             continue;
343 
344          progress = true;
345          if (new_array_size)
346             var->type = glsl_array_type(glsl_float_type(), new_array_size, 0);
347          else {
348             exec_node_remove(&var->node);
349             ralloc_free(var);
350          }
351 
352          struct remove_tess_level_accesses_data pass_data = {
353             .location = var->data.location,
354             .size = new_array_size
355          };
356 
357          nir_shader_instructions_pass(nir, remove_tess_level_accesses,
358             nir_metadata_control_flow, &pass_data);
359       }
360    }
361    return progress;
362 }
363 
364 static bool
tcs_update_deref_input_types(nir_builder * b,nir_instr * instr,void * data)365 tcs_update_deref_input_types(nir_builder *b, nir_instr *instr, void *data)
366 {
367    if (instr->type != nir_instr_type_deref)
368       return false;
369 
370    nir_deref_instr *deref = nir_instr_as_deref(instr);
371    if (deref->deref_type != nir_deref_type_var)
372       return false;
373 
374    nir_variable *var = deref->var;
375    deref->type = var->type;
376    return true;
377 }
378 
379 bool
dxil_nir_set_tcs_patches_in(nir_shader * nir,unsigned num_control_points)380 dxil_nir_set_tcs_patches_in(nir_shader *nir, unsigned num_control_points)
381 {
382    bool progress = false;
383    nir_foreach_variable_with_modes(var, nir, nir_var_shader_in) {
384       if (nir_is_arrayed_io(var, MESA_SHADER_TESS_CTRL)) {
385          var->type = glsl_array_type(glsl_get_array_element(var->type), num_control_points, 0);
386          progress = true;
387       }
388    }
389 
390    if (progress)
391       nir_shader_instructions_pass(nir, tcs_update_deref_input_types, nir_metadata_all, NULL);
392 
393    return progress;
394 }
395