xref: /aosp_15_r20/external/mesa3d/src/compiler/nir/nir_lower_gs_intrinsics.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright © 2015 Intel Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include "nir.h"
25 #include "nir_builder.h"
26 #include "nir_xfb_info.h"
27 
/**
 * \file nir_lower_gs_intrinsics.c
 *
 * Geometry Shaders can call EmitVertex()/EmitStreamVertex() to output an
 * arbitrary number of vertices.  However, the shader must declare the maximum
 * number of vertices that it will ever output - further attempts to emit
 * vertices result in undefined behavior according to the GLSL specification.
 *
 * Drivers might use this maximum number of vertices to allocate enough space
 * to hold the geometry shader's output.  Some drivers (such as i965) need to
 * implement "safety checks" which ensure that the shader hasn't emitted too
 * many vertices, to avoid overflowing that space and trashing other memory.
 *
 * The count of emitted vertices can also be useful in buffer offset
 * calculations, so drivers know where to write the GS output.
 *
 * However, for simple geometry shaders that emit a statically determinable
 * number of vertices, this extra bookkeeping is unnecessary and inefficient.
 * By tracking the vertex count in NIR, we allow constant folding/propagation
 * and dead control flow optimizations to eliminate most of it where possible.
 *
 * This pass introduces function-local variables which store the current
 * vertex count (initialized to 0) and, depending on the options, the
 * per-primitive vertex count and the primitive counts.  It converts
 * emit_vertex/end_primitive intrinsics to their *_with_counter variants.
 * emit_vertex is also wrapped in a safety check to avoid buffer overflows.
 * Finally, it adds a set_vertex_and_primitive_count intrinsic at the end of
 * the program, informing the driver of the final vertex and primitive counts.
 */
