1 /*
2  * Copyright 2023 Alyssa Rosenzweig
3  * Copyright 2023 Valve Corporation
4  * SPDX-License-Identifier: MIT
5  */
6 
7 #include "agx_nir_lower_gs.h"
8 #include "asahi/compiler/agx_compile.h"
9 #include "compiler/nir/nir_builder.h"
10 #include "gallium/include/pipe/p_defines.h"
11 #include "shaders/draws.h"
12 #include "shaders/geometry.h"
13 #include "util/bitscan.h"
14 #include "util/list.h"
15 #include "util/macros.h"
16 #include "util/ralloc.h"
17 #include "util/u_math.h"
18 #include "libagx_shaders.h"
19 #include "nir.h"
20 #include "nir_builder_opcodes.h"
21 #include "nir_intrinsics.h"
22 #include "nir_intrinsics_indices.h"
23 #include "nir_xfb_info.h"
24 #include "shader_enums.h"
25 
26 /* Marks a transform feedback store, which must not be stripped from the
27  * prepass since that's where the transform feedback happens. Chosen as a
28  * vendored flag so that it does not alias other flags we'll see.
29  */
30 #define ACCESS_XFB (ACCESS_IS_SWIZZLED_AMD)
31 
32 enum gs_counter {
33    GS_COUNTER_VERTICES = 0,
34    GS_COUNTER_PRIMITIVES,
35    GS_COUNTER_XFB_PRIMITIVES,
36    GS_NUM_COUNTERS
37 };
38 
39 #define MAX_PRIM_OUT_SIZE 3
40 
41 struct lower_gs_state {
42    int static_count[GS_NUM_COUNTERS][MAX_VERTEX_STREAMS];
43    nir_variable *outputs[NUM_TOTAL_VARYING_SLOTS][MAX_PRIM_OUT_SIZE];
44 
45    /* The count buffer contains `count_stride_el` 32-bit words in a row for each
46     * input primitive, for `input_primitives * count_stride_el * 4` total bytes.
47     */
48    unsigned count_stride_el;
49 
50    /* The index of each counter in the count buffer, or -1 if it's not in the
51     * count buffer.
52     *
53     * Invariant: count_stride_el == sum(count_index[i][j] >= 0).
54     */
55    int count_index[MAX_VERTEX_STREAMS][GS_NUM_COUNTERS];
56 
57    bool rasterizer_discard;
58 };
59 
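/* As a worked example of the count buffer layout described above: if the only
 * count that is not known statically is stream 0's XFB primitive count, then
 * count_index[0][GS_COUNTER_XFB_PRIMITIVES] = 0, every other count_index entry
 * is -1, count_stride_el = 1, and the count buffer holds a single 32-bit word
 * per unrolled input primitive (which later gets prefix summed).
 */
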
60 /* Helpers for loading from the geometry state buffer */
61 static nir_def *
62 load_geometry_param_offset(nir_builder *b, uint32_t offset, uint8_t bytes)
63 {
64    nir_def *base = nir_load_geometry_param_buffer_agx(b);
65    nir_def *addr = nir_iadd_imm(b, base, offset);
66 
67    assert((offset % bytes) == 0 && "must be naturally aligned");
68 
69    return nir_load_global_constant(b, addr, bytes, 1, bytes * 8);
70 }
71 
72 static void
73 store_geometry_param_offset(nir_builder *b, nir_def *def, uint32_t offset,
74                             uint8_t bytes)
75 {
76    nir_def *base = nir_load_geometry_param_buffer_agx(b);
77    nir_def *addr = nir_iadd_imm(b, base, offset);
78 
79    assert((offset % bytes) == 0 && "must be naturally aligned");
80 
81    nir_store_global(b, addr, 4, def, nir_component_mask(def->num_components));
82 }
83 
84 #define store_geometry_param(b, field, def)                                    \
85    store_geometry_param_offset(                                                \
86       b, def, offsetof(struct agx_geometry_params, field),                     \
87       sizeof(((struct agx_geometry_params *)0)->field))
88 
89 #define load_geometry_param(b, field)                                          \
90    load_geometry_param_offset(                                                 \
91       b, offsetof(struct agx_geometry_params, field),                          \
92       sizeof(((struct agx_geometry_params *)0)->field))
93 
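/* For example, load_geometry_param(b, gs_grid[0]) expands to
 * load_geometry_param_offset() with offsetof(struct agx_geometry_params,
 * gs_grid[0]) and the size of that field, i.e. a naturally aligned load of the
 * field straight out of the geometry parameter buffer. store_geometry_param
 * works the same way in the other direction.
 */
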
94 /* Helper for updating counters */
95 static void
96 add_counter(nir_builder *b, nir_def *counter, nir_def *increment)
97 {
98    /* If the counter is NULL, the counter is disabled. Skip the update. */
99    nir_if *nif = nir_push_if(b, nir_ine_imm(b, counter, 0));
100    {
101       nir_def *old = nir_load_global(b, counter, 4, 1, 32);
102       nir_def *new_ = nir_iadd(b, old, increment);
103       nir_store_global(b, counter, 4, new_, nir_component_mask(1));
104    }
105    nir_pop_if(b, nif);
106 }
107 
108 /* Helpers for lowering I/O to variables */
109 static void
110 lower_store_to_var(nir_builder *b, nir_intrinsic_instr *intr,
111                    struct agx_lower_output_to_var_state *state)
112 {
113    b->cursor = nir_instr_remove(&intr->instr);
114    nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
115    unsigned component = nir_intrinsic_component(intr);
116    nir_def *value = intr->src[0].ssa;
117 
118    assert(nir_src_is_const(intr->src[1]) && "no indirect outputs");
119    assert(nir_intrinsic_write_mask(intr) == nir_component_mask(1) &&
120           "should be scalarized");
121 
122    nir_variable *var =
123       state->outputs[sem.location + nir_src_as_uint(intr->src[1])];
124    if (!var) {
125       assert(sem.location == VARYING_SLOT_PSIZ &&
126              "otherwise in outputs_written");
127       return;
128    }
129 
130    unsigned nr_components = glsl_get_components(glsl_without_array(var->type));
131    assert(component < nr_components);
132 
133    /* Turn it into a vec4 write like NIR expects */
134    value = nir_vector_insert_imm(b, nir_undef(b, nr_components, 32), value,
135                                  component);
136 
137    nir_store_var(b, var, value, BITFIELD_BIT(component));
138 }
139 
140 bool
141 agx_lower_output_to_var(nir_builder *b, nir_instr *instr, void *data)
142 {
143    if (instr->type != nir_instr_type_intrinsic)
144       return false;
145 
146    nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
147    if (intr->intrinsic != nir_intrinsic_store_output)
148       return false;
149 
150    lower_store_to_var(b, intr, data);
151    return true;
152 }
153 
154 /*
155  * Geometry shader invocations are compute-like:
156  *
157  * (primitive ID, instance ID, 1)
158  */
159 static nir_def *
160 load_primitive_id(nir_builder *b)
161 {
162    return nir_channel(b, nir_load_global_invocation_id(b, 32), 0);
163 }
164 
165 static nir_def *
166 load_instance_id(nir_builder *b)
167 {
168    return nir_channel(b, nir_load_global_invocation_id(b, 32), 1);
169 }
170 
171 /* Geometry shaders use software input assembly. The software vertex shader
172  * is invoked for each index, and the geometry shader applies the topology.
173  * This helper maps a vertex within the input primitive to its vertex ID.
174  */
175 static nir_def *
176 vertex_id_for_topology_class(nir_builder *b, nir_def *vert, enum mesa_prim cls)
177 {
178    nir_def *prim = nir_load_primitive_id(b);
179    nir_def *flatshade_first = nir_ieq_imm(b, nir_load_provoking_last(b), 0);
180    nir_def *nr = load_geometry_param(b, gs_grid[0]);
181    nir_def *topology = nir_load_input_topology_agx(b);
182 
183    switch (cls) {
184    case MESA_PRIM_POINTS:
185       return prim;
186 
187    case MESA_PRIM_LINES:
188       return libagx_vertex_id_for_line_class(b, topology, prim, vert, nr);
189 
190    case MESA_PRIM_TRIANGLES:
191       return libagx_vertex_id_for_tri_class(b, topology, prim, vert,
192                                             flatshade_first);
193 
194    case MESA_PRIM_LINES_ADJACENCY:
195       return libagx_vertex_id_for_line_adj_class(b, topology, prim, vert);
196 
197    case MESA_PRIM_TRIANGLES_ADJACENCY:
198       return libagx_vertex_id_for_tri_adj_class(b, topology, prim, vert, nr,
199                                                 flatshade_first);
200 
201    default:
202       unreachable("invalid topology class");
203    }
204 }
205 
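/* For illustration: with a triangle list input, primitive p reads vertices
 * {3p, 3p+1, 3p+2}; with a triangle strip it reads {p, p+1, p+2}. The even/odd
 * winding and provoking-vertex fixups are handled inside the libagx helpers
 * called above rather than here.
 */
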
206 nir_def *
207 agx_load_per_vertex_input(nir_builder *b, nir_intrinsic_instr *intr,
208                           nir_def *vertex)
209 {
210    assert(intr->intrinsic == nir_intrinsic_load_per_vertex_input);
211    nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
212 
213    nir_def *location = nir_iadd_imm(b, intr->src[1].ssa, sem.location);
214    nir_def *addr;
215 
216    if (b->shader->info.stage == MESA_SHADER_GEOMETRY) {
217       /* GS may be preceded by VS or TES so specified as param */
218       addr = libagx_geometry_input_address(
219          b, nir_load_geometry_param_buffer_agx(b), vertex, location);
220    } else {
221       assert(b->shader->info.stage == MESA_SHADER_TESS_CTRL);
222 
223       /* TCS always preceded by VS so we use the VS state directly */
224       addr = libagx_vertex_output_address(b, nir_load_vs_output_buffer_agx(b),
225                                           nir_load_vs_outputs_agx(b), vertex,
226                                           location);
227    }
228 
229    addr = nir_iadd_imm(b, addr, 4 * nir_intrinsic_component(intr));
230    return nir_load_global_constant(b, addr, 4, intr->def.num_components,
231                                    intr->def.bit_size);
232 }
233 
234 static bool
235 lower_gs_inputs(nir_builder *b, nir_intrinsic_instr *intr, void *_)
236 {
237    if (intr->intrinsic != nir_intrinsic_load_per_vertex_input)
238       return false;
239 
240    b->cursor = nir_instr_remove(&intr->instr);
241 
242    /* Calculate the vertex ID we're pulling, based on the topology class */
243    nir_def *vert_in_prim = intr->src[0].ssa;
244    nir_def *vertex = vertex_id_for_topology_class(
245       b, vert_in_prim, b->shader->info.gs.input_primitive);
246 
247    nir_def *verts = load_geometry_param(b, vs_grid[0]);
248    nir_def *unrolled =
249       nir_iadd(b, nir_imul(b, nir_load_instance_id(b), verts), vertex);
250 
251    nir_def *val = agx_load_per_vertex_input(b, intr, unrolled);
252    nir_def_rewrite_uses(&intr->def, val);
253    return true;
254 }
255 
256 /*
257  * Unrolled ID is the index of the primitive in the count buffer, given as
258  * (instance ID * # vertices/instance) + vertex ID
259  */
260 static nir_def *
261 calc_unrolled_id(nir_builder *b)
262 {
263    return nir_iadd(
264       b, nir_imul(b, load_instance_id(b), load_geometry_param(b, gs_grid[0])),
265       load_primitive_id(b));
266 }
267 
268 static unsigned
269 output_vertex_id_stride(nir_shader *gs)
270 {
271    /* round up to power of two for cheap multiply/division */
272    return util_next_power_of_two(MAX2(gs->info.gs.vertices_out, 1));
273 }
274 
275 /* Variant of calc_unrolled_id that uses a power-of-two stride for indices. This
276  * is sparser (acceptable for index buffer values, not for count buffer
277  * indices). It has the nice property of being cheap to invert, unlike
278  * calc_unrolled_id. So, we use calc_unrolled_id for count buffers and
279  * calc_unrolled_index_id for index values.
280  *
281  * This also multiplies by the appropriate stride to calculate the final index
282  * base value.
283  */
284 static nir_def *
285 calc_unrolled_index_id(nir_builder *b)
286 {
287    unsigned vertex_stride = output_vertex_id_stride(b->shader);
288    nir_def *primitives_log2 = load_geometry_param(b, primitives_log2);
289 
290    nir_def *instance = nir_ishl(b, load_instance_id(b), primitives_log2);
291    nir_def *prim = nir_iadd(b, instance, load_primitive_id(b));
292 
293    return nir_imul_imm(b, prim, vertex_stride);
294 }
295 
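/* Worked example (illustrative numbers): with max_vertices = 6, the vertex
 * stride rounds up to 8. If primitives_log2 = 4 (up to 16 input primitives per
 * instance), then instance 2, primitive 5 gets base index
 * ((2 << 4) + 5) * 8 = 296. The rasterization shader inverts this cheaply:
 * for raw vertex ID 296 + v it recovers output_id = raw % 8 = v,
 * unrolled = raw / 8 = 37, instance = 37 >> 4 = 2, primitive = 37 & 15 = 5.
 */
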
296 static nir_def *
297 load_count_address(nir_builder *b, struct lower_gs_state *state,
298                    nir_def *unrolled_id, unsigned stream,
299                    enum gs_counter counter)
300 {
301    int index = state->count_index[stream][counter];
302    if (index < 0)
303       return NULL;
304 
305    nir_def *prim_offset_el =
306       nir_imul_imm(b, unrolled_id, state->count_stride_el);
307 
308    nir_def *offset_el = nir_iadd_imm(b, prim_offset_el, index);
309 
310    return nir_iadd(b, load_geometry_param(b, count_buffer),
311                    nir_u2u64(b, nir_imul_imm(b, offset_el, 4)));
312 }
313 
314 static void
315 write_counts(nir_builder *b, nir_intrinsic_instr *intr,
316              struct lower_gs_state *state)
317 {
318    /* Store each required counter */
319    nir_def *counts[GS_NUM_COUNTERS] = {
320       [GS_COUNTER_VERTICES] = intr->src[0].ssa,
321       [GS_COUNTER_PRIMITIVES] = intr->src[1].ssa,
322       [GS_COUNTER_XFB_PRIMITIVES] = intr->src[2].ssa,
323    };
324 
325    for (unsigned i = 0; i < GS_NUM_COUNTERS; ++i) {
326       nir_def *addr = load_count_address(b, state, calc_unrolled_id(b),
327                                          nir_intrinsic_stream_id(intr), i);
328 
329       if (addr)
330          nir_store_global(b, addr, 4, counts[i], nir_component_mask(1));
331    }
332 }
333 
334 static bool
335 lower_gs_count_instr(nir_builder *b, nir_intrinsic_instr *intr, void *data)
336 {
337    switch (intr->intrinsic) {
338    case nir_intrinsic_emit_vertex_with_counter:
339    case nir_intrinsic_end_primitive_with_counter:
340    case nir_intrinsic_store_output:
341       /* These are for the main shader, just remove them */
342       nir_instr_remove(&intr->instr);
343       return true;
344 
345    case nir_intrinsic_set_vertex_and_primitive_count:
346       b->cursor = nir_instr_remove(&intr->instr);
347       write_counts(b, intr, data);
348       return true;
349 
350    default:
351       return false;
352    }
353 }
354 
355 static bool
356 lower_id(nir_builder *b, nir_intrinsic_instr *intr, void *data)
357 {
358    b->cursor = nir_before_instr(&intr->instr);
359 
360    nir_def *id;
361    if (intr->intrinsic == nir_intrinsic_load_primitive_id)
362       id = load_primitive_id(b);
363    else if (intr->intrinsic == nir_intrinsic_load_instance_id)
364       id = load_instance_id(b);
365    else if (intr->intrinsic == nir_intrinsic_load_flat_mask)
366       id = load_geometry_param(b, flat_outputs);
367    else if (intr->intrinsic == nir_intrinsic_load_input_topology_agx)
368       id = load_geometry_param(b, input_topology);
369    else
370       return false;
371 
372    b->cursor = nir_instr_remove(&intr->instr);
373    nir_def_rewrite_uses(&intr->def, id);
374    return true;
375 }
376 
377 /*
378  * Create a "Geometry count" shader. This is a stripped down geometry shader
379  * that just writes its number of emitted vertices / primitives / transform
380  * feedback primitives to a count buffer. That count buffer will be prefix
381  * summed prior to running the real geometry shader. This is skipped if the
382  * counts are statically known.
383  */
384 static nir_shader *
385 agx_nir_create_geometry_count_shader(nir_shader *gs, const nir_shader *libagx,
386                                      struct lower_gs_state *state)
387 {
388    /* Don't muck up the original shader */
389    nir_shader *shader = nir_shader_clone(NULL, gs);
390 
391    if (shader->info.name) {
392       shader->info.name =
393          ralloc_asprintf(shader, "%s_count", shader->info.name);
394    } else {
395       shader->info.name = "count";
396    }
397 
398    NIR_PASS(_, shader, nir_shader_intrinsics_pass, lower_gs_count_instr,
399             nir_metadata_control_flow, state);
400 
401    NIR_PASS(_, shader, nir_shader_intrinsics_pass, lower_id,
402             nir_metadata_control_flow, NULL);
403 
404    agx_preprocess_nir(shader, libagx);
405    return shader;
406 }
407 
408 struct lower_gs_rast_state {
409    nir_def *instance_id, *primitive_id, *output_id;
410    struct agx_lower_output_to_var_state outputs;
411    struct agx_lower_output_to_var_state selected;
412 };
413 
414 static void
415 select_rast_output(nir_builder *b, nir_intrinsic_instr *intr,
416                    struct lower_gs_rast_state *state)
417 {
418    b->cursor = nir_instr_remove(&intr->instr);
419 
420    /* We only care about the rasterization stream in the rasterization
421     * shader, so just ignore emits from other streams.
422     */
423    if (nir_intrinsic_stream_id(intr) != 0)
424       return;
425 
426    u_foreach_bit64(slot, b->shader->info.outputs_written) {
427       nir_def *orig = nir_load_var(b, state->selected.outputs[slot]);
428       nir_def *data = nir_load_var(b, state->outputs.outputs[slot]);
429 
430       nir_def *value = nir_bcsel(
431          b, nir_ieq(b, intr->src[0].ssa, state->output_id), data, orig);
432 
433       nir_store_var(b, state->selected.outputs[slot], value,
434                     nir_component_mask(value->num_components));
435    }
436 }
437 
438 static bool
439 lower_to_gs_rast(nir_builder *b, nir_intrinsic_instr *intr, void *data)
440 {
441    struct lower_gs_rast_state *state = data;
442 
443    switch (intr->intrinsic) {
444    case nir_intrinsic_store_output:
445       lower_store_to_var(b, intr, &state->outputs);
446       return true;
447 
448    case nir_intrinsic_emit_vertex_with_counter:
449       select_rast_output(b, intr, state);
450       return true;
451 
452    case nir_intrinsic_load_primitive_id:
453       nir_def_rewrite_uses(&intr->def, state->primitive_id);
454       return true;
455 
456    case nir_intrinsic_load_instance_id:
457       nir_def_rewrite_uses(&intr->def, state->instance_id);
458       return true;
459 
460    case nir_intrinsic_load_flat_mask:
461    case nir_intrinsic_load_provoking_last:
462    case nir_intrinsic_load_input_topology_agx: {
463       /* Lowering the same in both GS variants */
464       return lower_id(b, intr, NULL);
465    }
466 
467    case nir_intrinsic_end_primitive_with_counter:
468    case nir_intrinsic_set_vertex_and_primitive_count:
469       nir_instr_remove(&intr->instr);
470       return true;
471 
472    default:
473       return false;
474    }
475 }
476 
477 /*
478  * Side effects in geometry shaders are problematic with our "GS rasterization
479  * shader" implementation. Where does the side effect happen? In the prepass?
480  * In the rast shader? In both?
481  *
482  * A perfect solution is impossible with rast shaders. Since the spec is loose
483  * here, we follow the principle of "least surprise":
484  *
485  * 1. Prefer side effects in the prepass over the rast shader. The prepass runs
486  *    once per API GS invocation so will match the expectations of buggy apps
487  *    not written for tilers.
488  *
489  * 2. If we must execute any side effect in the rast shader, try to execute all
490  *    side effects only in the rast shader. If some side effects must happen in
491  *    the rast shader and others don't, this gets consistent counts
492  *    (i.e. if the app expects plain stores and atomics to match up).
493  *
494  * 3. If we must execute side effects in both rast and the prepass,
495  *    execute all side effects in the rast shader and strip what we can from
496  *    the prepass. This gets the "unsurprising" behaviour from #2 without
497  *    falling over for ridiculous uses of atomics.
498  */
499 static bool
500 strip_side_effect_from_rast(nir_builder *b, nir_intrinsic_instr *intr,
501                             void *data)
502 {
503    switch (intr->intrinsic) {
504    case nir_intrinsic_store_global:
505    case nir_intrinsic_global_atomic:
506    case nir_intrinsic_global_atomic_swap:
507       break;
508    default:
509       return false;
510    }
511 
512    /* If there's a side effect that's actually required, keep it. */
513    if (nir_intrinsic_infos[intr->intrinsic].has_dest &&
514        !list_is_empty(&intr->def.uses)) {
515 
516       bool *any = data;
517       *any = true;
518       return false;
519    }
520 
521    /* Otherwise, remove the dead instruction. */
522    nir_instr_remove(&intr->instr);
523    return true;
524 }
525 
526 static bool
527 strip_side_effects_from_rast(nir_shader *s, bool *side_effects_for_rast)
528 {
529    bool progress, any;
530 
531    /* Rather than complex analysis, clone and try to remove as many side effects
532     * as possible. Then we check if we removed them all. We need to loop to
533     * handle complex control flow with side effects, where we can strip
534     * everything but can't figure that out with a simple one-shot analysis.
535     */
536    nir_shader *clone = nir_shader_clone(NULL, s);
537 
538    /* Drop as much as we can */
539    do {
540       progress = false;
541       any = false;
542       NIR_PASS(progress, clone, nir_shader_intrinsics_pass,
543                strip_side_effect_from_rast, nir_metadata_control_flow, &any);
544 
545       NIR_PASS(progress, clone, nir_opt_dce);
546       NIR_PASS(progress, clone, nir_opt_dead_cf);
547    } while (progress);
548 
549    ralloc_free(clone);
550 
551    /* If we need atomics, leave them in */
552    if (any) {
553       *side_effects_for_rast = true;
554       return false;
555    }
556 
557    /* Else strip it all */
558    do {
559       progress = false;
560       any = false;
561       NIR_PASS(progress, s, nir_shader_intrinsics_pass,
562                strip_side_effect_from_rast, nir_metadata_control_flow, &any);
563 
564       NIR_PASS(progress, s, nir_opt_dce);
565       NIR_PASS(progress, s, nir_opt_dead_cf);
566    } while (progress);
567 
568    assert(!any);
569    return progress;
570 }
571 
572 static bool
573 strip_side_effect_from_main(nir_builder *b, nir_intrinsic_instr *intr,
574                             void *data)
575 {
576    switch (intr->intrinsic) {
577    case nir_intrinsic_global_atomic:
578    case nir_intrinsic_global_atomic_swap:
579       break;
580    default:
581       return false;
582    }
583 
584    if (list_is_empty(&intr->def.uses)) {
585       nir_instr_remove(&intr->instr);
586       return true;
587    }
588 
589    return false;
590 }
591 
592 /*
593  * Create a GS rasterization shader. This is a hardware vertex shader that
594  * shades each rasterized output vertex in parallel.
595  */
596 static nir_shader *
597 agx_nir_create_gs_rast_shader(const nir_shader *gs, const nir_shader *libagx,
598                               bool *side_effects_for_rast)
599 {
600    /* Don't muck up the original shader */
601    nir_shader *shader = nir_shader_clone(NULL, gs);
602 
603    unsigned max_verts = output_vertex_id_stride(shader);
604 
605    /* Turn into a vertex shader run only for rasterization. Transform feedback
606     * was handled in the prepass.
607     */
608    shader->info.stage = MESA_SHADER_VERTEX;
609    shader->info.has_transform_feedback_varyings = false;
610    memset(&shader->info.vs, 0, sizeof(shader->info.vs));
611    shader->xfb_info = NULL;
612 
613    if (shader->info.name) {
614       shader->info.name = ralloc_asprintf(shader, "%s_rast", shader->info.name);
615    } else {
616       shader->info.name = "gs rast";
617    }
618 
619    nir_builder b_ =
620       nir_builder_at(nir_before_impl(nir_shader_get_entrypoint(shader)));
621    nir_builder *b = &b_;
622 
623    NIR_PASS(_, shader, strip_side_effects_from_rast, side_effects_for_rast);
624 
625    /* Optimize out pointless gl_PointSize outputs. Bizarrely, these occur. */
626    if (shader->info.gs.output_primitive != MESA_PRIM_POINTS)
627       shader->info.outputs_written &= ~VARYING_BIT_PSIZ;
628 
629    /* See calc_unrolled_index_id */
630    nir_def *raw_id = nir_load_vertex_id(b);
631    nir_def *output_id = nir_umod_imm(b, raw_id, max_verts);
632    nir_def *unrolled = nir_udiv_imm(b, raw_id, max_verts);
633 
634    nir_def *primitives_log2 = load_geometry_param(b, primitives_log2);
635    nir_def *instance_id = nir_ushr(b, unrolled, primitives_log2);
636    nir_def *primitive_id = nir_iand(
637       b, unrolled,
638       nir_iadd_imm(b, nir_ishl(b, nir_imm_int(b, 1), primitives_log2), -1));
639 
640    struct lower_gs_rast_state rast_state = {
641       .instance_id = instance_id,
642       .primitive_id = primitive_id,
643       .output_id = output_id,
644    };
645 
646    u_foreach_bit64(slot, shader->info.outputs_written) {
647       const char *slot_name =
648          gl_varying_slot_name_for_stage(slot, MESA_SHADER_GEOMETRY);
649 
650       bool scalar = (slot == VARYING_SLOT_PSIZ) ||
651                     (slot == VARYING_SLOT_LAYER) ||
652                     (slot == VARYING_SLOT_VIEWPORT);
653       unsigned comps = scalar ? 1 : 4;
654 
655       rast_state.outputs.outputs[slot] = nir_variable_create(
656          shader, nir_var_shader_temp, glsl_vector_type(GLSL_TYPE_UINT, comps),
657          ralloc_asprintf(shader, "%s-temp", slot_name));
658 
659       rast_state.selected.outputs[slot] = nir_variable_create(
660          shader, nir_var_shader_temp, glsl_vector_type(GLSL_TYPE_UINT, comps),
661          ralloc_asprintf(shader, "%s-selected", slot_name));
662    }
663 
664    nir_shader_intrinsics_pass(shader, lower_to_gs_rast,
665                               nir_metadata_control_flow, &rast_state);
666 
667    b->cursor = nir_after_impl(b->impl);
668 
669    /* Forward each selected output to the rasterizer */
670    u_foreach_bit64(slot, shader->info.outputs_written) {
671       assert(rast_state.selected.outputs[slot] != NULL);
672       nir_def *value = nir_load_var(b, rast_state.selected.outputs[slot]);
673 
674       /* We set NIR_COMPACT_ARRAYS so clip/cull distance needs to come all in
675        * DIST0. Undo the offset if we need to.
676        */
677       assert(slot != VARYING_SLOT_CULL_DIST1);
678       unsigned offset = 0;
679       if (slot == VARYING_SLOT_CLIP_DIST1)
680          offset = 1;
681 
682       nir_store_output(b, value, nir_imm_int(b, offset),
683                        .io_semantics.location = slot - offset,
684                        .io_semantics.num_slots = 1,
685                        .write_mask = nir_component_mask(value->num_components),
686                        .src_type = nir_type_uint32);
687    }
688 
689    /* It is legal to omit the point size write from the geometry shader when
690     * drawing points. In this case, the point size is implicitly 1.0. To
691     * implement, insert a synthetic `gl_PointSize = 1.0` write into the GS copy
692     * shader, if the GS does not export a point size while drawing points.
693     */
694    bool is_points = gs->info.gs.output_primitive == MESA_PRIM_POINTS;
695 
696    if (!(shader->info.outputs_written & VARYING_BIT_PSIZ) && is_points) {
697       nir_store_output(b, nir_imm_float(b, 1.0), nir_imm_int(b, 0),
698                        .io_semantics.location = VARYING_SLOT_PSIZ,
699                        .io_semantics.num_slots = 1,
700                        .write_mask = nir_component_mask(1),
701                        .src_type = nir_type_float32);
702 
703       shader->info.outputs_written |= VARYING_BIT_PSIZ;
704    }
705 
706    nir_opt_idiv_const(shader, 16);
707 
708    agx_preprocess_nir(shader, libagx);
709    return shader;
710 }
711 
712 static nir_def *
713 previous_count(nir_builder *b, struct lower_gs_state *state, unsigned stream,
714                nir_def *unrolled_id, enum gs_counter counter)
715 {
716    assert(stream < MAX_VERTEX_STREAMS);
717    assert(counter < GS_NUM_COUNTERS);
718    int static_count = state->static_count[counter][stream];
719 
720    if (static_count >= 0) {
721       /* If the number of outputted vertices per invocation is known statically,
722        * we can calculate the base.
723        */
724       return nir_imul_imm(b, unrolled_id, static_count);
725    } else {
726       /* Otherwise, we need to load from the prefix sum buffer. Note that the
727        * sums are inclusive, so index 0 is nonzero. This requires a little
728        * fixup here. We use a saturating unsigned subtraction so we don't read
729        * out-of-bounds for zero.
730        *
731        * TODO: Optimize this.
732        */
733       nir_def *prim_minus_1 = nir_usub_sat(b, unrolled_id, nir_imm_int(b, 1));
734       nir_def *addr =
735          load_count_address(b, state, prim_minus_1, stream, counter);
736 
737       return nir_bcsel(b, nir_ieq_imm(b, unrolled_id, 0), nir_imm_int(b, 0),
738                        nir_load_global_constant(b, addr, 4, 1, 32));
739    }
740 }
741 
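/* To illustrate the dynamic path: if three unrolled input primitives emit 3, 1
 * and 2 vertices respectively, the inclusive prefix sums stored in the count
 * buffer are 3, 4, 6, so previous_count() returns 0, 3 and 4 for unrolled IDs
 * 0, 1 and 2 respectively.
 */
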
742 static nir_def *
743 previous_vertices(nir_builder *b, struct lower_gs_state *state, unsigned stream,
744                   nir_def *unrolled_id)
745 {
746    return previous_count(b, state, stream, unrolled_id, GS_COUNTER_VERTICES);
747 }
748 
749 static nir_def *
750 previous_primitives(nir_builder *b, struct lower_gs_state *state,
751                     unsigned stream, nir_def *unrolled_id)
752 {
753    return previous_count(b, state, stream, unrolled_id, GS_COUNTER_PRIMITIVES);
754 }
755 
756 static nir_def *
757 previous_xfb_primitives(nir_builder *b, struct lower_gs_state *state,
758                         unsigned stream, nir_def *unrolled_id)
759 {
760    return previous_count(b, state, stream, unrolled_id,
761                          GS_COUNTER_XFB_PRIMITIVES);
762 }
763 
764 static void
765 lower_end_primitive(nir_builder *b, nir_intrinsic_instr *intr,
766                     struct lower_gs_state *state)
767 {
768    assert((intr->intrinsic == nir_intrinsic_set_vertex_and_primitive_count ||
769            b->shader->info.gs.output_primitive != MESA_PRIM_POINTS) &&
770           "endprimitive for points should've been removed");
771 
772    /* The GS is the last stage before rasterization, so if we discard the
773     * rasterization, we don't output an index buffer; nothing would read it.
774     * The index buffer is only for the rasterization stream.
775     */
776    unsigned stream = nir_intrinsic_stream_id(intr);
777    if (state->rasterizer_discard || stream != 0)
778       return;
779 
780    libagx_end_primitive(
781       b, load_geometry_param(b, output_index_buffer), intr->src[0].ssa,
782       intr->src[1].ssa, intr->src[2].ssa,
783       previous_vertices(b, state, 0, calc_unrolled_id(b)),
784       previous_primitives(b, state, 0, calc_unrolled_id(b)),
785       calc_unrolled_index_id(b),
786       nir_imm_bool(b, b->shader->info.gs.output_primitive != MESA_PRIM_POINTS));
787 }
788 
789 static unsigned
790 verts_in_output_prim(nir_shader *gs)
791 {
792    return mesa_vertices_per_prim(gs->info.gs.output_primitive);
793 }
794 
795 static void
796 write_xfb(nir_builder *b, struct lower_gs_state *state, unsigned stream,
797           nir_def *index_in_strip, nir_def *prim_id_in_invocation)
798 {
799    struct nir_xfb_info *xfb = b->shader->xfb_info;
800    unsigned verts = verts_in_output_prim(b->shader);
801 
802    /* Get the index of this primitive in the XFB buffer. That is, the base for
803     * this invocation for the stream plus the offset within this invocation.
804     */
805    nir_def *invocation_base =
806       previous_xfb_primitives(b, state, stream, calc_unrolled_id(b));
807 
808    nir_def *prim_index = nir_iadd(b, invocation_base, prim_id_in_invocation);
809    nir_def *base_index = nir_imul_imm(b, prim_index, verts);
810 
811    nir_def *xfb_prims = load_geometry_param(b, xfb_prims[stream]);
812    nir_push_if(b, nir_ult(b, prim_index, xfb_prims));
813 
814    /* Write XFB for each output */
815    for (unsigned i = 0; i < xfb->output_count; ++i) {
816       nir_xfb_output_info output = xfb->outputs[i];
817 
818       /* Only write to the selected stream */
819       if (xfb->buffer_to_stream[output.buffer] != stream)
820          continue;
821 
822       unsigned buffer = output.buffer;
823       unsigned stride = xfb->buffers[buffer].stride;
824       unsigned count = util_bitcount(output.component_mask);
825 
826       for (unsigned vert = 0; vert < verts; ++vert) {
827          /* We write out the vertices backwards, since 0 is the current
828           * emitted vertex (which is actually the last vertex).
829           *
830           * We handle NULL var for
831           * KHR-Single-GL44.enhanced_layouts.xfb_capture_struct.
832           */
833          unsigned v = (verts - 1) - vert;
834          nir_variable *var = state->outputs[output.location][v];
835          nir_def *value = var ? nir_load_var(b, var) : nir_undef(b, 4, 32);
836 
837          /* In case output.component_mask contains invalid components, write
838           * out zeroes instead of blowing up validation.
839           *
840           * KHR-Single-GL44.enhanced_layouts.xfb_capture_inactive_output_component
841           * hits this.
842           */
843          value = nir_pad_vector_imm_int(b, value, 0, 4);
844 
845          nir_def *rotated_vert = nir_imm_int(b, vert);
846          if (verts == 3) {
847             /* Map vertices for output so we get consistent winding order. For
848              * the primitive index, we use the index_in_strip. This is actually
849              * the vertex index in the strip, hence
850              * offset by 2 relative to the true primitive index (#2 for the
851              * first triangle in the strip, #3 for the second). That's ok
852              * because only the parity matters.
853              */
854             rotated_vert = libagx_map_vertex_in_tri_strip(
855                b, index_in_strip, rotated_vert,
856                nir_inot(b, nir_i2b(b, nir_load_provoking_last(b))));
857          }
858 
859          nir_def *addr = libagx_xfb_vertex_address(
860             b, nir_load_geometry_param_buffer_agx(b), base_index, rotated_vert,
861             nir_imm_int(b, buffer), nir_imm_int(b, stride),
862             nir_imm_int(b, output.offset));
863 
864          nir_build_store_global(
865             b, nir_channels(b, value, output.component_mask), addr,
866             .align_mul = 4, .write_mask = nir_component_mask(count),
867             .access = ACCESS_XFB);
868       }
869    }
870 
871    nir_pop_if(b, NULL);
872 }
873 
874 /* Handle transform feedback for a given emit_vertex_with_counter */
875 static void
876 lower_emit_vertex_xfb(nir_builder *b, nir_intrinsic_instr *intr,
877                       struct lower_gs_state *state)
878 {
879    /* Transform feedback is written for each decomposed output primitive. Since
880     * we're writing strips, that means we output XFB for each vertex after the
881     * first complete primitive is formed.
882     */
883    unsigned first_prim = verts_in_output_prim(b->shader) - 1;
884    nir_def *index_in_strip = intr->src[1].ssa;
885 
886    nir_push_if(b, nir_uge_imm(b, index_in_strip, first_prim));
887    {
888       write_xfb(b, state, nir_intrinsic_stream_id(intr), index_in_strip,
889                 intr->src[3].ssa);
890    }
891    nir_pop_if(b, NULL);
892 
893    /* Transform feedback writes out entire primitives during the emit_vertex. To
894     * do that, we store the values at all vertices in the strip in a little ring
895     * buffer. Index #0 is always the most recent primitive (so non-XFB code can
896     * just grab index #0 without any checking). Index #1 is the previous vertex,
897     * and index #2 is the vertex before that. Now that we've written XFB, since
898     * we've emitted a vertex we need to cycle the ringbuffer, freeing up index
899     * #0 for the next vertex that we are about to emit. We do that by copying
900     * the first n - 1 vertices forward one slot, which has to happen with a
901     * backwards copy implemented here.
902     *
903     * If we're lucky, all of these copies will be propagated away. If we're
904     * unlucky, this involves at most 2 copies per component per XFB output per
905     * vertex.
906     */
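   /* For triangle strips (3 vertices per output primitive) the copies below
    * are: slot 2 <- slot 1, then slot 1 <- slot 0, freeing slot 0 for the
    * vertex about to be emitted.
    */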
907    u_foreach_bit64(slot, b->shader->info.outputs_written) {
908       /* Note: if we're outputting points, verts_in_output_prim will be 1, so
909        * this loop will not execute. This is intended: points are self-contained
910        * primitives and do not need these copies.
911        */
912       for (int v = verts_in_output_prim(b->shader) - 1; v >= 1; --v) {
913          nir_def *value = nir_load_var(b, state->outputs[slot][v - 1]);
914 
915          nir_store_var(b, state->outputs[slot][v], value,
916                        nir_component_mask(value->num_components));
917       }
918    }
919 }
920 
921 static bool
922 lower_gs_instr(nir_builder *b, nir_intrinsic_instr *intr, void *state)
923 {
924    b->cursor = nir_before_instr(&intr->instr);
925 
926    switch (intr->intrinsic) {
927    case nir_intrinsic_set_vertex_and_primitive_count:
928       /* This instruction is mostly for the count shader, so just remove. But
929        * for points, we write the index buffer here so the rast shader can map.
930        */
931       if (b->shader->info.gs.output_primitive == MESA_PRIM_POINTS) {
932          lower_end_primitive(b, intr, state);
933       }
934 
935       break;
936 
937    case nir_intrinsic_end_primitive_with_counter: {
938       unsigned min = verts_in_output_prim(b->shader);
939 
940       /* We only write out complete primitives */
941       nir_push_if(b, nir_uge_imm(b, intr->src[1].ssa, min));
942       {
943          lower_end_primitive(b, intr, state);
944       }
945       nir_pop_if(b, NULL);
946       break;
947    }
948 
949    case nir_intrinsic_emit_vertex_with_counter:
950       /* emit_vertex triggers transform feedback but is otherwise a no-op. */
951       if (b->shader->xfb_info)
952          lower_emit_vertex_xfb(b, intr, state);
953       break;
954 
955    default:
956       return false;
957    }
958 
959    nir_instr_remove(&intr->instr);
960    return true;
961 }
962 
963 static bool
964 collect_components(nir_builder *b, nir_intrinsic_instr *intr, void *data)
965 {
966    uint8_t *counts = data;
967    if (intr->intrinsic != nir_intrinsic_store_output)
968       return false;
969 
970    unsigned count = nir_intrinsic_component(intr) +
971                     util_last_bit(nir_intrinsic_write_mask(intr));
972 
973    unsigned loc =
974       nir_intrinsic_io_semantics(intr).location + nir_src_as_uint(intr->src[1]);
975 
976    uint8_t *total_count = &counts[loc];
977 
978    *total_count = MAX2(*total_count, count);
979    return true;
980 }
981 
982 /*
983  * Create the pre-GS shader. This is a small compute 1x1x1 kernel that produces
984  * an indirect draw to rasterize the produced geometry, as well as updates
985  * transform feedback offsets and counters as applicable.
986  */
987 static nir_shader *
988 agx_nir_create_pre_gs(struct lower_gs_state *state, const nir_shader *libagx,
989                       bool indexed, bool restart, struct nir_xfb_info *xfb,
990                       unsigned vertices_per_prim, uint8_t streams,
991                       unsigned invocations)
992 {
993    nir_builder b_ = nir_builder_init_simple_shader(
994       MESA_SHADER_COMPUTE, &agx_nir_options, "Pre-GS patch up");
995    nir_builder *b = &b_;
996 
997    /* Load the number of primitives input to the GS */
998    nir_def *unrolled_in_prims = load_geometry_param(b, input_primitives);
999 
1000    /* Setup the draw from the rasterization stream (0). */
1001    if (!state->rasterizer_discard) {
1002       libagx_build_gs_draw(
1003          b, nir_load_geometry_param_buffer_agx(b),
1004          previous_vertices(b, state, 0, unrolled_in_prims),
1005          restart ? previous_primitives(b, state, 0, unrolled_in_prims)
1006                  : nir_imm_int(b, 0));
1007    }
1008 
1009    /* Determine the number of primitives generated in each stream */
1010    nir_def *in_prims[MAX_VERTEX_STREAMS], *prims[MAX_VERTEX_STREAMS];
1011 
1012    u_foreach_bit(i, streams) {
1013       in_prims[i] = previous_xfb_primitives(b, state, i, unrolled_in_prims);
1014       prims[i] = in_prims[i];
1015 
1016       add_counter(b, load_geometry_param(b, prims_generated_counter[i]),
1017                   prims[i]);
1018    }
1019 
1020    if (xfb) {
1021       /* Write XFB addresses */
1022       nir_def *offsets[4] = {NULL};
1023       u_foreach_bit(i, xfb->buffers_written) {
1024          offsets[i] = libagx_setup_xfb_buffer(
1025             b, nir_load_geometry_param_buffer_agx(b), nir_imm_int(b, i));
1026       }
1027 
1028       /* Now clamp to the number that XFB captures */
1029       for (unsigned i = 0; i < xfb->output_count; ++i) {
1030          nir_xfb_output_info output = xfb->outputs[i];
1031 
1032          unsigned buffer = output.buffer;
1033          unsigned stream = xfb->buffer_to_stream[buffer];
1034          unsigned stride = xfb->buffers[buffer].stride;
1035          unsigned words_written = util_bitcount(output.component_mask);
1036          unsigned bytes_written = words_written * 4;
1037 
1038          /* Primitive P will write up to (but not including) offset:
1039           *
1040           *    xfb_offset + ((P - 1) * (verts_per_prim * stride))
1041           *               + ((verts_per_prim - 1) * stride)
1042           *               + output_offset
1043           *               + output_size
1044           *
1045           * Given an XFB buffer of size xfb_size, we get the inequality:
1046           *
1047           *    floor(P) <= (stride + xfb_size - xfb_offset - output_offset -
1048           *                     output_size) // (stride * verts_per_prim)
1049           */
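         /* For instance (illustrative numbers): stride = 16 B, 3 vertices per
          * primitive, an output at offset 4 writing 8 B, xfb_size = 1000 B and
          * a current buffer offset of 0 give
          * floor(P) <= (16 + 1000 - 0 - 4 - 8) / 48 = 20, i.e. at most 20
          * complete primitives are captured for this output.
          */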
1050          nir_def *size = load_geometry_param(b, xfb_size[buffer]);
1051          size = nir_iadd_imm(b, size, stride - output.offset - bytes_written);
1052          size = nir_isub(b, size, offsets[buffer]);
1053          size = nir_imax(b, size, nir_imm_int(b, 0));
1054          nir_def *max_prims = nir_udiv_imm(b, size, stride * vertices_per_prim);
1055 
1056          prims[stream] = nir_umin(b, prims[stream], max_prims);
1057       }
1058 
1059       nir_def *any_overflow = nir_imm_false(b);
1060 
1061       u_foreach_bit(i, streams) {
1062          nir_def *overflow = nir_ult(b, prims[i], in_prims[i]);
1063          any_overflow = nir_ior(b, any_overflow, overflow);
1064 
1065          store_geometry_param(b, xfb_prims[i], prims[i]);
1066 
1067          add_counter(b, load_geometry_param(b, xfb_overflow[i]),
1068                      nir_b2i32(b, overflow));
1069 
1070          add_counter(b, load_geometry_param(b, xfb_prims_generated_counter[i]),
1071                      prims[i]);
1072       }
1073 
1074       add_counter(b, load_geometry_param(b, xfb_any_overflow),
1075                   nir_b2i32(b, any_overflow));
1076 
1077       /* Update XFB counters */
1078       u_foreach_bit(i, xfb->buffers_written) {
1079          uint32_t prim_stride_B = xfb->buffers[i].stride * vertices_per_prim;
1080          unsigned stream = xfb->buffer_to_stream[i];
1081 
1082          nir_def *off_ptr = load_geometry_param(b, xfb_offs_ptrs[i]);
1083          nir_def *size = nir_imul_imm(b, prims[stream], prim_stride_B);
1084          add_counter(b, off_ptr, size);
1085       }
1086    }
1087 
1088    /* The geometry shader receives a number of input primitives. The driver
1089     * should disable this counter when tessellation is active and count
1090     * patches separately instead (TODO).
1091     */
1092    add_counter(
1093       b,
1094       nir_load_stat_query_address_agx(b, .base = PIPE_STAT_QUERY_IA_PRIMITIVES),
1095       unrolled_in_prims);
1096 
1097    /* The geometry shader is invoked once per primitive (after unrolling
1098     * primitive restart). From the spec:
1099     *
1100     *    In case of instanced geometry shaders (see section 11.3.4.2) the
1101     *    geometry shader invocations count is incremented for each separate
1102     *    instanced invocation.
1103     */
1104    add_counter(b,
1105                nir_load_stat_query_address_agx(
1106                   b, .base = PIPE_STAT_QUERY_GS_INVOCATIONS),
1107                nir_imul_imm(b, unrolled_in_prims, invocations));
1108 
1109    nir_def *emitted_prims = nir_imm_int(b, 0);
1110    u_foreach_bit(i, streams) {
1111       emitted_prims =
1112          nir_iadd(b, emitted_prims,
1113                   previous_xfb_primitives(b, state, i, unrolled_in_prims));
1114    }
1115 
1116    add_counter(
1117       b,
1118       nir_load_stat_query_address_agx(b, .base = PIPE_STAT_QUERY_GS_PRIMITIVES),
1119       emitted_prims);
1120 
1121    /* Clipper queries are not well-defined, so we can emulate them in lots of
1122     * silly ways. We need the hardware counters to implement them properly. For
1123     * now, just consider all primitives emitted as passing through the clipper.
1124     * This satisfies spec text:
1125     *
1126     *    The number of primitives that reach the primitive clipping stage.
1127     *
1128     * and
1129     *
1130     *    If at least one vertex of the primitive lies inside the clipping
1131     *    volume, the counter is incremented by one or more. Otherwise, the
1132     *    counter is incremented by zero or more.
1133     */
1134    add_counter(
1135       b,
1136       nir_load_stat_query_address_agx(b, .base = PIPE_STAT_QUERY_C_PRIMITIVES),
1137       emitted_prims);
1138 
1139    add_counter(
1140       b,
1141       nir_load_stat_query_address_agx(b, .base = PIPE_STAT_QUERY_C_INVOCATIONS),
1142       emitted_prims);
1143 
1144    agx_preprocess_nir(b->shader, libagx);
1145    return b->shader;
1146 }
1147 
1148 static bool
1149 rewrite_invocation_id(nir_builder *b, nir_intrinsic_instr *intr, void *data)
1150 {
1151    if (intr->intrinsic != nir_intrinsic_load_invocation_id)
1152       return false;
1153 
1154    b->cursor = nir_instr_remove(&intr->instr);
1155    nir_def_rewrite_uses(&intr->def, nir_u2uN(b, data, intr->def.bit_size));
1156    return true;
1157 }
1158 
1159 /*
1160  * Geometry shader instancing allows a GS to run multiple times. The number of
1161  * times is statically known and small. It's easiest to turn this into a loop
1162  * inside the GS, to avoid the feature "leaking" outside and affecting e.g. the
1163  * counts.
1164  */
1165 static void
1166 agx_nir_lower_gs_instancing(nir_shader *gs)
1167 {
1168    unsigned nr_invocations = gs->info.gs.invocations;
1169    nir_function_impl *impl = nir_shader_get_entrypoint(gs);
1170 
1171    /* Each invocation can produce up to the shader-declared max_vertices, so
1172     * multiply it up for a proper bounds check. Emitting more than the declared
1173     * max_vertices per invocation results in undefined behaviour, so erroneously
1174     * emitting more than asked on early invocations is a perfectly cromulent
1175     * behaviour.
1176     */
1177    gs->info.gs.vertices_out *= gs->info.gs.invocations;
1178 
1179    /* Get the original function */
1180    nir_cf_list list;
1181    nir_cf_extract(&list, nir_before_impl(impl), nir_after_impl(impl));
1182 
1183    /* Create a builder for the wrapped function */
1184    nir_builder b = nir_builder_at(nir_after_block(nir_start_block(impl)));
1185 
1186    nir_variable *i =
1187       nir_local_variable_create(impl, glsl_uintN_t_type(16), NULL);
1188    nir_store_var(&b, i, nir_imm_intN_t(&b, 0, 16), ~0);
1189    nir_def *index = NULL;
1190 
1191    /* Create a loop in the wrapped function */
1192    nir_loop *loop = nir_push_loop(&b);
1193    {
1194       index = nir_load_var(&b, i);
1195       nir_push_if(&b, nir_uge_imm(&b, index, nr_invocations));
1196       {
1197          nir_jump(&b, nir_jump_break);
1198       }
1199       nir_pop_if(&b, NULL);
1200 
1201       b.cursor = nir_cf_reinsert(&list, b.cursor);
1202       nir_store_var(&b, i, nir_iadd_imm(&b, index, 1), ~0);
1203 
1204       /* Make sure we end the primitive between invocations. If the geometry
1205        * shader already ended the primitive, this will get optimized out.
1206        */
1207       nir_end_primitive(&b);
1208    }
1209    nir_pop_loop(&b, loop);
1210 
1211    /* We've mucked about with control flow */
1212    nir_metadata_preserve(impl, nir_metadata_none);
1213 
1214    /* Use the loop counter as the invocation ID each iteration */
1215    nir_shader_intrinsics_pass(gs, rewrite_invocation_id,
1216                               nir_metadata_control_flow, index);
1217 }
1218 
1219 static void
1220 link_libagx(nir_shader *nir, const nir_shader *libagx)
1221 {
1222    nir_link_shader_functions(nir, libagx);
1223    NIR_PASS(_, nir, nir_inline_functions);
1224    nir_remove_non_entrypoints(nir);
1225    NIR_PASS(_, nir, nir_lower_indirect_derefs, nir_var_function_temp, 64);
1226    NIR_PASS(_, nir, nir_opt_dce);
1227    NIR_PASS(_, nir, nir_lower_vars_to_explicit_types,
1228             nir_var_shader_temp | nir_var_function_temp | nir_var_mem_shared,
1229             glsl_get_cl_type_size_align);
1230    NIR_PASS(_, nir, nir_opt_deref);
1231    NIR_PASS(_, nir, nir_lower_vars_to_ssa);
1232    NIR_PASS(_, nir, nir_lower_explicit_io,
1233             nir_var_shader_temp | nir_var_function_temp | nir_var_mem_shared |
1234                nir_var_mem_global,
1235             nir_address_format_62bit_generic);
1236 }
1237 
1238 bool
1239 agx_nir_lower_gs(nir_shader *gs, const nir_shader *libagx,
1240                  bool rasterizer_discard, nir_shader **gs_count,
1241                  nir_shader **gs_copy, nir_shader **pre_gs,
1242                  enum mesa_prim *out_mode, unsigned *out_count_words)
1243 {
1244    /* Lower I/O as assumed by the rest of GS lowering */
1245    if (gs->xfb_info != NULL) {
1246       NIR_PASS(_, gs, nir_io_add_const_offset_to_base,
1247                nir_var_shader_in | nir_var_shader_out);
1248       NIR_PASS(_, gs, nir_io_add_intrinsic_xfb_info);
1249    }
1250 
1251    NIR_PASS(_, gs, nir_lower_io_to_scalar, nir_var_shader_out, NULL, NULL);
1252 
1253    /* Collect output component counts so we can size the geometry output buffer
1254     * appropriately, instead of assuming everything is vec4.
1255     */
1256    uint8_t component_counts[NUM_TOTAL_VARYING_SLOTS] = {0};
1257    nir_shader_intrinsics_pass(gs, collect_components, nir_metadata_all,
1258                               component_counts);
1259 
1260    /* If geometry shader instancing is used, lower it away before linking
1261     * anything. Otherwise, smash the invocation ID to zero.
1262     */
1263    if (gs->info.gs.invocations != 1) {
1264       agx_nir_lower_gs_instancing(gs);
1265    } else {
1266       nir_function_impl *impl = nir_shader_get_entrypoint(gs);
1267       nir_builder b = nir_builder_at(nir_before_impl(impl));
1268 
1269       nir_shader_intrinsics_pass(gs, rewrite_invocation_id,
1270                                  nir_metadata_control_flow, nir_imm_int(&b, 0));
1271    }
1272 
1273    NIR_PASS(_, gs, nir_shader_intrinsics_pass, lower_gs_inputs,
1274             nir_metadata_control_flow, NULL);
1275 
1276    /* Lower geometry shader writes to contain all of the required counts, so we
1277     * know where in the various buffers we should write vertices.
1278     */
1279    NIR_PASS(_, gs, nir_lower_gs_intrinsics,
1280             nir_lower_gs_intrinsics_count_primitives |
1281                nir_lower_gs_intrinsics_per_stream |
1282                nir_lower_gs_intrinsics_count_vertices_per_primitive |
1283                nir_lower_gs_intrinsics_overwrite_incomplete |
1284                nir_lower_gs_intrinsics_always_end_primitive |
1285                nir_lower_gs_intrinsics_count_decomposed_primitives);
1286 
1287    /* Clean up after all that lowering we did */
1288    bool progress = false;
1289    do {
1290       progress = false;
1291       NIR_PASS(progress, gs, nir_lower_var_copies);
1292       NIR_PASS(progress, gs, nir_lower_variable_initializers,
1293                nir_var_shader_temp);
1294       NIR_PASS(progress, gs, nir_lower_vars_to_ssa);
1295       NIR_PASS(progress, gs, nir_copy_prop);
1296       NIR_PASS(progress, gs, nir_opt_constant_folding);
1297       NIR_PASS(progress, gs, nir_opt_algebraic);
1298       NIR_PASS(progress, gs, nir_opt_cse);
1299       NIR_PASS(progress, gs, nir_opt_dead_cf);
1300       NIR_PASS(progress, gs, nir_opt_dce);
1301 
1302       /* Unrolling lets us statically determine counts more often, which
1303        * otherwise would not be possible with multiple invocations even in the
1304        * simplest of cases.
1305        */
1306       NIR_PASS(progress, gs, nir_opt_loop_unroll);
1307    } while (progress);
1308 
1309    /* If we know counts at compile-time we can simplify, so try to figure out
1310     * the counts statically.
1311     */
1312    struct lower_gs_state gs_state = {
1313       .rasterizer_discard = rasterizer_discard,
1314    };
1315 
1316    nir_gs_count_vertices_and_primitives(
1317       gs, gs_state.static_count[GS_COUNTER_VERTICES],
1318       gs_state.static_count[GS_COUNTER_PRIMITIVES],
1319       gs_state.static_count[GS_COUNTER_XFB_PRIMITIVES], 4);
1320 
1321    /* Anything we don't know statically will be tracked by the count buffer.
1322     * Determine the layout for it.
1323     */
1324    for (unsigned i = 0; i < MAX_VERTEX_STREAMS; ++i) {
1325       for (unsigned c = 0; c < GS_NUM_COUNTERS; ++c) {
1326          gs_state.count_index[i][c] =
1327             (gs_state.static_count[c][i] < 0) ? gs_state.count_stride_el++ : -1;
1328       }
1329    }
1330 
1331    bool side_effects_for_rast = false;
1332    *gs_copy = agx_nir_create_gs_rast_shader(gs, libagx, &side_effects_for_rast);
1333 
1334    NIR_PASS(_, gs, nir_shader_intrinsics_pass, lower_id,
1335             nir_metadata_control_flow, NULL);
1336 
1337    link_libagx(gs, libagx);
1338 
1339    NIR_PASS(_, gs, nir_lower_idiv,
1340             &(const nir_lower_idiv_options){.allow_fp16 = true});
1341 
1342    /* All those variables we created should've gone away by now */
1343    NIR_PASS(_, gs, nir_remove_dead_variables, nir_var_function_temp, NULL);
1344 
1345    /* If there is any unknown count, we need a geometry count shader */
1346    if (gs_state.count_stride_el > 0)
1347       *gs_count = agx_nir_create_geometry_count_shader(gs, libagx, &gs_state);
1348    else
1349       *gs_count = NULL;
1350 
1351    /* Geometry shader outputs are staged to temporaries */
1352    struct agx_lower_output_to_var_state state = {0};
1353 
1354    u_foreach_bit64(slot, gs->info.outputs_written) {
1355       /* After enough optimizations, the shader metadata can go out of sync; fix it
1356        * up with our gathered info. Otherwise glsl_vector_type will assert fail.
1357        */
1358       if (component_counts[slot] == 0) {
1359          gs->info.outputs_written &= ~BITFIELD64_BIT(slot);
1360          continue;
1361       }
1362 
1363       const char *slot_name =
1364          gl_varying_slot_name_for_stage(slot, MESA_SHADER_GEOMETRY);
1365 
1366       for (unsigned i = 0; i < MAX_PRIM_OUT_SIZE; ++i) {
1367          gs_state.outputs[slot][i] = nir_variable_create(
1368             gs, nir_var_shader_temp,
1369             glsl_vector_type(GLSL_TYPE_UINT, component_counts[slot]),
1370             ralloc_asprintf(gs, "%s-%u", slot_name, i));
1371       }
1372 
1373       state.outputs[slot] = gs_state.outputs[slot][0];
1374    }
1375 
1376    NIR_PASS(_, gs, nir_shader_instructions_pass, agx_lower_output_to_var,
1377             nir_metadata_control_flow, &state);
1378 
1379    NIR_PASS(_, gs, nir_shader_intrinsics_pass, lower_gs_instr,
1380             nir_metadata_none, &gs_state);
1381 
1382    /* Determine if we are guaranteed to rasterize at least one vertex, so that
1383     * we can strip the prepass of side effects knowing they will execute in the
1384     * rasterization shader.
1385     */
1386    bool rasterizes_at_least_one_vertex =
1387       !rasterizer_discard && gs_state.static_count[0][0] > 0;
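   /* Note: static_count[0][0] is GS_COUNTER_VERTICES for stream 0; a positive
    * value means the shader statically emits at least one vertex on stream 0,
    * the rasterized stream.
    */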
1388 
1389    /* Clean up after all that lowering we did */
1390    nir_lower_global_vars_to_local(gs);
1391    do {
1392       progress = false;
1393       NIR_PASS(progress, gs, nir_lower_var_copies);
1394       NIR_PASS(progress, gs, nir_lower_variable_initializers,
1395                nir_var_shader_temp);
1396       NIR_PASS(progress, gs, nir_lower_vars_to_ssa);
1397       NIR_PASS(progress, gs, nir_copy_prop);
1398       NIR_PASS(progress, gs, nir_opt_constant_folding);
1399       NIR_PASS(progress, gs, nir_opt_algebraic);
1400       NIR_PASS(progress, gs, nir_opt_cse);
1401       NIR_PASS(progress, gs, nir_opt_dead_cf);
1402       NIR_PASS(progress, gs, nir_opt_dce);
1403       NIR_PASS(progress, gs, nir_opt_loop_unroll);
1404 
1405    } while (progress);
1406 
1407    /* Side effects will run in the rasterization shader, so strip them here. */
1408    if (rasterizes_at_least_one_vertex && side_effects_for_rast) {
1409       do {
1410          progress = false;
1411          NIR_PASS(progress, gs, nir_shader_intrinsics_pass,
1412                   strip_side_effect_from_main, nir_metadata_control_flow, NULL);
1413 
1414          NIR_PASS(progress, gs, nir_opt_dce);
1415          NIR_PASS(progress, gs, nir_opt_dead_cf);
1416       } while (progress);
1417    }
1418 
1419    /* All those variables we created should've gone away by now */
1420    NIR_PASS(_, gs, nir_remove_dead_variables, nir_var_function_temp, NULL);
1421 
1422    NIR_PASS(_, gs, nir_opt_sink, ~0);
1423    NIR_PASS(_, gs, nir_opt_move, ~0);
1424 
1425    NIR_PASS(_, gs, nir_shader_intrinsics_pass, lower_id,
1426             nir_metadata_control_flow, NULL);
1427 
1428    /* Create auxiliary programs */
1429    *pre_gs = agx_nir_create_pre_gs(
1430       &gs_state, libagx, true, gs->info.gs.output_primitive != MESA_PRIM_POINTS,
1431       gs->xfb_info, verts_in_output_prim(gs), gs->info.gs.active_stream_mask,
1432       gs->info.gs.invocations);
1433 
1434    /* Signal which primitive type to draw the GS copy VS with */
1435    *out_mode = gs->info.gs.output_primitive;
1436    *out_count_words = gs_state.count_stride_el;
1437    return true;
1438 }
1439 
1440 /*
1441  * Vertex shaders (or tessellation evaluation shaders) before a geometry
1442  * shader run as a dedicated compute prepass, invoked as (vertex count,
1443  * instance count, 1). Linear ID = (instance ID * vertex count) + vertex ID.
1444  *
1445  * This function lowers their vertex shader I/O to compute.
1446  *
1447  * Vertex ID becomes an index buffer pull (without applying the topology), and
1448  * output stores become stores into the global vertex output buffer.
1449  */
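/* Illustrative example: with 100 vertices per instance, instance 2 / vertex 7
 * gets linear ID 2 * 100 + 7 = 207, which is then used to address the global
 * vertex output buffer below.
 */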
1450 static bool
1451 lower_vs_before_gs(nir_builder *b, nir_intrinsic_instr *intr, void *data)
1452 {
1453    if (intr->intrinsic != nir_intrinsic_store_output)
1454       return false;
1455 
1456    b->cursor = nir_instr_remove(&intr->instr);
1457    nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
1458    nir_def *location = nir_iadd_imm(b, intr->src[1].ssa, sem.location);
1459 
1460    /* We inline the outputs_written because it's known at compile-time, even
1461     * with shader objects. This lets us constant fold a bit of address math.
1462     */
1463    nir_def *mask = nir_imm_int64(b, b->shader->info.outputs_written);
1464 
1465    nir_def *buffer;
1466    nir_def *nr_verts;
1467    if (b->shader->info.stage == MESA_SHADER_VERTEX) {
1468       buffer = nir_load_vs_output_buffer_agx(b);
1469       nr_verts =
1470          libagx_input_vertices(b, nir_load_input_assembly_buffer_agx(b));
1471    } else {
1472       assert(b->shader->info.stage == MESA_SHADER_TESS_EVAL);
1473 
1474       /* Instancing is unrolled during tessellation so nr_verts is ignored. */
1475       nr_verts = nir_imm_int(b, 0);
1476       buffer = libagx_tes_buffer(b, nir_load_tess_param_buffer_agx(b));
1477    }
1478 
1479    nir_def *linear_id = nir_iadd(b, nir_imul(b, load_instance_id(b), nr_verts),
1480                                  load_primitive_id(b));
1481 
1482    nir_def *addr =
1483       libagx_vertex_output_address(b, buffer, mask, linear_id, location);
1484 
1485    assert(nir_src_bit_size(intr->src[0]) == 32);
1486    addr = nir_iadd_imm(b, addr, nir_intrinsic_component(intr) * 4);
1487 
1488    nir_store_global(b, addr, 4, intr->src[0].ssa,
1489                     nir_intrinsic_write_mask(intr));
1490    return true;
1491 }
1492 
1493 bool
1494 agx_nir_lower_vs_before_gs(struct nir_shader *vs,
1495                            const struct nir_shader *libagx)
1496 {
1497    bool progress = false;
1498 
1499    /* Lower vertex stores to memory stores */
1500    progress |= nir_shader_intrinsics_pass(vs, lower_vs_before_gs,
1501                                           nir_metadata_control_flow, NULL);
1502 
1503    /* Link libagx, used in lower_vs_before_gs */
1504    if (progress)
1505       link_libagx(vs, libagx);
1506 
1507    return progress;
1508 }
1509 
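/* Builds the kernel that prefix sums the geometry count buffer. Each 1024-wide
 * workgroup invokes libagx_prefix_sum with the count buffer, the number of
 * input primitives, the count word stride, and its workgroup index (presumably
 * selecting which count word it scans).
 */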
1510 void
1511 agx_nir_prefix_sum_gs(nir_builder *b, const void *data)
1512 {
1513    const unsigned *words = data;
1514 
1515    b->shader->info.workgroup_size[0] = 1024;
1516 
1517    libagx_prefix_sum(b, load_geometry_param(b, count_buffer),
1518                      load_geometry_param(b, input_primitives),
1519                      nir_imm_int(b, *words),
1520                      nir_channel(b, nir_load_workgroup_id(b), 0));
1521 }
1522 
1523 void
1524 agx_nir_prefix_sum_tess(nir_builder *b, const void *data)
1525 {
1526    b->shader->info.workgroup_size[0] = 1024;
1527    libagx_prefix_sum_tess(b, nir_load_preamble(b, 1, 64, .base = 0));
1528 }
1529 
1530 void
1531 agx_nir_gs_setup_indirect(nir_builder *b, const void *data)
1532 {
1533    const struct agx_gs_setup_indirect_key *key = data;
1534 
1535    libagx_gs_setup_indirect(b, nir_load_preamble(b, 1, 64, .base = 0),
1536                             nir_imm_int(b, key->prim),
1537                             nir_channel(b, nir_load_local_invocation_id(b), 0));
1538 }
1539 
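/* Builds the kernel that unrolls primitive-restart index buffers. The libagx
 * variant is picked by index size (8, 16, or 32-bit); the workgroup ID selects
 * the draw and the local invocation ID selects the lane within it.
 */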
1540 void
1541 agx_nir_unroll_restart(nir_builder *b, const void *data)
1542 {
1543    const struct agx_unroll_restart_key *key = data;
1544    b->shader->info.workgroup_size[0] = 1024;
1545 
1546    nir_def *ia = nir_load_preamble(b, 1, 64, .base = 0);
1547    nir_def *draw = nir_channel(b, nir_load_workgroup_id(b), 0);
1548    nir_def *lane = nir_channel(b, nir_load_local_invocation_id(b), 0);
1549    nir_def *mode = nir_imm_int(b, key->prim);
1550 
1551    if (key->index_size_B == 1)
1552       libagx_unroll_restart_u8(b, ia, mode, draw, lane);
1553    else if (key->index_size_B == 2)
1554       libagx_unroll_restart_u16(b, ia, mode, draw, lane);
1555    else if (key->index_size_B == 4)
1556       libagx_unroll_restart_u32(b, ia, mode, draw, lane);
1557    else
1558       unreachable("invalid index size");
1559 }
1560 
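/* Builds the tessellator kernel: 64-wide workgroups with one invocation per
 * patch (indexed by the global invocation ID), dispatching to the isoline,
 * triangle, or quad libagx tessellator selected by the key.
 */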
1561 void
1562 agx_nir_tessellate(nir_builder *b, const void *data)
1563 {
1564    const struct agx_tessellator_key *key = data;
1565    b->shader->info.workgroup_size[0] = 64;
1566 
1567    nir_def *params = nir_load_preamble(b, 1, 64, .base = 0);
1568    nir_def *patch = nir_channel(b, nir_load_global_invocation_id(b, 32), 0);
1569    nir_def *mode = nir_imm_int(b, key->mode);
1570    nir_def *partitioning = nir_imm_int(b, key->partitioning);
1571    nir_def *output_prim = nir_imm_int(b, key->output_primitive);
1572 
1573    if (key->prim == TESS_PRIMITIVE_ISOLINES)
1574       libagx_tess_isoline(b, params, mode, partitioning, output_prim, patch);
1575    else if (key->prim == TESS_PRIMITIVE_TRIANGLES)
1576       libagx_tess_tri(b, params, mode, partitioning, output_prim, patch);
1577    else if (key->prim == TESS_PRIMITIVE_QUADS)
1578       libagx_tess_quad(b, params, mode, partitioning, output_prim, patch);
1579    else
1580       unreachable("invalid tess primitive");
1581 }
1582 
1583 void
1584 agx_nir_tess_setup_indirect(nir_builder *b, const void *data)
1585 {
1586    const struct agx_tess_setup_indirect_key *key = data;
1587 
1588    nir_def *params = nir_load_preamble(b, 1, 64, .base = 0);
1589    nir_def *with_counts = nir_imm_bool(b, key->with_counts);
1590    nir_def *point_mode = nir_imm_bool(b, key->point_mode);
1591 
1592    libagx_tess_setup_indirect(b, params, with_counts, point_mode);
1593 }
1594 
1595 void
1596 agx_nir_increment_statistic(nir_builder *b, const void *data)
1597 {
1598    libagx_increment_statistic(b, nir_load_preamble(b, 1, 64, .base = 0));
1599 }
1600 
1601 void
1602 agx_nir_increment_cs_invocations(nir_builder *b, const void *data)
1603 {
1604    libagx_increment_cs_invocations(b, nir_load_preamble(b, 1, 64, .base = 0));
1605 }
1606 
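/* Builds the kernel that increments input-assembly counters. With a nonzero
 * index size, 1024 threads per workgroup are used (presumably to walk the
 * index buffer); otherwise a single thread suffices.
 */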
1607 void
1608 agx_nir_increment_ia_counters(nir_builder *b, const void *data)
1609 {
1610    const struct agx_increment_ia_counters_key *key = data;
1611    b->shader->info.workgroup_size[0] = key->index_size_B ? 1024 : 1;
1612 
1613    nir_def *params = nir_load_preamble(b, 1, 64, .base = 0);
1614    nir_def *index_size_B = nir_imm_int(b, key->index_size_B);
1615    nir_def *thread = nir_channel(b, nir_load_global_invocation_id(b, 32), 0);
1616 
1617    libagx_increment_ia_counters(b, params, index_size_B, thread);
1618 }
1619 
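/* Builds the kernel that predicates indirect draws, passing each thread's
 * global invocation ID and whether the draw is indexed to the libagx helper.
 */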
1620 void
1621 agx_nir_predicate_indirect(nir_builder *b, const void *data)
1622 {
1623    const struct agx_predicate_indirect_key *key = data;
1624 
1625    nir_def *params = nir_load_preamble(b, 1, 64, .base = 0);
1626    nir_def *indexed = nir_imm_bool(b, key->indexed);
1627    nir_def *thread = nir_channel(b, nir_load_global_invocation_id(b, 32), 0);
1628 
1629    libagx_predicate_indirect(b, params, thread, indexed);
1630 }
1631 
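/* Builds the kernel that decompresses a compressed surface: one workgroup per
 * tile, with the local invocation ID and sample count forwarded to
 * libagx_decompress.
 */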
1632 void
1633 agx_nir_decompress(nir_builder *b, const void *data)
1634 {
1635    const struct agx_decompress_key *key = data;
1636 
1637    nir_def *params = nir_load_preamble(b, 1, 64, .base = 0);
1638    nir_def *tile = nir_load_workgroup_id(b);
1639    nir_def *local = nir_channel(b, nir_load_local_invocation_id(b), 0);
1640    nir_def *samples = nir_imm_int(b, key->nr_samples);
1641 
1642    libagx_decompress(b, params, tile, local, samples);
1643 }
1644