/*
 * Copyright 2023 Advanced Micro Devices, Inc.
 *
 * SPDX-License-Identifier: MIT
 */

#include "nir_builder.h"

#include "ac_nir.h"
#include "si_shader_internal.h"
#include "si_state.h"
#include "si_pipe.h"

struct lower_vs_inputs_state {
   struct si_shader *shader;
   struct si_shader_args *args;

   nir_def *instance_divisor_constbuf;
   nir_def *vertex_index[16];
};

/* See fast_idiv_by_const.h. */
/* If num != UINT_MAX, this more efficient version can be used. */
/* Set: increment = util_fast_udiv_info::increment; */
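/* A worked example (constants shown for illustration; the driver gets them
 * precomputed on the CPU, presumably via util_compute_fast_udiv_info()):
 * for divisor == 3, one valid set is multiplier = 0xAAAAAAAB, pre_shift = 0,
 * post_shift = 1, increment = 0, because 3 * 0xAAAAAAAB == 2^33 + 1, so
 * n / 3 == umul_high(n, 0xAAAAAAAB) >> 1 for every 32-bit n.
 */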
static nir_def *
fast_udiv_nuw(nir_builder *b, nir_def *num, nir_def *divisor)
{
   nir_def *multiplier = nir_channel(b, divisor, 0);
   nir_def *pre_shift = nir_channel(b, divisor, 1);
   nir_def *post_shift = nir_channel(b, divisor, 2);
   nir_def *increment = nir_channel(b, divisor, 3);

   num = nir_ushr(b, num, pre_shift);
   num = nir_iadd_nuw(b, num, increment);
   num = nir_umul_high(b, num, multiplier);
   return nir_ushr(b, num, post_shift);
}

static nir_def *
get_vertex_index(nir_builder *b, int input_index, struct lower_vs_inputs_state *s)
{
   const union si_shader_key *key = &s->shader->key;

   bool divisor_is_one =
      key->ge.mono.instance_divisor_is_one & (1u << input_index);
   bool divisor_is_fetched =
      key->ge.mono.instance_divisor_is_fetched & (1u << input_index);
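   /* divisor == 1 is the common instancing case: the index is simply
    * InstanceID + StartInstance. Arbitrary divisors are handled by fetching
    * precomputed fast-division constants from an internal constant buffer.
    */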

   if (divisor_is_one || divisor_is_fetched) {
      nir_def *instance_id = nir_load_instance_id(b);

      /* This is used to determine vs vgpr count in si_get_vs_vgpr_comp_cnt(). */
      s->shader->info.uses_instanceid = true;

      nir_def *index = NULL;
      if (divisor_is_one) {
         index = instance_id;
      } else {
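         /* Each input occupies 4 dwords in the divisor constant buffer (hence
          * the input_index * 16 byte offset); the channels hold the
          * multiplier, pre_shift, post_shift and increment that
          * fast_udiv_nuw() consumes.
          */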
         nir_def *offset = nir_imm_int(b, input_index * 16);
         nir_def *divisor = nir_load_ubo(b, 4, 32, s->instance_divisor_constbuf, offset,
                                             .range = ~0);

         /* The faster NUW version doesn't work when InstanceID == UINT_MAX.
          * Such InstanceID might not be achievable in a reasonable time though.
          */
         index = fast_udiv_nuw(b, instance_id, divisor);
      }

      nir_def *start_instance = nir_load_base_instance(b);
      return nir_iadd(b, index, start_instance);
   } else {
      nir_def *vertex_id = nir_load_vertex_id_zero_base(b);
      nir_def *base_vertex = nir_load_first_vertex(b);

      return nir_iadd(b, vertex_id, base_vertex);
   }
}

static void
get_vertex_index_for_all_inputs(nir_shader *nir, struct lower_vs_inputs_state *s)
{
   nir_function_impl *impl = nir_shader_get_entrypoint(nir);

   nir_builder builder = nir_builder_at(nir_before_impl(impl));
   nir_builder *b = &builder;

   const struct si_shader_selector *sel = s->shader->selector;
   const union si_shader_key *key = &s->shader->key;

   if (key->ge.mono.instance_divisor_is_fetched) {
      s->instance_divisor_constbuf =
         si_nir_load_internal_binding(b, s->args, SI_VS_CONST_INSTANCE_DIVISORS, 4);
   }

   for (int i = 0; i < sel->info.num_inputs; i++)
      s->vertex_index[i] = get_vertex_index(b, i, s);
}

static void
load_vs_input_from_blit_sgpr(nir_builder *b, unsigned input_index,
                             struct lower_vs_inputs_state *s,
                             nir_def *out[4])
{
   nir_def *vertex_id = nir_load_vertex_id_zero_base(b);
   nir_def *sel_x1 = nir_ule_imm(b, vertex_id, 1);
   /* Use nir_ine, because we have 3 vertices and only
    * the middle one should use y2.
    */
   nir_def *sel_y1 = nir_ine_imm(b, vertex_id, 1);
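   /* Blits are drawn as one large triangle: vertex 0 gets (x1, y1), vertex 1
    * gets (x1, y2) and vertex 2 gets (x2, y1), which is what the two selects
    * below implement.
    */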

   if (input_index == 0) {
      /* Position: */
      nir_def *x1y1 = ac_nir_load_arg_at_offset(b, &s->args->ac, s->args->vs_blit_inputs, 0);
      nir_def *x2y2 = ac_nir_load_arg_at_offset(b, &s->args->ac, s->args->vs_blit_inputs, 1);

      x1y1 = nir_i2i32(b, nir_unpack_32_2x16(b, x1y1));
      x2y2 = nir_i2i32(b, nir_unpack_32_2x16(b, x2y2));

      nir_def *x1 = nir_channel(b, x1y1, 0);
      nir_def *y1 = nir_channel(b, x1y1, 1);
      nir_def *x2 = nir_channel(b, x2y2, 0);
      nir_def *y2 = nir_channel(b, x2y2, 1);

      out[0] = nir_i2f32(b, nir_bcsel(b, sel_x1, x1, x2));
      out[1] = nir_i2f32(b, nir_bcsel(b, sel_y1, y1, y2));
      out[2] = ac_nir_load_arg_at_offset(b, &s->args->ac, s->args->vs_blit_inputs, 2);
      out[3] = nir_imm_float(b, 1);
   } else {
      bool has_attribute_ring_address = s->shader->selector->screen->info.gfx_level >= GFX11;

      /* Color or texture coordinates: */
      assert(input_index == 1);

      unsigned vs_blit_property = s->shader->selector->info.base.vs.blit_sgprs_amd;
      if (vs_blit_property == SI_VS_BLIT_SGPRS_POS_COLOR + has_attribute_ring_address) {
         for (int i = 0; i < 4; i++)
            out[i] = ac_nir_load_arg_at_offset(b, &s->args->ac, s->args->vs_blit_inputs, 3 + i);
      } else {
         assert(vs_blit_property == SI_VS_BLIT_SGPRS_POS_TEXCOORD + has_attribute_ring_address);

         nir_def *x1 = ac_nir_load_arg_at_offset(b, &s->args->ac, s->args->vs_blit_inputs, 3);
         nir_def *y1 = ac_nir_load_arg_at_offset(b, &s->args->ac, s->args->vs_blit_inputs, 4);
         nir_def *x2 = ac_nir_load_arg_at_offset(b, &s->args->ac, s->args->vs_blit_inputs, 5);
         nir_def *y2 = ac_nir_load_arg_at_offset(b, &s->args->ac, s->args->vs_blit_inputs, 6);

         out[0] = nir_bcsel(b, sel_x1, x1, x2);
         out[1] = nir_bcsel(b, sel_y1, y1, y2);
         out[2] = ac_nir_load_arg_at_offset(b, &s->args->ac, s->args->vs_blit_inputs, 7);
         out[3] = ac_nir_load_arg_at_offset(b, &s->args->ac, s->args->vs_blit_inputs, 8);
      }
   }
}

/**
 * Convert an 11- or 10-bit unsigned floating point number to an f32.
 *
 * The input exponent is expected to be biased analogous to IEEE-754, i.e. by
 * 2^(exp_bits-1) - 1 (as defined in OpenGL and other graphics APIs).
 */
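/* E.g. for the 11-bit case (exp_bits = 5, mant_bits = 6), a normal input with
 * exponent field e and mantissa m represents 2^(e - 15) * (1 + m / 64), so the
 * conversion below shifts the mantissa up by 23 - 6 bits and rebiases the
 * exponent by 127 - 15.
 */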
static nir_def *
ufN_to_float(nir_builder *b, nir_def *src, unsigned exp_bits, unsigned mant_bits)
{
   assert(src->bit_size == 32);

   nir_def *mantissa = nir_iand_imm(b, src, (1 << mant_bits) - 1);

   /* Converting normal numbers is just a shift + correcting the exponent bias */
   unsigned normal_shift = 23 - mant_bits;
   unsigned bias_shift = 127 - ((1 << (exp_bits - 1)) - 1);

   nir_def *shifted = nir_ishl_imm(b, src, normal_shift);
   nir_def *normal = nir_iadd_imm(b, shifted, bias_shift << 23);

   /* Converting nan/inf numbers is the same, but with a different exponent update */
   nir_def *naninf = nir_ior_imm(b, normal, 0xff << 23);

   /* Converting denormals is the complex case: determine the leading zeros of the
    * mantissa to obtain the correct shift for the mantissa and exponent correction.
    */
   nir_def *ctlz = nir_uclz(b, mantissa);
   /* Shift such that the leading 1 ends up as the LSB of the exponent field. */
   nir_def *denormal = nir_ishl(b, mantissa, nir_iadd_imm(b, ctlz, -8));

   unsigned denormal_exp = bias_shift + (32 - mant_bits) - 1;
   nir_def *tmp = nir_isub_imm(b, denormal_exp, ctlz);
   denormal = nir_iadd(b, denormal, nir_ishl_imm(b, tmp, 23));

   /* Select the final result. */
   nir_def *cond = nir_uge_imm(b, src, ((1ULL << exp_bits) - 1) << mant_bits);
   nir_def *result = nir_bcsel(b, cond, naninf, normal);

   cond = nir_uge_imm(b, src, 1ULL << mant_bits);
   result = nir_bcsel(b, cond, result, denormal);

   cond = nir_ine_imm(b, src, 0);
   result = nir_bcsel(b, cond, result, nir_imm_int(b, 0));

   return result;
}

/**
 * Generate a fully general open coded buffer format fetch with all required
 * fixups suitable for vertex fetch, using non-format buffer loads.
 *
 * Some combinations of argument values have special interpretations:
 * - size = 8 bytes, format = fixed indicates PIPE_FORMAT_R11G11B10_FLOAT
 * - size = 8 bytes, format != {float,fixed} indicates a 2_10_10_10 data format
 */
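/* A note on the si_vs_fix_fetch fields consumed here: log_size is
 * log2(bytes per channel), num_channels_m1 is the channel count minus one,
 * format selects the AC_FETCH_FORMAT_* conversion, and reverse requests an
 * X/Z channel swap (BGR(A) -> RGB(A)).
 */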
static void
opencoded_load_format(nir_builder *b, nir_def *rsrc, nir_def *vindex,
                      union si_vs_fix_fetch fix_fetch, bool known_aligned,
                      enum amd_gfx_level gfx_level, nir_def *out[4])
{
   unsigned log_size = fix_fetch.u.log_size;
   unsigned num_channels = fix_fetch.u.num_channels_m1 + 1;
   unsigned format = fix_fetch.u.format;
   bool reverse = fix_fetch.u.reverse;

   unsigned load_log_size = log_size;
   unsigned load_num_channels = num_channels;
   if (log_size == 3) {
      load_log_size = 2;
      if (format == AC_FETCH_FORMAT_FLOAT) {
         load_num_channels = 2 * num_channels;
      } else {
         load_num_channels = 1; /* 10_11_11 or 2_10_10_10 */
      }
   }

   int log_recombine = 0;
   if ((gfx_level == GFX6 || gfx_level >= GFX10) && !known_aligned) {
      /* Avoid alignment restrictions by loading one byte at a time. */
      load_num_channels <<= load_log_size;
      log_recombine = load_log_size;
      load_log_size = 0;
   } else if (load_num_channels == 2 || load_num_channels == 4) {
      log_recombine = -util_logbase2(load_num_channels);
      load_num_channels = 1;
      load_log_size += -log_recombine;
   }
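   /* After this, log_recombine records how the raw loads relate to the wanted
    * channels: positive means narrow loads must be OR'ed together into wider
    * values (the unaligned byte-load path), negative means wide loads must be
    * split into narrower channels.
    */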

   nir_def *loads[32]; /* up to 32 bytes */
   for (unsigned i = 0; i < load_num_channels; ++i) {
      nir_def *soffset = nir_imm_int(b, i << load_log_size);
      unsigned num_channels = 1 << (MAX2(load_log_size, 2) - 2);
      unsigned bit_size = 8 << MIN2(load_log_size, 2);
      nir_def *zero = nir_imm_int(b, 0);

      loads[i] = nir_load_buffer_amd(b, num_channels, bit_size, rsrc, zero, soffset, vindex);
   }

   if (log_recombine > 0) {
      /* Recombine bytes if necessary (GFX6 only) */
      unsigned dst_bitsize = log_recombine == 2 ? 32 : 16;

      for (unsigned src = 0, dst = 0; src < load_num_channels; ++dst) {
         nir_def *accum = NULL;
         for (unsigned i = 0; i < (1 << log_recombine); ++i, ++src) {
            nir_def *tmp = nir_u2uN(b, loads[src], dst_bitsize);
            if (i == 0) {
               accum = tmp;
            } else {
               tmp = nir_ishl_imm(b, tmp, 8 * i);
               accum = nir_ior(b, accum, tmp);
            }
         }
         loads[dst] = accum;
      }
   } else if (log_recombine < 0) {
      /* Split vectors of dwords */
      if (load_log_size > 2) {
         assert(load_num_channels == 1);
         nir_def *loaded = loads[0];
         unsigned log_split = load_log_size - 2;
         log_recombine += log_split;
         load_num_channels = 1 << log_split;
         load_log_size = 2;
         for (unsigned i = 0; i < load_num_channels; ++i)
            loads[i] = nir_channel(b, loaded, i);
      }

      /* Further split dwords and shorts if required */
      if (log_recombine < 0) {
         for (unsigned src = load_num_channels, dst = load_num_channels << -log_recombine;
              src > 0; --src) {
            unsigned dst_bits = 1 << (3 + load_log_size + log_recombine);
            nir_def *loaded = loads[src - 1];
            for (unsigned i = 1 << -log_recombine; i > 0; --i, --dst) {
               nir_def *tmp = nir_ushr_imm(b, loaded, dst_bits * (i - 1));
               loads[dst - 1] = nir_u2uN(b, tmp, dst_bits);
            }
         }
      }
   }

   if (log_size == 3) {
      switch (format) {
      case AC_FETCH_FORMAT_FLOAT: {
         for (unsigned i = 0; i < num_channels; ++i)
            loads[i] = nir_pack_64_2x32_split(b, loads[2 * i], loads[2 * i + 1]);
         break;
      }
      case AC_FETCH_FORMAT_FIXED: {
         /* 10_11_11_FLOAT */
         nir_def *data = loads[0];
         nir_def *red = nir_iand_imm(b, data, 2047);
         nir_def *green = nir_iand_imm(b, nir_ushr_imm(b, data, 11), 2047);
         nir_def *blue = nir_ushr_imm(b, data, 22);

         loads[0] = ufN_to_float(b, red, 5, 6);
         loads[1] = ufN_to_float(b, green, 5, 6);
         loads[2] = ufN_to_float(b, blue, 5, 5);

         num_channels = 3;
         log_size = 2;
         format = AC_FETCH_FORMAT_FLOAT;
         break;
      }
      case AC_FETCH_FORMAT_UINT:
      case AC_FETCH_FORMAT_UNORM:
      case AC_FETCH_FORMAT_USCALED: {
         /* 2_10_10_10 data formats */
         nir_def *data = loads[0];

         loads[0] = nir_ubfe_imm(b, data, 0, 10);
         loads[1] = nir_ubfe_imm(b, data, 10, 10);
         loads[2] = nir_ubfe_imm(b, data, 20, 10);
         loads[3] = nir_ubfe_imm(b, data, 30, 2);

         num_channels = 4;
         break;
      }
      case AC_FETCH_FORMAT_SINT:
      case AC_FETCH_FORMAT_SNORM:
      case AC_FETCH_FORMAT_SSCALED: {
         /* 2_10_10_10 data formats */
         nir_def *data = loads[0];

         loads[0] = nir_ibfe_imm(b, data, 0, 10);
         loads[1] = nir_ibfe_imm(b, data, 10, 10);
         loads[2] = nir_ibfe_imm(b, data, 20, 10);
         loads[3] = nir_ibfe_imm(b, data, 30, 2);

         num_channels = 4;
         break;
      }
      default:
         unreachable("invalid fetch format");
         break;
      }
   }

   switch (format) {
   case AC_FETCH_FORMAT_FLOAT:
      if (log_size != 2) {
         for (unsigned chan = 0; chan < num_channels; ++chan)
            loads[chan] = nir_f2f32(b, loads[chan]);
      }
      break;
   case AC_FETCH_FORMAT_UINT:
      if (log_size != 2) {
         for (unsigned chan = 0; chan < num_channels; ++chan)
            loads[chan] = nir_u2u32(b, loads[chan]);
      }
      break;
   case AC_FETCH_FORMAT_SINT:
      if (log_size != 2) {
         for (unsigned chan = 0; chan < num_channels; ++chan)
            loads[chan] = nir_i2i32(b, loads[chan]);
      }
      break;
   case AC_FETCH_FORMAT_USCALED:
      for (unsigned chan = 0; chan < num_channels; ++chan)
         loads[chan] = nir_u2f32(b, loads[chan]);
      break;
   case AC_FETCH_FORMAT_SSCALED:
      for (unsigned chan = 0; chan < num_channels; ++chan)
         loads[chan] = nir_i2f32(b, loads[chan]);
      break;
   case AC_FETCH_FORMAT_FIXED:
      for (unsigned chan = 0; chan < num_channels; ++chan) {
         nir_def *tmp = nir_i2f32(b, loads[chan]);
         loads[chan] = nir_fmul_imm(b, tmp, 1.0 / 0x10000);
      }
      break;
   case AC_FETCH_FORMAT_UNORM:
      for (unsigned chan = 0; chan < num_channels; ++chan) {
         /* 2_10_10_10 data formats */
         unsigned bits = log_size == 3 ? (chan == 3 ? 2 : 10) : (8 << log_size);
         nir_def *tmp = nir_u2f32(b, loads[chan]);
         loads[chan] = nir_fmul_imm(b, tmp, 1.0 / (double)BITFIELD64_MASK(bits));
      }
      break;
   case AC_FETCH_FORMAT_SNORM:
      for (unsigned chan = 0; chan < num_channels; ++chan) {
         /* 2_10_10_10 data formats */
         unsigned bits = log_size == 3 ? (chan == 3 ? 2 : 10) : (8 << log_size);
         nir_def *tmp = nir_i2f32(b, loads[chan]);
         tmp = nir_fmul_imm(b, tmp, 1.0 / (double)BITFIELD64_MASK(bits - 1));
         /* Clamp to [-1, 1] */
         tmp = nir_fmax(b, tmp, nir_imm_float(b, -1));
         loads[chan] = nir_fmin(b, tmp, nir_imm_float(b, 1));
      }
      break;
   default:
      unreachable("invalid fetch format");
      break;
   }

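   /* Pad missing components with the (0, 0, 0, 1) default. */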
   while (num_channels < 4) {
      unsigned pad_value = num_channels == 3 ? 1 : 0;
      loads[num_channels] =
         format == AC_FETCH_FORMAT_UINT || format == AC_FETCH_FORMAT_SINT ?
         nir_imm_int(b, pad_value) : nir_imm_float(b, pad_value);
      num_channels++;
   }

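   /* reverse swaps X and Z, e.g. turning BGRA data into RGBA. */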
   if (reverse) {
      nir_def *tmp = loads[0];
      loads[0] = loads[2];
      loads[2] = tmp;
   }

   memcpy(out, loads, 4 * sizeof(out[0]));
}

static void
load_vs_input_from_vertex_buffer(nir_builder *b, unsigned input_index,
                                 struct lower_vs_inputs_state *s,
                                 unsigned bit_size, nir_def *out[4])
{
   const struct si_shader_selector *sel = s->shader->selector;
   const union si_shader_key *key = &s->shader->key;

   nir_def *vb_desc;
   if (input_index < sel->info.num_vbos_in_user_sgprs) {
      vb_desc = ac_nir_load_arg(b, &s->args->ac, s->args->vb_descriptors[input_index]);
   } else {
      unsigned index = input_index - sel->info.num_vbos_in_user_sgprs;
      nir_def *addr = ac_nir_load_arg(b, &s->args->ac, s->args->ac.vertex_buffers);
      vb_desc = nir_load_smem_amd(b, 4, addr, nir_imm_int(b, index * 16));
   }

   nir_def *vertex_index = s->vertex_index[input_index];

   /* Use the open-coded implementation for all loads of doubles and
    * of dword-sized data that needs fixups. We need to insert conversion
    * code anyway.
    */
   bool opencode = key->ge.mono.vs_fetch_opencode & (1 << input_index);
   union si_vs_fix_fetch fix_fetch = key->ge.mono.vs_fix_fetch[input_index];
   if (opencode ||
       (fix_fetch.u.log_size == 3 && fix_fetch.u.format == AC_FETCH_FORMAT_FLOAT) ||
       fix_fetch.u.log_size == 2) {
      opencoded_load_format(b, vb_desc, vertex_index, fix_fetch, !opencode,
                            sel->screen->info.gfx_level, out);

      if (bit_size == 16) {
         if (fix_fetch.u.format == AC_FETCH_FORMAT_UINT ||
             fix_fetch.u.format == AC_FETCH_FORMAT_SINT) {
            for (unsigned i = 0; i < 4; i++)
               out[i] = nir_u2u16(b, out[i]);
         } else {
            for (unsigned i = 0; i < 4; i++)
               out[i] = nir_f2f16(b, out[i]);
         }
      }
      return;
   }

   unsigned required_channels = util_last_bit(sel->info.input[input_index].usage_mask);
   if (required_channels == 0) {
      for (unsigned i = 0; i < 4; ++i)
         out[i] = nir_undef(b, 1, bit_size);
      return;
   }

   /* Do multiple loads for special formats. */
   nir_def *fetches[4];
   unsigned num_fetches;
   unsigned fetch_stride;
   unsigned channels_per_fetch;

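   /* Three-channel 8- and 16-bit attributes (e.g. R8G8B8) are fetched one
    * channel at a time, presumably because the hardware has no matching
    * 3-channel buffer data format; everything else is a single fetch that
    * relies on the buffer descriptor's format conversion.
    */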
   if (fix_fetch.u.log_size <= 1 && fix_fetch.u.num_channels_m1 == 2) {
      num_fetches = MIN2(required_channels, 3);
      fetch_stride = 1 << fix_fetch.u.log_size;
      channels_per_fetch = 1;
   } else {
      num_fetches = 1;
      fetch_stride = 0;
      channels_per_fetch = required_channels;
   }

   for (unsigned i = 0; i < num_fetches; ++i) {
      nir_def *zero = nir_imm_int(b, 0);
      fetches[i] = nir_load_buffer_amd(b, channels_per_fetch, bit_size, vb_desc,
                                       zero, zero, vertex_index,
                                       .base = fetch_stride * i,
                                       .access = ACCESS_USES_FORMAT_AMD);
   }

   if (num_fetches == 1 && channels_per_fetch > 1) {
      nir_def *fetch = fetches[0];
      for (unsigned i = 0; i < channels_per_fetch; ++i)
         fetches[i] = nir_channel(b, fetch, i);

      num_fetches = channels_per_fetch;
      channels_per_fetch = 1;
   }

   for (unsigned i = num_fetches; i < 4; ++i)
      fetches[i] = nir_undef(b, 1, bit_size);

   if (fix_fetch.u.log_size <= 1 && fix_fetch.u.num_channels_m1 == 2 && required_channels == 4) {
      if (fix_fetch.u.format == AC_FETCH_FORMAT_UINT || fix_fetch.u.format == AC_FETCH_FORMAT_SINT)
         fetches[3] = nir_imm_intN_t(b, 1, bit_size);
      else
         fetches[3] = nir_imm_floatN_t(b, 1, bit_size);
   } else if (fix_fetch.u.log_size == 3 &&
              (fix_fetch.u.format == AC_FETCH_FORMAT_SNORM ||
               fix_fetch.u.format == AC_FETCH_FORMAT_SSCALED ||
               fix_fetch.u.format == AC_FETCH_FORMAT_SINT) &&
              required_channels == 4) {

      /* For 2_10_10_10, the hardware returns an unsigned value;
       * convert it to a signed one.
       */
      nir_def *tmp = fetches[3];

      /* First, recover the sign-extended signed integer value. */
      if (fix_fetch.u.format == AC_FETCH_FORMAT_SSCALED)
         tmp = nir_f2uN(b, tmp, bit_size);

      /* For the integer-like cases, do a natural sign extension.
       *
       * For the SNORM case, the values are 0.0, 0.333, 0.666, 1.0
       * and happen to contain 0, 1, 2, 3 as the two LSBs of the
       * exponent.
       */
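      /* Concretely for SNORM: 0.0 = 0x00000000, 0.333 = 0x3EAAAAAB,
       * 0.666 = 0x3F2AAAAB and 1.0 = 0x3F800000 carry 0, 1, 2, 3 in the two
       * exponent LSBs (bits 24:23), so shifting left by 7 and arithmetically
       * right by 30 recovers the sign-extended 2-bit integer.
       */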
      tmp = nir_ishl_imm(b, tmp, fix_fetch.u.format == AC_FETCH_FORMAT_SNORM ? 7 : 30);
      tmp = nir_ishr_imm(b, tmp, 30);

      /* Convert back to the right type. */
      if (fix_fetch.u.format == AC_FETCH_FORMAT_SNORM) {
         tmp = nir_i2fN(b, tmp, bit_size);
         /* Clamp to [-1, 1] */
         tmp = nir_fmax(b, tmp, nir_imm_float(b, -1));
         tmp = nir_fmin(b, tmp, nir_imm_float(b, 1));
      } else if (fix_fetch.u.format == AC_FETCH_FORMAT_SSCALED) {
         tmp = nir_i2fN(b, tmp, bit_size);
      }

      fetches[3] = tmp;
   }

   memcpy(out, fetches, 4 * sizeof(out[0]));
}

static bool
lower_vs_input_instr(nir_builder *b, nir_intrinsic_instr *intrin, void *state)
{
   if (intrin->intrinsic != nir_intrinsic_load_input)
      return false;

   struct lower_vs_inputs_state *s = (struct lower_vs_inputs_state *)state;

   b->cursor = nir_before_instr(&intrin->instr);

   unsigned input_index = nir_intrinsic_base(intrin);
   unsigned component = nir_intrinsic_component(intrin);
   unsigned num_components = intrin->def.num_components;

   nir_def *comp[4];
   if (s->shader->selector->info.base.vs.blit_sgprs_amd)
      load_vs_input_from_blit_sgpr(b, input_index, s, comp);
   else
      load_vs_input_from_vertex_buffer(b, input_index, s, intrin->def.bit_size, comp);

   nir_def *replacement = nir_vec(b, &comp[component], num_components);

   nir_def_replace(&intrin->def, replacement);
   nir_instr_free(&intrin->instr);

   return true;
}

bool
si_nir_lower_vs_inputs(nir_shader *nir, struct si_shader *shader, struct si_shader_args *args)
{
   const struct si_shader_selector *sel = shader->selector;

   /* no inputs to lower */
   if (!sel->info.num_inputs)
      return false;

   struct lower_vs_inputs_state state = {
      .shader = shader,
      .args = args,
   };

   if (!sel->info.base.vs.blit_sgprs_amd)
      get_vertex_index_for_all_inputs(nir, &state);

   return nir_shader_intrinsics_pass(nir, lower_vs_input_instr,
                                     nir_metadata_control_flow,
                                     &state);
}
608