1 /*
2 * Copyright 2023 Alyssa Rosenzweig
3 * Copyright 2023 Valve Corporation
4 * SPDX-License-Identifier: MIT
5 */
6
7 #include "agx_nir_lower_gs.h"
8 #include "asahi/compiler/agx_compile.h"
9 #include "compiler/nir/nir_builder.h"
10 #include "gallium/include/pipe/p_defines.h"
11 #include "shaders/draws.h"
12 #include "shaders/geometry.h"
13 #include "util/bitscan.h"
14 #include "util/list.h"
15 #include "util/macros.h"
16 #include "util/ralloc.h"
17 #include "util/u_math.h"
18 #include "libagx_shaders.h"
19 #include "nir.h"
20 #include "nir_builder_opcodes.h"
21 #include "nir_intrinsics.h"
22 #include "nir_intrinsics_indices.h"
23 #include "nir_xfb_info.h"
24 #include "shader_enums.h"
25
26 /* Marks a transform feedback store, which must not be stripped from the
27 * prepass, since that's where the transform feedback happens. Chosen as a
28 * vendored flag so it does not alias other access flags we'll see.
29 */
30 #define ACCESS_XFB (ACCESS_IS_SWIZZLED_AMD)
31
32 enum gs_counter {
33 GS_COUNTER_VERTICES = 0,
34 GS_COUNTER_PRIMITIVES,
35 GS_COUNTER_XFB_PRIMITIVES,
36 GS_NUM_COUNTERS
37 };
38
39 #define MAX_PRIM_OUT_SIZE 3
40
41 struct lower_gs_state {
42 int static_count[GS_NUM_COUNTERS][MAX_VERTEX_STREAMS];
43 nir_variable *outputs[NUM_TOTAL_VARYING_SLOTS][MAX_PRIM_OUT_SIZE];
44
45 /* The count buffer contains `count_stride_el` 32-bit words in a row for each
46 * input primitive, for `input_primitives * count_stride_el * 4` total bytes.
47 */
48 unsigned count_stride_el;
49
50 /* The index of each counter in the count buffer, or -1 if it's not in the
51 * count buffer.
52 *
53 * Invariant: count_stride_el == sum(count_index[i][j] >= 0).
54 */
55 int count_index[MAX_VERTEX_STREAMS][GS_NUM_COUNTERS];
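/* For example, if only the stream-0 vertex and primitive counts are dynamic,
 * count_stride_el is 2 and each input primitive's slice of the count buffer
 * is laid out as [vertex count, primitive count]; every other count_index
 * entry is -1.
 */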
56
57 bool rasterizer_discard;
58 };
59
60 /* Helpers for loading from the geometry state buffer */
61 static nir_def *
62 load_geometry_param_offset(nir_builder *b, uint32_t offset, uint8_t bytes)
63 {
64 nir_def *base = nir_load_geometry_param_buffer_agx(b);
65 nir_def *addr = nir_iadd_imm(b, base, offset);
66
67 assert((offset % bytes) == 0 && "must be naturally aligned");
68
69 return nir_load_global_constant(b, addr, bytes, 1, bytes * 8);
70 }
71
72 static void
73 store_geometry_param_offset(nir_builder *b, nir_def *def, uint32_t offset,
74 uint8_t bytes)
75 {
76 nir_def *base = nir_load_geometry_param_buffer_agx(b);
77 nir_def *addr = nir_iadd_imm(b, base, offset);
78
79 assert((offset % bytes) == 0 && "must be naturally aligned");
80
81 nir_store_global(b, addr, 4, def, nir_component_mask(def->num_components));
82 }
83
84 #define store_geometry_param(b, field, def) \
85 store_geometry_param_offset( \
86 b, def, offsetof(struct agx_geometry_params, field), \
87 sizeof(((struct agx_geometry_params *)0)->field))
88
89 #define load_geometry_param(b, field) \
90 load_geometry_param_offset( \
91 b, offsetof(struct agx_geometry_params, field), \
92 sizeof(((struct agx_geometry_params *)0)->field))
93
94 /* Helper for updating counters */
95 static void
96 add_counter(nir_builder *b, nir_def *counter, nir_def *increment)
97 {
98 /* If the counter is NULL, the counter is disabled. Skip the update. */
99 nir_if *nif = nir_push_if(b, nir_ine_imm(b, counter, 0));
100 {
101 nir_def *old = nir_load_global(b, counter, 4, 1, 32);
102 nir_def *new_ = nir_iadd(b, old, increment);
103 nir_store_global(b, counter, 4, new_, nir_component_mask(1));
104 }
105 nir_pop_if(b, nif);
106 }
107
108 /* Helpers for lowering I/O to variables */
109 static void
110 lower_store_to_var(nir_builder *b, nir_intrinsic_instr *intr,
111 struct agx_lower_output_to_var_state *state)
112 {
113 b->cursor = nir_instr_remove(&intr->instr);
114 nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
115 unsigned component = nir_intrinsic_component(intr);
116 nir_def *value = intr->src[0].ssa;
117
118 assert(nir_src_is_const(intr->src[1]) && "no indirect outputs");
119 assert(nir_intrinsic_write_mask(intr) == nir_component_mask(1) &&
120 "should be scalarized");
121
122 nir_variable *var =
123 state->outputs[sem.location + nir_src_as_uint(intr->src[1])];
124 if (!var) {
125 assert(sem.location == VARYING_SLOT_PSIZ &&
126 "otherwise in outputs_written");
127 return;
128 }
129
130 unsigned nr_components = glsl_get_components(glsl_without_array(var->type));
131 assert(component < nr_components);
132
133 /* Turn it into a vec4 write like NIR expects */
134 value = nir_vector_insert_imm(b, nir_undef(b, nr_components, 32), value,
135 component);
136
137 nir_store_var(b, var, value, BITFIELD_BIT(component));
138 }
139
140 bool
141 agx_lower_output_to_var(nir_builder *b, nir_instr *instr, void *data)
142 {
143 if (instr->type != nir_instr_type_intrinsic)
144 return false;
145
146 nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
147 if (intr->intrinsic != nir_intrinsic_store_output)
148 return false;
149
150 lower_store_to_var(b, intr, data);
151 return true;
152 }
153
154 /*
155 * Geometry shader invocations are compute-like:
156 *
157 * (primitive ID, instance ID, 1)
158 */
159 static nir_def *
160 load_primitive_id(nir_builder *b)
161 {
162 return nir_channel(b, nir_load_global_invocation_id(b, 32), 0);
163 }
164
165 static nir_def *
166 load_instance_id(nir_builder *b)
167 {
168 return nir_channel(b, nir_load_global_invocation_id(b, 32), 1);
169 }
170
171 /* Geometry shaders use software input assembly. The software vertex shader
172 * is invoked for each index, and the geometry shader applies the topology.
173 * This helper maps a vertex within the input primitive to the vertex ID to fetch.
174 */
175 static nir_def *
176 vertex_id_for_topology_class(nir_builder *b, nir_def *vert, enum mesa_prim cls)
177 {
178 nir_def *prim = nir_load_primitive_id(b);
179 nir_def *flatshade_first = nir_ieq_imm(b, nir_load_provoking_last(b), 0);
180 nir_def *nr = load_geometry_param(b, gs_grid[0]);
181 nir_def *topology = nir_load_input_topology_agx(b);
182
183 switch (cls) {
184 case MESA_PRIM_POINTS:
185 return prim;
186
187 case MESA_PRIM_LINES:
188 return libagx_vertex_id_for_line_class(b, topology, prim, vert, nr);
189
190 case MESA_PRIM_TRIANGLES:
191 return libagx_vertex_id_for_tri_class(b, topology, prim, vert,
192 flatshade_first);
193
194 case MESA_PRIM_LINES_ADJACENCY:
195 return libagx_vertex_id_for_line_adj_class(b, topology, prim, vert);
196
197 case MESA_PRIM_TRIANGLES_ADJACENCY:
198 return libagx_vertex_id_for_tri_adj_class(b, topology, prim, vert, nr,
199 flatshade_first);
200
201 default:
202 unreachable("invalid topology class");
203 }
204 }
205
206 nir_def *
207 agx_load_per_vertex_input(nir_builder *b, nir_intrinsic_instr *intr,
208 nir_def *vertex)
209 {
210 assert(intr->intrinsic == nir_intrinsic_load_per_vertex_input);
211 nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
212
213 nir_def *location = nir_iadd_imm(b, intr->src[1].ssa, sem.location);
214 nir_def *addr;
215
216 if (b->shader->info.stage == MESA_SHADER_GEOMETRY) {
217 /* GS may be preceded by VS or TES so specified as param */
218 addr = libagx_geometry_input_address(
219 b, nir_load_geometry_param_buffer_agx(b), vertex, location);
220 } else {
221 assert(b->shader->info.stage == MESA_SHADER_TESS_CTRL);
222
223 /* TCS always preceded by VS so we use the VS state directly */
224 addr = libagx_vertex_output_address(b, nir_load_vs_output_buffer_agx(b),
225 nir_load_vs_outputs_agx(b), vertex,
226 location);
227 }
228
229 addr = nir_iadd_imm(b, addr, 4 * nir_intrinsic_component(intr));
230 return nir_load_global_constant(b, addr, 4, intr->def.num_components,
231 intr->def.bit_size);
232 }
233
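/* Lower GS per-vertex input loads to reads from the software VS output buffer,
 * using the unrolled vertex ID (instance ID * vertices per instance + vertex).
 */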
234 static bool
235 lower_gs_inputs(nir_builder *b, nir_intrinsic_instr *intr, void *_)
236 {
237 if (intr->intrinsic != nir_intrinsic_load_per_vertex_input)
238 return false;
239
240 b->cursor = nir_instr_remove(&intr->instr);
241
242 /* Calculate the vertex ID we're pulling, based on the topology class */
243 nir_def *vert_in_prim = intr->src[0].ssa;
244 nir_def *vertex = vertex_id_for_topology_class(
245 b, vert_in_prim, b->shader->info.gs.input_primitive);
246
247 nir_def *verts = load_geometry_param(b, vs_grid[0]);
248 nir_def *unrolled =
249 nir_iadd(b, nir_imul(b, nir_load_instance_id(b), verts), vertex);
250
251 nir_def *val = agx_load_per_vertex_input(b, intr, unrolled);
252 nir_def_rewrite_uses(&intr->def, val);
253 return true;
254 }
255
256 /*
257 * Unrolled ID is the index of the primitive in the count buffer, given as
258 * (instance ID * # vertices/instance) + vertex ID
259 */
260 static nir_def *
261 calc_unrolled_id(nir_builder *b)
262 {
263 return nir_iadd(
264 b, nir_imul(b, load_instance_id(b), load_geometry_param(b, gs_grid[0])),
265 load_primitive_id(b));
266 }
267
268 static unsigned
269 output_vertex_id_stride(nir_shader *gs)
270 {
271 /* round up to power of two for cheap multiply/division */
272 return util_next_power_of_two(MAX2(gs->info.gs.vertices_out, 1));
273 }
274
275 /* Variant of calc_unrolled_id that uses a power-of-two stride for indices. This
276 * is sparser (acceptable for index buffer values, not for count buffer
277 * indices). It has the nice property of being cheap to invert, unlike
278 * calc_unrolled_id. So, we use calc_unrolled_id for count buffers and
279 * calc_unrolled_index_id for index values.
280 *
281 * This also multiplies by the appropriate stride to calculate the final index
282 * base value.
283 */
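/* For example, if max_vertices rounds up to 4, the index base is
 * ((instance_id << primitives_log2) + primitive_id) * 4; the rasterization
 * shader inverts this with a udiv/umod by 4 followed by a shift and mask.
 */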
284 static nir_def *
285 calc_unrolled_index_id(nir_builder *b)
286 {
287 unsigned vertex_stride = output_vertex_id_stride(b->shader);
288 nir_def *primitives_log2 = load_geometry_param(b, primitives_log2);
289
290 nir_def *instance = nir_ishl(b, load_instance_id(b), primitives_log2);
291 nir_def *prim = nir_iadd(b, instance, load_primitive_id(b));
292
293 return nir_imul_imm(b, prim, vertex_stride);
294 }
295
296 static nir_def *
297 load_count_address(nir_builder *b, struct lower_gs_state *state,
298 nir_def *unrolled_id, unsigned stream,
299 enum gs_counter counter)
300 {
301 int index = state->count_index[stream][counter];
302 if (index < 0)
303 return NULL;
304
305 nir_def *prim_offset_el =
306 nir_imul_imm(b, unrolled_id, state->count_stride_el);
307
308 nir_def *offset_el = nir_iadd_imm(b, prim_offset_el, index);
309
310 return nir_iadd(b, load_geometry_param(b, count_buffer),
311 nir_u2u64(b, nir_imul_imm(b, offset_el, 4)));
312 }
313
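/* Write this input primitive's dynamic counts (vertices, primitives, XFB
 * primitives) into its slice of the count buffer. Statically known counters
 * have no slot and are skipped.
 */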
314 static void
315 write_counts(nir_builder *b, nir_intrinsic_instr *intr,
316 struct lower_gs_state *state)
317 {
318 /* Store each required counter */
319 nir_def *counts[GS_NUM_COUNTERS] = {
320 [GS_COUNTER_VERTICES] = intr->src[0].ssa,
321 [GS_COUNTER_PRIMITIVES] = intr->src[1].ssa,
322 [GS_COUNTER_XFB_PRIMITIVES] = intr->src[2].ssa,
323 };
324
325 for (unsigned i = 0; i < GS_NUM_COUNTERS; ++i) {
326 nir_def *addr = load_count_address(b, state, calc_unrolled_id(b),
327 nir_intrinsic_stream_id(intr), i);
328
329 if (addr)
330 nir_store_global(b, addr, 4, counts[i], nir_component_mask(1));
331 }
332 }
333
334 static bool
335 lower_gs_count_instr(nir_builder *b, nir_intrinsic_instr *intr, void *data)
336 {
337 switch (intr->intrinsic) {
338 case nir_intrinsic_emit_vertex_with_counter:
339 case nir_intrinsic_end_primitive_with_counter:
340 case nir_intrinsic_store_output:
341 /* These are for the main shader, just remove them */
342 nir_instr_remove(&intr->instr);
343 return true;
344
345 case nir_intrinsic_set_vertex_and_primitive_count:
346 b->cursor = nir_instr_remove(&intr->instr);
347 write_counts(b, intr, data);
348 return true;
349
350 default:
351 return false;
352 }
353 }
354
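/* Lower compute-style system values and geometry parameters: primitive ID,
 * instance ID, the flat-shading output mask, and the input topology.
 */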
355 static bool
356 lower_id(nir_builder *b, nir_intrinsic_instr *intr, void *data)
357 {
358 b->cursor = nir_before_instr(&intr->instr);
359
360 nir_def *id;
361 if (intr->intrinsic == nir_intrinsic_load_primitive_id)
362 id = load_primitive_id(b);
363 else if (intr->intrinsic == nir_intrinsic_load_instance_id)
364 id = load_instance_id(b);
365 else if (intr->intrinsic == nir_intrinsic_load_flat_mask)
366 id = load_geometry_param(b, flat_outputs);
367 else if (intr->intrinsic == nir_intrinsic_load_input_topology_agx)
368 id = load_geometry_param(b, input_topology);
369 else
370 return false;
371
372 b->cursor = nir_instr_remove(&intr->instr);
373 nir_def_rewrite_uses(&intr->def, id);
374 return true;
375 }
376
377 /*
378 * Create a "Geometry count" shader. This is a stripped down geometry shader
379 * that just writes its number of emitted vertices / primitives / transform
380 * feedback primitives to a count buffer. That count buffer will be prefix
381 * summed prior to running the real geometry shader. This is skipped if the
382 * counts are statically known.
383 */
384 static nir_shader *
385 agx_nir_create_geometry_count_shader(nir_shader *gs, const nir_shader *libagx,
386 struct lower_gs_state *state)
387 {
388 /* Don't muck up the original shader */
389 nir_shader *shader = nir_shader_clone(NULL, gs);
390
391 if (shader->info.name) {
392 shader->info.name =
393 ralloc_asprintf(shader, "%s_count", shader->info.name);
394 } else {
395 shader->info.name = "count";
396 }
397
398 NIR_PASS(_, shader, nir_shader_intrinsics_pass, lower_gs_count_instr,
399 nir_metadata_control_flow, state);
400
401 NIR_PASS(_, shader, nir_shader_intrinsics_pass, lower_id,
402 nir_metadata_control_flow, NULL);
403
404 agx_preprocess_nir(shader, libagx);
405 return shader;
406 }
407
408 struct lower_gs_rast_state {
409 nir_def *instance_id, *primitive_id, *output_id;
410 struct agx_lower_output_to_var_state outputs;
411 struct agx_lower_output_to_var_state selected;
412 };
413
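/* For an emit_vertex in the rasterization shader, copy the staged outputs into
 * the "selected" variables when this emit corresponds to the vertex that this
 * hardware vertex shader invocation is shading.
 */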
414 static void
415 select_rast_output(nir_builder *b, nir_intrinsic_instr *intr,
416 struct lower_gs_rast_state *state)
417 {
418 b->cursor = nir_instr_remove(&intr->instr);
419
420 /* We only care about the rasterization stream in the rasterization
421 * shader, so just ignore emits from other streams.
422 */
423 if (nir_intrinsic_stream_id(intr) != 0)
424 return;
425
426 u_foreach_bit64(slot, b->shader->info.outputs_written) {
427 nir_def *orig = nir_load_var(b, state->selected.outputs[slot]);
428 nir_def *data = nir_load_var(b, state->outputs.outputs[slot]);
429
430 nir_def *value = nir_bcsel(
431 b, nir_ieq(b, intr->src[0].ssa, state->output_id), data, orig);
432
433 nir_store_var(b, state->selected.outputs[slot], value,
434 nir_component_mask(value->num_components));
435 }
436 }
437
438 static bool
439 lower_to_gs_rast(nir_builder *b, nir_intrinsic_instr *intr, void *data)
440 {
441 struct lower_gs_rast_state *state = data;
442
443 switch (intr->intrinsic) {
444 case nir_intrinsic_store_output:
445 lower_store_to_var(b, intr, &state->outputs);
446 return true;
447
448 case nir_intrinsic_emit_vertex_with_counter:
449 select_rast_output(b, intr, state);
450 return true;
451
452 case nir_intrinsic_load_primitive_id:
453 nir_def_rewrite_uses(&intr->def, state->primitive_id);
454 return true;
455
456 case nir_intrinsic_load_instance_id:
457 nir_def_rewrite_uses(&intr->def, state->instance_id);
458 return true;
459
460 case nir_intrinsic_load_flat_mask:
461 case nir_intrinsic_load_provoking_last:
462 case nir_intrinsic_load_input_topology_agx: {
463 /* Lowering the same in both GS variants */
464 return lower_id(b, intr, NULL);
465 }
466
467 case nir_intrinsic_end_primitive_with_counter:
468 case nir_intrinsic_set_vertex_and_primitive_count:
469 nir_instr_remove(&intr->instr);
470 return true;
471
472 default:
473 return false;
474 }
475 }
476
477 /*
478 * Side effects in geometry shaders are problematic with our "GS rasterization
479 * shader" implementation. Where does the side effect happen? In the prepass?
480 * In the rast shader? In both?
481 *
482 * A perfect solution is impossible with rast shaders. Since the spec is loose
483 * here, we follow the principle of "least surprise":
484 *
485 * 1. Prefer side effects in the prepass over the rast shader. The prepass runs
486 * once per API GS invocation so will match the expectations of buggy apps
487 * not written for tilers.
488 *
489 * 2. If we must execute any side effect in the rast shader, try to execute all
490 * side effects only in the rast shader. If some side effects must happen in
491 * the rast shader and others don't, this gets consistent counts
492 * (i.e. if the app expects plain stores and atomics to match up).
493 *
494 * 3. If we must execute side effects in both rast and the prepass,
495 * execute all side effects in the rast shader and strip what we can from
496 * the prepass. This gets the "unsurprising" behaviour from #2 without
497 * falling over for ridiculous uses of atomics.
498 */
499 static bool
500 strip_side_effect_from_rast(nir_builder *b, nir_intrinsic_instr *intr,
501 void *data)
502 {
503 switch (intr->intrinsic) {
504 case nir_intrinsic_store_global:
505 case nir_intrinsic_global_atomic:
506 case nir_intrinsic_global_atomic_swap:
507 break;
508 default:
509 return false;
510 }
511
512 /* If there's a side effect that's actually required, keep it. */
513 if (nir_intrinsic_infos[intr->intrinsic].has_dest &&
514 !list_is_empty(&intr->def.uses)) {
515
516 bool *any = data;
517 *any = true;
518 return false;
519 }
520
521 /* Otherwise, remove the dead instruction. */
522 nir_instr_remove(&intr->instr);
523 return true;
524 }
525
526 static bool
527 strip_side_effects_from_rast(nir_shader *s, bool *side_effects_for_rast)
528 {
529 bool progress, any;
530
531 /* Rather than complex analysis, clone and try to remove as many side effects
532 * as possible. Then we check if we removed them all. We need to loop to
533 * handle complex control flow with side effects, where we can strip
534 * everything but can't figure that out with a simple one-shot analysis.
535 */
536 nir_shader *clone = nir_shader_clone(NULL, s);
537
538 /* Drop as much as we can */
539 do {
540 progress = false;
541 any = false;
542 NIR_PASS(progress, clone, nir_shader_intrinsics_pass,
543 strip_side_effect_from_rast, nir_metadata_control_flow, &any);
544
545 NIR_PASS(progress, clone, nir_opt_dce);
546 NIR_PASS(progress, clone, nir_opt_dead_cf);
547 } while (progress);
548
549 ralloc_free(clone);
550
551 /* If we need atomics, leave them in */
552 if (any) {
553 *side_effects_for_rast = true;
554 return false;
555 }
556
557 /* Else strip it all */
558 do {
559 progress = false;
560 any = false;
561 NIR_PASS(progress, s, nir_shader_intrinsics_pass,
562 strip_side_effect_from_rast, nir_metadata_control_flow, &any);
563
564 NIR_PASS(progress, s, nir_opt_dce);
565 NIR_PASS(progress, s, nir_opt_dead_cf);
566 } while (progress);
567
568 assert(!any);
569 return progress;
570 }
571
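/* Strip atomics whose results are unused from the main (prepass) shader; used
 * when side effects are known to execute in the rasterization shader instead.
 */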
572 static bool
573 strip_side_effect_from_main(nir_builder *b, nir_intrinsic_instr *intr,
574 void *data)
575 {
576 switch (intr->intrinsic) {
577 case nir_intrinsic_global_atomic:
578 case nir_intrinsic_global_atomic_swap:
579 break;
580 default:
581 return false;
582 }
583
584 if (list_is_empty(&intr->def.uses)) {
585 nir_instr_remove(&intr->instr);
586 return true;
587 }
588
589 return false;
590 }
591
592 /*
593 * Create a GS rasterization shader. This is a hardware vertex shader that
594 * shades each rasterized output vertex in parallel.
595 */
596 static nir_shader *
597 agx_nir_create_gs_rast_shader(const nir_shader *gs, const nir_shader *libagx,
598 bool *side_effects_for_rast)
599 {
600 /* Don't muck up the original shader */
601 nir_shader *shader = nir_shader_clone(NULL, gs);
602
603 unsigned max_verts = output_vertex_id_stride(shader);
604
605 /* Turn into a vertex shader run only for rasterization. Transform feedback
606 * was handled in the prepass.
607 */
608 shader->info.stage = MESA_SHADER_VERTEX;
609 shader->info.has_transform_feedback_varyings = false;
610 memset(&shader->info.vs, 0, sizeof(shader->info.vs));
611 shader->xfb_info = NULL;
612
613 if (shader->info.name) {
614 shader->info.name = ralloc_asprintf(shader, "%s_rast", shader->info.name);
615 } else {
616 shader->info.name = "gs rast";
617 }
618
619 nir_builder b_ =
620 nir_builder_at(nir_before_impl(nir_shader_get_entrypoint(shader)));
621 nir_builder *b = &b_;
622
623 NIR_PASS(_, shader, strip_side_effects_from_rast, side_effects_for_rast);
624
625 /* Optimize out pointless gl_PointSize outputs. Bizarrely, these occur. */
626 if (shader->info.gs.output_primitive != MESA_PRIM_POINTS)
627 shader->info.outputs_written &= ~VARYING_BIT_PSIZ;
628
629 /* See calc_unrolled_index_id */
630 nir_def *raw_id = nir_load_vertex_id(b);
631 nir_def *output_id = nir_umod_imm(b, raw_id, max_verts);
632 nir_def *unrolled = nir_udiv_imm(b, raw_id, max_verts);
633
634 nir_def *primitives_log2 = load_geometry_param(b, primitives_log2);
635 nir_def *instance_id = nir_ushr(b, unrolled, primitives_log2);
636 nir_def *primitive_id = nir_iand(
637 b, unrolled,
638 nir_iadd_imm(b, nir_ishl(b, nir_imm_int(b, 1), primitives_log2), -1));
639
640 struct lower_gs_rast_state rast_state = {
641 .instance_id = instance_id,
642 .primitive_id = primitive_id,
643 .output_id = output_id,
644 };
645
646 u_foreach_bit64(slot, shader->info.outputs_written) {
647 const char *slot_name =
648 gl_varying_slot_name_for_stage(slot, MESA_SHADER_GEOMETRY);
649
650 bool scalar = (slot == VARYING_SLOT_PSIZ) ||
651 (slot == VARYING_SLOT_LAYER) ||
652 (slot == VARYING_SLOT_VIEWPORT);
653 unsigned comps = scalar ? 1 : 4;
654
655 rast_state.outputs.outputs[slot] = nir_variable_create(
656 shader, nir_var_shader_temp, glsl_vector_type(GLSL_TYPE_UINT, comps),
657 ralloc_asprintf(shader, "%s-temp", slot_name));
658
659 rast_state.selected.outputs[slot] = nir_variable_create(
660 shader, nir_var_shader_temp, glsl_vector_type(GLSL_TYPE_UINT, comps),
661 ralloc_asprintf(shader, "%s-selected", slot_name));
662 }
663
664 nir_shader_intrinsics_pass(shader, lower_to_gs_rast,
665 nir_metadata_control_flow, &rast_state);
666
667 b->cursor = nir_after_impl(b->impl);
668
669 /* Forward each selected output to the rasterizer */
670 u_foreach_bit64(slot, shader->info.outputs_written) {
671 assert(rast_state.selected.outputs[slot] != NULL);
672 nir_def *value = nir_load_var(b, rast_state.selected.outputs[slot]);
673
674 /* We set NIR_COMPACT_ARRAYS, so all clip/cull distances need to land in
675 * DIST0. Undo the offset if we need to.
676 */
677 assert(slot != VARYING_SLOT_CULL_DIST1);
678 unsigned offset = 0;
679 if (slot == VARYING_SLOT_CLIP_DIST1)
680 offset = 1;
681
682 nir_store_output(b, value, nir_imm_int(b, offset),
683 .io_semantics.location = slot - offset,
684 .io_semantics.num_slots = 1,
685 .write_mask = nir_component_mask(value->num_components),
686 .src_type = nir_type_uint32);
687 }
688
689 /* It is legal to omit the point size write from the geometry shader when
690 * drawing points. In this case, the point size is implicitly 1.0. To
691 * implement, insert a synthetic `gl_PointSize = 1.0` write into the GS copy
692 * shader, if the GS does not export a point size while drawing points.
693 */
694 bool is_points = gs->info.gs.output_primitive == MESA_PRIM_POINTS;
695
696 if (!(shader->info.outputs_written & VARYING_BIT_PSIZ) && is_points) {
697 nir_store_output(b, nir_imm_float(b, 1.0), nir_imm_int(b, 0),
698 .io_semantics.location = VARYING_SLOT_PSIZ,
699 .io_semantics.num_slots = 1,
700 .write_mask = nir_component_mask(1),
701 .src_type = nir_type_float32);
702
703 shader->info.outputs_written |= VARYING_BIT_PSIZ;
704 }
705
706 nir_opt_idiv_const(shader, 16);
707
708 agx_preprocess_nir(shader, libagx);
709 return shader;
710 }
711
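/* Number of vertices/primitives/XFB primitives emitted for a stream by all
 * input primitives before unrolled_id, taken either from the static count or
 * from the prefix-summed count buffer.
 */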
712 static nir_def *
713 previous_count(nir_builder *b, struct lower_gs_state *state, unsigned stream,
714 nir_def *unrolled_id, enum gs_counter counter)
715 {
716 assert(stream < MAX_VERTEX_STREAMS);
717 assert(counter < GS_NUM_COUNTERS);
718 int static_count = state->static_count[counter][stream];
719
720 if (static_count >= 0) {
721 /* If the number of outputted vertices per invocation is known statically,
722 * we can calculate the base.
723 */
724 return nir_imul_imm(b, unrolled_id, static_count);
725 } else {
726 /* Otherwise, we need to load from the prefix sum buffer. Note that the
727 * sums are inclusive, so index 0 is nonzero. This requires a little
728 * fixup here. We use a saturating unsigned subtraction so we don't read
729 * out-of-bounds for zero.
730 *
731 * TODO: Optimize this.
732 */
733 nir_def *prim_minus_1 = nir_usub_sat(b, unrolled_id, nir_imm_int(b, 1));
734 nir_def *addr =
735 load_count_address(b, state, prim_minus_1, stream, counter);
736
737 return nir_bcsel(b, nir_ieq_imm(b, unrolled_id, 0), nir_imm_int(b, 0),
738 nir_load_global_constant(b, addr, 4, 1, 32));
739 }
740 }
741
742 static nir_def *
743 previous_vertices(nir_builder *b, struct lower_gs_state *state, unsigned stream,
744 nir_def *unrolled_id)
745 {
746 return previous_count(b, state, stream, unrolled_id, GS_COUNTER_VERTICES);
747 }
748
749 static nir_def *
750 previous_primitives(nir_builder *b, struct lower_gs_state *state,
751 unsigned stream, nir_def *unrolled_id)
752 {
753 return previous_count(b, state, stream, unrolled_id, GS_COUNTER_PRIMITIVES);
754 }
755
756 static nir_def *
757 previous_xfb_primitives(nir_builder *b, struct lower_gs_state *state,
758 unsigned stream, nir_def *unrolled_id)
759 {
760 return previous_count(b, state, stream, unrolled_id,
761 GS_COUNTER_XFB_PRIMITIVES);
762 }
763
764 static void
765 lower_end_primitive(nir_builder *b, nir_intrinsic_instr *intr,
766 struct lower_gs_state *state)
767 {
768 assert((intr->intrinsic == nir_intrinsic_set_vertex_and_primitive_count ||
769 b->shader->info.gs.output_primitive != MESA_PRIM_POINTS) &&
770 "endprimitive for points should've been removed");
771
772 /* The GS is the last stage before rasterization, so if we discard the
773 * rasterization, we don't output an index buffer; nothing would read it.
774 * The index buffer is only written for the rasterization stream.
775 */
776 unsigned stream = nir_intrinsic_stream_id(intr);
777 if (state->rasterizer_discard || stream != 0)
778 return;
779
780 libagx_end_primitive(
781 b, load_geometry_param(b, output_index_buffer), intr->src[0].ssa,
782 intr->src[1].ssa, intr->src[2].ssa,
783 previous_vertices(b, state, 0, calc_unrolled_id(b)),
784 previous_primitives(b, state, 0, calc_unrolled_id(b)),
785 calc_unrolled_index_id(b),
786 nir_imm_bool(b, b->shader->info.gs.output_primitive != MESA_PRIM_POINTS));
787 }
788
789 static unsigned
790 verts_in_output_prim(nir_shader *gs)
791 {
792 return mesa_vertices_per_prim(gs->info.gs.output_primitive);
793 }
794
795 static void
796 write_xfb(nir_builder *b, struct lower_gs_state *state, unsigned stream,
797 nir_def *index_in_strip, nir_def *prim_id_in_invocation)
798 {
799 struct nir_xfb_info *xfb = b->shader->xfb_info;
800 unsigned verts = verts_in_output_prim(b->shader);
801
802 /* Get the index of this primitive in the XFB buffer. That is, the base for
803 * this invocation for the stream plus the offset within this invocation.
804 */
805 nir_def *invocation_base =
806 previous_xfb_primitives(b, state, stream, calc_unrolled_id(b));
807
808 nir_def *prim_index = nir_iadd(b, invocation_base, prim_id_in_invocation);
809 nir_def *base_index = nir_imul_imm(b, prim_index, verts);
810
811 nir_def *xfb_prims = load_geometry_param(b, xfb_prims[stream]);
812 nir_push_if(b, nir_ult(b, prim_index, xfb_prims));
813
814 /* Write XFB for each output */
815 for (unsigned i = 0; i < xfb->output_count; ++i) {
816 nir_xfb_output_info output = xfb->outputs[i];
817
818 /* Only write to the selected stream */
819 if (xfb->buffer_to_stream[output.buffer] != stream)
820 continue;
821
822 unsigned buffer = output.buffer;
823 unsigned stride = xfb->buffers[buffer].stride;
824 unsigned count = util_bitcount(output.component_mask);
825
826 for (unsigned vert = 0; vert < verts; ++vert) {
827 /* We write out the vertices backwards, since 0 is the current
828 * emitted vertex (which is actually the last vertex).
829 *
830 * We handle NULL var for
831 * KHR-Single-GL44.enhanced_layouts.xfb_capture_struct.
832 */
833 unsigned v = (verts - 1) - vert;
834 nir_variable *var = state->outputs[output.location][v];
835 nir_def *value = var ? nir_load_var(b, var) : nir_undef(b, 4, 32);
836
837 /* In case output.component_mask contains invalid components, write
838 * out zeroes instead of blowing up validation.
839 *
840 * KHR-Single-GL44.enhanced_layouts.xfb_capture_inactive_output_component
841 * hits this.
842 */
843 value = nir_pad_vector_imm_int(b, value, 0, 4);
844
845 nir_def *rotated_vert = nir_imm_int(b, vert);
846 if (verts == 3) {
847 /* Map vertices for output so we get consistent winding order. For
848 * the primitive index, we use the index_in_strip. This is actually
849 * the vertex index in the strip, hence
850 * offset by 2 relative to the true primitive index (#2 for the
851 * first triangle in the strip, #3 for the second). That's ok
852 * because only the parity matters.
853 */
854 rotated_vert = libagx_map_vertex_in_tri_strip(
855 b, index_in_strip, rotated_vert,
856 nir_inot(b, nir_i2b(b, nir_load_provoking_last(b))));
857 }
858
859 nir_def *addr = libagx_xfb_vertex_address(
860 b, nir_load_geometry_param_buffer_agx(b), base_index, rotated_vert,
861 nir_imm_int(b, buffer), nir_imm_int(b, stride),
862 nir_imm_int(b, output.offset));
863
864 nir_build_store_global(
865 b, nir_channels(b, value, output.component_mask), addr,
866 .align_mul = 4, .write_mask = nir_component_mask(count),
867 .access = ACCESS_XFB);
868 }
869 }
870
871 nir_pop_if(b, NULL);
872 }
873
874 /* Handle transform feedback for a given emit_vertex_with_counter */
875 static void
876 lower_emit_vertex_xfb(nir_builder *b, nir_intrinsic_instr *intr,
877 struct lower_gs_state *state)
878 {
879 /* Transform feedback is written for each decomposed output primitive. Since
880 * we're writing strips, that means we output XFB for each vertex after the
881 * first complete primitive is formed.
882 */
883 unsigned first_prim = verts_in_output_prim(b->shader) - 1;
884 nir_def *index_in_strip = intr->src[1].ssa;
885
886 nir_push_if(b, nir_uge_imm(b, index_in_strip, first_prim));
887 {
888 write_xfb(b, state, nir_intrinsic_stream_id(intr), index_in_strip,
889 intr->src[3].ssa);
890 }
891 nir_pop_if(b, NULL);
892
893 /* Transform feedback writes out entire primitives during the emit_vertex. To
894 * do that, we store the values at all vertices in the strip in a little ring
895 * buffer. Index #0 is always the most recent primitive (so non-XFB code can
896 * just grab index #0 without any checking). Index #1 is the previous vertex,
897 * and index #2 is the vertex before that. Now that we've written XFB, since
898 * we've emitted a vertex we need to cycle the ringbuffer, freeing up index
899 * #0 for the next vertex that we are about to emit. We do that by copying
900 * the first n - 1 vertices forward one slot, which has to happen with a
901 * backwards copy implemented here.
902 *
903 * If we're lucky, all of these copies will be propagated away. If we're
904 * unlucky, this involves at most 2 copies per component per XFB output per
905 * vertex.
906 */
907 u_foreach_bit64(slot, b->shader->info.outputs_written) {
908 /* Note: if we're outputting points, verts_in_output_prim will be 1, so
909 * this loop will not execute. This is intended: points are self-contained
910 * primitives and do not need these copies.
911 */
912 for (int v = verts_in_output_prim(b->shader) - 1; v >= 1; --v) {
913 nir_def *value = nir_load_var(b, state->outputs[slot][v - 1]);
914
915 nir_store_var(b, state->outputs[slot][v], value,
916 nir_component_mask(value->num_components));
917 }
918 }
919 }
920
921 static bool
922 lower_gs_instr(nir_builder *b, nir_intrinsic_instr *intr, void *state)
923 {
924 b->cursor = nir_before_instr(&intr->instr);
925
926 switch (intr->intrinsic) {
927 case nir_intrinsic_set_vertex_and_primitive_count:
928 /* This instruction is mostly for the count shader, so just remove. But
929 * for points, we write the index buffer here so the rast shader can map.
930 */
931 if (b->shader->info.gs.output_primitive == MESA_PRIM_POINTS) {
932 lower_end_primitive(b, intr, state);
933 }
934
935 break;
936
937 case nir_intrinsic_end_primitive_with_counter: {
938 unsigned min = verts_in_output_prim(b->shader);
939
940 /* We only write out complete primitives */
941 nir_push_if(b, nir_uge_imm(b, intr->src[1].ssa, min));
942 {
943 lower_end_primitive(b, intr, state);
944 }
945 nir_pop_if(b, NULL);
946 break;
947 }
948
949 case nir_intrinsic_emit_vertex_with_counter:
950 /* emit_vertex triggers transform feedback but is otherwise a no-op. */
951 if (b->shader->xfb_info)
952 lower_emit_vertex_xfb(b, intr, state);
953 break;
954
955 default:
956 return false;
957 }
958
959 nir_instr_remove(&intr->instr);
960 return true;
961 }
962
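/* Gather the number of components actually written per output slot, so the
 * geometry output buffer can be sized tightly instead of assuming vec4.
 */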
963 static bool
964 collect_components(nir_builder *b, nir_intrinsic_instr *intr, void *data)
965 {
966 uint8_t *counts = data;
967 if (intr->intrinsic != nir_intrinsic_store_output)
968 return false;
969
970 unsigned count = nir_intrinsic_component(intr) +
971 util_last_bit(nir_intrinsic_write_mask(intr));
972
973 unsigned loc =
974 nir_intrinsic_io_semantics(intr).location + nir_src_as_uint(intr->src[1]);
975
976 uint8_t *total_count = &counts[loc];
977
978 *total_count = MAX2(*total_count, count);
979 return true;
980 }
981
982 /*
983 * Create the pre-GS shader. This is a small compute 1x1x1 kernel that produces
984 * an indirect draw to rasterize the produced geometry, as well as updating
985 * transform feedback offsets and counters as applicable.
986 */
987 static nir_shader *
988 agx_nir_create_pre_gs(struct lower_gs_state *state, const nir_shader *libagx,
989 bool indexed, bool restart, struct nir_xfb_info *xfb,
990 unsigned vertices_per_prim, uint8_t streams,
991 unsigned invocations)
992 {
993 nir_builder b_ = nir_builder_init_simple_shader(
994 MESA_SHADER_COMPUTE, &agx_nir_options, "Pre-GS patch up");
995 nir_builder *b = &b_;
996
997 /* Load the number of primitives input to the GS */
998 nir_def *unrolled_in_prims = load_geometry_param(b, input_primitives);
999
1000 /* Setup the draw from the rasterization stream (0). */
1001 if (!state->rasterizer_discard) {
1002 libagx_build_gs_draw(
1003 b, nir_load_geometry_param_buffer_agx(b),
1004 previous_vertices(b, state, 0, unrolled_in_prims),
1005 restart ? previous_primitives(b, state, 0, unrolled_in_prims)
1006 : nir_imm_int(b, 0));
1007 }
1008
1009 /* Determine the number of primitives generated in each stream */
1010 nir_def *in_prims[MAX_VERTEX_STREAMS], *prims[MAX_VERTEX_STREAMS];
1011
1012 u_foreach_bit(i, streams) {
1013 in_prims[i] = previous_xfb_primitives(b, state, i, unrolled_in_prims);
1014 prims[i] = in_prims[i];
1015
1016 add_counter(b, load_geometry_param(b, prims_generated_counter[i]),
1017 prims[i]);
1018 }
1019
1020 if (xfb) {
1021 /* Write XFB addresses */
1022 nir_def *offsets[4] = {NULL};
1023 u_foreach_bit(i, xfb->buffers_written) {
1024 offsets[i] = libagx_setup_xfb_buffer(
1025 b, nir_load_geometry_param_buffer_agx(b), nir_imm_int(b, i));
1026 }
1027
1028 /* Now clamp to the number that XFB captures */
1029 for (unsigned i = 0; i < xfb->output_count; ++i) {
1030 nir_xfb_output_info output = xfb->outputs[i];
1031
1032 unsigned buffer = output.buffer;
1033 unsigned stream = xfb->buffer_to_stream[buffer];
1034 unsigned stride = xfb->buffers[buffer].stride;
1035 unsigned words_written = util_bitcount(output.component_mask);
1036 unsigned bytes_written = words_written * 4;
1037
1038 /* Primitive P will write up to (but not including) offset:
1039 *
1040 * xfb_offset + ((P - 1) * (verts_per_prim * stride))
1041 * + ((verts_per_prim - 1) * stride)
1042 * + output_offset
1043 * + output_size
1044 *
1045 * Given an XFB buffer of size xfb_size, we get the inequality:
1046 *
1047 * floor(P) <= (stride + xfb_size - xfb_offset - output_offset -
1048 * output_size) // (stride * verts_per_prim)
1049 */
1050 nir_def *size = load_geometry_param(b, xfb_size[buffer]);
1051 size = nir_iadd_imm(b, size, stride - output.offset - bytes_written);
1052 size = nir_isub(b, size, offsets[buffer]);
1053 size = nir_imax(b, size, nir_imm_int(b, 0));
1054 nir_def *max_prims = nir_udiv_imm(b, size, stride * vertices_per_prim);
1055
1056 prims[stream] = nir_umin(b, prims[stream], max_prims);
1057 }
1058
1059 nir_def *any_overflow = nir_imm_false(b);
1060
1061 u_foreach_bit(i, streams) {
1062 nir_def *overflow = nir_ult(b, prims[i], in_prims[i]);
1063 any_overflow = nir_ior(b, any_overflow, overflow);
1064
1065 store_geometry_param(b, xfb_prims[i], prims[i]);
1066
1067 add_counter(b, load_geometry_param(b, xfb_overflow[i]),
1068 nir_b2i32(b, overflow));
1069
1070 add_counter(b, load_geometry_param(b, xfb_prims_generated_counter[i]),
1071 prims[i]);
1072 }
1073
1074 add_counter(b, load_geometry_param(b, xfb_any_overflow),
1075 nir_b2i32(b, any_overflow));
1076
1077 /* Update XFB counters */
1078 u_foreach_bit(i, xfb->buffers_written) {
1079 uint32_t prim_stride_B = xfb->buffers[i].stride * vertices_per_prim;
1080 unsigned stream = xfb->buffer_to_stream[i];
1081
1082 nir_def *off_ptr = load_geometry_param(b, xfb_offs_ptrs[i]);
1083 nir_def *size = nir_imul_imm(b, prims[stream], prim_stride_B);
1084 add_counter(b, off_ptr, size);
1085 }
1086 }
1087
1088 /* The geometry shader receives a number of input primitives. The driver
1089 * should disable this counter when tessellation is active (TODO) and count
1090 * patches separately.
1091 */
1092 add_counter(
1093 b,
1094 nir_load_stat_query_address_agx(b, .base = PIPE_STAT_QUERY_IA_PRIMITIVES),
1095 unrolled_in_prims);
1096
1097 /* The geometry shader is invoked once per primitive (after unrolling
1098 * primitive restart). From the spec:
1099 *
1100 * In case of instanced geometry shaders (see section 11.3.4.2) the
1101 * geometry shader invocations count is incremented for each separate
1102 * instanced invocation.
1103 */
1104 add_counter(b,
1105 nir_load_stat_query_address_agx(
1106 b, .base = PIPE_STAT_QUERY_GS_INVOCATIONS),
1107 nir_imul_imm(b, unrolled_in_prims, invocations));
1108
1109 nir_def *emitted_prims = nir_imm_int(b, 0);
1110 u_foreach_bit(i, streams) {
1111 emitted_prims =
1112 nir_iadd(b, emitted_prims,
1113 previous_xfb_primitives(b, state, i, unrolled_in_prims));
1114 }
1115
1116 add_counter(
1117 b,
1118 nir_load_stat_query_address_agx(b, .base = PIPE_STAT_QUERY_GS_PRIMITIVES),
1119 emitted_prims);
1120
1121 /* Clipper queries are not well-defined, so we can emulate them in lots of
1122 * silly ways. We need the hardware counters to implement them properly. For
1123 * now, just consider all primitives emitted as passing through the clipper.
1124 * This satisfies spec text:
1125 *
1126 * The number of primitives that reach the primitive clipping stage.
1127 *
1128 * and
1129 *
1130 * If at least one vertex of the primitive lies inside the clipping
1131 * volume, the counter is incremented by one or more. Otherwise, the
1132 * counter is incremented by zero or more.
1133 */
1134 add_counter(
1135 b,
1136 nir_load_stat_query_address_agx(b, .base = PIPE_STAT_QUERY_C_PRIMITIVES),
1137 emitted_prims);
1138
1139 add_counter(
1140 b,
1141 nir_load_stat_query_address_agx(b, .base = PIPE_STAT_QUERY_C_INVOCATIONS),
1142 emitted_prims);
1143
1144 agx_preprocess_nir(b->shader, libagx);
1145 return b->shader;
1146 }
1147
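/* Replace gl_InvocationID with a supplied value: the instancing loop counter
 * when GS instancing is lowered, or zero otherwise.
 */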
1148 static bool
1149 rewrite_invocation_id(nir_builder *b, nir_intrinsic_instr *intr, void *data)
1150 {
1151 if (intr->intrinsic != nir_intrinsic_load_invocation_id)
1152 return false;
1153
1154 b->cursor = nir_instr_remove(&intr->instr);
1155 nir_def_rewrite_uses(&intr->def, nir_u2uN(b, data, intr->def.bit_size));
1156 return true;
1157 }
1158
1159 /*
1160 * Geometry shader instancing allows a GS to run multiple times. The number of
1161 * times is statically known and small. It's easiest to turn this into a loop
1162 * inside the GS, to avoid the feature "leaking" outside and affecting e.g. the
1163 * counts.
1164 */
1165 static void
1166 agx_nir_lower_gs_instancing(nir_shader *gs)
1167 {
1168 unsigned nr_invocations = gs->info.gs.invocations;
1169 nir_function_impl *impl = nir_shader_get_entrypoint(gs);
1170
1171 /* Each invocation can produce up to the shader-declared max_vertices, so
1172 * multiply it up for a proper bounds check. Emitting more than the declared
1173 * max_vertices per invocation results in undefined behaviour, so erroneously
1174 * emitting more than asked on early invocations is a perfectly cromulent
1175 * behaviour.
1176 */
1177 gs->info.gs.vertices_out *= gs->info.gs.invocations;
1178
1179 /* Get the original function */
1180 nir_cf_list list;
1181 nir_cf_extract(&list, nir_before_impl(impl), nir_after_impl(impl));
1182
1183 /* Create a builder for the wrapped function */
1184 nir_builder b = nir_builder_at(nir_after_block(nir_start_block(impl)));
1185
1186 nir_variable *i =
1187 nir_local_variable_create(impl, glsl_uintN_t_type(16), NULL);
1188 nir_store_var(&b, i, nir_imm_intN_t(&b, 0, 16), ~0);
1189 nir_def *index = NULL;
1190
1191 /* Create a loop in the wrapped function */
1192 nir_loop *loop = nir_push_loop(&b);
1193 {
1194 index = nir_load_var(&b, i);
1195 nir_push_if(&b, nir_uge_imm(&b, index, nr_invocations));
1196 {
1197 nir_jump(&b, nir_jump_break);
1198 }
1199 nir_pop_if(&b, NULL);
1200
1201 b.cursor = nir_cf_reinsert(&list, b.cursor);
1202 nir_store_var(&b, i, nir_iadd_imm(&b, index, 1), ~0);
1203
1204 /* Make sure we end the primitive between invocations. If the geometry
1205 * shader already ended the primitive, this will get optimized out.
1206 */
1207 nir_end_primitive(&b);
1208 }
1209 nir_pop_loop(&b, loop);
1210
1211 /* We've mucked about with control flow */
1212 nir_metadata_preserve(impl, nir_metadata_none);
1213
1214 /* Use the loop counter as the invocation ID each iteration */
1215 nir_shader_intrinsics_pass(gs, rewrite_invocation_id,
1216 nir_metadata_control_flow, index);
1217 }
1218
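/* Link the libagx helper library into a shader and lower it: inline the called
 * helpers, then lower derefs and explicit I/O so the result is plain NIR.
 */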
1219 static void
1220 link_libagx(nir_shader *nir, const nir_shader *libagx)
1221 {
1222 nir_link_shader_functions(nir, libagx);
1223 NIR_PASS(_, nir, nir_inline_functions);
1224 nir_remove_non_entrypoints(nir);
1225 NIR_PASS(_, nir, nir_lower_indirect_derefs, nir_var_function_temp, 64);
1226 NIR_PASS(_, nir, nir_opt_dce);
1227 NIR_PASS(_, nir, nir_lower_vars_to_explicit_types,
1228 nir_var_shader_temp | nir_var_function_temp | nir_var_mem_shared,
1229 glsl_get_cl_type_size_align);
1230 NIR_PASS(_, nir, nir_opt_deref);
1231 NIR_PASS(_, nir, nir_lower_vars_to_ssa);
1232 NIR_PASS(_, nir, nir_lower_explicit_io,
1233 nir_var_shader_temp | nir_var_function_temp | nir_var_mem_shared |
1234 nir_var_mem_global,
1235 nir_address_format_62bit_generic);
1236 }
1237
1238 bool
1239 agx_nir_lower_gs(nir_shader *gs, const nir_shader *libagx,
1240 bool rasterizer_discard, nir_shader **gs_count,
1241 nir_shader **gs_copy, nir_shader **pre_gs,
1242 enum mesa_prim *out_mode, unsigned *out_count_words)
1243 {
1244 /* Lower I/O as assumed by the rest of GS lowering */
1245 if (gs->xfb_info != NULL) {
1246 NIR_PASS(_, gs, nir_io_add_const_offset_to_base,
1247 nir_var_shader_in | nir_var_shader_out);
1248 NIR_PASS(_, gs, nir_io_add_intrinsic_xfb_info);
1249 }
1250
1251 NIR_PASS(_, gs, nir_lower_io_to_scalar, nir_var_shader_out, NULL, NULL);
1252
1253 /* Collect output component counts so we can size the geometry output buffer
1254 * appropriately, instead of assuming everything is vec4.
1255 */
1256 uint8_t component_counts[NUM_TOTAL_VARYING_SLOTS] = {0};
1257 nir_shader_intrinsics_pass(gs, collect_components, nir_metadata_all,
1258 component_counts);
1259
1260 /* If geometry shader instancing is used, lower it away before linking
1261 * anything. Otherwise, smash the invocation ID to zero.
1262 */
1263 if (gs->info.gs.invocations != 1) {
1264 agx_nir_lower_gs_instancing(gs);
1265 } else {
1266 nir_function_impl *impl = nir_shader_get_entrypoint(gs);
1267 nir_builder b = nir_builder_at(nir_before_impl(impl));
1268
1269 nir_shader_intrinsics_pass(gs, rewrite_invocation_id,
1270 nir_metadata_control_flow, nir_imm_int(&b, 0));
1271 }
1272
1273 NIR_PASS(_, gs, nir_shader_intrinsics_pass, lower_gs_inputs,
1274 nir_metadata_control_flow, NULL);
1275
1276 /* Lower geometry shader writes to contain all of the required counts, so we
1277 * know where in the various buffers we should write vertices.
1278 */
1279 NIR_PASS(_, gs, nir_lower_gs_intrinsics,
1280 nir_lower_gs_intrinsics_count_primitives |
1281 nir_lower_gs_intrinsics_per_stream |
1282 nir_lower_gs_intrinsics_count_vertices_per_primitive |
1283 nir_lower_gs_intrinsics_overwrite_incomplete |
1284 nir_lower_gs_intrinsics_always_end_primitive |
1285 nir_lower_gs_intrinsics_count_decomposed_primitives);
1286
1287 /* Clean up after all that lowering we did */
1288 bool progress = false;
1289 do {
1290 progress = false;
1291 NIR_PASS(progress, gs, nir_lower_var_copies);
1292 NIR_PASS(progress, gs, nir_lower_variable_initializers,
1293 nir_var_shader_temp);
1294 NIR_PASS(progress, gs, nir_lower_vars_to_ssa);
1295 NIR_PASS(progress, gs, nir_copy_prop);
1296 NIR_PASS(progress, gs, nir_opt_constant_folding);
1297 NIR_PASS(progress, gs, nir_opt_algebraic);
1298 NIR_PASS(progress, gs, nir_opt_cse);
1299 NIR_PASS(progress, gs, nir_opt_dead_cf);
1300 NIR_PASS(progress, gs, nir_opt_dce);
1301
1302 /* Unrolling lets us statically determine counts more often, which
1303 * otherwise would not be possible with multiple invocations even in the
1304 * simplest of cases.
1305 */
1306 NIR_PASS(progress, gs, nir_opt_loop_unroll);
1307 } while (progress);
1308
1309 /* If we know counts at compile-time we can simplify, so try to figure out
1310 * the counts statically.
1311 */
1312 struct lower_gs_state gs_state = {
1313 .rasterizer_discard = rasterizer_discard,
1314 };
1315
1316 nir_gs_count_vertices_and_primitives(
1317 gs, gs_state.static_count[GS_COUNTER_VERTICES],
1318 gs_state.static_count[GS_COUNTER_PRIMITIVES],
1319 gs_state.static_count[GS_COUNTER_XFB_PRIMITIVES], 4);
1320
1321 /* Anything we don't know statically will be tracked by the count buffer.
1322 * Determine the layout for it.
1323 */
1324 for (unsigned i = 0; i < MAX_VERTEX_STREAMS; ++i) {
1325 for (unsigned c = 0; c < GS_NUM_COUNTERS; ++c) {
1326 gs_state.count_index[i][c] =
1327 (gs_state.static_count[c][i] < 0) ? gs_state.count_stride_el++ : -1;
1328 }
1329 }
1330
1331 bool side_effects_for_rast = false;
1332 *gs_copy = agx_nir_create_gs_rast_shader(gs, libagx, &side_effects_for_rast);
1333
1334 NIR_PASS(_, gs, nir_shader_intrinsics_pass, lower_id,
1335 nir_metadata_control_flow, NULL);
1336
1337 link_libagx(gs, libagx);
1338
1339 NIR_PASS(_, gs, nir_lower_idiv,
1340 &(const nir_lower_idiv_options){.allow_fp16 = true});
1341
1342 /* All those variables we created should've gone away by now */
1343 NIR_PASS(_, gs, nir_remove_dead_variables, nir_var_function_temp, NULL);
1344
1345 /* If there is any unknown count, we need a geometry count shader */
1346 if (gs_state.count_stride_el > 0)
1347 *gs_count = agx_nir_create_geometry_count_shader(gs, libagx, &gs_state);
1348 else
1349 *gs_count = NULL;
1350
1351 /* Geometry shader outputs are staged to temporaries */
1352 struct agx_lower_output_to_var_state state = {0};
1353
1354 u_foreach_bit64(slot, gs->info.outputs_written) {
1355 /* After enough optimizations, the shader metadata can go out of sync; fix it
1356 * up with our gathered info. Otherwise glsl_vector_type would assert.
1357 */
1358 if (component_counts[slot] == 0) {
1359 gs->info.outputs_written &= ~BITFIELD64_BIT(slot);
1360 continue;
1361 }
1362
1363 const char *slot_name =
1364 gl_varying_slot_name_for_stage(slot, MESA_SHADER_GEOMETRY);
1365
1366 for (unsigned i = 0; i < MAX_PRIM_OUT_SIZE; ++i) {
1367 gs_state.outputs[slot][i] = nir_variable_create(
1368 gs, nir_var_shader_temp,
1369 glsl_vector_type(GLSL_TYPE_UINT, component_counts[slot]),
1370 ralloc_asprintf(gs, "%s-%u", slot_name, i));
1371 }
1372
1373 state.outputs[slot] = gs_state.outputs[slot][0];
1374 }
1375
1376 NIR_PASS(_, gs, nir_shader_instructions_pass, agx_lower_output_to_var,
1377 nir_metadata_control_flow, &state);
1378
1379 NIR_PASS(_, gs, nir_shader_intrinsics_pass, lower_gs_instr,
1380 nir_metadata_none, &gs_state);
1381
1382 /* Determine if we are guaranteed to rasterize at least one vertex, so that
1383 * we can strip the prepass of side effects knowing they will execute in the
1384 * rasterization shader.
1385 */
1386 bool rasterizes_at_least_one_vertex =
1387 !rasterizer_discard && gs_state.static_count[0][0] > 0;
1388
1389 /* Clean up after all that lowering we did */
1390 nir_lower_global_vars_to_local(gs);
1391 do {
1392 progress = false;
1393 NIR_PASS(progress, gs, nir_lower_var_copies);
1394 NIR_PASS(progress, gs, nir_lower_variable_initializers,
1395 nir_var_shader_temp);
1396 NIR_PASS(progress, gs, nir_lower_vars_to_ssa);
1397 NIR_PASS(progress, gs, nir_copy_prop);
1398 NIR_PASS(progress, gs, nir_opt_constant_folding);
1399 NIR_PASS(progress, gs, nir_opt_algebraic);
1400 NIR_PASS(progress, gs, nir_opt_cse);
1401 NIR_PASS(progress, gs, nir_opt_dead_cf);
1402 NIR_PASS(progress, gs, nir_opt_dce);
1403 NIR_PASS(progress, gs, nir_opt_loop_unroll);
1404
1405 } while (progress);
1406
1407 /* When rasterizing, we try to handle side effects sensibly. */
1408 if (rasterizes_at_least_one_vertex && side_effects_for_rast) {
1409 do {
1410 progress = false;
1411 NIR_PASS(progress, gs, nir_shader_intrinsics_pass,
1412 strip_side_effect_from_main, nir_metadata_control_flow, NULL);
1413
1414 NIR_PASS(progress, gs, nir_opt_dce);
1415 NIR_PASS(progress, gs, nir_opt_dead_cf);
1416 } while (progress);
1417 }
1418
1419 /* All those variables we created should've gone away by now */
1420 NIR_PASS(_, gs, nir_remove_dead_variables, nir_var_function_temp, NULL);
1421
1422 NIR_PASS(_, gs, nir_opt_sink, ~0);
1423 NIR_PASS(_, gs, nir_opt_move, ~0);
1424
1425 NIR_PASS(_, gs, nir_shader_intrinsics_pass, lower_id,
1426 nir_metadata_control_flow, NULL);
1427
1428 /* Create auxiliary programs */
1429 *pre_gs = agx_nir_create_pre_gs(
1430 &gs_state, libagx, true, gs->info.gs.output_primitive != MESA_PRIM_POINTS,
1431 gs->xfb_info, verts_in_output_prim(gs), gs->info.gs.active_stream_mask,
1432 gs->info.gs.invocations);
1433
1434 /* Signal what primitive we want to draw the GS Copy VS with */
1435 *out_mode = gs->info.gs.output_primitive;
1436 *out_count_words = gs_state.count_stride_el;
1437 return true;
1438 }
1439
1440 /*
1441 * Vertex shaders (or tessellation evaluation shaders) before a geometry shader
1442 * run as a dedicated compute prepass. They are invoked as (count, instances, 1).
1443 * Their linear ID is therefore (instance ID * vertices per instance) + vertex ID.
1444 *
1445 * This function lowers their vertex shader I/O to compute.
1446 *
1447 * Vertex ID becomes an index buffer pull (without applying the topology). Store
1448 * output becomes a store into the global vertex output buffer.
1449 */
1450 static bool
1451 lower_vs_before_gs(nir_builder *b, nir_intrinsic_instr *intr, void *data)
1452 {
1453 if (intr->intrinsic != nir_intrinsic_store_output)
1454 return false;
1455
1456 b->cursor = nir_instr_remove(&intr->instr);
1457 nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
1458 nir_def *location = nir_iadd_imm(b, intr->src[1].ssa, sem.location);
1459
1460 /* We inline the outputs_written because it's known at compile-time, even
1461 * with shader objects. This lets us constant fold a bit of address math.
1462 */
1463 nir_def *mask = nir_imm_int64(b, b->shader->info.outputs_written);
1464
1465 nir_def *buffer;
1466 nir_def *nr_verts;
1467 if (b->shader->info.stage == MESA_SHADER_VERTEX) {
1468 buffer = nir_load_vs_output_buffer_agx(b);
1469 nr_verts =
1470 libagx_input_vertices(b, nir_load_input_assembly_buffer_agx(b));
1471 } else {
1472 assert(b->shader->info.stage == MESA_SHADER_TESS_EVAL);
1473
1474 /* Instancing is unrolled during tessellation so nr_verts is ignored. */
1475 nr_verts = nir_imm_int(b, 0);
1476 buffer = libagx_tes_buffer(b, nir_load_tess_param_buffer_agx(b));
1477 }
1478
1479 nir_def *linear_id = nir_iadd(b, nir_imul(b, load_instance_id(b), nr_verts),
1480 load_primitive_id(b));
1481
1482 nir_def *addr =
1483 libagx_vertex_output_address(b, buffer, mask, linear_id, location);
1484
1485 assert(nir_src_bit_size(intr->src[0]) == 32);
1486 addr = nir_iadd_imm(b, addr, nir_intrinsic_component(intr) * 4);
1487
1488 nir_store_global(b, addr, 4, intr->src[0].ssa,
1489 nir_intrinsic_write_mask(intr));
1490 return true;
1491 }
1492
1493 bool
1494 agx_nir_lower_vs_before_gs(struct nir_shader *vs,
1495 const struct nir_shader *libagx)
1496 {
1497 bool progress = false;
1498
1499 /* Lower vertex stores to memory stores */
1500 progress |= nir_shader_intrinsics_pass(vs, lower_vs_before_gs,
1501 nir_metadata_control_flow, NULL);
1502
1503 /* Link libagx, used in lower_vs_before_gs */
1504 if (progress)
1505 link_libagx(vs, libagx);
1506
1507 return progress;
1508 }
1509
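/* The entrypoints below build small internal compute kernels around libagx
 * helpers, configured by the data/key passed in. This one prefix sums the GS
 * count buffer before the main GS pass runs.
 */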
1510 void
1511 agx_nir_prefix_sum_gs(nir_builder *b, const void *data)
1512 {
1513 const unsigned *words = data;
1514
1515 b->shader->info.workgroup_size[0] = 1024;
1516
1517 libagx_prefix_sum(b, load_geometry_param(b, count_buffer),
1518 load_geometry_param(b, input_primitives),
1519 nir_imm_int(b, *words),
1520 nir_channel(b, nir_load_workgroup_id(b), 0));
1521 }
1522
1523 void
1524 agx_nir_prefix_sum_tess(nir_builder *b, const void *data)
1525 {
1526 b->shader->info.workgroup_size[0] = 1024;
1527 libagx_prefix_sum_tess(b, nir_load_preamble(b, 1, 64, .base = 0));
1528 }
1529
1530 void
1531 agx_nir_gs_setup_indirect(nir_builder *b, const void *data)
1532 {
1533 const struct agx_gs_setup_indirect_key *key = data;
1534
1535 libagx_gs_setup_indirect(b, nir_load_preamble(b, 1, 64, .base = 0),
1536 nir_imm_int(b, key->prim),
1537 nir_channel(b, nir_load_local_invocation_id(b), 0));
1538 }
1539
1540 void
1541 agx_nir_unroll_restart(nir_builder *b, const void *data)
1542 {
1543 const struct agx_unroll_restart_key *key = data;
1544 b->shader->info.workgroup_size[0] = 1024;
1545
1546 nir_def *ia = nir_load_preamble(b, 1, 64, .base = 0);
1547 nir_def *draw = nir_channel(b, nir_load_workgroup_id(b), 0);
1548 nir_def *lane = nir_channel(b, nir_load_local_invocation_id(b), 0);
1549 nir_def *mode = nir_imm_int(b, key->prim);
1550
1551 if (key->index_size_B == 1)
1552 libagx_unroll_restart_u8(b, ia, mode, draw, lane);
1553 else if (key->index_size_B == 2)
1554 libagx_unroll_restart_u16(b, ia, mode, draw, lane);
1555 else if (key->index_size_B == 4)
1556 libagx_unroll_restart_u32(b, ia, mode, draw, lane);
1557 else
1558 unreachable("invalid index size");
1559 }
1560
1561 void
1562 agx_nir_tessellate(nir_builder *b, const void *data)
1563 {
1564 const struct agx_tessellator_key *key = data;
1565 b->shader->info.workgroup_size[0] = 64;
1566
1567 nir_def *params = nir_load_preamble(b, 1, 64, .base = 0);
1568 nir_def *patch = nir_channel(b, nir_load_global_invocation_id(b, 32), 0);
1569 nir_def *mode = nir_imm_int(b, key->mode);
1570 nir_def *partitioning = nir_imm_int(b, key->partitioning);
1571 nir_def *output_prim = nir_imm_int(b, key->output_primitive);
1572
1573 if (key->prim == TESS_PRIMITIVE_ISOLINES)
1574 libagx_tess_isoline(b, params, mode, partitioning, output_prim, patch);
1575 else if (key->prim == TESS_PRIMITIVE_TRIANGLES)
1576 libagx_tess_tri(b, params, mode, partitioning, output_prim, patch);
1577 else if (key->prim == TESS_PRIMITIVE_QUADS)
1578 libagx_tess_quad(b, params, mode, partitioning, output_prim, patch);
1579 else
1580 unreachable("invalid tess primitive");
1581 }
1582
1583 void
1584 agx_nir_tess_setup_indirect(nir_builder *b, const void *data)
1585 {
1586 const struct agx_tess_setup_indirect_key *key = data;
1587
1588 nir_def *params = nir_load_preamble(b, 1, 64, .base = 0);
1589 nir_def *with_counts = nir_imm_bool(b, key->with_counts);
1590 nir_def *point_mode = nir_imm_bool(b, key->point_mode);
1591
1592 libagx_tess_setup_indirect(b, params, with_counts, point_mode);
1593 }
1594
1595 void
1596 agx_nir_increment_statistic(nir_builder *b, const void *data)
1597 {
1598 libagx_increment_statistic(b, nir_load_preamble(b, 1, 64, .base = 0));
1599 }
1600
1601 void
1602 agx_nir_increment_cs_invocations(nir_builder *b, const void *data)
1603 {
1604 libagx_increment_cs_invocations(b, nir_load_preamble(b, 1, 64, .base = 0));
1605 }
1606
1607 void
1608 agx_nir_increment_ia_counters(nir_builder *b, const void *data)
1609 {
1610 const struct agx_increment_ia_counters_key *key = data;
1611 b->shader->info.workgroup_size[0] = key->index_size_B ? 1024 : 1;
1612
1613 nir_def *params = nir_load_preamble(b, 1, 64, .base = 0);
1614 nir_def *index_size_B = nir_imm_int(b, key->index_size_B);
1615 nir_def *thread = nir_channel(b, nir_load_global_invocation_id(b, 32), 0);
1616
1617 libagx_increment_ia_counters(b, params, index_size_B, thread);
1618 }
1619
1620 void
1621 agx_nir_predicate_indirect(nir_builder *b, const void *data)
1622 {
1623 const struct agx_predicate_indirect_key *key = data;
1624
1625 nir_def *params = nir_load_preamble(b, 1, 64, .base = 0);
1626 nir_def *indexed = nir_imm_bool(b, key->indexed);
1627 nir_def *thread = nir_channel(b, nir_load_global_invocation_id(b, 32), 0);
1628
1629 libagx_predicate_indirect(b, params, thread, indexed);
1630 }
1631
1632 void
1633 agx_nir_decompress(nir_builder *b, const void *data)
1634 {
1635 const struct agx_decompress_key *key = data;
1636
1637 nir_def *params = nir_load_preamble(b, 1, 64, .base = 0);
1638 nir_def *tile = nir_load_workgroup_id(b);
1639 nir_def *local = nir_channel(b, nir_load_local_invocation_id(b), 0);
1640 nir_def *samples = nir_imm_int(b, key->nr_samples);
1641
1642 libagx_decompress(b, params, tile, local, samples);
1643 }
1644