/*
 * Copyright © 2019 Google, Inc.
 * SPDX-License-Identifier: MIT
 */

#include "compiler/nir/nir_builder.h"
#include "ir3_compiler.h"
#include "ir3_nir.h"

struct state {
   uint32_t topology;

   struct primitive_map {
      /* +POSITION, +PSIZE, ... - see shader_io_get_unique_index */
      unsigned loc[12 + 32];
      unsigned stride;
   } map;

   nir_def *header;

   nir_variable *vertex_count_var;
   nir_variable *emitted_vertex_var;
   nir_variable *vertex_flags_out;

   struct exec_list old_outputs;
   struct exec_list new_outputs;
   struct exec_list emit_outputs;

   /* tess ctrl shader on a650 gets the local primitive id at different bits: */
   unsigned local_primitive_id_start;
};

static nir_def *
bitfield_extract(nir_builder *b, nir_def *v, uint32_t start, uint32_t mask)
{
   return nir_iand_imm(b, nir_ushr_imm(b, v, start), mask);
}

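/* The tcs/gs header word (nir_load_tcs_header_ir3 / nir_load_gs_header_ir3)
 * packs several IDs that the helpers below unpack. The bit positions here
 * simply restate what the extraction code does: the invocation ID in bits
 * [11..15], the vertex ID in bits [6..10], and a 6-bit local primitive ID
 * starting at local_primitive_id_start (0 by default, 16 for the a650
 * shared-memory tess path, see ir3_nir_lower_to_explicit_input()).
 */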
static nir_def *
build_invocation_id(nir_builder *b, struct state *state)
{
   return bitfield_extract(b, state->header, 11, 31);
}

static nir_def *
build_vertex_id(nir_builder *b, struct state *state)
{
   return bitfield_extract(b, state->header, 6, 31);
}

static nir_def *
build_local_primitive_id(nir_builder *b, struct state *state)
{
   return bitfield_extract(b, state->header, state->local_primitive_id_start,
                           63);
}

static bool
is_tess_levels(gl_varying_slot slot)
{
   return (slot == VARYING_SLOT_PRIMITIVE_ID ||
           slot == VARYING_SLOT_TESS_LEVEL_OUTER ||
           slot == VARYING_SLOT_TESS_LEVEL_INNER);
}

/* Return a deterministic index for varyings. We can't rely on driver_location
 * to be correct without linking the different stages first, so we create
 * "primitive maps" where the producer decides on the location of each varying
 * slot and then exports a per-slot array to the consumer. This compacts the
 * gl_varying_slot space down a bit so that the primitive maps aren't too
 * large.
 *
 * Note: per-patch varyings are currently handled separately, without any
 * compacting.
 *
 * TODO: We could probably use the driver_location's directly in the non-SSO
 * (Vulkan) case.
 */

static unsigned
shader_io_get_unique_index(gl_varying_slot slot)
{
   switch (slot) {
   case VARYING_SLOT_POS:         return 0;
   case VARYING_SLOT_PSIZ:        return 1;
   case VARYING_SLOT_COL0:        return 2;
   case VARYING_SLOT_COL1:        return 3;
   case VARYING_SLOT_BFC0:        return 4;
   case VARYING_SLOT_BFC1:        return 5;
   case VARYING_SLOT_FOGC:        return 6;
   case VARYING_SLOT_CLIP_DIST0:  return 7;
   case VARYING_SLOT_CLIP_DIST1:  return 8;
   case VARYING_SLOT_CLIP_VERTEX: return 9;
   case VARYING_SLOT_LAYER:       return 10;
   case VARYING_SLOT_VIEWPORT:    return 11;
   case VARYING_SLOT_VAR0 ... VARYING_SLOT_VAR31: {
      struct state state = {};
      STATIC_ASSERT(ARRAY_SIZE(state.map.loc) - 1 ==
                    (12 + VARYING_SLOT_VAR31 - VARYING_SLOT_VAR0));
      struct ir3_shader_variant v = {};
      STATIC_ASSERT(ARRAY_SIZE(v.output_loc) - 1 ==
                    (12 + VARYING_SLOT_VAR31 - VARYING_SLOT_VAR0));
      return 12 + (slot - VARYING_SLOT_VAR0);
   }
   default:
      unreachable("illegal slot in get unique index\n");
   }
}

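/* Build the byte offset into the local (shared) storage used to pass
 * per-vertex outputs to the next stage. Summarizing what the code below
 * computes:
 *
 *    offset = local_primitive_id * primitive_stride
 *           + vertex * vertex_stride
 *           + attr_offset                 (per-slot location + component * 4)
 *           + (indirect_offset << 4)      (indirect offset is in vec4 units)
 *
 * Producer stages (VS/TES) know their primitive map directly; consumer stages
 * (TCS/GS) read the per-slot locations via load_primitive_location_ir3.
 */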
static nir_def *
build_local_offset(nir_builder *b, struct state *state, nir_def *vertex,
                   uint32_t location, uint32_t comp, nir_def *offset)
{
   nir_def *primitive_stride = nir_load_vs_primitive_stride_ir3(b);
   nir_def *primitive_offset =
      nir_imul24(b, build_local_primitive_id(b, state), primitive_stride);
   nir_def *attr_offset;
   nir_def *vertex_stride;
   unsigned index = shader_io_get_unique_index(location);

   switch (b->shader->info.stage) {
   case MESA_SHADER_VERTEX:
   case MESA_SHADER_TESS_EVAL:
      vertex_stride = nir_imm_int(b, state->map.stride * 4);
      attr_offset = nir_imm_int(b, state->map.loc[index] + 4 * comp);
      break;
   case MESA_SHADER_TESS_CTRL:
   case MESA_SHADER_GEOMETRY:
      vertex_stride = nir_load_vs_vertex_stride_ir3(b);
      attr_offset = nir_iadd_imm(b, nir_load_primitive_location_ir3(b, index),
                                 comp * 4);
      break;
   default:
      unreachable("bad shader stage");
   }

   nir_def *vertex_offset = nir_imul24(b, vertex, vertex_stride);

   return nir_iadd(
      b, nir_iadd(b, primitive_offset, vertex_offset),
      nir_iadd(b, attr_offset, nir_ishl_imm(b, offset, 4)));
}

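/* Replace an IO intrinsic with a new intrinsic of the given op, reusing the
 * original's num_components and bit size, rewriting all uses of the old
 * destination (if the new op has one), and removing the old instruction.
 * src1/src2 may be NULL for ops with fewer sources.
 */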
static nir_intrinsic_instr *
replace_intrinsic(nir_builder *b, nir_intrinsic_instr *intr,
                  nir_intrinsic_op op, nir_def *src0, nir_def *src1,
                  nir_def *src2)
{
   nir_intrinsic_instr *new_intr = nir_intrinsic_instr_create(b->shader, op);

   new_intr->src[0] = nir_src_for_ssa(src0);
   if (src1)
      new_intr->src[1] = nir_src_for_ssa(src1);
   if (src2)
      new_intr->src[2] = nir_src_for_ssa(src2);

   new_intr->num_components = intr->num_components;

   if (nir_intrinsic_infos[op].has_dest)
      nir_def_init(&new_intr->instr, &new_intr->def,
                   intr->num_components, intr->def.bit_size);

   nir_builder_instr_insert(b, &new_intr->instr);

   if (nir_intrinsic_infos[op].has_dest)
      nir_def_rewrite_uses(&intr->def, &new_intr->def);

   nir_instr_remove(&intr->instr);

   return new_intr;
}

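/* Build the primitive map for this stage's outputs. As an illustrative
 * example of what the loop below produces: a VS (16-byte slots) whose
 * outputs_written covers POS, PSIZ and VAR0 gets map->loc[0] = 0,
 * map->loc[1] = 16, map->loc[12] = 32 and a stride of 12 dwords (48 / 4).
 */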
static void
build_primitive_map(nir_shader *shader, struct primitive_map *map)
{
   /* All interfaces except the TCS <-> TES interface use ldlw, which takes
    * an offset in bytes, so each vec4 slot is 16 bytes. TCS <-> TES uses
    * ldg, which takes an offset in dwords, but each per-vertex slot has
    * space for every vertex, and there's space at the beginning for
    * per-patch varyings.
    */
   unsigned slot_size = 16, start = 0;
   if (shader->info.stage == MESA_SHADER_TESS_CTRL) {
      slot_size = shader->info.tess.tcs_vertices_out * 4;
      start = util_last_bit(shader->info.patch_outputs_written) * 4;
   }

   uint64_t mask = shader->info.outputs_written;
   unsigned loc = start;
   while (mask) {
      int location = u_bit_scan64(&mask);
      if (is_tess_levels(location))
         continue;

      unsigned index = shader_io_get_unique_index(location);
      map->loc[index] = loc;
      loc += slot_size;
   }

   map->stride = loc;
   /* Use units of dwords for the stride. */
   if (shader->info.stage != MESA_SHADER_TESS_CTRL)
      map->stride /= 4;
}

/* For shader stages that receive a primitive map, calculate how big it should
 * be.
 */

static unsigned
calc_primitive_map_size(nir_shader *shader)
{
   uint64_t mask = shader->info.inputs_read;
   unsigned max_index = 0;
   while (mask) {
      int location = u_bit_scan64(&mask);

      if (is_tess_levels(location))
         continue;

      unsigned index = shader_io_get_unique_index(location);
      max_index = MAX2(max_index, index + 1);
   }

   return max_index;
}

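/* Rewrite store_output into an explicit store_shared_ir3 at an offset
 * computed from the vertex id (unpacked from the header) and the primitive
 * map, so the consuming stage can read the value back with load_shared_ir3.
 */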
static void
lower_block_to_explicit_output(nir_block *block, nir_builder *b,
                               struct state *state)
{
   nir_foreach_instr_safe (instr, block) {
      if (instr->type != nir_instr_type_intrinsic)
         continue;

      nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

      switch (intr->intrinsic) {
      case nir_intrinsic_store_output: {
         // src[] = { value, offset }.

         /* nir_lower_io_to_temporaries replaces all access to output
          * variables with temp variables and then emits a nir_copy_var at
          * the end of the shader.  Thus, we should always get a full wrmask
          * here.
          */
         assert(
            util_is_power_of_two_nonzero(nir_intrinsic_write_mask(intr) + 1));

         b->cursor = nir_instr_remove(&intr->instr);

         nir_def *vertex_id = build_vertex_id(b, state);
         nir_def *offset = build_local_offset(
            b, state, vertex_id, nir_intrinsic_io_semantics(intr).location,
            nir_intrinsic_component(intr), intr->src[1].ssa);

         nir_store_shared_ir3(b, intr->src[0].ssa, offset);
         break;
      }

      default:
         break;
      }
   }
}

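/* The local thread id lives in bits [16..25] of the GS header word (10 bits,
 * mask 1023), alongside the IDs unpacked near the top of this file.
 */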
static nir_def *
local_thread_id(nir_builder *b)
{
   return bitfield_extract(b, nir_load_gs_header_ir3(b), 16, 1023);
}

void
ir3_nir_lower_to_explicit_output(nir_shader *shader,
                                 struct ir3_shader_variant *v,
                                 unsigned topology)
{
   struct state state = {};

   build_primitive_map(shader, &state.map);
   memcpy(v->output_loc, state.map.loc, sizeof(v->output_loc));

   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
   assert(impl);

   nir_builder b = nir_builder_at(nir_before_impl(impl));

   if (v->type == MESA_SHADER_VERTEX && topology != IR3_TESS_NONE)
      state.header = nir_load_tcs_header_ir3(&b);
   else
      state.header = nir_load_gs_header_ir3(&b);

   nir_foreach_block_safe (block, impl)
      lower_block_to_explicit_output(block, &b, &state);

   nir_metadata_preserve(impl,
                         nir_metadata_control_flow);

   v->output_size = state.map.stride;
}

static void
lower_block_to_explicit_input(nir_block *block, nir_builder *b,
                              struct state *state)
{
   nir_foreach_instr_safe (instr, block) {
      if (instr->type != nir_instr_type_intrinsic)
         continue;

      nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

      switch (intr->intrinsic) {
      case nir_intrinsic_load_per_vertex_input: {
         // src[] = { vertex, offset }.

         b->cursor = nir_before_instr(&intr->instr);

         nir_def *offset = build_local_offset(
            b, state,
            intr->src[0].ssa, // this is typically gl_InvocationID
            nir_intrinsic_io_semantics(intr).location,
            nir_intrinsic_component(intr), intr->src[1].ssa);

         replace_intrinsic(b, intr, nir_intrinsic_load_shared_ir3, offset, NULL,
                           NULL);
         break;
      }

      case nir_intrinsic_load_invocation_id: {
         b->cursor = nir_before_instr(&intr->instr);

         nir_def *iid = build_invocation_id(b, state);
         nir_def_replace(&intr->def, iid);
         break;
      }

      default:
         break;
      }
   }
}

void
ir3_nir_lower_to_explicit_input(nir_shader *shader,
                                struct ir3_shader_variant *v)
{
   struct state state = {};

   /* when using stl/ldl (instead of stlw/ldlw) for linking VS and HS,
    * HS uses a different primitive id, which starts at bit 16 in the header
    */
   if (shader->info.stage == MESA_SHADER_TESS_CTRL &&
       v->compiler->tess_use_shared)
      state.local_primitive_id_start = 16;

   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
   assert(impl);

   nir_builder b = nir_builder_at(nir_before_impl(impl));

   if (shader->info.stage == MESA_SHADER_GEOMETRY)
      state.header = nir_load_gs_header_ir3(&b);
   else
      state.header = nir_load_tcs_header_ir3(&b);

   nir_foreach_block_safe (block, impl)
      lower_block_to_explicit_input(block, &b, &state);

   v->input_size = calc_primitive_map_size(shader);
}

static nir_def *
build_tcs_out_vertices(nir_builder *b)
{
   if (b->shader->info.stage == MESA_SHADER_TESS_CTRL)
      return nir_imm_int(b, b->shader->info.tess.tcs_vertices_out);
   else
      return nir_load_patch_vertices_in(b);
}

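/* Build the dword offset into the tess param BO for a per-vertex TCS output /
 * TES input: rel_patch_id * patch_stride, plus the slot's attr offset (and
 * component), plus the component-unit indirect offset scaled by the patch's
 * vertex count, plus vertex * 4. With vertex == NULL it instead addresses a
 * per-patch slot relative to VARYING_SLOT_PATCH0 (see build_patch_offset
 * below).
 */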
static nir_def *
build_per_vertex_offset(nir_builder *b, struct state *state,
                        nir_def *vertex, uint32_t location, uint32_t comp,
                        nir_def *offset)
{
   nir_def *patch_id = nir_load_rel_patch_id_ir3(b);
   nir_def *patch_stride = nir_load_hs_patch_stride_ir3(b);
   nir_def *patch_offset = nir_imul24(b, patch_id, patch_stride);
   nir_def *attr_offset;

   if (nir_src_is_const(nir_src_for_ssa(offset))) {
      location += nir_src_as_uint(nir_src_for_ssa(offset));
      offset = nir_imm_int(b, 0);
   } else {
      /* Offset is in vec4's, but we need it in unit of components for the
       * load/store_global_ir3 offset.
       */
      offset = nir_ishl_imm(b, offset, 2);
   }

   nir_def *vertex_offset;
   if (vertex) {
      unsigned index = shader_io_get_unique_index(location);
      switch (b->shader->info.stage) {
      case MESA_SHADER_TESS_CTRL:
         attr_offset = nir_imm_int(b, state->map.loc[index] + comp);
         break;
      case MESA_SHADER_TESS_EVAL:
         attr_offset = nir_iadd_imm(b, nir_load_primitive_location_ir3(b, index),
                                    comp);
         break;
      default:
         unreachable("bad shader state");
      }

      attr_offset = nir_iadd(b, attr_offset,
                             nir_imul24(b, offset, build_tcs_out_vertices(b)));
      vertex_offset = nir_ishl_imm(b, vertex, 2);
   } else {
      assert(location >= VARYING_SLOT_PATCH0 &&
             location <= VARYING_SLOT_TESS_MAX);
      unsigned index = location - VARYING_SLOT_PATCH0;
      attr_offset = nir_iadd_imm(b, offset, index * 4 + comp);
      vertex_offset = nir_imm_int(b, 0);
   }

   return nir_iadd(b, nir_iadd(b, patch_offset, attr_offset), vertex_offset);
}

static nir_def *
build_patch_offset(nir_builder *b, struct state *state, uint32_t base,
                   uint32_t comp, nir_def *offset)
{
   return build_per_vertex_offset(b, state, NULL, base, comp, offset);
}

static void
tess_level_components(struct state *state, uint32_t *inner, uint32_t *outer)
{
   switch (state->topology) {
   case IR3_TESS_TRIANGLES:
      *inner = 1;
      *outer = 3;
      break;
   case IR3_TESS_QUADS:
      *inner = 2;
      *outer = 4;
      break;
   case IR3_TESS_ISOLINES:
      *inner = 0;
      *outer = 2;
      break;
   default:
      unreachable("bad");
   }
}

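/* Compute the dword offset into the tess factor BO. Per patch, the layout the
 * code below assumes is:
 *
 *    [ primitive id ][ outer levels ... ][ inner levels ... ]
 *
 * with patch_stride = 1 + outer_levels + inner_levels, where the level counts
 * come from tess_level_components() above.
 */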
static nir_def *
build_tessfactor_base(nir_builder *b, gl_varying_slot slot, uint32_t comp,
                      struct state *state)
{
   uint32_t inner_levels, outer_levels;
   tess_level_components(state, &inner_levels, &outer_levels);

   const uint32_t patch_stride = 1 + inner_levels + outer_levels;

   nir_def *patch_id = nir_load_rel_patch_id_ir3(b);

   nir_def *patch_offset =
      nir_imul24(b, patch_id, nir_imm_int(b, patch_stride));

   uint32_t offset;
   switch (slot) {
   case VARYING_SLOT_PRIMITIVE_ID:
      offset = 0;
      break;
   case VARYING_SLOT_TESS_LEVEL_OUTER:
      offset = 1;
      break;
   case VARYING_SLOT_TESS_LEVEL_INNER:
      offset = 1 + outer_levels;
      break;
   default:
      unreachable("bad");
   }

   return nir_iadd_imm(b, patch_offset, offset + comp);
}

static void
lower_tess_ctrl_block(nir_block *block, nir_builder *b, struct state *state)
{
   nir_foreach_instr_safe (instr, block) {
      if (instr->type != nir_instr_type_intrinsic)
         continue;

      nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

      switch (intr->intrinsic) {
      case nir_intrinsic_load_per_vertex_output: {
         // src[] = { vertex, offset }.

         b->cursor = nir_before_instr(&intr->instr);

         nir_def *address = nir_load_tess_param_base_ir3(b);
         nir_def *offset = build_per_vertex_offset(
            b, state, intr->src[0].ssa,
            nir_intrinsic_io_semantics(intr).location,
            nir_intrinsic_component(intr), intr->src[1].ssa);

         replace_intrinsic(b, intr, nir_intrinsic_load_global_ir3, address,
                           offset, NULL);
         break;
      }

      case nir_intrinsic_store_per_vertex_output: {
         // src[] = { value, vertex, offset }.

         b->cursor = nir_before_instr(&intr->instr);

         /* sparse writemask not supported */
         assert(
            util_is_power_of_two_nonzero(nir_intrinsic_write_mask(intr) + 1));

         nir_def *value = intr->src[0].ssa;
         nir_def *address = nir_load_tess_param_base_ir3(b);
         nir_def *offset = build_per_vertex_offset(
            b, state, intr->src[1].ssa,
            nir_intrinsic_io_semantics(intr).location,
            nir_intrinsic_component(intr), intr->src[2].ssa);

         replace_intrinsic(b, intr, nir_intrinsic_store_global_ir3, value,
                           address, offset);

         break;
      }

      case nir_intrinsic_load_output: {
         // src[] = { offset }.

         b->cursor = nir_before_instr(&intr->instr);

         nir_def *address, *offset;

         /* note if vectorization of the tess level loads ever happens:
          * "ldg" across 16-byte boundaries can behave incorrectly if results
          * are never used. most likely some issue with (sy) not properly
          * syncing with values coming from a second memory transaction.
          */
         gl_varying_slot location = nir_intrinsic_io_semantics(intr).location;
         if (is_tess_levels(location)) {
            assert(intr->def.num_components == 1);
            address = nir_load_tess_factor_base_ir3(b);
            offset = build_tessfactor_base(
               b, location, nir_intrinsic_component(intr), state);
         } else {
            address = nir_load_tess_param_base_ir3(b);
            offset = build_patch_offset(b, state, location,
                                        nir_intrinsic_component(intr),
                                        intr->src[0].ssa);
         }

         replace_intrinsic(b, intr, nir_intrinsic_load_global_ir3, address,
                           offset, NULL);
         break;
      }

      case nir_intrinsic_store_output: {
         // src[] = { value, offset }.

         /* write patch output to bo */

         b->cursor = nir_before_instr(&intr->instr);

         /* sparse writemask not supported */
         assert(
            util_is_power_of_two_nonzero(nir_intrinsic_write_mask(intr) + 1));

         gl_varying_slot location = nir_intrinsic_io_semantics(intr).location;
         if (is_tess_levels(location)) {
            uint32_t inner_levels, outer_levels, levels;
            tess_level_components(state, &inner_levels, &outer_levels);

            assert(intr->src[0].ssa->num_components == 1);

            nir_if *nif = NULL;
            if (location != VARYING_SLOT_PRIMITIVE_ID) {
               /* The tess levels are declared as float[4] and float[2], but
                * the tess factor BO has smaller sizes for tris/isolines, so
                * we have to discard any writes beyond the number of
                * components for the inner/outer levels.
                */
               if (location == VARYING_SLOT_TESS_LEVEL_OUTER)
                  levels = outer_levels;
               else
                  levels = inner_levels;

               nir_def *offset = nir_iadd_imm(
                  b, intr->src[1].ssa, nir_intrinsic_component(intr));
               nif = nir_push_if(b, nir_ult_imm(b, offset, levels));
            }

            nir_def *offset = build_tessfactor_base(
               b, location, nir_intrinsic_component(intr), state);

            replace_intrinsic(b, intr, nir_intrinsic_store_global_ir3,
                              intr->src[0].ssa,
                              nir_load_tess_factor_base_ir3(b),
                              nir_iadd(b, intr->src[1].ssa, offset));

            if (location != VARYING_SLOT_PRIMITIVE_ID) {
               nir_pop_if(b, nif);
            }
         } else {
            nir_def *address = nir_load_tess_param_base_ir3(b);
            nir_def *offset = build_patch_offset(
               b, state, location, nir_intrinsic_component(intr),
               intr->src[1].ssa);

            replace_intrinsic(b, intr, nir_intrinsic_store_global_ir3,
                              intr->src[0].ssa, address, offset);
         }
         break;
      }

      default:
         break;
      }
   }
}

void
ir3_nir_lower_tess_ctrl(nir_shader *shader, struct ir3_shader_variant *v,
                        unsigned topology)
{
   struct state state = {.topology = topology};

   if (shader_debug_enabled(shader->info.stage, shader->info.internal)) {
      mesa_logi("NIR (before tess lowering) for %s shader:",
                _mesa_shader_stage_to_string(shader->info.stage));
      nir_log_shaderi(shader);
   }

   build_primitive_map(shader, &state.map);
   memcpy(v->output_loc, state.map.loc, sizeof(v->output_loc));
   v->output_size = state.map.stride;

   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
   assert(impl);

   nir_builder b = nir_builder_at(nir_before_impl(impl));

   state.header = nir_load_tcs_header_ir3(&b);

   /* If required, store gl_PrimitiveID. */
   if (v->key.tcs_store_primid) {
      b.cursor = nir_after_impl(impl);

      nir_store_output(&b, nir_load_primitive_id(&b), nir_imm_int(&b, 0),
                       .io_semantics = {
                           .location = VARYING_SLOT_PRIMITIVE_ID,
                           .num_slots = 1
                        });

      b.cursor = nir_before_impl(impl);
   }

   nir_foreach_block_safe (block, impl)
      lower_tess_ctrl_block(block, &b, &state);

   /* Now move the body of the TCS into a conditional:
    *
    *   if (gl_InvocationID < num_vertices)
    *     // body
    *
    */

   nir_cf_list body;
   nir_cf_extract(&body, nir_before_impl(impl),
                  nir_after_impl(impl));

   b.cursor = nir_after_impl(impl);

   /* Re-emit the header, since the old one got moved into the if branch */
   state.header = nir_load_tcs_header_ir3(&b);
   nir_def *iid = build_invocation_id(&b, &state);

   const uint32_t nvertices = shader->info.tess.tcs_vertices_out;
   nir_def *cond = nir_ult_imm(&b, iid, nvertices);

   nir_if *nif = nir_push_if(&b, cond);

   nir_cf_reinsert(&body, b.cursor);

   b.cursor = nir_after_cf_list(&nif->then_list);

   nir_pop_if(&b, nif);

   nir_metadata_preserve(impl, nir_metadata_none);
}

static void
lower_tess_eval_block(nir_block *block, nir_builder *b, struct state *state)
{
   nir_foreach_instr_safe (instr, block) {
      if (instr->type != nir_instr_type_intrinsic)
         continue;

      nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

      switch (intr->intrinsic) {
      case nir_intrinsic_load_per_vertex_input: {
         // src[] = { vertex, offset }.

         b->cursor = nir_before_instr(&intr->instr);

         nir_def *address = nir_load_tess_param_base_ir3(b);
         nir_def *offset = build_per_vertex_offset(
            b, state, intr->src[0].ssa,
            nir_intrinsic_io_semantics(intr).location,
            nir_intrinsic_component(intr), intr->src[1].ssa);

         replace_intrinsic(b, intr, nir_intrinsic_load_global_ir3, address,
                           offset, NULL);
         break;
      }

      case nir_intrinsic_load_input: {
         // src[] = { offset }.

         b->cursor = nir_before_instr(&intr->instr);

         nir_def *address, *offset;

         /* note if vectorization of the tess level loads ever happens:
          * "ldg" across 16-byte boundaries can behave incorrectly if results
          * are never used. most likely some issue with (sy) not properly
          * syncing with values coming from a second memory transaction.
          */
         gl_varying_slot location = nir_intrinsic_io_semantics(intr).location;
         if (is_tess_levels(location)) {
            assert(intr->def.num_components == 1);
            address = nir_load_tess_factor_base_ir3(b);
            offset = build_tessfactor_base(
               b, location, nir_intrinsic_component(intr), state);
         } else {
            address = nir_load_tess_param_base_ir3(b);
            offset = build_patch_offset(b, state, location,
                                        nir_intrinsic_component(intr),
                                        intr->src[0].ssa);
         }

         replace_intrinsic(b, intr, nir_intrinsic_load_global_ir3, address,
                           offset, NULL);
         break;
      }

      default:
         break;
      }
   }
}

void
ir3_nir_lower_tess_eval(nir_shader *shader, struct ir3_shader_variant *v,
                        unsigned topology)
{
   struct state state = {.topology = topology};

   if (shader_debug_enabled(shader->info.stage, shader->info.internal)) {
      mesa_logi("NIR (before tess lowering) for %s shader:",
                _mesa_shader_stage_to_string(shader->info.stage));
      nir_log_shaderi(shader);
   }

   NIR_PASS_V(shader, nir_lower_tess_coord_z, topology == IR3_TESS_TRIANGLES);

   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
   assert(impl);

   nir_builder b = nir_builder_create(impl);

   nir_foreach_block_safe (block, impl)
      lower_tess_eval_block(block, &b, &state);

   v->input_size = calc_primitive_map_size(shader);

   nir_metadata_preserve(impl, nir_metadata_none);
}

/* The hardware does not support incomplete primitives in multiple streams at
 * once or ending the "wrong" stream, but Vulkan allows this. That is,
 * EmitStreamVertex(N) followed by EmitStreamVertex(M) or EndStreamPrimitive(M)
 * where N != M and there isn't a call to EndStreamPrimitive(N) in between isn't
 * supported by the hardware. Fix this up by duplicating the entire shader per
 * stream, removing EmitStreamVertex/EndStreamPrimitive calls for streams other
 * than the current one.
 */

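/* For example (illustrative GLSL, not taken from a real test):
 *
 *    EmitStreamVertex(0);
 *    EmitStreamVertex(1);   // no EndStreamPrimitive(0) in between
 *
 * becomes two copies of the shader body, one keeping only the stream-0 calls
 * and one keeping only the stream-1 calls, each followed by a trailing
 * EndStreamPrimitive() for its stream.
 */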
static void
lower_mixed_streams(nir_shader *nir)
{
   /* We don't have to do anything for points because there is only one vertex
    * per primitive and therefore no possibility of mixing.
    */
   if (nir->info.gs.output_primitive == MESA_PRIM_POINTS)
      return;

   nir_function_impl *entrypoint = nir_shader_get_entrypoint(nir);

   uint8_t stream_mask = 0;

   nir_foreach_block (block, entrypoint) {
      nir_foreach_instr (instr, block) {
         if (instr->type != nir_instr_type_intrinsic)
            continue;

         nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

         if (intrin->intrinsic == nir_intrinsic_emit_vertex ||
             intrin->intrinsic == nir_intrinsic_end_primitive)
            stream_mask |= 1 << nir_intrinsic_stream_id(intrin);
      }
   }

   if (util_is_power_of_two_or_zero(stream_mask))
      return;

   nir_cf_list body;
   nir_cf_list_extract(&body, &entrypoint->body);

   nir_builder b = nir_builder_create(entrypoint);

   u_foreach_bit (stream, stream_mask) {
      b.cursor = nir_after_impl(entrypoint);

      /* Inserting the cloned body invalidates any cursor not using an
       * instruction, so we need to emit this to keep track of where the new
       * body is to iterate over it.
       */
      nir_instr *anchor = &nir_nop(&b)->instr;

      nir_cf_list_clone_and_reinsert(&body, &entrypoint->cf_node, b.cursor, NULL);

      /* We need to iterate over all instructions after the anchor, which is a
       * bit tricky to do so we do it manually.
       */
      for (nir_block *block = anchor->block; block != NULL;
           block = nir_block_cf_tree_next(block)) {
         for (nir_instr *instr =
               (block == anchor->block) ? anchor : nir_block_first_instr(block),
               *next = instr ? nir_instr_next(instr) : NULL;
              instr != NULL; instr = next, next = next ? nir_instr_next(next) : NULL) {
            if (instr->type != nir_instr_type_intrinsic)
               continue;

            nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
            if ((intrin->intrinsic == nir_intrinsic_emit_vertex ||
                 intrin->intrinsic == nir_intrinsic_end_primitive) &&
                nir_intrinsic_stream_id(intrin) != stream) {
               nir_instr_remove(instr);
            }
         }
      }

      nir_instr_remove(anchor);

      /* The user can omit the last EndStreamPrimitive(), so add an extra one
       * here before potentially adding other copies of the body that emit to
       * different streams. Our lowering means that redundant calls to
       * EndStreamPrimitive are safe and should be optimized out.
       */
      b.cursor = nir_after_impl(entrypoint);
      nir_end_primitive(&b, .stream_id = stream);
   }

   nir_cf_delete(&body);
}

static void
copy_vars(nir_builder *b, struct exec_list *dests, struct exec_list *srcs)
{
   foreach_two_lists (dest_node, dests, src_node, srcs) {
      nir_variable *dest = exec_node_data(nir_variable, dest_node, node);
      nir_variable *src = exec_node_data(nir_variable, src_node, node);
      nir_copy_var(b, dest, src);
   }
}

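/* Lower EmitVertex()/EndPrimitive() into explicit vertex_flags/vertex_count
 * bookkeeping. Per the code below: EndPrimitive() becomes a store of 4 to
 * vertex_flags, and EmitVertex() copies the shadowed outputs and ORs the
 * stream id into vertex_flags, but only in the invocation whose local thread
 * id matches the running vertex count; every invocation then increments its
 * vertex count.
 */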
static void
lower_gs_block(nir_block *block, nir_builder *b, struct state *state)
{
   nir_foreach_instr_safe (instr, block) {
      if (instr->type != nir_instr_type_intrinsic)
         continue;

      nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

      switch (intr->intrinsic) {
      case nir_intrinsic_end_primitive: {
         /* The HW will use the stream from the preceding emitted vertices,
          * which thanks to the lower_mixed_streams is the same as the stream
          * for this instruction, so we can ignore it here.
          */
         b->cursor = nir_before_instr(&intr->instr);
         nir_store_var(b, state->vertex_flags_out, nir_imm_int(b, 4), 0x1);
         nir_instr_remove(&intr->instr);
         break;
      }

      case nir_intrinsic_emit_vertex: {
         /* Load the vertex count */
         b->cursor = nir_before_instr(&intr->instr);
         nir_def *count = nir_load_var(b, state->vertex_count_var);

         nir_push_if(b, nir_ieq(b, count, local_thread_id(b)));

         unsigned stream = nir_intrinsic_stream_id(intr);
         /* vertex_flags_out |= stream */
         nir_store_var(b, state->vertex_flags_out,
                       nir_ior_imm(b, nir_load_var(b, state->vertex_flags_out),
                                   stream),
                       0x1 /* .x */);

         copy_vars(b, &state->emit_outputs, &state->old_outputs);

         nir_instr_remove(&intr->instr);

         nir_store_var(b, state->emitted_vertex_var,
                       nir_iadd_imm(b,
                                    nir_load_var(b,
                                                 state->emitted_vertex_var),
                                                 1),
                       0x1);

         nir_pop_if(b, NULL);

         /* Increment the vertex count by 1 */
         nir_store_var(b, state->vertex_count_var,
                       nir_iadd_imm(b, count, 1), 0x1); /* .x */
         nir_store_var(b, state->vertex_flags_out, nir_imm_int(b, 0), 0x1);

         break;
      }

      default:
         break;
      }
   }
}

void
ir3_nir_lower_gs(nir_shader *shader)
{
   struct state state = {};

   /* Don't lower multiple times: */
   nir_foreach_shader_out_variable (var, shader)
      if (var->data.location == VARYING_SLOT_GS_VERTEX_FLAGS_IR3)
         return;

   if (shader_debug_enabled(shader->info.stage, shader->info.internal)) {
      mesa_logi("NIR (before gs lowering):");
      nir_log_shaderi(shader);
   }

   lower_mixed_streams(shader);

   /* Create an output var for vertex_flags. This will be shadowed below,
    * same way regular outputs get shadowed, and this variable will become a
    * temporary.
    */
   state.vertex_flags_out = nir_variable_create(
      shader, nir_var_shader_out, glsl_uint_type(), "vertex_flags");
   state.vertex_flags_out->data.driver_location = shader->num_outputs++;
   state.vertex_flags_out->data.location = VARYING_SLOT_GS_VERTEX_FLAGS_IR3;
   state.vertex_flags_out->data.interpolation = INTERP_MODE_NONE;

   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
   assert(impl);

   nir_builder b = nir_builder_at(nir_before_impl(impl));

   state.header = nir_load_gs_header_ir3(&b);

   /* Generate two sets of shadow vars for the output variables.  The first
    * set replaces the real outputs and the second set (emit_outputs) is what
    * we assign in the emit_vertex conditionals.  Then at the end of the
    * shader we copy the emit_outputs to the real outputs, so that we get
    * store_output in uniform control flow.
    */
   exec_list_make_empty(&state.old_outputs);
   nir_foreach_shader_out_variable_safe (var, shader) {
      exec_node_remove(&var->node);
      exec_list_push_tail(&state.old_outputs, &var->node);
   }
   exec_list_make_empty(&state.new_outputs);
   exec_list_make_empty(&state.emit_outputs);
   nir_foreach_variable_in_list (var, &state.old_outputs) {
      /* Create a new output var by cloning the original output var and
       * stealing the name.
       */
      nir_variable *output = nir_variable_clone(var, shader);
      exec_list_push_tail(&state.new_outputs, &output->node);

      /* Rewrite the original output to be a shadow variable. */
      var->name = ralloc_asprintf(var, "%s@gs-temp", output->name);
      var->data.mode = nir_var_shader_temp;

      /* Clone the shadow variable to create the emit shadow variable that
       * we'll assign in the emit conditionals.
       */
      nir_variable *emit_output = nir_variable_clone(var, shader);
      emit_output->name = ralloc_asprintf(var, "%s@emit-temp", output->name);
      exec_list_push_tail(&state.emit_outputs, &emit_output->node);
   }

   /* During the shader we'll keep track of which vertex we're currently
    * emitting for the EmitVertex test and how many vertices we emitted, so we
    * know to discard if we didn't emit any.  In most simple shaders, this can
    * all be statically determined and gets optimized away.
    */
   state.vertex_count_var =
      nir_local_variable_create(impl, glsl_uint_type(), "vertex_count");
   state.emitted_vertex_var =
      nir_local_variable_create(impl, glsl_uint_type(), "emitted_vertex");

   /* Initialize to 0. */
   b.cursor = nir_before_impl(impl);
   nir_store_var(&b, state.vertex_count_var, nir_imm_int(&b, 0), 0x1);
   nir_store_var(&b, state.emitted_vertex_var, nir_imm_int(&b, 0), 0x1);
   nir_store_var(&b, state.vertex_flags_out, nir_imm_int(&b, 4), 0x1);

   nir_foreach_block_safe (block, impl)
      lower_gs_block(block, &b, &state);

   /* Note: returns are lowered, so there should be only one block before the
    * end block.  If we had real returns, we would probably want to redirect
    * them to this new if statement, rather than emitting this code at every
    * return statement.
    */
   assert(impl->end_block->predecessors->entries == 1);
   nir_block *block = nir_impl_last_block(impl);
   b.cursor = nir_after_block_before_jump(block);

   /* If we haven't emitted any vertex we need to copy the shadow (old)
    * outputs to emit outputs here.
    *
    * Also some piglit GS tests[1] don't have EndPrimitive() so throw
    * in an extra vertex_flags write for good measure.  If unneeded it
    * will be optimized out.
    *
    * [1] ex, tests/spec/glsl-1.50/execution/compatibility/clipping/gs-clip-vertex-const-accept.shader_test
    */
   nir_def *cond =
      nir_ieq_imm(&b, nir_load_var(&b, state.emitted_vertex_var), 0);
   nir_push_if(&b, cond);
   nir_store_var(&b, state.vertex_flags_out, nir_imm_int(&b, 4), 0x1);
   copy_vars(&b, &state.emit_outputs, &state.old_outputs);
   nir_pop_if(&b, NULL);

   nir_discard_if(&b, cond);

   copy_vars(&b, &state.new_outputs, &state.emit_outputs);

   exec_list_append(&shader->variables, &state.old_outputs);
   exec_list_append(&shader->variables, &state.emit_outputs);
   exec_list_append(&shader->variables, &state.new_outputs);

   nir_metadata_preserve(impl, nir_metadata_none);

   nir_lower_global_vars_to_local(shader);
   nir_split_var_copies(shader);
   nir_lower_var_copies(shader);

   nir_fixup_deref_modes(shader);

   if (shader_debug_enabled(shader->info.stage, shader->info.internal)) {
      mesa_logi("NIR (after gs lowering):");
      nir_log_shaderi(shader);
   }
}
1083