56 
57 struct state {
58    nir_builder *builder;
59    nir_variable *vertex_count_vars[NIR_MAX_XFB_STREAMS];
60    nir_variable *vtxcnt_per_prim_vars[NIR_MAX_XFB_STREAMS];
61    nir_variable *primitive_count_vars[NIR_MAX_XFB_STREAMS];
62    nir_variable *decomposed_primitive_count_vars[NIR_MAX_XFB_STREAMS];
63    bool per_stream;
64    bool count_prims;
65    bool count_vtx_per_prim;
66    bool count_decomposed_prims;
67    bool overwrite_incomplete;
68    bool is_points;
69    bool progress;
70 };
71 
72 /**
73  * Replace emit_vertex intrinsics with:
74  *
75  * if (vertex_count < max_vertices) {
76  *    emit_vertex_with_counter vertex_count, vertex_count_per_primitive (optional) ...
77  *    vertex_count += 1
78  *    vertex_count_per_primitive += 1
79  * }
80  */
81 static void
rewrite_emit_vertex(nir_intrinsic_instr * intrin,struct state * state)82 rewrite_emit_vertex(nir_intrinsic_instr *intrin, struct state *state)
83 {
84    nir_builder *b = state->builder;
85    unsigned stream = nir_intrinsic_stream_id(intrin);
86 
87    /* Load the vertex count */
88    b->cursor = nir_before_instr(&intrin->instr);
89    assert(state->vertex_count_vars[stream] != NULL);
90    nir_def *count = nir_load_var(b, state->vertex_count_vars[stream]);
91    nir_def *count_per_primitive;
92    nir_def *primitive_count;
93    nir_def *decomposed_primitive_count;
94 
95    if (state->count_vtx_per_prim)
96       count_per_primitive = nir_load_var(b, state->vtxcnt_per_prim_vars[stream]);
97    else if (state->is_points)
98       count_per_primitive = nir_imm_int(b, 0);
99    else
100       count_per_primitive = nir_undef(b, 1, 32);
101 
102    if (state->count_prims)
103       primitive_count = nir_load_var(b, state->primitive_count_vars[stream]);
104    else
105       primitive_count = nir_undef(b, 1, 32);
106 
107    if (state->count_decomposed_prims) {
108       decomposed_primitive_count =
109          nir_load_var(b, state->decomposed_primitive_count_vars[stream]);
110    } else {
111       decomposed_primitive_count = nir_undef(b, 1, 32);
112    }
113 
114    /* Create: if (vertex_count < max_vertices) and insert it.
115     *
116     * The new if statement needs to be hooked up to the control flow graph
117     * before we start inserting instructions into it.
118     */
119    nir_push_if(b, nir_ilt_imm(b, count, b->shader->info.gs.vertices_out));
120 
121    nir_emit_vertex_with_counter(b, count, count_per_primitive, primitive_count,
122                                 decomposed_primitive_count, stream);
123 
124    /* Increment the vertex count by 1 */
125    nir_store_var(b, state->vertex_count_vars[stream],
126                  nir_iadd_imm(b, count, 1),
127                  0x1); /* .x */
128 
129    if (state->count_vtx_per_prim) {
130       /* Increment the per-primitive vertex count by 1 */
131       nir_variable *var = state->vtxcnt_per_prim_vars[stream];
132       nir_def *vtx_per_prim_cnt = nir_load_var(b, var);
133       nir_store_var(b, var,
134                     nir_iadd_imm(b, vtx_per_prim_cnt, 1),
135                     0x1); /* .x */
136    }
137 
138    if (state->count_decomposed_prims) {
139       nir_variable *vtx_var = state->vtxcnt_per_prim_vars[stream];
140       nir_def *vtx_per_prim_cnt = state->is_points ? nir_imm_int(b, 1) :
141                                                      nir_load_var(b, vtx_var);
142 
143       /* We form a new primitive for every vertex emitted after the first
144        * complete primitive (since we're outputting strips).
145        */
146       unsigned min_verts =
147          mesa_vertices_per_prim(b->shader->info.gs.output_primitive);
148       nir_def *new_prim = nir_uge_imm(b, vtx_per_prim_cnt, min_verts);
149 
150       /* Increment the decomposed primitive count by 1 if we formed a complete
151        * primitive.
152        */
153       nir_variable *var = state->decomposed_primitive_count_vars[stream];
154       nir_def *cnt = nir_load_var(b, var);
155       nir_store_var(b, var,
156                     nir_iadd(b, cnt, nir_b2i32(b, new_prim)),
157                     0x1); /* .x */
158    }
159 
160    nir_pop_if(b, NULL);
161 
162    nir_instr_remove(&intrin->instr);
163 
164    state->progress = true;
165 }
166 
167 /**
168  * Emits code that overwrites incomplete primitives and their vertices.
169  *
170  * A primitive is considered incomplete when it doesn't have enough vertices.
171  * For example, a triangle strip that has 2 or fewer vertices, or a line strip
172  * with 1 vertex are considered incomplete.
173  *
174  * After each end_primitive and at the end of the shader before emitting
175  * set_vertex_and_primitive_count, we check if the primitive that is being
176  * emitted has enough vertices or not, and we adjust the vertex and primitive
177  * counters accordingly.
178  *
179  * This means that the following emit_vertex can reuse the vertex index of
180  * a previous vertex, if the previous primitive was incomplete, so the compiler
181  * backend is expected to simply overwrite any data that belonged to those.
182  */
183 static void
overwrite_incomplete_primitives(struct state * state,unsigned stream)184 overwrite_incomplete_primitives(struct state *state, unsigned stream)
185 {
186    assert(state->count_vtx_per_prim);
187 
188    nir_builder *b = state->builder;
189    unsigned outprim_min_vertices =
190       mesa_vertices_per_prim(b->shader->info.gs.output_primitive);
191 
192    /* Total count of vertices emitted so far. */
193    nir_def *vtxcnt_total =
194       nir_load_var(b, state->vertex_count_vars[stream]);
195 
196    /* Number of vertices emitted for the last primitive */
197    nir_def *vtxcnt_per_primitive =
198       nir_load_var(b, state->vtxcnt_per_prim_vars[stream]);
199 
200    /* See if the current primitive is a incomplete */
201    nir_def *is_inc_prim =
202       nir_ilt_imm(b, vtxcnt_per_primitive, outprim_min_vertices);
203 
204    /* Number of vertices in the incomplete primitive */
205    nir_def *num_inc_vtx =
206       nir_bcsel(b, is_inc_prim, vtxcnt_per_primitive, nir_imm_int(b, 0));
207 
208    /* Store corrected total vertex count */
209    nir_store_var(b, state->vertex_count_vars[stream],
210                  nir_isub(b, vtxcnt_total, num_inc_vtx),
211                  0x1); /* .x */
212 
213    if (state->count_prims) {
214       /* Number of incomplete primitives (0 or 1) */
215       nir_def *num_inc_prim = nir_b2i32(b, is_inc_prim);
216 
217       /* Store corrected primitive count */
218       nir_def *prim_cnt = nir_load_var(b, state->primitive_count_vars[stream]);
219       nir_store_var(b, state->primitive_count_vars[stream],
220                     nir_isub(b, prim_cnt, num_inc_prim),
221                     0x1); /* .x */
222    }
223 }
224 
225 /**
226  * Replace end_primitive with end_primitive_with_counter.
227  */
228 static void
rewrite_end_primitive(nir_intrinsic_instr * intrin,struct state * state)229 rewrite_end_primitive(nir_intrinsic_instr *intrin, struct state *state)
230 {
231    nir_builder *b = state->builder;
232    unsigned stream = nir_intrinsic_stream_id(intrin);
233 
234    b->cursor = nir_instr_remove(&intrin->instr);
235    state->progress = true;
236 
237    /* end_primitive doesn't do anything for points, remove without replacing */
238    if (state->is_points) {
239       b->shader->info.gs.uses_end_primitive = false;
240       return;
241    }
242 
243    assert(state->vertex_count_vars[stream] != NULL);
244    nir_def *count = nir_load_var(b, state->vertex_count_vars[stream]);
245    nir_def *count_per_primitive;
246    nir_def *primitive_count;
247    nir_def *decomposed_primitive_count;
248 
249    if (state->count_vtx_per_prim)
250       count_per_primitive = nir_load_var(b, state->vtxcnt_per_prim_vars[stream]);
251    else
252       count_per_primitive = nir_undef(b, count->num_components, count->bit_size);
253 
254    if (state->count_prims)
255       primitive_count = nir_load_var(b, state->primitive_count_vars[stream]);
256    else
257       primitive_count = nir_undef(b, 1, 32);
258 
259    if (state->count_decomposed_prims) {
260       decomposed_primitive_count =
261          nir_load_var(b, state->decomposed_primitive_count_vars[stream]);
262    } else {
263       decomposed_primitive_count = nir_undef(b, 1, 32);
264    }
265 
266    nir_end_primitive_with_counter(b, count, count_per_primitive,
267                                   primitive_count,
268                                   decomposed_primitive_count, stream);
269 
270    if (state->count_prims) {
271       /* Increment the primitive count by 1 */
272       nir_def *prim_cnt = nir_load_var(b, state->primitive_count_vars[stream]);
273       nir_store_var(b, state->primitive_count_vars[stream],
274                     nir_iadd_imm(b, prim_cnt, 1),
275                     0x1); /* .x */
276    }
277 
278    if (state->count_vtx_per_prim) {
279       if (state->overwrite_incomplete)
280          overwrite_incomplete_primitives(state, stream);
281 
282       /* Store 0 to per-primitive vertex count */
283       nir_store_var(b, state->vtxcnt_per_prim_vars[stream],
284                     nir_imm_int(b, 0),
285                     0x1); /* .x */
286    }
287 }
288 
289 static bool
rewrite_intrinsics(nir_block * block,struct state * state)290 rewrite_intrinsics(nir_block *block, struct state *state)
291 {
292    nir_foreach_instr_safe(instr, block) {
293       if (instr->type != nir_instr_type_intrinsic)
294          continue;
295 
296       nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
297       switch (intrin->intrinsic) {
298       case nir_intrinsic_emit_vertex:
299       case nir_intrinsic_emit_vertex_with_counter:
300          rewrite_emit_vertex(intrin, state);
301          break;
302       case nir_intrinsic_end_primitive:
303       case nir_intrinsic_end_primitive_with_counter:
304          rewrite_end_primitive(intrin, state);
305          break;
306       default:
307          /* not interesting; skip this */
308          break;
309       }
310    }
311 
312    return true;
313 }
314 
315 /**
316  * Add a set_vertex_and_primitive_count intrinsic at the end of the program
317  * (representing the final total vertex and primitive count).
318  */
319 static void
append_set_vertex_and_primitive_count(nir_block * end_block,struct state * state)320 append_set_vertex_and_primitive_count(nir_block *end_block, struct state *state)
321 {
322    nir_builder *b = state->builder;
323    nir_shader *shader = state->builder->shader;
324 
325    /* Insert the new intrinsic in all of the predecessors of the end block,
326     * but before any jump instructions (return).
327     */
328    set_foreach(end_block->predecessors, entry) {
329       nir_block *pred = (nir_block *)entry->key;
330       b->cursor = nir_after_block_before_jump(pred);
331 
332       for (unsigned stream = 0; stream < NIR_MAX_XFB_STREAMS; ++stream) {
333          /* When it's not per-stream, we only need to write one variable. */
334          if (!state->per_stream && stream != 0)
335             continue;
336 
337          nir_def *vtx_cnt;
338          nir_def *prim_cnt;
339          nir_def *decomposed_prim_cnt;
340 
341          if (state->per_stream && !(shader->info.gs.active_stream_mask & (1 << stream))) {
342             /* Inactive stream: vertex count is 0, primitive count is 0 or undef. */
343             vtx_cnt = nir_imm_int(b, 0);
344             prim_cnt = state->count_prims || state->is_points
345                           ? nir_imm_int(b, 0)
346                           : nir_undef(b, 1, 32);
347             decomposed_prim_cnt = prim_cnt;
348          } else {
349             if (state->overwrite_incomplete)
350                overwrite_incomplete_primitives(state, stream);
351 
352             vtx_cnt = nir_load_var(b, state->vertex_count_vars[stream]);
353 
354             if (state->count_prims)
355                prim_cnt = nir_load_var(b, state->primitive_count_vars[stream]);
356             else if (state->is_points)
357                /* EndPrimitive does not affect primitive count for points,
358                 * just use vertex count instead
359                 */
360                prim_cnt = vtx_cnt;
361             else
362                prim_cnt = nir_undef(b, 1, 32);
363 
364             if (state->count_decomposed_prims) {
365                decomposed_prim_cnt =
366                   nir_load_var(b, state->decomposed_primitive_count_vars[stream]);
367             } else {
368                decomposed_prim_cnt = nir_undef(b, 1, 32);
369             }
370          }
371 
372          nir_set_vertex_and_primitive_count(b, vtx_cnt, prim_cnt,
373                                             decomposed_prim_cnt, stream);
374          state->progress = true;
375       }
376    }
377 }
378 
379 /*
380  * Append an EndPrimitive intrinsic to the end of the geometry shader. This
381  * allows the backend to emit primitives only when EndPrimitive is used. If this
382  * EndPrimitive is not needed, it will be predicated out via
383  * overwrite_incomplete_primitives.
384  */
385 static void
append_end_primitive(nir_block * end_block,struct state * state)386 append_end_primitive(nir_block *end_block, struct state *state)
387 {
388    nir_builder *b = state->builder;
389 
390    /* Only end a primitive if there is a primitive to end */
391    if (b->shader->info.gs.active_stream_mask == 0)
392       return;
393 
394    /* Insert the new intrinsic in all of the predecessors of the end block,
395     * but before any jump instructions (return).
396     */
397    set_foreach(end_block->predecessors, entry) {
398       nir_block *pred = (nir_block *) entry->key;
399       b->cursor = nir_after_block_before_jump(pred);
400 
401       nir_end_primitive(b);
402    }
403 }
404 
405 /**
406  * Check to see if there are any blocks that need set_vertex_and_primitive_count
407  *
408  * If every block that could need the set_vertex_and_primitive_count intrinsic
409  * already has one, there is nothing for this pass to do.
410  */
411 static bool
a_block_needs_set_vertex_and_primitive_count(nir_block * end_block,bool per_stream)412 a_block_needs_set_vertex_and_primitive_count(nir_block *end_block, bool per_stream)
413 {
414    set_foreach(end_block->predecessors, entry) {
415       nir_block *pred = (nir_block *)entry->key;
416 
417       for (unsigned stream = 0; stream < NIR_MAX_XFB_STREAMS; ++stream) {
418          /* When it's not per-stream, we only need to write one variable. */
419          if (!per_stream && stream != 0)
420             continue;
421 
422          bool found = false;
423 
424          nir_foreach_instr_reverse(instr, pred) {
425             if (instr->type != nir_instr_type_intrinsic)
426                continue;
427 
428             const nir_intrinsic_instr *const intrin =
429                nir_instr_as_intrinsic(instr);
430 
431             if (intrin->intrinsic == nir_intrinsic_set_vertex_and_primitive_count &&
432                 intrin->const_index[0] == stream) {
433                found = true;
434                break;
435             }
436          }
437 
438          if (!found)
439             return true;
440       }
441    }
442 
443    return false;
444 }
445 
446 bool
nir_lower_gs_intrinsics(nir_shader * shader,nir_lower_gs_intrinsics_flags options)447 nir_lower_gs_intrinsics(nir_shader *shader, nir_lower_gs_intrinsics_flags options)
448 {
449    bool per_stream = options & nir_lower_gs_intrinsics_per_stream;
450    bool count_primitives = options & nir_lower_gs_intrinsics_count_primitives;
451    bool overwrite_incomplete = options & nir_lower_gs_intrinsics_overwrite_incomplete;
452    bool always_end_primitive_non_points = options & nir_lower_gs_intrinsics_always_end_primitive;
453    bool count_vtx_per_prim =
454       overwrite_incomplete ||
455       (options & nir_lower_gs_intrinsics_count_vertices_per_primitive);
456    bool count_decomposed_prims = options & nir_lower_gs_intrinsics_count_decomposed_primitives;
457 
458    bool is_points = shader->info.gs.output_primitive == MESA_PRIM_POINTS;
459    /* points are always complete primitives with a single vertex, so these are
460     * not needed when primitive is points.
461     */
462    if (is_points) {
463       count_primitives = false;
464       overwrite_incomplete = false;
465       count_vtx_per_prim = false;
466    }
467 
468    struct state state;
469    state.progress = false;
470    state.count_prims = count_primitives;
471    state.count_vtx_per_prim = count_vtx_per_prim;
472    state.count_decomposed_prims = count_decomposed_prims;
473    state.overwrite_incomplete = overwrite_incomplete;
474    state.per_stream = per_stream;
475    state.is_points = is_points;
476 
477    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
478    assert(impl);
479 
480    if (!a_block_needs_set_vertex_and_primitive_count(impl->end_block, per_stream))
481       return false;
482 
483    nir_builder b = nir_builder_at(nir_before_impl(impl));
484    state.builder = &b;
485 
486    for (unsigned i = 0; i < NIR_MAX_XFB_STREAMS; i++) {
487       if (per_stream && !(shader->info.gs.active_stream_mask & (1 << i)))
488          continue;
489 
490       if (i == 0 || per_stream) {
491          state.vertex_count_vars[i] =
492             nir_local_variable_create(impl, glsl_uint_type(), "vertex_count");
493          /* initialize to 0 */
494          nir_store_var(&b, state.vertex_count_vars[i], nir_imm_int(&b, 0), 0x1);
495 
496          if (count_primitives) {
497             state.primitive_count_vars[i] =
498                nir_local_variable_create(impl, glsl_uint_type(), "primitive_count");
499             /* initialize to 1 */
500             nir_store_var(&b, state.primitive_count_vars[i], nir_imm_int(&b, 1), 0x1);
501          }
502          if (count_vtx_per_prim) {
503             state.vtxcnt_per_prim_vars[i] =
504                nir_local_variable_create(impl, glsl_uint_type(), "vertices_per_primitive");
505             /* initialize to 0 */
506             nir_store_var(&b, state.vtxcnt_per_prim_vars[i], nir_imm_int(&b, 0), 0x1);
507          }
508          if (count_decomposed_prims) {
509             state.decomposed_primitive_count_vars[i] =
510                nir_local_variable_create(impl, glsl_uint_type(), "decomposed_primitive_count");
511             /* initialize to 0 */
512             nir_store_var(&b, state.decomposed_primitive_count_vars[i],
513                            nir_imm_int(&b, 0), 0x1);
514          }
515       } else {
516          /* If per_stream is false, we only have one counter of each kind which we
517           * want to use for all streams. Duplicate the counter pointers so all
518           * streams use the same counters.
519           */
520          state.vertex_count_vars[i] = state.vertex_count_vars[0];
521 
522          if (count_primitives)
523             state.primitive_count_vars[i] = state.primitive_count_vars[0];
524          if (count_vtx_per_prim)
525             state.vtxcnt_per_prim_vars[i] = state.vtxcnt_per_prim_vars[0];
526          if (count_decomposed_prims)
527             state.decomposed_primitive_count_vars[i] = state.decomposed_primitive_count_vars[0];
528       }
529    }
530 
531    if (always_end_primitive_non_points && !is_points)
532       append_end_primitive(impl->end_block, &state);
533 
534    nir_foreach_block_safe(block, impl)
535       rewrite_intrinsics(block, &state);
536 
537    /* This only works because we have a single main() function. */
538    append_set_vertex_and_primitive_count(impl->end_block, &state);
539 
540    nir_metadata_preserve(impl, nir_metadata_none);
541 
542    return state.progress;
543 }
544