xref: /aosp_15_r20/external/mesa3d/src/amd/common/ac_nir_lower_tess_io_to_mem.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright © 2021 Valve Corporation
3  *
4  * SPDX-License-Identifier: MIT
5  */
6 
7 #include "ac_nir.h"
8 #include "ac_nir_helpers.h"
9 #include "nir_builder.h"
10 
11 /*
12  * These NIR passes are used to lower NIR cross-stage I/O intrinsics into the
13  * memory accesses that actually happen on the HW.
14  *
15  * Each input and output has a 16-byte (4 dwords) slot reserved for it, and
16  * can have up to 4 components. Each component is 32 bits.
17  *
18  * ## VS-TCS-TES I/O - Terminology:
19  *
20  * * patch - Group of vertices, used instead of primitives in tessellation
21  * * per-vertex - input or output which can be different for every vertex.
22  * * per-patch - input or output which applies to a patch (a group of vertices)
23  *
24  * ## VS-TCS-TES I/O - How it works:
25  *
26  * ```
27  * SW model:    SW VS         SW TCS    tessellator    SW TES
28  *                ┊             ┊             ┊          ┊
29  *              ┌────┐        ┌────┐        ┌────┐    ┌─────┐
30  * HW pipeline: │ LS │─╮   ╭─>│ HS │─╮   ╭─>│ FF │ ╭─>│VS/ES│
31  *              └────┘ │   │  └────┘ │   │  └────┘ │  └─────┘
32  * Memory:             ╰─>LDS<──╯    ╰─>VRAM───────╯
33  * ```
34  *
35  * * SW VS runs as a HW LS (Local Shader, merged into HS on GFX9+),
36  *   and SW TCS runs as HW HS (Hull Shader).
37  *   SW TES runs as either HW VS or HW ES (Export Shader).
38  * * LS and HS share the same LDS space.
39  * * LS (SW VS) stores outputs to LDS to be read by HS (SW TCS).
40  * * HS (SW TCS) stores outputs in LDS if the HS (SW TCS) reads them.
41  * * HS (SW TCS) stores outputs in VRAM if the next stage (SW TES) reads them.
42  *
43  * Side note: some old HW supports having TES read from the same LDS space where LS/HS write, but
44  * Mesa always stores HS outputs to VRAM to avoid forcing TES waves to run on the same CU as the LS/HS waves.
45  *
46  * ### Passing VS-TCS I/O in registers
47  *
48  * On GPUs that run SW VS and SW TCS on the same HW stage (HS on GFX9+),
49  * IO can be passed through registers instead of LDS when the following conditions are met:
50  *
51  * 1. TCS input and output patch size match
52  * 2. Floating point execution modes in SW VS and SW TCS match
53  * 3. The SW VS output is not written indirectly, and the corresponding SW TCS input is not read indirectly
54  *
55  * Some HS outputs could be passed through registers too, but this is a TODO.
56  *
57  * ### LDS layout used by VS-TCS:
58  *
59  * ```
60  * TCS per-vertex inputs for patch 0  <─── 0
61  * TCS per-vertex inputs for patch 1
62  * TCS per-vertex inputs for patch 2  <─── hs_per_vertex_input_lds_offset (rel_patch_id = 2)
63  * ...
64  * TCS per-vertex outputs for patch 0 <─── output_patch0_offset
65  * TCS per-patch outputs for patch 0  <─── output_patch0_patch_data_offset
66  * TCS per-vertex outputs for patch 1
67  * TCS per-patch outputs for patch 1
68  * TCS per-vertex outputs for patch 2 <─── hs_output_lds_offset (rel_patch_id = 2, per-vertex)
69  * TCS per-patch outputs for patch 2  <─── hs_output_lds_offset (rel_patch_id = 2, per-patch)
70  * ...
71  * ```
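 *
 * Roughly, in pseudocode (names follow the helper functions below; the actual
 * values come from intrinsics such as nir_load_lshs_vertex_stride_amd,
 * nir_load_patch_vertices_in and nir_load_tcs_num_patches_amd):
 *
 *    input_patch_stride = tcs_in_vtxcnt * lshs_vertex_stride
 *    hs_per_vertex_input_lds_offset = rel_patch_id * input_patch_stride +
 *                                     vertex_index * lshs_vertex_stride + io_offset
 *    output_patch0_offset = input_patch_stride * tcs_num_patches
 *    hs_output_lds_offset = output_patch0_offset +
 *                           rel_patch_id * output_patch_stride +
 *                           (per-vertex ? vertex_index * output_vertex_size
 *                                       : pervertex_output_patch_size) + io_offset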
72  *
73  * ### VRAM layout used by TCS-TES I/O:
74  *
75  * ```
76  * attr 0 of patch 0 vertex 0   <─── "off-chip LDS" offset
77  * attr 0 of patch 0 vertex 1
78  * attr 0 of patch 0 vertex 2
79  * ...
80  * attr 0 of patch 1 vertex 0
81  * attr 0 of patch 1 vertex 1
82  * attr 0 of patch 1 vertex 2   <─── hs_per_vertex_output_vmem_offset (attribute slot = 0, rel_patch_id = 1, vertex index = 1)
83  * ...
84  * attr 0 of patch 2 vertex 0
85  * attr 0 of patch 2 vertex 1
86  * attr 0 of patch 2 vertex 2
87  * ...
88  * attr 1 of patch 0 vertex 0
89  * attr 1 of patch 0 vertex 1
90  * attr 1 of patch 0 vertex 2
91  * ...
92  * ...
93  * per-patch attr 0 of patch 0  <─── hs_out_patch_data_offset_amd
94  * per-patch attr 0 of patch 1
95  * per-patch attr 0 of patch 2  <─── hs_per_patch_output_vmem_offset (attribute slot = 0, rel_patch_id = 2)
96  * ...
97  * per-patch attr 1 of patch 0
98  * per-patch attr 1 of patch 1
99  * per-patch attr 1 of patch 2
100  * ...
101  * ```
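 *
 * Roughly, in pseudocode (see hs_per_vertex_output_vmem_offset and
 * hs_per_patch_output_vmem_offset below; each attribute slot takes 16 bytes
 * per vertex or per patch, and the component offset within a slot is omitted):
 *
 *    attr_stride = tcs_num_patches * out_vertices_per_patch * 16
 *    hs_per_vertex_output_vmem_offset = attr_slot * attr_stride +
 *                                       rel_patch_id * out_vertices_per_patch * 16 +
 *                                       vertex_index * 16
 *    hs_per_patch_output_vmem_offset = hs_out_patch_data_offset +
 *                                      attr_slot * tcs_num_patches * 16 +
 *                                      rel_patch_id * 16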
102  *
103  */
104 
105 typedef struct {
106    /* Which hardware generation we're dealing with */
107    enum amd_gfx_level gfx_level;
108 
109    /* I/O semantic -> real location used by lowering. */
110    ac_nir_map_io_driver_location map_io;
111 
112    /* True if the merged VS+TCS (on GFX9+) has the same
113     * input and output patch size.
114     */
115    bool tcs_in_out_eq;
116 
117    /* Bit mask of TCS per-vertex inputs (VS outputs) which
118     * are passed between the two stages only in temporaries (registers).
119     *
120     * A VS output can be passed to TCS in registers when:
121     * - VS is known to write, and TCS is known to read it
122     * - Neither VS nor TCS accesses it indirectly
123     * - There are no TCS cross-invocation reads to this input
124     */
125    uint64_t tcs_temp_only_inputs;
126 
127    /* Bit mask of inputs read by the TCS;
128     * used for linking VS outputs to TCS inputs.
129     */
130    uint64_t tcs_inputs_read;
131 
132    /* Bit mask of TCS outputs read by TES. */
133    uint64_t tes_inputs_read;
134    uint32_t tes_patch_inputs_read;
135 
136    /* True if the output patch fits the subgroup, so all TCS outputs are always written in the same
137     * subgroup that reads them.
138     */
139    bool tcs_out_patch_fits_subgroup;
140 
141    /* Set if all invocations will write to all tess factors, so tess factors
142     * can be passed by register.
143     */
144    bool tcs_pass_tessfactors_by_reg;
145 
146    /* Saved TCS tess factors, for the tess factor writer. */
147    nir_variable *tcs_tess_level_outer;
148    nir_variable *tcs_tess_level_inner;
149    unsigned tcs_tess_level_outer_base;
150    unsigned tcs_tess_level_outer_mask;
151    unsigned tcs_tess_level_inner_base;
152    unsigned tcs_tess_level_inner_mask;
153 } lower_tess_io_state;
154 
155 typedef struct {
156    nir_def *outer;
157    nir_def *inner;
158 } tess_levels;
159 
160 #define TESS_LVL_MASK (VARYING_BIT_TESS_LEVEL_OUTER | VARYING_BIT_TESS_LEVEL_INNER)
161 
162 static uint64_t
163 tcs_vram_per_vtx_out_mask(nir_shader *shader, lower_tess_io_state *st)
164 {
165    return st->tes_inputs_read & ~TESS_LVL_MASK;
166 }
167 
168 static uint32_t
169 tcs_vram_tf_out_mask(nir_shader *shader, lower_tess_io_state *st)
170 {
171    return st->tes_inputs_read & TESS_LVL_MASK;
172 }
173 
174 static uint32_t
175 tcs_vram_per_patch_out_mask(nir_shader *shader, lower_tess_io_state *st)
176 {
177    return st->tes_patch_inputs_read;
178 }
179 
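/* Returns whether a TCS output also has to be written to VRAM (the off-chip
 * ring), i.e. whether the TES reads it. Tess level outputs always return
 * false here; they are written by hs_finale instead.
 */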
180 static bool
181 tcs_output_needs_vmem(nir_intrinsic_instr *intrin,
182                       nir_shader *shader,
183                       lower_tess_io_state *st)
184 {
185    /* no_varying indicates that TES doesn't read the output. */
186    if (nir_intrinsic_io_semantics(intrin).no_varying)
187       return false;
188 
189    const unsigned loc = nir_intrinsic_io_semantics(intrin).location;
190    const bool per_vertex = intrin->intrinsic == nir_intrinsic_store_per_vertex_output ||
191                            intrin->intrinsic == nir_intrinsic_load_per_vertex_output;
192 
193    if (per_vertex) {
194       return tcs_vram_per_vtx_out_mask(shader, st) & BITFIELD64_BIT(loc);
195    } else if (loc == VARYING_SLOT_TESS_LEVEL_OUTER || loc == VARYING_SLOT_TESS_LEVEL_INNER) {
196       return false;
197    } else {
198       return tcs_vram_per_patch_out_mask(shader, st) & BITFIELD_BIT(loc - VARYING_SLOT_PATCH0);
199    }
200 }
201 
202 static uint64_t
203 tcs_lds_per_vtx_out_mask(nir_shader *shader)
204 {
205    return shader->info.outputs_read & shader->info.outputs_written & ~TESS_LVL_MASK;
206 }
207 
208 static uint64_t
209 tcs_lds_tf_out_mask(nir_shader *shader, lower_tess_io_state *st)
210 {
211    return st->tcs_pass_tessfactors_by_reg ? 0ull : (shader->info.outputs_written & TESS_LVL_MASK);
212 }
213 
214 static uint32_t
215 tcs_lds_per_patch_out_mask(nir_shader *shader)
216 {
217    return shader->info.patch_outputs_read & shader->info.patch_outputs_written;
218 }
219 
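/* Returns whether a TCS output has to be stored in LDS: per-vertex and
 * per-patch outputs need LDS when the TCS reads them back, and tess levels
 * need it when they are not passed in registers (hs_finale then loads them
 * from LDS).
 */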
220 static bool
221 tcs_output_needs_lds(nir_intrinsic_instr *intrin,
222                      nir_shader *shader,
223                      lower_tess_io_state *st)
224 {
225    const unsigned loc = nir_intrinsic_io_semantics(intrin).location;
226    const bool per_vertex = intrin->intrinsic == nir_intrinsic_store_per_vertex_output ||
227                            intrin->intrinsic == nir_intrinsic_load_per_vertex_output;
228 
229    if (per_vertex) {
230       return tcs_lds_per_vtx_out_mask(shader) & BITFIELD64_BIT(loc);
231    } else if (loc == VARYING_SLOT_TESS_LEVEL_OUTER || loc == VARYING_SLOT_TESS_LEVEL_INNER) {
232       return tcs_lds_tf_out_mask(shader, st) & BITFIELD64_BIT(loc);
233    } else {
234       return tcs_lds_per_patch_out_mask(shader) & BITFIELD_BIT(loc - VARYING_SLOT_PATCH0);
235    }
236 }
237 
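/* Lower LS (SW VS) output stores to LDS stores at the location where the
 * HS (SW TCS) expects to read them. Stores to outputs the TCS never reads
 * are removed, and temp-only inputs bypass LDS entirely.
 */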
238 static bool
239 lower_ls_output_store(nir_builder *b,
240                       nir_intrinsic_instr *intrin,
241                       void *state)
242 {
243    if (intrin->intrinsic != nir_intrinsic_store_output)
244       return false;
245 
246    /* The ARB_shader_viewport_layer_array spec contains the
247     * following issue:
248     *
249     *    2) What happens if gl_ViewportIndex or gl_Layer is
250     *    written in the vertex shader and a geometry shader is
251     *    present?
252     *
253     *    RESOLVED: The value written by the last vertex processing
254     *    stage is used. If the last vertex processing stage
255     *    (vertex, tessellation evaluation or geometry) does not
256     *    statically assign to gl_ViewportIndex or gl_Layer, index
257     *    or layer zero is assumed.
258     *
259     * So writes to those outputs in VS-as-LS are simply ignored.
260     */
261    const nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
262    if (io_sem.location == VARYING_SLOT_LAYER || io_sem.location == VARYING_SLOT_VIEWPORT) {
263       nir_instr_remove(&intrin->instr);
264       return true;
265    }
266 
267    lower_tess_io_state *st = (lower_tess_io_state *) state;
268 
269    /* When a VS output isn't read by TCS, don't emit anything. */
270    if ((io_sem.no_varying || !(st->tcs_inputs_read & BITFIELD64_BIT(io_sem.location)))) {
271       nir_instr_remove(&intrin->instr);
272       return true;
273    }
274 
275    /* If this is a temp-only TCS input, we don't need to use shared memory at all. */
276    if (st->tcs_temp_only_inputs & BITFIELD64_BIT(io_sem.location))
277       return false;
278 
279    b->cursor = nir_before_instr(&intrin->instr);
280 
281    nir_def *vertex_idx = nir_load_local_invocation_index(b);
282    nir_def *base_off_var = nir_imul(b, vertex_idx, nir_load_lshs_vertex_stride_amd(b));
283 
284    unsigned mapped = ac_nir_map_io_location(io_sem.location, st->tcs_inputs_read & ~st->tcs_temp_only_inputs,
285                                             st->map_io);
286    nir_def *io_off = ac_nir_calc_io_off(b, intrin, nir_imm_int(b, 16u), 4u, mapped);
287    unsigned write_mask = nir_intrinsic_write_mask(intrin);
288 
289    nir_def *off = nir_iadd_nuw(b, base_off_var, io_off);
290    AC_NIR_STORE_IO(b, intrin->src[0].ssa, 0, write_mask, io_sem.high_16bits,
291                    nir_store_shared, off, .write_mask = store_write_mask, .base = store_const_offset);
292 
293    /* NOTE: don't remove the store_output intrinsic on GFX9+ when tcs_in_out_eq;
294     * it will be used by same-invocation TCS input loads.
295     */
296    if (!st->tcs_in_out_eq)
297       nir_instr_remove(&intrin->instr);
298 
299    return true;
300 }
301 
302 static bool
303 filter_load_tcs_per_vertex_input(const nir_instr *instr,
304                                  UNUSED const void *state)
305 {
306    if (instr->type != nir_instr_type_intrinsic)
307       return false;
308 
309    lower_tess_io_state *st = (lower_tess_io_state *) state;
310    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
311 
312    if (intrin->intrinsic != nir_intrinsic_load_per_vertex_input)
313       return false;
314    if (!st->tcs_in_out_eq)
315       return true;
316 
317    /* tcs_in_out_eq: a same-invocation input load without an indirect offset
318     * can use temporaries, so there is no need to use shared memory.
319     */
320    nir_src *off_src = nir_get_io_offset_src(intrin);
321    nir_src *vertex_index_src = nir_get_io_arrayed_index_src(intrin);
322    nir_instr *vertex_index_instr = vertex_index_src->ssa->parent_instr;
323 
324 
325    const nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
326 
327    /* If this is a temp-only TCS input, we don't need to use shared memory at all. */
328    if (st->tcs_temp_only_inputs & BITFIELD64_BIT(io_sem.location)) {
329       ASSERTED bool can_use_temps =
330          nir_src_is_const(*off_src) &&
331          vertex_index_instr->type == nir_instr_type_intrinsic &&
332          nir_instr_as_intrinsic(vertex_index_instr)->intrinsic == nir_intrinsic_load_invocation_id;
333 
334       assert(can_use_temps);
335       return false;
336    }
337 
338    return true;
339 }
340 
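/* Compute the LDS offset of a TCS per-vertex input: the base of the current
 * input patch (rel_patch_id * input patch stride), plus the vertex's offset
 * within the patch, plus the mapped attribute offset.
 */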
341 static nir_def *
342 hs_per_vertex_input_lds_offset(nir_builder *b,
343                                lower_tess_io_state *st,
344                                nir_intrinsic_instr *instr)
345 {
346    nir_def *tcs_in_vtxcnt = nir_load_patch_vertices_in(b);
347    nir_def *rel_patch_id = nir_load_tess_rel_patch_id_amd(b);
348    nir_def *vertex_index = nir_get_io_arrayed_index_src(instr)->ssa;
349 
350    nir_def *stride = nir_load_lshs_vertex_stride_amd(b);
351    nir_def *tcs_in_patch_stride = nir_imul(b, tcs_in_vtxcnt, stride);
352    nir_def *vertex_index_off = nir_imul(b, vertex_index, stride);
353 
354    nir_def *tcs_in_current_patch_offset = nir_imul(b, rel_patch_id, tcs_in_patch_stride);
355 
356    const nir_io_semantics io_sem = nir_intrinsic_io_semantics(instr);
357    const unsigned mapped = ac_nir_map_io_location(io_sem.location, st->tcs_inputs_read & ~st->tcs_temp_only_inputs,
358                                                   st->map_io);
359    nir_def *io_offset = ac_nir_calc_io_off(b, instr, nir_imm_int(b, 16u), 4u, mapped);
360 
361    return nir_iadd_nuw(b, nir_iadd_nuw(b, tcs_in_current_patch_offset, vertex_index_off), io_offset);
362 }
363 
364 static unsigned
365 hs_output_lds_map_io_location(nir_shader *shader,
366                               const bool per_vertex,
367                               const unsigned loc,
368                               lower_tess_io_state *st)
369 {
370    if (!per_vertex) {
371       const uint64_t tf_mask = tcs_lds_tf_out_mask(shader, st);
372       if (loc == VARYING_SLOT_TESS_LEVEL_INNER || loc == VARYING_SLOT_TESS_LEVEL_OUTER) {
373          assert(tf_mask & BITFIELD64_BIT(loc));
374          return util_bitcount64(tf_mask & BITFIELD64_MASK(loc));
375       }
376 
377       const uint32_t patch_out_mask = tcs_lds_per_patch_out_mask(shader);
378       assert(patch_out_mask & BITFIELD_BIT(loc - VARYING_SLOT_PATCH0));
379       return util_bitcount64(tf_mask) +
380              util_bitcount(patch_out_mask & BITFIELD_MASK(loc - VARYING_SLOT_PATCH0));
381    } else {
382       const uint64_t per_vertex_mask = tcs_lds_per_vtx_out_mask(shader);
383       assert(per_vertex_mask & BITFIELD64_BIT(loc));
384       return util_bitcount64(per_vertex_mask & BITFIELD64_MASK(loc));
385    }
386 }
387 
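/* Compute the LDS offset of a TCS output. Output patches are placed after
 * all input patches; within an output patch, per-vertex outputs come first,
 * followed by per-patch outputs. With a NULL intrinsic, this returns the
 * base of the current patch's per-patch outputs.
 */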
388 static nir_def *
389 hs_output_lds_offset(nir_builder *b,
390                      lower_tess_io_state *st,
391                      nir_intrinsic_instr *intrin)
392 {
393    bool per_vertex = intrin &&
394                      (intrin->intrinsic == nir_intrinsic_store_per_vertex_output ||
395                       intrin->intrinsic == nir_intrinsic_load_per_vertex_output);
396 
397    const uint64_t per_vertex_mask = tcs_lds_per_vtx_out_mask(b->shader);
398    const uint64_t tf_mask = tcs_lds_tf_out_mask(b->shader, st);
399    const uint32_t patch_out_mask = tcs_lds_per_patch_out_mask(b->shader);
400 
401    unsigned tcs_num_reserved_outputs = util_bitcount64(per_vertex_mask);
402    unsigned tcs_num_reserved_patch_outputs = util_bitcount64(tf_mask) + util_bitcount(patch_out_mask);
403    unsigned output_vertex_size = tcs_num_reserved_outputs * 16u;
404    unsigned pervertex_output_patch_size = b->shader->info.tess.tcs_vertices_out * output_vertex_size;
405    unsigned output_patch_stride = pervertex_output_patch_size + tcs_num_reserved_patch_outputs * 16u;
406 
407    nir_def *off = NULL;
408 
409    if (intrin) {
410       const nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
411       const unsigned mapped = hs_output_lds_map_io_location(b->shader, per_vertex, io_sem.location, st);
412       off = ac_nir_calc_io_off(b, intrin, nir_imm_int(b, 16u), 4, mapped);
413    } else {
414       off = nir_imm_int(b, 0);
415    }
416 
417    nir_def *rel_patch_id = nir_load_tess_rel_patch_id_amd(b);
418    nir_def *patch_offset = nir_imul_imm(b, rel_patch_id, output_patch_stride);
419 
420    nir_def *tcs_in_vtxcnt = nir_load_patch_vertices_in(b);
421    nir_def *tcs_num_patches = nir_load_tcs_num_patches_amd(b);
422    nir_def *input_patch_size = nir_imul(b, tcs_in_vtxcnt, nir_load_lshs_vertex_stride_amd(b));
423    nir_def *output_patch0_offset = nir_imul(b, input_patch_size, tcs_num_patches);
424    nir_def *output_patch_offset = nir_iadd_nuw(b, patch_offset, output_patch0_offset);
425 
426    if (per_vertex) {
427       nir_def *vertex_index = nir_get_io_arrayed_index_src(intrin)->ssa;
428       nir_def *vertex_index_off = nir_imul_imm(b, vertex_index, output_vertex_size);
429 
430       off = nir_iadd_nuw(b, off, vertex_index_off);
431       return nir_iadd_nuw(b, off, output_patch_offset);
432    } else {
433       off = nir_iadd_imm_nuw(b, off, pervertex_output_patch_size);
434       return nir_iadd_nuw(b, off, output_patch_offset);
435    }
436 }
437 
438 static unsigned
439 hs_output_vram_map_io_location(nir_shader *shader,
440                                const bool per_vertex,
441                                const unsigned loc,
442                                lower_tess_io_state *st)
443 {
444    /* Unlinked shaders:
445     * We are unaware of TES inputs while lowering TCS outputs.
446     * The driver needs to pass a callback to map varyings to a fixed location.
447     */
448    if (st->map_io)
449       return st->map_io(loc);
450 
451    /* Linked shaders:
452     * Take advantage of having knowledge of TES inputs while lowering TCS outputs.
453     * Map varyings to a prefix sum of the IO mask to save space in VRAM.
454     */
455    if (!per_vertex) {
456       const uint64_t tf_mask = tcs_vram_tf_out_mask(shader, st);
457       if (loc == VARYING_SLOT_TESS_LEVEL_INNER || loc == VARYING_SLOT_TESS_LEVEL_OUTER) {
458          assert(tf_mask & BITFIELD64_BIT(loc));
459          return util_bitcount64(tf_mask & BITFIELD64_MASK(loc));
460       }
461 
462       const uint32_t patch_out_mask = tcs_vram_per_patch_out_mask(shader, st);
463       assert(patch_out_mask & BITFIELD_BIT(loc - VARYING_SLOT_PATCH0));
464       return util_bitcount64(tf_mask) +
465              util_bitcount(patch_out_mask & BITFIELD_MASK(loc - VARYING_SLOT_PATCH0));
466    } else {
467       const uint64_t per_vertex_mask = tcs_vram_per_vtx_out_mask(shader, st);
468       assert(per_vertex_mask & BITFIELD64_BIT(loc));
469       return util_bitcount64(per_vertex_mask & BITFIELD64_MASK(loc));
470    }
471 }
472 
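/* Compute the off-chip (VRAM) offset of a TCS per-vertex output, following
 * the attribute-major layout described at the top of this file.
 */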
473 static nir_def *
474 hs_per_vertex_output_vmem_offset(nir_builder *b,
475                                  lower_tess_io_state *st,
476                                  nir_intrinsic_instr *intrin)
477 {
478    const nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
479 
480    nir_def *out_vertices_per_patch = b->shader->info.stage == MESA_SHADER_TESS_CTRL
481                                          ? nir_imm_int(b, b->shader->info.tess.tcs_vertices_out)
482                                          : nir_load_patch_vertices_in(b);
483 
484    nir_def *tcs_num_patches = nir_load_tcs_num_patches_amd(b);
485    nir_def *attr_stride = nir_imul(b, tcs_num_patches, nir_imul_imm(b, out_vertices_per_patch, 16u));
486    nir_def *io_offset =
487       ac_nir_calc_io_off(b, intrin, attr_stride, 4u,
488                                    hs_output_vram_map_io_location(b->shader, true, io_sem.location, st));
489 
490    nir_def *rel_patch_id = nir_load_tess_rel_patch_id_amd(b);
491    nir_def *patch_offset = nir_imul(b, rel_patch_id, nir_imul_imm(b, out_vertices_per_patch, 16u));
492 
493    nir_def *vertex_index = nir_get_io_arrayed_index_src(intrin)->ssa;
494    nir_def *vertex_index_off = nir_imul_imm(b, vertex_index, 16u);
495 
496    return nir_iadd_nuw(b, nir_iadd_nuw(b, patch_offset, vertex_index_off), io_offset);
497 }
498 
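/* Compute the off-chip (VRAM) offset of a TCS per-patch output. Per-patch
 * data starts at hs_out_patch_data_offset_amd. When there is no intrinsic,
 * const_base_offset (attribute slot * 16 bytes) selects the attribute
 * region; this is used for the tess level stores in hs_finale.
 */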
499 static nir_def *
500 hs_per_patch_output_vmem_offset(nir_builder *b,
501                                 lower_tess_io_state *st,
502                                 nir_intrinsic_instr *intrin,
503                                 unsigned const_base_offset)
504 {
505    nir_def *tcs_num_patches = nir_load_tcs_num_patches_amd(b);
506    nir_def *per_patch_data_offset = nir_load_hs_out_patch_data_offset_amd(b);
507 
508    nir_def * off =
509       intrin
510       ? ac_nir_calc_io_off(b, intrin, nir_imul_imm(b, tcs_num_patches, 16u), 4u,
511                                      hs_output_vram_map_io_location(b->shader, false, nir_intrinsic_io_semantics(intrin).location, st))
512       : nir_imm_int(b, 0);
513 
514    if (const_base_offset)
515       off = nir_iadd_nuw(b, off, nir_imul_imm(b, tcs_num_patches, const_base_offset));
516 
517    nir_def *rel_patch_id = nir_load_tess_rel_patch_id_amd(b);
518    nir_def *patch_offset = nir_imul_imm(b, rel_patch_id, 16u);
519    off = nir_iadd_nuw(b, off, per_patch_data_offset);
520    return nir_iadd_nuw(b, off, patch_offset);
521 }
522 
523 static nir_def *
524 lower_hs_per_vertex_input_load(nir_builder *b,
525                                nir_instr *instr,
526                                void *state)
527 {
528    lower_tess_io_state *st = (lower_tess_io_state *) state;
529    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
530 
531    const nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
532    nir_def *off = hs_per_vertex_input_lds_offset(b, st, intrin);
533    nir_def *load = NULL;
534 
535    AC_NIR_LOAD_IO(load, b, intrin->def.num_components, intrin->def.bit_size, io_sem.high_16bits,
536                   nir_load_shared, off);
537 
538    return load;
539 }
540 
541 static nir_def *
542 lower_hs_output_store(nir_builder *b,
543                       nir_intrinsic_instr *intrin,
544                       lower_tess_io_state *st)
545 {
546    assert(intrin->intrinsic == nir_intrinsic_store_per_vertex_output ||
547           intrin->intrinsic == nir_intrinsic_store_output);
548 
549    nir_io_semantics semantics = nir_intrinsic_io_semantics(intrin);
550    nir_def *store_val = intrin->src[0].ssa;
551    const unsigned write_mask = nir_intrinsic_write_mask(intrin);
552    const bool write_to_vmem = tcs_output_needs_vmem(intrin, b->shader, st);
553    const bool write_to_lds =  tcs_output_needs_lds(intrin, b->shader, st);
554 
555    if (write_to_vmem) {
556       nir_def *vmem_off = intrin->intrinsic == nir_intrinsic_store_per_vertex_output
557                             ? hs_per_vertex_output_vmem_offset(b, st, intrin)
558                             : hs_per_patch_output_vmem_offset(b, st, intrin, 0);
559 
560       nir_def *hs_ring_tess_offchip = nir_load_ring_tess_offchip_amd(b);
561       nir_def *offchip_offset = nir_load_ring_tess_offchip_offset_amd(b);
562       nir_def *zero = nir_imm_int(b, 0);
563       AC_NIR_STORE_IO(b, store_val, 0, write_mask, semantics.high_16bits,
564                       nir_store_buffer_amd, hs_ring_tess_offchip, vmem_off, offchip_offset, zero,
565                       .write_mask = store_write_mask, .base = store_const_offset,
566                       .memory_modes = nir_var_shader_out, .access = ACCESS_COHERENT);
567    }
568 
569    if (write_to_lds) {
570       nir_def *lds_off = hs_output_lds_offset(b, st, intrin);
571       AC_NIR_STORE_IO(b, store_val, 0, write_mask, semantics.high_16bits,
572                       nir_store_shared, lds_off, .write_mask = store_write_mask, .base = store_const_offset);
573    }
574 
575    /* Save the tess factors to be used by the tess factor writer, or to
576     * reconstruct the store output instruction later.
577     */
578    if (semantics.location == VARYING_SLOT_TESS_LEVEL_INNER ||
579        semantics.location == VARYING_SLOT_TESS_LEVEL_OUTER) {
580       const unsigned base = nir_intrinsic_base(intrin);
581       const unsigned component = nir_intrinsic_component(intrin);
582 
583       if (semantics.location == VARYING_SLOT_TESS_LEVEL_INNER) {
584          st->tcs_tess_level_inner_base = base;
585          st->tcs_tess_level_inner_mask |= write_mask << component;
586 
587          if (st->tcs_pass_tessfactors_by_reg)
588             ac_nir_store_var_components(b, st->tcs_tess_level_inner, store_val,
589                                         component, write_mask);
590       } else {
591          st->tcs_tess_level_outer_base = base;
592          st->tcs_tess_level_outer_mask |= write_mask << component;
593 
594          if (st->tcs_pass_tessfactors_by_reg)
595             ac_nir_store_var_components(b, st->tcs_tess_level_outer, store_val,
596                                         component, write_mask);
597       }
598    }
599 
600    return NIR_LOWER_INSTR_PROGRESS_REPLACE;
601 }
602 
603 static nir_def *
604 lower_hs_output_load(nir_builder *b,
605                      nir_intrinsic_instr *intrin,
606                      lower_tess_io_state *st)
607 {
608    const nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
609    const bool is_tess_factor = io_sem.location == VARYING_SLOT_TESS_LEVEL_INNER ||
610                                io_sem.location == VARYING_SLOT_TESS_LEVEL_OUTER;
611 
612    if (is_tess_factor && st->tcs_pass_tessfactors_by_reg) {
613       const unsigned component = nir_intrinsic_component(intrin);
614       const unsigned num_components = intrin->def.num_components;
615       const unsigned bit_size = intrin->def.bit_size;
616 
617       nir_def *var =
618          io_sem.location == VARYING_SLOT_TESS_LEVEL_OUTER
619             ? nir_load_var(b, st->tcs_tess_level_outer)
620             : nir_load_var(b, st->tcs_tess_level_inner);
621 
622       return nir_extract_bits(b, &var, 1, component * bit_size, num_components, bit_size);
623    }
624 
625    /* If an output is not stored by the shader, replace the output load by undef. */
626    if (!tcs_output_needs_lds(intrin, b->shader, st))
627       return nir_undef(b, intrin->def.num_components, intrin->def.bit_size);
628 
629    nir_def *off = hs_output_lds_offset(b, st, intrin);
630    nir_def *load = NULL;
631 
632    AC_NIR_LOAD_IO(load, b, intrin->def.num_components, intrin->def.bit_size, io_sem.high_16bits,
633                   nir_load_shared, off);
634 
635    return load;
636 }
637 
638 static void
639 update_hs_barrier(nir_intrinsic_instr *intrin, lower_tess_io_state *st)
640 {
641    /* Output loads and stores are lowered to shared memory access,
642     * so we have to update the barriers to also reflect this.
643     */
644    unsigned mem_modes = nir_intrinsic_memory_modes(intrin);
645    if (mem_modes & nir_var_shader_out) {
646       mem_modes |= nir_var_mem_shared;
647       mem_modes &= ~nir_var_shader_out;
648    }
649    nir_intrinsic_set_memory_modes(intrin, mem_modes);
650 
651    mesa_scope exec_scope = nir_intrinsic_execution_scope(intrin);
652    if (exec_scope == SCOPE_WORKGROUP && st->tcs_out_patch_fits_subgroup)
653       nir_intrinsic_set_execution_scope(intrin, SCOPE_SUBGROUP);
654 
655    mesa_scope mem_scope = nir_intrinsic_memory_scope(intrin);
656    if (mem_scope == SCOPE_WORKGROUP && st->tcs_out_patch_fits_subgroup)
657       nir_intrinsic_set_memory_scope(intrin, SCOPE_SUBGROUP);
658 }
659 
660 static nir_def *
661 lower_hs_output_access(nir_builder *b,
662                        nir_instr *instr,
663                        void *state)
664 {
665    lower_tess_io_state *st = (lower_tess_io_state *) state;
666    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
667 
668    if (intrin->intrinsic == nir_intrinsic_store_output ||
669        intrin->intrinsic == nir_intrinsic_store_per_vertex_output) {
670       return lower_hs_output_store(b, intrin, st);
671    } else if (intrin->intrinsic == nir_intrinsic_load_output ||
672               intrin->intrinsic == nir_intrinsic_load_per_vertex_output) {
673       return lower_hs_output_load(b, intrin, st);
674    } else if (intrin->intrinsic == nir_intrinsic_barrier) {
675       update_hs_barrier(intrin, st);
676       return NIR_LOWER_INSTR_PROGRESS;
677    } else {
678       unreachable("intrinsic not supported by lower_hs_output_access");
679    }
680 }
681 
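/* Load the tess levels written by the shader, either from the temporary
 * variables (when passed by register) or from LDS. Factors the shader did
 * not write are replaced with zero.
 */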
682 static tess_levels
683 hs_load_tess_levels(nir_builder *b,
684                     lower_tess_io_state *st)
685 {
686    unsigned outer_comps, inner_comps;
687    mesa_count_tess_level_components(b->shader->info.tess._primitive_mode,
688                                     &outer_comps, &inner_comps);
689 
690    nir_def *outer = NULL;
691    nir_def *inner = NULL;
692 
693    if (st->tcs_pass_tessfactors_by_reg) {
694       if (st->tcs_tess_level_outer_mask) {
695          outer = nir_load_var(b, st->tcs_tess_level_outer);
696          outer = nir_trim_vector(b, outer, outer_comps);
697       }
698 
699       if (inner_comps && st->tcs_tess_level_inner_mask) {
700          inner = nir_load_var(b, st->tcs_tess_level_inner);
701          inner = nir_trim_vector(b, inner, inner_comps);
702       }
703    } else {
704       /* Base LDS address of per-patch outputs in the current patch. */
705       nir_def *lds_base = hs_output_lds_offset(b, st, NULL);
706 
707       /* Load all tessellation factors (aka. tess levels) from LDS. */
708       if (st->tcs_tess_level_outer_mask) {
709          const unsigned mapped = hs_output_lds_map_io_location(b->shader, false, VARYING_SLOT_TESS_LEVEL_OUTER, st);
710          outer = nir_load_shared(b, outer_comps, 32, lds_base, .base = mapped * 16);
711       }
712 
713       if (inner_comps && st->tcs_tess_level_inner_mask) {
714          const unsigned mapped = hs_output_lds_map_io_location(b->shader, false, VARYING_SLOT_TESS_LEVEL_INNER, st);
715          inner = nir_load_shared(b, inner_comps, 32, lds_base, .base = mapped * 16);
716       }
717    }
718 
719    /* Set tess factors to zero if the shader did not write them. */
720    if (!outer)
721       outer = nir_imm_zero(b, outer_comps, 32);
722    if (inner_comps && !inner)
723       inner = nir_imm_zero(b, inner_comps, 32);
724 
725    tess_levels r = {
726       .outer = outer,
727       .inner = inner,
728    };
729 
730    return r;
731 }
732 
733 static void
734 hs_store_dynamic_control_word_gfx6(nir_builder *b)
735 {
736    nir_def *rel_patch_id = nir_load_tess_rel_patch_id_amd(b);
737    nir_def *tessfactor_ring = nir_load_ring_tess_factors_amd(b);
738    nir_def *tess_factors_base = nir_load_ring_tess_factors_offset_amd(b);
739 
740    /* Store the dynamic HS control word. */
741    nir_if *rel_patch_id_zero = nir_push_if(b, nir_ieq_imm(b, rel_patch_id, 0));
742    nir_def *zero = nir_imm_int(b, 0);
743    nir_def *ctrlw = nir_imm_int(b, 0x80000000u);
744    nir_store_buffer_amd(b, ctrlw, tessfactor_ring, zero, tess_factors_base, zero,
745                         .access = ACCESS_COHERENT);
746    nir_pop_if(b, rel_patch_id_zero);
747 }
748 
749 static nir_def *
750 hs_resize_tess_factor(nir_builder *b, nir_def *tf, unsigned comps)
751 {
752    if (!comps)
753       return NULL;
754    else if (!tf)
755       return nir_imm_zero(b, comps, 32);
756    else if (comps > tf->num_components)
757       return nir_pad_vector_imm_int(b, tf, 0, comps);
758    else if (comps < tf->num_components)
759       return nir_trim_vector(b, tf, comps);
760    else
761       return tf;
762 }
763 
764 static void
765 hs_store_tess_factors_for_tessellator(nir_builder *b, enum amd_gfx_level gfx_level,
766                                       enum tess_primitive_mode prim_mode,
767                                       tess_levels tessfactors)
768 {
769    nir_def *rel_patch_id = nir_load_tess_rel_patch_id_amd(b);
770    nir_def *tessfactor_ring = nir_load_ring_tess_factors_amd(b);
771    nir_def *tess_factors_base = nir_load_ring_tess_factors_offset_amd(b);
772    nir_def *zero = nir_imm_int(b, 0);
773 
774    const unsigned tess_factors_const_offset = gfx_level <= GFX8 ? 4 : 0;
775    unsigned outer_comps, inner_comps;
776 
777    mesa_count_tess_level_components(prim_mode, &outer_comps, &inner_comps);
778 
779    nir_def *tess_factors_offset =
780       nir_imul_imm(b, rel_patch_id, (inner_comps + outer_comps) * 4u);
781 
782    nir_def *tf_outer = hs_resize_tess_factor(b, tessfactors.outer, outer_comps);
783    nir_def *tf_inner = hs_resize_tess_factor(b, tessfactors.inner, inner_comps);
784 
785    /* Store tess factors for the tessellator */
786    if (prim_mode == TESS_PRIMITIVE_ISOLINES) {
787       /* Isolines: store the two outer tess factors in reversed order. */
788       nir_def *t = nir_vec2(b, nir_channel(b, tf_outer, 1), nir_channel(b, tf_outer, 0));
789       nir_store_buffer_amd(b, t, tessfactor_ring, tess_factors_offset, tess_factors_base, zero,
790                            .base = tess_factors_const_offset, .access = ACCESS_COHERENT | ACCESS_CP_GE_COHERENT_AMD);
791    } else if (prim_mode == TESS_PRIMITIVE_TRIANGLES) {
792       nir_def *t = nir_vec4(b, nir_channel(b, tf_outer, 0), nir_channel(b, tf_outer, 1),
793                                nir_channel(b, tf_outer, 2), nir_channel(b, tf_inner, 0));
794       nir_store_buffer_amd(b, t, tessfactor_ring, tess_factors_offset, tess_factors_base, zero,
795                            .base = tess_factors_const_offset, .access = ACCESS_COHERENT | ACCESS_CP_GE_COHERENT_AMD);
796    } else {
797       nir_store_buffer_amd(b, tf_outer, tessfactor_ring, tess_factors_offset, tess_factors_base, zero,
798                            .base = tess_factors_const_offset, .access = ACCESS_COHERENT | ACCESS_CP_GE_COHERENT_AMD);
799       nir_store_buffer_amd(b, tf_inner, tessfactor_ring, tess_factors_offset, tess_factors_base, zero,
800                            .base = tess_factors_const_offset + 4u * outer_comps,
801                            .access = ACCESS_COHERENT | ACCESS_CP_GE_COHERENT_AMD);
802    }
803 }
804 
805 static void
806 hs_store_tess_factors_for_tes(nir_builder *b, tess_levels tessfactors, lower_tess_io_state *st)
807 {
808    nir_def *hs_ring_tess_offchip = nir_load_ring_tess_offchip_amd(b);
809    nir_def *offchip_offset = nir_load_ring_tess_offchip_offset_amd(b);
810    nir_def *zero = nir_imm_int(b, 0);
811 
812    /* For linked shaders, we must only write the tess factors that the TES actually reads;
813     * otherwise we would write to a memory location reserved for another per-patch output.
814     */
815    const bool tes_reads_outer = st->tes_inputs_read & VARYING_BIT_TESS_LEVEL_OUTER;
816    const bool tes_reads_inner = st->tes_inputs_read & VARYING_BIT_TESS_LEVEL_INNER;
817 
818    if (st->tcs_tess_level_outer_mask && tes_reads_outer) {
819       const unsigned tf_outer_loc = hs_output_vram_map_io_location(b->shader, false, VARYING_SLOT_TESS_LEVEL_OUTER, st);
820       nir_def *vmem_off_outer = hs_per_patch_output_vmem_offset(b, st, NULL, tf_outer_loc * 16);
821 
822       nir_store_buffer_amd(b, tessfactors.outer, hs_ring_tess_offchip,
823                            vmem_off_outer, offchip_offset, zero,
824                            .memory_modes = nir_var_shader_out,
825                            .access = ACCESS_COHERENT);
826    }
827 
828    if (tessfactors.inner && st->tcs_tess_level_inner_mask && tes_reads_inner) {
829       const unsigned tf_inner_loc = hs_output_vram_map_io_location(b->shader, false, VARYING_SLOT_TESS_LEVEL_INNER, st);
830       nir_def *vmem_off_inner = hs_per_patch_output_vmem_offset(b, st, NULL, tf_inner_loc * 16);
831 
832       nir_store_buffer_amd(b, tessfactors.inner, hs_ring_tess_offchip,
833                            vmem_off_inner, offchip_offset, zero,
834                            .memory_modes = nir_var_shader_out,
835                            .access = ACCESS_COHERENT);
836    }
837 }
838 
839 static nir_if *
840 hs_if_invocation_id_zero(nir_builder *b)
841 {
842    nir_def *invocation_id = nir_load_invocation_id(b);
843 
844    /* Only the 1st invocation of each patch needs to do this. */
845    nir_if *invocation_id_zero = nir_push_if(b, nir_ieq_imm(b, invocation_id, 0));
846 
847    /* When the output patch size is <= 32, we can flatten the branch here
848     * because we know for sure that at least 1 invocation in every wave will
849     * take the branch.
850     */
851    if (b->shader->info.tess.tcs_vertices_out <= 32)
852       invocation_id_zero->control = nir_selection_control_divergent_always_taken;
853 
854    return invocation_id_zero;
855 }
856 
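/* Epilogue appended to the end of the TCS: the first invocation of each
 * patch gathers the tess levels, writes them to the tess factor ring for the
 * fixed-function tessellator, and, if the TES reads them, also writes them
 * to the off-chip ring.
 */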
857 static void
858 hs_finale(nir_shader *shader,
859           lower_tess_io_state *st)
860 {
861    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
862    assert(impl);
863    nir_block *last_block = nir_impl_last_block(impl);
864    assert(last_block);
865 
866    nir_builder builder = nir_builder_at(nir_after_block(last_block));
867    nir_builder *b = &builder; /* Shorthand so we don't have to write &builder everywhere. */
868 
869    /* If tess factors are loaded from LDS, wait for the previous LDS stores to finish. */
870    if (!st->tcs_pass_tessfactors_by_reg) {
871       mesa_scope scope = st->tcs_out_patch_fits_subgroup ? SCOPE_SUBGROUP : SCOPE_WORKGROUP;
872       nir_barrier(b, .execution_scope = scope, .memory_scope = scope,
873                      .memory_semantics = NIR_MEMORY_ACQ_REL, .memory_modes = nir_var_mem_shared);
874    }
875 
876    /* Only the 1st invocation of each patch needs to access VRAM and/or LDS. */
877    nir_if *if_invocation_id_zero = hs_if_invocation_id_zero(b);
878    {
879       tess_levels tessfactors = hs_load_tess_levels(b, st);
880 
881       if (st->gfx_level <= GFX8)
882          hs_store_dynamic_control_word_gfx6(b);
883 
884       nir_def *prim_mode = nir_load_tcs_primitive_mode_amd(b);
885       nir_if *if_triangles = nir_push_if(b, nir_ieq_imm(b, prim_mode, TESS_PRIMITIVE_TRIANGLES));
886       {
887          hs_store_tess_factors_for_tessellator(b, st->gfx_level, TESS_PRIMITIVE_TRIANGLES, tessfactors);
888       }
889       nir_push_else(b, if_triangles);
890       {
891          nir_if *if_isolines = nir_push_if(b, nir_ieq_imm(b, prim_mode, TESS_PRIMITIVE_ISOLINES));
892          {
893             hs_store_tess_factors_for_tessellator(b, st->gfx_level, TESS_PRIMITIVE_ISOLINES, tessfactors);
894          }
895          nir_push_else(b, if_isolines);
896          {
897             hs_store_tess_factors_for_tessellator(b, st->gfx_level, TESS_PRIMITIVE_QUADS, tessfactors);
898          }
899          nir_pop_if(b, if_isolines);
900       }
901       nir_pop_if(b, if_triangles);
902 
903       nir_if *if_tes_reads_tf = nir_push_if(b, nir_load_tcs_tess_levels_to_tes_amd(b));
904       {
905          hs_store_tess_factors_for_tes(b, tessfactors, st);
906       }
907       nir_pop_if(b, if_tes_reads_tf);
908    }
909 
910    nir_pop_if(b, if_invocation_id_zero);
911 
912    nir_metadata_preserve(impl, nir_metadata_none);
913 }
914 
915 static nir_def *
916 lower_tes_input_load(nir_builder *b,
917                      nir_instr *instr,
918                      void *state)
919 {
920    lower_tess_io_state *st = (lower_tess_io_state *) state;
921    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
922 
923    const nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
924    nir_def *offchip_ring = nir_load_ring_tess_offchip_amd(b);
925    nir_def *offchip_offset = nir_load_ring_tess_offchip_offset_amd(b);
926    nir_def *off = intrin->intrinsic == nir_intrinsic_load_per_vertex_input
927                     ? hs_per_vertex_output_vmem_offset(b, st, intrin)
928                     : hs_per_patch_output_vmem_offset(b, st, intrin, 0);
929 
930    nir_def *zero = nir_imm_int(b, 0);
931    nir_def *load = NULL;
932 
933    AC_NIR_LOAD_IO(load, b, intrin->def.num_components, intrin->def.bit_size, io_sem.high_16bits,
934                   nir_load_buffer_amd, offchip_ring, off, offchip_offset, zero, .access = ACCESS_COHERENT);
935 
936    return load;
937 }
938 
939 static bool
940 filter_hs_output_access(const nir_instr *instr,
941                          UNUSED const void *st)
942 {
943    if (instr->type != nir_instr_type_intrinsic)
944       return false;
945 
946    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
947    return intrin->intrinsic == nir_intrinsic_store_output ||
948           intrin->intrinsic == nir_intrinsic_store_per_vertex_output ||
949           intrin->intrinsic == nir_intrinsic_load_output ||
950           intrin->intrinsic == nir_intrinsic_load_per_vertex_output ||
951           intrin->intrinsic == nir_intrinsic_barrier;
952 }
953 
954 static bool
955 filter_any_input_access(const nir_instr *instr,
956                         UNUSED const void *st)
957 {
958    if (instr->type != nir_instr_type_intrinsic)
959       return false;
960 
961    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
962    return intrin->intrinsic == nir_intrinsic_load_input ||
963           intrin->intrinsic == nir_intrinsic_load_per_vertex_input;
964 }
965 
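/* Entry point: lower VS (as LS) output stores to LDS stores read by the TCS. */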
966 void
967 ac_nir_lower_ls_outputs_to_mem(nir_shader *shader,
968                                ac_nir_map_io_driver_location map,
969                                bool tcs_in_out_eq,
970                                uint64_t tcs_inputs_read,
971                                uint64_t tcs_temp_only_inputs)
972 {
973    assert(shader->info.stage == MESA_SHADER_VERTEX);
974 
975    lower_tess_io_state state = {
976       .tcs_in_out_eq = tcs_in_out_eq,
977       .tcs_inputs_read = tcs_inputs_read,
978       .tcs_temp_only_inputs = tcs_in_out_eq ? tcs_temp_only_inputs : 0,
979       .map_io = map,
980    };
981 
982    nir_shader_intrinsics_pass(shader, lower_ls_output_store,
983                                 nir_metadata_control_flow,
984                                 &state);
985 }
986 
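/* Entry point: lower TCS per-vertex input loads to LDS loads (except inputs
 * that are passed in registers when tcs_in_out_eq).
 */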
987 void
988 ac_nir_lower_hs_inputs_to_mem(nir_shader *shader,
989                               ac_nir_map_io_driver_location map,
990                               bool tcs_in_out_eq,
991                               uint64_t tcs_temp_only_inputs)
992 {
993    assert(shader->info.stage == MESA_SHADER_TESS_CTRL);
994 
995    lower_tess_io_state state = {
996       .tcs_inputs_read = shader->info.inputs_read,
997       .tcs_in_out_eq = tcs_in_out_eq,
998       .tcs_temp_only_inputs = tcs_in_out_eq ? tcs_temp_only_inputs : 0,
999       .map_io = map,
1000    };
1001 
1002    nir_shader_lower_instructions(shader,
1003                                  filter_load_tcs_per_vertex_input,
1004                                  lower_hs_per_vertex_input_load,
1005                                  &state);
1006 }
1007 
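/* Entry point: lower TCS output stores/loads to LDS and/or the off-chip ring,
 * and append the tess factor writes (see hs_finale).
 */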
1008 void
1009 ac_nir_lower_hs_outputs_to_mem(nir_shader *shader,
1010                                ac_nir_map_io_driver_location map,
1011                                enum amd_gfx_level gfx_level,
1012                                uint64_t tes_inputs_read,
1013                                uint32_t tes_patch_inputs_read,
1014                                unsigned wave_size,
1015                                bool pass_tessfactors_by_reg)
1016 {
1017    assert(shader->info.stage == MESA_SHADER_TESS_CTRL);
1018 
1019    lower_tess_io_state state = {
1020       .gfx_level = gfx_level,
1021       .tes_inputs_read = tes_inputs_read,
1022       .tes_patch_inputs_read = tes_patch_inputs_read,
1023       .tcs_out_patch_fits_subgroup = wave_size % shader->info.tess.tcs_vertices_out == 0,
1024       .tcs_pass_tessfactors_by_reg = pass_tessfactors_by_reg,
1025       .map_io = map,
1026    };
1027 
1028    if (pass_tessfactors_by_reg) {
1029       nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1030       state.tcs_tess_level_outer =
1031          nir_local_variable_create(impl, glsl_vec4_type(), "tess outer");
1032       state.tcs_tess_level_inner =
1033          nir_local_variable_create(impl, glsl_vec4_type(), "tess inner");
1034    }
1035 
1036    nir_shader_lower_instructions(shader,
1037                                  filter_hs_output_access,
1038                                  lower_hs_output_access,
1039                                  &state);
1040 
1041    hs_finale(shader, &state);
1042 }
1043 
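/* Entry point: lower TES input loads to loads from the off-chip ring. */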
1044 void
1045 ac_nir_lower_tes_inputs_to_mem(nir_shader *shader,
1046                                ac_nir_map_io_driver_location map)
1047 {
1048    assert(shader->info.stage == MESA_SHADER_TESS_EVAL);
1049 
1050    lower_tess_io_state state = {
1051       .map_io = map,
1052       .tes_inputs_read = shader->info.inputs_read,
1053       .tes_patch_inputs_read = shader->info.patch_inputs_read,
1054    };
1055 
1056    nir_shader_lower_instructions(shader,
1057                                  filter_any_input_access,
1058                                  lower_tes_input_load,
1059                                  &state);
1060 }
1061