1 /*
2  * Copyright © Microsoft Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include "dxil_nir.h"
25 #include "dxil_module.h"
26 
27 #include "nir_builder.h"
28 #include "nir_deref.h"
29 #include "nir_worklist.h"
30 #include "nir_to_dxil.h"
31 #include "util/u_math.h"
32 #include "vulkan/vulkan_core.h"
33 
34 static void
35 cl_type_size_align(const struct glsl_type *type, unsigned *size,
36                    unsigned *align)
37 {
38    *size = glsl_get_cl_size(type);
39    *align = glsl_get_cl_alignment(type);
40 }
41 
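/* Packs an array of num_src_comps components of src_bit_size each into a
 * vector with dst_bit_size components. When the sources are narrower than
 * the destination, several of them are OR'd into each destination component
 * (e.g. four 8-bit values a,b,c,d become the 32-bit value
 * a | b << 8 | c << 16 | d << 24); when they are wider, nir_extract_bits()
 * splits them up instead.
 */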
42 static nir_def *
43 load_comps_to_vec(nir_builder *b, unsigned src_bit_size,
44                   nir_def **src_comps, unsigned num_src_comps,
45                   unsigned dst_bit_size)
46 {
47    if (src_bit_size == dst_bit_size)
48       return nir_vec(b, src_comps, num_src_comps);
49    else if (src_bit_size > dst_bit_size)
50       return nir_extract_bits(b, src_comps, num_src_comps, 0, src_bit_size * num_src_comps / dst_bit_size, dst_bit_size);
51 
52    unsigned num_dst_comps = DIV_ROUND_UP(num_src_comps * src_bit_size, dst_bit_size);
53    unsigned comps_per_dst = dst_bit_size / src_bit_size;
54    nir_def *dst_comps[4];
55 
56    for (unsigned i = 0; i < num_dst_comps; i++) {
57       unsigned src_offs = i * comps_per_dst;
58 
59       dst_comps[i] = nir_u2uN(b, src_comps[src_offs], dst_bit_size);
60       for (unsigned j = 1; j < comps_per_dst && src_offs + j < num_src_comps; j++) {
61          nir_def *tmp = nir_ishl_imm(b, nir_u2uN(b, src_comps[src_offs + j], dst_bit_size),
62                                          j * src_bit_size);
63          dst_comps[i] = nir_ior(b, dst_comps[i], tmp);
64       }
65    }
66 
67    return nir_vec(b, dst_comps, num_dst_comps);
68 }
69 
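/* Rewrites load_shared/load_scratch into loads from the uint[] replacement
 * variable created by dxil_nir_lower_loads_stores_to_dxil(). The byte offset
 * becomes a 32-bit element index (offset >> 2); e.g. a 16-bit load at byte
 * offset 6 reads element 1 and shifts it right by 16 before the final
 * bit-size conversion.
 */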
70 static bool
71 lower_32b_offset_load(nir_builder *b, nir_intrinsic_instr *intr, nir_variable *var)
72 {
73    unsigned bit_size = intr->def.bit_size;
74    unsigned num_components = intr->def.num_components;
75    unsigned num_bits = num_components * bit_size;
76 
77    b->cursor = nir_before_instr(&intr->instr);
78 
79    nir_def *offset = intr->src[0].ssa;
80    if (intr->intrinsic == nir_intrinsic_load_shared)
81       offset = nir_iadd_imm(b, offset, nir_intrinsic_base(intr));
82    else
83       offset = nir_u2u32(b, offset);
84    nir_def *index = nir_ushr_imm(b, offset, 2);
85    nir_def *comps[NIR_MAX_VEC_COMPONENTS];
86    nir_def *comps_32bit[NIR_MAX_VEC_COMPONENTS * 2];
87 
88    /* We need to split loads into 32-bit accesses because the buffer
89     * is an i32 array and DXIL does not support type casts.
90     */
91    unsigned num_32bit_comps = DIV_ROUND_UP(num_bits, 32);
92    for (unsigned i = 0; i < num_32bit_comps; i++)
93       comps_32bit[i] = nir_load_array_var(b, var, nir_iadd_imm(b, index, i));
94    unsigned num_comps_per_pass = MIN2(num_32bit_comps, 4);
95 
96    for (unsigned i = 0; i < num_32bit_comps; i += num_comps_per_pass) {
97       unsigned num_vec32_comps = MIN2(num_32bit_comps - i, 4);
98       unsigned num_dest_comps = num_vec32_comps * 32 / bit_size;
99       nir_def *vec32 = nir_vec(b, &comps_32bit[i], num_vec32_comps);
100 
101       /* If we have 16 bits or less to load we need to adjust the u32 value so
102        * we can always extract the LSB.
103        */
104       if (num_bits <= 16) {
105          nir_def *shift =
106             nir_imul_imm(b, nir_iand_imm(b, offset, 3), 8);
107          vec32 = nir_ushr(b, vec32, shift);
108       }
109 
110       /* And now comes the pack/unpack step to match the original type. */
111       unsigned dest_index = i * 32 / bit_size;
112       nir_def *temp_vec = nir_extract_bits(b, &vec32, 1, 0, num_dest_comps, bit_size);
113       for (unsigned comp = 0; comp < num_dest_comps; ++comp, ++dest_index)
114          comps[dest_index] = nir_channel(b, temp_vec, comp);
115    }
116 
117    nir_def *result = nir_vec(b, comps, num_components);
118    nir_def_replace(&intr->def, result);
119 
120    return true;
121 }
122 
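/* Stores fewer than 32 bits into a single uint element without clobbering
 * the other bytes of that element. Shared memory uses an atomic AND/OR pair
 * since other invocations may touch neighboring bytes; scratch is private,
 * so a plain load/modify/store sequence is enough.
 */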
123 static void
124 lower_masked_store_vec32(nir_builder *b, nir_def *offset, nir_def *index,
125                          nir_def *vec32, unsigned num_bits, nir_variable *var, unsigned alignment)
126 {
127    nir_def *mask = nir_imm_int(b, (1 << num_bits) - 1);
128 
129    /* If we have small alignments, we need to place them correctly in the u32 component. */
130    if (alignment <= 2) {
131       nir_def *shift =
132          nir_imul_imm(b, nir_iand_imm(b, offset, 3), 8);
133 
134       vec32 = nir_ishl(b, vec32, shift);
135       mask = nir_ishl(b, mask, shift);
136    }
137 
138    if (var->data.mode == nir_var_mem_shared) {
139       /* Use the dedicated masked intrinsic */
140       nir_deref_instr *deref = nir_build_deref_array(b, nir_build_deref_var(b, var), index);
141       nir_deref_atomic(b, 32, &deref->def, nir_inot(b, mask), .atomic_op = nir_atomic_op_iand);
142       nir_deref_atomic(b, 32, &deref->def, vec32, .atomic_op = nir_atomic_op_ior);
143    } else {
144       /* For scratch, since we don't need atomics, just generate the read-modify-write in NIR */
145       nir_def *load = nir_load_array_var(b, var, index);
146 
147       nir_def *new_val = nir_ior(b, vec32,
148                                      nir_iand(b,
149                                               nir_inot(b, mask),
150                                               load));
151 
152       nir_store_array_var(b, var, index, new_val, 1);
153    }
154 }
155 
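/* Store counterpart of lower_32b_offset_load(): splits store_shared and
 * store_scratch into 32-bit scalar stores to the uint[] replacement
 * variable, using lower_masked_store_vec32() for chunks smaller than 32
 * bits.
 */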
156 static bool
157 lower_32b_offset_store(nir_builder *b, nir_intrinsic_instr *intr, nir_variable *var)
158 {
159    unsigned num_components = nir_src_num_components(intr->src[0]);
160    unsigned bit_size = nir_src_bit_size(intr->src[0]);
161    unsigned num_bits = num_components * bit_size;
162 
163    b->cursor = nir_before_instr(&intr->instr);
164 
165    nir_def *offset = intr->src[1].ssa;
166    if (intr->intrinsic == nir_intrinsic_store_shared)
167       offset = nir_iadd_imm(b, offset, nir_intrinsic_base(intr));
168    else
169       offset = nir_u2u32(b, offset);
170    nir_def *comps[NIR_MAX_VEC_COMPONENTS];
171 
172    unsigned comp_idx = 0;
173    for (unsigned i = 0; i < num_components; i++)
174       comps[i] = nir_channel(b, intr->src[0].ssa, i);
175 
176    unsigned step = MAX2(bit_size, 32);
177    for (unsigned i = 0; i < num_bits; i += step) {
178       /* For each 4byte chunk (or smaller) we generate a 32bit scalar store. */
179       unsigned substore_num_bits = MIN2(num_bits - i, step);
180       nir_def *local_offset = nir_iadd_imm(b, offset, i / 8);
181       nir_def *vec32 = load_comps_to_vec(b, bit_size, &comps[comp_idx],
182                                              substore_num_bits / bit_size, 32);
183       nir_def *index = nir_ushr_imm(b, local_offset, 2);
184 
185       /* For anything less than 32bits we need to use the masked version of the
186        * intrinsic to preserve data living in the same 32bit slot. */
187       if (substore_num_bits < 32) {
188          lower_masked_store_vec32(b, local_offset, index, vec32, num_bits, var, nir_intrinsic_align(intr));
189       } else {
190          for (unsigned i = 0; i < vec32->num_components; ++i)
191             nir_store_array_var(b, var, nir_iadd_imm(b, index, i), nir_channel(b, vec32, i), 1);
192       }
193 
194       comp_idx += substore_num_bits / bit_size;
195    }
196 
197    nir_instr_remove(&intr->instr);
198 
199    return true;
200 }
201 
202 #define CONSTANT_LOCATION_UNVISITED 0
203 #define CONSTANT_LOCATION_VALID 1
204 #define CONSTANT_LOCATION_INVALID 2
205 
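/* Promotes constant-initialized nir_var_mem_constant variables to
 * nir_var_shader_temp. A variable only qualifies if none of its derefs has a
 * complex use; qualifying derefs are then retargeted to the new mode and
 * their array indices narrowed to 32-bit.
 */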
206 bool
207 dxil_nir_lower_constant_to_temp(nir_shader *nir)
208 {
209    bool progress = false;
210    nir_foreach_variable_with_modes(var, nir, nir_var_mem_constant)
211       var->data.location = var->constant_initializer ?
212          CONSTANT_LOCATION_UNVISITED : CONSTANT_LOCATION_INVALID;
213 
214    /* First pass: collect all UBO accesses that could be turned into
215     * shader temp accesses.
216     */
217    nir_foreach_function(func, nir) {
218       if (!func->is_entrypoint)
219          continue;
220       assert(func->impl);
221 
222       nir_foreach_block(block, func->impl) {
223          nir_foreach_instr_safe(instr, block) {
224             if (instr->type != nir_instr_type_deref)
225                continue;
226 
227             nir_deref_instr *deref = nir_instr_as_deref(instr);
228             if (!nir_deref_mode_is(deref, nir_var_mem_constant) ||
229                 deref->deref_type != nir_deref_type_var ||
230                 deref->var->data.location == CONSTANT_LOCATION_INVALID)
231                continue;
232 
233             deref->var->data.location = nir_deref_instr_has_complex_use(deref, 0) ?
234                CONSTANT_LOCATION_INVALID : CONSTANT_LOCATION_VALID;
235          }
236       }
237    }
238 
239    nir_foreach_variable_with_modes(var, nir, nir_var_mem_constant) {
240       if (var->data.location != CONSTANT_LOCATION_VALID)
241          continue;
242 
243       /* Change the variable mode. */
244       var->data.mode = nir_var_shader_temp;
245 
246       progress = true;
247    }
248 
249    /* Second pass: patch all derefs that were accessing the converted UBO
250     * variables.
251     */
252    nir_foreach_function(func, nir) {
253       if (!func->is_entrypoint)
254          continue;
255       assert(func->impl);
256 
257       nir_builder b = nir_builder_create(func->impl);
258       nir_foreach_block(block, func->impl) {
259          nir_foreach_instr_safe(instr, block) {
260             if (instr->type != nir_instr_type_deref)
261                continue;
262 
263             nir_deref_instr *deref = nir_instr_as_deref(instr);
264             if (nir_deref_mode_is(deref, nir_var_mem_constant)) {
265                nir_deref_instr *parent = deref;
266                while (parent && parent->deref_type != nir_deref_type_var)
267                   parent = nir_src_as_deref(parent->parent);
268                if (parent && parent->var->data.mode != nir_var_mem_constant) {
269                   deref->modes = parent->var->data.mode;
270                   /* Also change "pointer" size to 32-bit since this is now a logical pointer */
271                   deref->def.bit_size = 32;
272                   if (deref->deref_type == nir_deref_type_array) {
273                      b.cursor = nir_before_instr(instr);
274                      nir_src_rewrite(&deref->arr.index, nir_u2u32(&b, deref->arr.index.ssa));
275                   }
276                }
277             }
278          }
279       }
280    }
281 
282    return progress;
283 }
284 
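/* Intrinsic callback for dxil_nir_flatten_var_arrays(): rewrites accesses to
 * variables whose type was flattened by flatten_var_array_types(). Nested
 * array/vector indices are collapsed into one scalar index, and vector
 * loads/stores are split into one scalar access per (written) component.
 */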
285 static bool
286 flatten_var_arrays(nir_builder *b, nir_intrinsic_instr *intr, void *data)
287 {
288    switch (intr->intrinsic) {
289    case nir_intrinsic_load_deref:
290    case nir_intrinsic_store_deref:
291    case nir_intrinsic_deref_atomic:
292    case nir_intrinsic_deref_atomic_swap:
293       break;
294    default:
295       return false;
296    }
297 
298    nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
299    nir_variable *var = NULL;
300    for (nir_deref_instr *d = deref; d; d = nir_deref_instr_parent(d)) {
301       if (d->deref_type == nir_deref_type_cast)
302          return false;
303       if (d->deref_type == nir_deref_type_var) {
304          var = d->var;
305          if (d->type == var->type)
306             return false;
307       }
308    }
309    if (!var)
310       return false;
311 
312    nir_deref_path path;
313    nir_deref_path_init(&path, deref, NULL);
314 
315    assert(path.path[0]->deref_type == nir_deref_type_var);
316    b->cursor = nir_before_instr(&path.path[0]->instr);
317    nir_deref_instr *new_var_deref = nir_build_deref_var(b, var);
318    nir_def *index = NULL;
319    for (unsigned level = 1; path.path[level]; ++level) {
320       nir_deref_instr *arr_deref = path.path[level];
321       assert(arr_deref->deref_type == nir_deref_type_array);
322       b->cursor = nir_before_instr(&arr_deref->instr);
323       nir_def *val = nir_imul_imm(b, arr_deref->arr.index.ssa,
324                                       glsl_get_component_slots(arr_deref->type));
325       if (index) {
326          index = nir_iadd(b, index, val);
327       } else {
328          index = val;
329       }
330    }
331 
332    unsigned vector_comps = intr->num_components;
333    if (vector_comps > 1) {
334       b->cursor = nir_before_instr(&intr->instr);
335       if (intr->intrinsic == nir_intrinsic_load_deref) {
336          nir_def *components[NIR_MAX_VEC_COMPONENTS];
337          for (unsigned i = 0; i < vector_comps; ++i) {
338             nir_def *final_index = index ? nir_iadd_imm(b, index, i) : nir_imm_int(b, i);
339             nir_deref_instr *comp_deref = nir_build_deref_array(b, new_var_deref, final_index);
340             components[i] = nir_load_deref(b, comp_deref);
341          }
342          nir_def_rewrite_uses(&intr->def, nir_vec(b, components, vector_comps));
343       } else if (intr->intrinsic == nir_intrinsic_store_deref) {
344          for (unsigned i = 0; i < vector_comps; ++i) {
345             if (((1 << i) & nir_intrinsic_write_mask(intr)) == 0)
346                continue;
347             nir_def *final_index = index ? nir_iadd_imm(b, index, i) : nir_imm_int(b, i);
348             nir_deref_instr *comp_deref = nir_build_deref_array(b, new_var_deref, final_index);
349             nir_store_deref(b, comp_deref, nir_channel(b, intr->src[1].ssa, i), 1);
350          }
351       }
352       nir_instr_remove(&intr->instr);
353    } else {
354       nir_src_rewrite(&intr->src[0], &nir_build_deref_array(b, new_var_deref, index)->def);
355    }
356 
357    nir_deref_path_finish(&path);
358    return true;
359 }
360 
361 static void
362 flatten_constant_initializer(nir_variable *var, nir_constant *src, nir_constant ***dest, unsigned vector_elements)
363 {
364    if (src->num_elements == 0) {
365       for (unsigned i = 0; i < vector_elements; ++i) {
366          nir_constant *new_scalar = rzalloc(var, nir_constant);
367          memcpy(&new_scalar->values[0], &src->values[i], sizeof(src->values[0]));
368          new_scalar->is_null_constant = src->values[i].u64 == 0;
369 
370          nir_constant **array_entry = (*dest)++;
371          *array_entry = new_scalar;
372       }
373    } else {
374       for (unsigned i = 0; i < src->num_elements; ++i)
375          flatten_constant_initializer(var, src->elements[i], dest, vector_elements);
376    }
377 }
378 
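/* Flattens an array/vector/matrix variable type into a flat array of scalars
 * with the same number of component slots, flattening any constant
 * initializer to match.
 */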
379 static bool
380 flatten_var_array_types(nir_variable *var)
381 {
382    assert(!glsl_type_is_struct(glsl_without_array(var->type)));
383    const struct glsl_type *matrix_type = glsl_without_array(var->type);
384    if (!glsl_type_is_array_of_arrays(var->type) && glsl_get_components(matrix_type) == 1)
385       return false;
386 
387    enum glsl_base_type base_type = glsl_get_base_type(matrix_type);
388    const struct glsl_type *flattened_type = glsl_array_type(glsl_scalar_type(base_type),
389                                                             glsl_get_component_slots(var->type), 0);
390    var->type = flattened_type;
391    if (var->constant_initializer) {
392       nir_constant **new_elements = ralloc_array(var, nir_constant *, glsl_get_length(flattened_type));
393       nir_constant **temp = new_elements;
394       flatten_constant_initializer(var, var->constant_initializer, &temp, glsl_get_vector_elements(matrix_type));
395       var->constant_initializer->num_elements = glsl_get_length(flattened_type);
396       var->constant_initializer->elements = new_elements;
397    }
398    return true;
399 }
400 
401 bool
402 dxil_nir_flatten_var_arrays(nir_shader *shader, nir_variable_mode modes)
403 {
404    bool progress = false;
405    nir_foreach_variable_with_modes(var, shader, modes & ~nir_var_function_temp)
406       progress |= flatten_var_array_types(var);
407 
408    if (modes & nir_var_function_temp) {
409       nir_foreach_function_impl(impl, shader) {
410          nir_foreach_function_temp_variable(var, impl)
411             progress |= flatten_var_array_types(var);
412       }
413    }
414 
415    if (!progress)
416       return false;
417 
418    nir_shader_intrinsics_pass(shader, flatten_var_arrays,
419                                 nir_metadata_control_flow |
420                                 nir_metadata_loop_analysis,
421                                 NULL);
422    nir_remove_dead_derefs(shader);
423    return true;
424 }
425 
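/* Intrinsic callback for dxil_nir_lower_var_bit_size(): fixes up loads and
 * stores of variables retyped by lower_var_bit_size_types(). Widened
 * variables get an explicit conversion on load/store; 64-bit accesses to
 * variables split into uint pairs become two 32-bit accesses combined with
 * (un)pack_64_2x32_split.
 */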
426 static bool
427 lower_deref_bit_size(nir_builder *b, nir_intrinsic_instr *intr, void *data)
428 {
429    switch (intr->intrinsic) {
430    case nir_intrinsic_load_deref:
431    case nir_intrinsic_store_deref:
432       break;
433    default:
434       /* Atomics can't be smaller than 32-bit */
435       return false;
436    }
437 
438    nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
439    nir_variable *var = nir_deref_instr_get_variable(deref);
440    /* Only interested in full deref chains */
441    if (!var)
442       return false;
443 
444    const struct glsl_type *var_scalar_type = glsl_without_array(var->type);
445    if (deref->type == var_scalar_type || !glsl_type_is_scalar(var_scalar_type))
446       return false;
447 
448    assert(deref->deref_type == nir_deref_type_var || deref->deref_type == nir_deref_type_array);
449    const struct glsl_type *old_glsl_type = deref->type;
450    nir_alu_type old_type = nir_get_nir_type_for_glsl_type(old_glsl_type);
451    nir_alu_type new_type = nir_get_nir_type_for_glsl_type(var_scalar_type);
452    if (glsl_get_bit_size(old_glsl_type) < glsl_get_bit_size(var_scalar_type)) {
453       deref->type = var_scalar_type;
454       if (intr->intrinsic == nir_intrinsic_load_deref) {
455          intr->def.bit_size = glsl_get_bit_size(var_scalar_type);
456          b->cursor = nir_after_instr(&intr->instr);
457          nir_def *downcast = nir_type_convert(b, &intr->def, new_type, old_type, nir_rounding_mode_undef);
458          nir_def_rewrite_uses_after(&intr->def, downcast, downcast->parent_instr);
459       }
460       else {
461          b->cursor = nir_before_instr(&intr->instr);
462          nir_def *upcast = nir_type_convert(b, intr->src[1].ssa, old_type, new_type, nir_rounding_mode_undef);
463          nir_src_rewrite(&intr->src[1], upcast);
464       }
465 
466       while (deref->deref_type == nir_deref_type_array) {
467          nir_deref_instr *parent = nir_deref_instr_parent(deref);
468          parent->type = glsl_type_wrap_in_arrays(deref->type, parent->type);
469          deref = parent;
470       }
471    } else {
472       /* Assumed arrays are already flattened */
473       b->cursor = nir_before_instr(&deref->instr);
474       nir_deref_instr *parent = nir_build_deref_var(b, var);
475       if (deref->deref_type == nir_deref_type_array)
476          deref = nir_build_deref_array(b, parent, nir_imul_imm(b, deref->arr.index.ssa, 2));
477       else
478          deref = nir_build_deref_array_imm(b, parent, 0);
479       nir_deref_instr *deref2 = nir_build_deref_array(b, parent,
480                                                       nir_iadd_imm(b, deref->arr.index.ssa, 1));
481       b->cursor = nir_before_instr(&intr->instr);
482       if (intr->intrinsic == nir_intrinsic_load_deref) {
483          nir_def *src1 = nir_load_deref(b, deref);
484          nir_def *src2 = nir_load_deref(b, deref2);
485          nir_def_rewrite_uses(&intr->def, nir_pack_64_2x32_split(b, src1, src2));
486       } else {
487          nir_def *src1 = nir_unpack_64_2x32_split_x(b, intr->src[1].ssa);
488          nir_def *src2 = nir_unpack_64_2x32_split_y(b, intr->src[1].ssa);
489          nir_store_deref(b, deref, src1, 1);
490          nir_store_deref(b, deref2, src2, 1);
491       }
492       nir_instr_remove(&intr->instr);
493    }
494    return true;
495 }
496 
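/* Retypes variables whose scalar bit size falls outside
 * [min_bit_size, max_bit_size]: bool/8-bit/16-bit types are widened in place
 * and 64-bit types are split into a uint array of low/high halves, with
 * constant initializers rewritten accordingly.
 */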
497 static bool
498 lower_var_bit_size_types(nir_variable *var, unsigned min_bit_size, unsigned max_bit_size)
499 {
500    assert(!glsl_type_is_array_of_arrays(var->type) && !glsl_type_is_struct(var->type));
501    const struct glsl_type *type = glsl_without_array(var->type);
502    assert(glsl_type_is_scalar(type));
503    enum glsl_base_type base_type = glsl_get_base_type(type);
504    if (glsl_base_type_get_bit_size(base_type) < min_bit_size) {
505       switch (min_bit_size) {
506       case 16:
507          switch (base_type) {
508          case GLSL_TYPE_BOOL:
509             base_type = GLSL_TYPE_UINT16;
510             for (unsigned i = 0; i < (var->constant_initializer ? var->constant_initializer->num_elements : 0); ++i)
511                var->constant_initializer->elements[i]->values[0].u16 = var->constant_initializer->elements[i]->values[0].b ? 0xffff : 0;
512             break;
513          case GLSL_TYPE_INT8:
514             base_type = GLSL_TYPE_INT16;
515             for (unsigned i = 0; i < (var->constant_initializer ? var->constant_initializer->num_elements : 0); ++i)
516                var->constant_initializer->elements[i]->values[0].i16 = var->constant_initializer->elements[i]->values[0].i8;
517             break;
518          case GLSL_TYPE_UINT8: base_type = GLSL_TYPE_UINT16; break;
519          default: unreachable("Unexpected base type");
520          }
521          break;
522       case 32:
523          switch (base_type) {
524          case GLSL_TYPE_BOOL:
525             base_type = GLSL_TYPE_UINT;
526             for (unsigned i = 0; i < (var->constant_initializer ? var->constant_initializer->num_elements : 0); ++i)
527                var->constant_initializer->elements[i]->values[0].u32 = var->constant_initializer->elements[i]->values[0].b ? 0xffffffff : 0;
528             break;
529          case GLSL_TYPE_INT8:
530             base_type = GLSL_TYPE_INT;
531             for (unsigned i = 0; i < (var->constant_initializer ? var->constant_initializer->num_elements : 0); ++i)
532                var->constant_initializer->elements[i]->values[0].i32 = var->constant_initializer->elements[i]->values[0].i8;
533             break;
534          case GLSL_TYPE_INT16:
535             base_type = GLSL_TYPE_INT;
536             for (unsigned i = 0; i < (var->constant_initializer ? var->constant_initializer->num_elements : 0); ++i)
537                var->constant_initializer->elements[i]->values[0].i32 = var->constant_initializer->elements[i]->values[0].i16;
538             break;
539          case GLSL_TYPE_FLOAT16:
540             base_type = GLSL_TYPE_FLOAT;
541             for (unsigned i = 0; i < (var->constant_initializer ? var->constant_initializer->num_elements : 0); ++i)
542                var->constant_initializer->elements[i]->values[0].f32 = _mesa_half_to_float(var->constant_initializer->elements[i]->values[0].u16);
543             break;
544          case GLSL_TYPE_UINT8: base_type = GLSL_TYPE_UINT; break;
545          case GLSL_TYPE_UINT16: base_type = GLSL_TYPE_UINT; break;
546          default: unreachable("Unexpected base type");
547          }
548          break;
549       default: unreachable("Unexpected min bit size");
550       }
551       var->type = glsl_type_wrap_in_arrays(glsl_scalar_type(base_type), var->type);
552       return true;
553    }
554    if (glsl_base_type_bit_size(base_type) > max_bit_size) {
555       assert(!glsl_type_is_array_of_arrays(var->type));
556       var->type = glsl_array_type(glsl_scalar_type(GLSL_TYPE_UINT),
557                                     glsl_type_is_array(var->type) ? glsl_get_length(var->type) * 2 : 2,
558                                     0);
559       if (var->constant_initializer) {
560          unsigned num_elements = var->constant_initializer->num_elements ?
561             var->constant_initializer->num_elements * 2 : 2;
562          nir_constant **element_arr = ralloc_array(var, nir_constant *, num_elements);
563          nir_constant *elements = rzalloc_array(var, nir_constant, num_elements);
564          for (unsigned i = 0; i < var->constant_initializer->num_elements; ++i) {
565             element_arr[i*2] = &elements[i*2];
566             element_arr[i*2+1] = &elements[i*2+1];
567             const nir_const_value *src = var->constant_initializer->num_elements ?
568                var->constant_initializer->elements[i]->values : var->constant_initializer->values;
569             elements[i*2].values[0].u32 = (uint32_t)src->u64;
570             elements[i*2].is_null_constant = (uint32_t)src->u64 == 0;
571             elements[i*2+1].values[0].u32 = (uint32_t)(src->u64 >> 32);
572             elements[i*2+1].is_null_constant = (uint32_t)(src->u64 >> 32) == 0;
573          }
574          var->constant_initializer->num_elements = num_elements;
575          var->constant_initializer->elements = element_arr;
576       }
577       return true;
578    }
579    return false;
580 }
581 
582 bool
583 dxil_nir_lower_var_bit_size(nir_shader *shader, nir_variable_mode modes,
584                             unsigned min_bit_size, unsigned max_bit_size)
585 {
586    bool progress = false;
587    nir_foreach_variable_with_modes(var, shader, modes & ~nir_var_function_temp)
588       progress |= lower_var_bit_size_types(var, min_bit_size, max_bit_size);
589 
590    if (modes & nir_var_function_temp) {
591       nir_foreach_function_impl(impl, shader) {
592          nir_foreach_function_temp_variable(var, impl)
593             progress |= lower_var_bit_size_types(var, min_bit_size, max_bit_size);
594       }
595    }
596 
597    if (!progress)
598       return false;
599 
600    nir_shader_intrinsics_pass(shader, lower_deref_bit_size,
601                                 nir_metadata_control_flow |
602                                 nir_metadata_loop_analysis,
603                                 NULL);
604    nir_remove_dead_derefs(shader);
605    return true;
606 }
607 
608 static bool
609 remove_oob_array_access(nir_builder *b, nir_intrinsic_instr *intr, void *data)
610 {
611    uint32_t num_derefs = 1;
612 
613    switch (intr->intrinsic) {
614    case nir_intrinsic_copy_deref:
615       num_derefs = 2;
616       FALLTHROUGH;
617    case nir_intrinsic_load_deref:
618    case nir_intrinsic_store_deref:
619    case nir_intrinsic_deref_atomic:
620    case nir_intrinsic_deref_atomic_swap:
621       break;
622    default:
623       return false;
624    }
625 
626    for (uint32_t i = 0; i < num_derefs; ++i) {
627       if (nir_deref_instr_is_known_out_of_bounds(nir_src_as_deref(intr->src[i]))) {
628          switch (intr->intrinsic) {
629          case nir_intrinsic_load_deref:
630          case nir_intrinsic_deref_atomic:
631          case nir_intrinsic_deref_atomic_swap:
632             b->cursor = nir_before_instr(&intr->instr);
633             nir_def *undef = nir_undef(b, intr->def.num_components, intr->def.bit_size);
634             nir_def_rewrite_uses(&intr->def, undef);
635             break;
636          default:
637             break;
638          }
639          nir_instr_remove(&intr->instr);
640          return true;
641       }
642    }
643 
644    return false;
645 }
646 
647 bool
648 dxil_nir_remove_oob_array_accesses(nir_shader *shader)
649 {
650    return nir_shader_intrinsics_pass(shader, remove_oob_array_access,
651                                      nir_metadata_control_flow |
652                                      nir_metadata_loop_analysis,
653                                      NULL);
654 }
655 
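/* Redirects shared_atomic/shared_atomic_swap to deref atomics on the uint[]
 * replacement variable, converting the byte offset (plus base) into a 32-bit
 * element index.
 */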
656 static bool
657 lower_shared_atomic(nir_builder *b, nir_intrinsic_instr *intr, nir_variable *var)
658 {
659    b->cursor = nir_before_instr(&intr->instr);
660 
661    nir_def *offset =
662       nir_iadd_imm(b, intr->src[0].ssa, nir_intrinsic_base(intr));
663    nir_def *index = nir_ushr_imm(b, offset, 2);
664 
665    nir_deref_instr *deref = nir_build_deref_array(b, nir_build_deref_var(b, var), index);
666    nir_def *result;
667    if (intr->intrinsic == nir_intrinsic_shared_atomic_swap)
668       result = nir_deref_atomic_swap(b, 32, &deref->def, intr->src[1].ssa, intr->src[2].ssa,
669                                      .atomic_op = nir_intrinsic_atomic_op(intr));
670    else
671       result = nir_deref_atomic(b, 32, &deref->def, intr->src[1].ssa,
672                                 .atomic_op = nir_intrinsic_atomic_op(intr));
673 
674    nir_def_replace(&intr->def, result);
675    return true;
676 }
677 
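/* Entry point for the shared/scratch lowering above: shared memory and
 * scratch are each replaced by a single uint array variable
 * ("lowered_shared_mem" / "lowered_scratch_mem"), and all loads, stores and
 * atomics on those modes are rewritten as 32-bit accesses to it.
 */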
678 bool
679 dxil_nir_lower_loads_stores_to_dxil(nir_shader *nir,
680                                     const struct dxil_nir_lower_loads_stores_options *options)
681 {
682    bool progress = nir_remove_dead_variables(nir, nir_var_function_temp | nir_var_mem_shared, NULL);
683    nir_variable *shared_var = NULL;
684    if (nir->info.shared_size) {
685       shared_var = nir_variable_create(nir, nir_var_mem_shared,
686                                        glsl_array_type(glsl_uint_type(), DIV_ROUND_UP(nir->info.shared_size, 4), 4),
687                                        "lowered_shared_mem");
688    }
689 
690    unsigned ptr_size = nir->info.cs.ptr_size;
691    if (nir->info.stage == MESA_SHADER_KERNEL) {
692       /* All the derefs created here will be used as GEP indices so force 32-bit */
693       nir->info.cs.ptr_size = 32;
694    }
695    nir_foreach_function_impl(impl, nir) {
696       nir_builder b = nir_builder_create(impl);
697 
698       nir_variable *scratch_var = NULL;
699       if (nir->scratch_size) {
700          const struct glsl_type *scratch_type = glsl_array_type(glsl_uint_type(), DIV_ROUND_UP(nir->scratch_size, 4), 4);
701          scratch_var = nir_local_variable_create(impl, scratch_type, "lowered_scratch_mem");
702       }
703 
704       nir_foreach_block(block, impl) {
705          nir_foreach_instr_safe(instr, block) {
706             if (instr->type != nir_instr_type_intrinsic)
707                continue;
708             nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
709 
710             switch (intr->intrinsic) {
711             case nir_intrinsic_load_shared:
712                progress |= lower_32b_offset_load(&b, intr, shared_var);
713                break;
714             case nir_intrinsic_load_scratch:
715                progress |= lower_32b_offset_load(&b, intr, scratch_var);
716                break;
717             case nir_intrinsic_store_shared:
718                progress |= lower_32b_offset_store(&b, intr, shared_var);
719                break;
720             case nir_intrinsic_store_scratch:
721                progress |= lower_32b_offset_store(&b, intr, scratch_var);
722                break;
723             case nir_intrinsic_shared_atomic:
724             case nir_intrinsic_shared_atomic_swap:
725                progress |= lower_shared_atomic(&b, intr, shared_var);
726                break;
727             default:
728                break;
729             }
730          }
731       }
732    }
733    if (nir->info.stage == MESA_SHADER_KERNEL) {
734       nir->info.cs.ptr_size = ptr_size;
735    }
736 
737    return progress;
738 }
739 
740 static bool
741 lower_deref_ssbo(nir_builder *b, nir_deref_instr *deref)
742 {
743    assert(nir_deref_mode_is(deref, nir_var_mem_ssbo));
744    assert(deref->deref_type == nir_deref_type_var ||
745           deref->deref_type == nir_deref_type_cast);
746    nir_variable *var = deref->var;
747 
748    b->cursor = nir_before_instr(&deref->instr);
749 
750    if (deref->deref_type == nir_deref_type_var) {
751       /* We turn all deref_var into deref_cast and build a pointer value based on
752        * the var binding which encodes the UAV id.
753        */
754       nir_def *ptr = nir_imm_int64(b, (uint64_t)var->data.binding << 32);
755       nir_deref_instr *deref_cast =
756          nir_build_deref_cast(b, ptr, nir_var_mem_ssbo, deref->type,
757                               glsl_get_explicit_stride(var->type));
758       nir_def_replace(&deref->def, &deref_cast->def);
759 
760       deref = deref_cast;
761       return true;
762    }
763    return false;
764 }
765 
766 bool
767 dxil_nir_lower_deref_ssbo(nir_shader *nir)
768 {
769    bool progress = false;
770 
771    foreach_list_typed(nir_function, func, node, &nir->functions) {
772       if (!func->is_entrypoint)
773          continue;
774       assert(func->impl);
775 
776       nir_builder b = nir_builder_create(func->impl);
777 
778       nir_foreach_block(block, func->impl) {
779          nir_foreach_instr_safe(instr, block) {
780             if (instr->type != nir_instr_type_deref)
781                continue;
782 
783             nir_deref_instr *deref = nir_instr_as_deref(instr);
784 
785             if (!nir_deref_mode_is(deref, nir_var_mem_ssbo) ||
786                 (deref->deref_type != nir_deref_type_var &&
787                  deref->deref_type != nir_deref_type_cast))
788                continue;
789 
790             progress |= lower_deref_ssbo(&b, deref);
791          }
792       }
793    }
794 
795    return progress;
796 }
797 
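/* For ALU instructions that take a deref as a source, replaces the deref
 * with root_pointer + byte_offset, where the offset is computed by
 * nir_build_deref_offset() using CL size/alignment rules. Only deref chains
 * rooted at a deref_cast are rewritten.
 */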
798 static bool
799 lower_alu_deref_srcs(nir_builder *b, nir_alu_instr *alu)
800 {
801    const nir_op_info *info = &nir_op_infos[alu->op];
802    bool progress = false;
803 
804    b->cursor = nir_before_instr(&alu->instr);
805 
806    for (unsigned i = 0; i < info->num_inputs; i++) {
807       nir_deref_instr *deref = nir_src_as_deref(alu->src[i].src);
808 
809       if (!deref)
810          continue;
811 
812       nir_deref_path path;
813       nir_deref_path_init(&path, deref, NULL);
814       nir_deref_instr *root_deref = path.path[0];
815       nir_deref_path_finish(&path);
816 
817       if (root_deref->deref_type != nir_deref_type_cast)
818          continue;
819 
820       nir_def *ptr =
821          nir_iadd(b, root_deref->parent.ssa,
822                      nir_build_deref_offset(b, deref, cl_type_size_align));
823       nir_src_rewrite(&alu->src[i].src, ptr);
824       progress = true;
825    }
826 
827    return progress;
828 }
829 
830 bool
831 dxil_nir_opt_alu_deref_srcs(nir_shader *nir)
832 {
833    bool progress = false;
834 
835    foreach_list_typed(nir_function, func, node, &nir->functions) {
836       if (!func->is_entrypoint)
837          continue;
838       assert(func->impl);
839 
840       nir_builder b = nir_builder_create(func->impl);
841 
842       nir_foreach_block(block, func->impl) {
843          nir_foreach_instr_safe(instr, block) {
844             if (instr->type != nir_instr_type_alu)
845                continue;
846 
847             nir_alu_instr *alu = nir_instr_as_alu(instr);
848             progress |= lower_alu_deref_srcs(&b, alu);
849          }
850       }
851    }
852 
853    return progress;
854 }
855 
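/* Replaces a narrow phi with one at new_bit_size: each source is widened
 * with u2uN right after its defining instruction, and the result is narrowed
 * back to the original bit size after the block's phis.
 */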
856 static void
857 cast_phi(nir_builder *b, nir_phi_instr *phi, unsigned new_bit_size)
858 {
859    nir_phi_instr *lowered = nir_phi_instr_create(b->shader);
860    int num_components = 0;
861    int old_bit_size = phi->def.bit_size;
862 
863    nir_foreach_phi_src(src, phi) {
864       assert(num_components == 0 || num_components == src->src.ssa->num_components);
865       num_components = src->src.ssa->num_components;
866 
867       b->cursor = nir_after_instr_and_phis(src->src.ssa->parent_instr);
868 
869       nir_def *cast = nir_u2uN(b, src->src.ssa, new_bit_size);
870 
871       nir_phi_instr_add_src(lowered, src->pred, cast);
872    }
873 
874    nir_def_init(&lowered->instr, &lowered->def, num_components,
875                 new_bit_size);
876 
877    b->cursor = nir_before_instr(&phi->instr);
878    nir_builder_instr_insert(b, &lowered->instr);
879 
880    b->cursor = nir_after_phis(nir_cursor_current_block(b->cursor));
881    nir_def *result = nir_u2uN(b, &lowered->def, old_bit_size);
882 
883    nir_def_replace(&phi->def, result);
884 }
885 
886 static bool
887 upcast_phi_impl(nir_function_impl *impl, unsigned min_bit_size)
888 {
889    nir_builder b = nir_builder_create(impl);
890    bool progress = false;
891 
892    nir_foreach_block_reverse(block, impl) {
893       nir_foreach_phi_safe(phi, block) {
894          if (phi->def.bit_size == 1 ||
895              phi->def.bit_size >= min_bit_size)
896             continue;
897 
898          cast_phi(&b, phi, min_bit_size);
899          progress = true;
900       }
901    }
902 
903    if (progress) {
904       nir_metadata_preserve(impl, nir_metadata_control_flow);
905    } else {
906       nir_metadata_preserve(impl, nir_metadata_all);
907    }
908 
909    return progress;
910 }
911 
912 bool
913 dxil_nir_lower_upcast_phis(nir_shader *shader, unsigned min_bit_size)
914 {
915    bool progress = false;
916 
917    nir_foreach_function_impl(impl, shader) {
918       progress |= upcast_phi_impl(impl, min_bit_size);
919    }
920 
921    return progress;
922 }
923 
924 struct dxil_nir_split_clip_cull_distance_params {
925    nir_variable *new_var[2];
926    nir_shader *shader;
927 };
928 
929 /* In GLSL and SPIR-V, clip and cull distance are arrays of floats (with a limit of 8).
930  * In DXIL, clip and cull distances are up to 2 float4s combined.
931  * Coming from GLSL, we can request this 2 float4 format, but coming from SPIR-V,
932  * we can't, and have to accept a "compact" array of scalar floats.
933  * To help emit a valid input signature for this case, split the variables so that they
934  * To help emitting a valid input signature for this case, split the variables so that they
935  * match what we need to put in the signature (e.g. { float clip[4]; float clip1; float cull[3]; })
936  */
937 static bool
938 dxil_nir_split_clip_cull_distance_instr(nir_builder *b,
939                                         nir_instr *instr,
940                                         void *cb_data)
941 {
942    struct dxil_nir_split_clip_cull_distance_params *params = cb_data;
943 
944    if (instr->type != nir_instr_type_deref)
945       return false;
946 
947    nir_deref_instr *deref = nir_instr_as_deref(instr);
948    nir_variable *var = nir_deref_instr_get_variable(deref);
949    if (!var ||
950        var->data.location < VARYING_SLOT_CLIP_DIST0 ||
951        var->data.location > VARYING_SLOT_CULL_DIST1 ||
952        !var->data.compact)
953       return false;
954 
955    unsigned new_var_idx = var->data.mode == nir_var_shader_in ? 0 : 1;
956    nir_variable *new_var = params->new_var[new_var_idx];
957 
958    /* The location should only be inside clip distance, because clip
959     * and cull should've been merged by nir_lower_clip_cull_distance_arrays()
960     */
961    assert(var->data.location == VARYING_SLOT_CLIP_DIST0 ||
962           var->data.location == VARYING_SLOT_CLIP_DIST1);
963 
964    /* The deref chain to the clip/cull variables should be simple, just the
965     * var and an array with a constant index, otherwise more lowering/optimization
966     * might be needed before this pass, e.g. copy prop, lower_io_to_temporaries,
967     * split_var_copies, and/or lower_var_copies. In the case of arrayed I/O like
968     * inputs to the tessellation or geometry stages, there might be a second level
969     * of array index.
970     */
971    assert(deref->deref_type == nir_deref_type_var ||
972           deref->deref_type == nir_deref_type_array);
973 
974    b->cursor = nir_before_instr(instr);
975    unsigned arrayed_io_length = 0;
976    const struct glsl_type *old_type = var->type;
977    if (nir_is_arrayed_io(var, b->shader->info.stage)) {
978       arrayed_io_length = glsl_array_size(old_type);
979       old_type = glsl_get_array_element(old_type);
980    }
981    if (!new_var) {
982       /* Update lengths for new and old vars */
983       int old_length = glsl_array_size(old_type);
984       int new_length = (old_length + var->data.location_frac) - 4;
985       old_length -= new_length;
986 
987       /* The existing variable fits in the float4 */
988       if (new_length <= 0)
989          return false;
990 
991       new_var = nir_variable_clone(var, params->shader);
992       nir_shader_add_variable(params->shader, new_var);
993       assert(glsl_get_base_type(glsl_get_array_element(old_type)) == GLSL_TYPE_FLOAT);
994       var->type = glsl_array_type(glsl_float_type(), old_length, 0);
995       new_var->type = glsl_array_type(glsl_float_type(), new_length, 0);
996       if (arrayed_io_length) {
997          var->type = glsl_array_type(var->type, arrayed_io_length, 0);
998          new_var->type = glsl_array_type(new_var->type, arrayed_io_length, 0);
999       }
1000       new_var->data.location++;
1001       new_var->data.location_frac = 0;
1002       params->new_var[new_var_idx] = new_var;
1003    }
1004 
1005    /* Update the type for derefs of the old var */
1006    if (deref->deref_type == nir_deref_type_var) {
1007       deref->type = var->type;
1008       return false;
1009    }
1010 
1011    if (glsl_type_is_array(deref->type)) {
1012       assert(arrayed_io_length > 0);
1013       deref->type = glsl_get_array_element(var->type);
1014       return false;
1015    }
1016 
1017    assert(glsl_get_base_type(deref->type) == GLSL_TYPE_FLOAT);
1018 
1019    nir_const_value *index = nir_src_as_const_value(deref->arr.index);
1020    assert(index);
1021 
1022    /* Treat this array as a vector starting at the component index in location_frac,
1023     * so if location_frac is 1 and index is 0, then it's accessing the 'y' component
1024     * of the vector. If index + location_frac is >= 4, there's no component there,
1025     * so we need to add a new variable and adjust the index.
1026     */
1027    unsigned total_index = index->u32 + var->data.location_frac;
1028    if (total_index < 4)
1029       return false;
1030 
1031    nir_deref_instr *new_var_deref = nir_build_deref_var(b, new_var);
1032    nir_deref_instr *new_intermediate_deref = new_var_deref;
1033    if (arrayed_io_length) {
1034       nir_deref_instr *parent = nir_src_as_deref(deref->parent);
1035       assert(parent->deref_type == nir_deref_type_array);
1036       new_intermediate_deref = nir_build_deref_array(b, new_intermediate_deref, parent->arr.index.ssa);
1037    }
1038    nir_deref_instr *new_array_deref = nir_build_deref_array(b, new_intermediate_deref, nir_imm_int(b, total_index % 4));
1039    nir_def_rewrite_uses(&deref->def, &new_array_deref->def);
1040    return true;
1041 }
1042 
1043 bool
1044 dxil_nir_split_clip_cull_distance(nir_shader *shader)
1045 {
1046    struct dxil_nir_split_clip_cull_distance_params params = {
1047       .new_var = { NULL, NULL },
1048       .shader = shader,
1049    };
1050    nir_shader_instructions_pass(shader,
1051                                 dxil_nir_split_clip_cull_distance_instr,
1052                                 nir_metadata_control_flow |
1053                                 nir_metadata_loop_analysis,
1054                                 &params);
1055    return params.new_var[0] != NULL || params.new_var[1] != NULL;
1056 }
1057 
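/* For 64-bit float ALU ops (and 64-bit float reduce/scan intrinsics), adds
 * explicit conversions: sources go through unpack_64_2x32 +
 * pack_double_2x32_dxil and results through unpack_double_2x32_dxil +
 * pack_64_2x32, so the backend sees the 2x32 <-> double transitions
 * explicitly.
 */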
1058 static bool
1059 dxil_nir_lower_double_math_instr(nir_builder *b,
1060                                  nir_instr *instr,
1061                                  UNUSED void *cb_data)
1062 {
1063    if (instr->type == nir_instr_type_intrinsic) {
1064       nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
1065       switch (intr->intrinsic) {
1066          case nir_intrinsic_reduce:
1067          case nir_intrinsic_exclusive_scan:
1068          case nir_intrinsic_inclusive_scan:
1069             break;
1070          default:
1071             return false;
1072       }
1073       if (intr->def.bit_size != 64)
1074          return false;
1075       nir_op reduction = nir_intrinsic_reduction_op(intr);
1076       switch (reduction) {
1077          case nir_op_fmul:
1078          case nir_op_fadd:
1079          case nir_op_fmin:
1080          case nir_op_fmax:
1081             break;
1082          default:
1083             return false;
1084       }
1085       b->cursor = nir_before_instr(instr);
1086       nir_src_rewrite(&intr->src[0], nir_pack_double_2x32_dxil(b, nir_unpack_64_2x32(b, intr->src[0].ssa)));
1087       b->cursor = nir_after_instr(instr);
1088       nir_def *result = nir_pack_64_2x32(b, nir_unpack_double_2x32_dxil(b, &intr->def));
1089       nir_def_rewrite_uses_after(&intr->def, result, result->parent_instr);
1090       return true;
1091    }
1092 
1093    if (instr->type != nir_instr_type_alu)
1094       return false;
1095 
1096    nir_alu_instr *alu = nir_instr_as_alu(instr);
1097 
1098    /* TODO: See if we can apply this explicitly to packs/unpacks that are then
1099     * used as a double. As-is, if we had an app explicitly do a 64bit integer op,
1100     * then try to bitcast to double (not expressible in HLSL, but it is in other
1101     * source languages), this would unpack the integer and repack as a double, when
1102     * we probably want to just send the bitcast through to the backend.
1103     */
1104 
1105    b->cursor = nir_before_instr(&alu->instr);
1106 
1107    bool progress = false;
1108    for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; ++i) {
1109       if (nir_alu_type_get_base_type(nir_op_infos[alu->op].input_types[i]) == nir_type_float &&
1110           alu->src[i].src.ssa->bit_size == 64) {
1111          unsigned num_components = nir_op_infos[alu->op].input_sizes[i];
1112          if (!num_components)
1113             num_components = alu->def.num_components;
1114          nir_def *components[NIR_MAX_VEC_COMPONENTS];
1115          for (unsigned c = 0; c < num_components; ++c) {
1116             nir_def *packed_double = nir_channel(b, alu->src[i].src.ssa, alu->src[i].swizzle[c]);
1117             nir_def *unpacked_double = nir_unpack_64_2x32(b, packed_double);
1118             components[c] = nir_pack_double_2x32_dxil(b, unpacked_double);
1119             alu->src[i].swizzle[c] = c;
1120          }
1121          nir_src_rewrite(&alu->src[i].src,
1122                          nir_vec(b, components, num_components));
1123          progress = true;
1124       }
1125    }
1126 
1127    if (nir_alu_type_get_base_type(nir_op_infos[alu->op].output_type) == nir_type_float &&
1128        alu->def.bit_size == 64) {
1129       b->cursor = nir_after_instr(&alu->instr);
1130       nir_def *components[NIR_MAX_VEC_COMPONENTS];
1131       for (unsigned c = 0; c < alu->def.num_components; ++c) {
1132          nir_def *packed_double = nir_channel(b, &alu->def, c);
1133          nir_def *unpacked_double = nir_unpack_double_2x32_dxil(b, packed_double);
1134          components[c] = nir_pack_64_2x32(b, unpacked_double);
1135       }
1136       nir_def *repacked_dvec = nir_vec(b, components, alu->def.num_components);
1137       nir_def_rewrite_uses_after(&alu->def, repacked_dvec, repacked_dvec->parent_instr);
1138       progress = true;
1139    }
1140 
1141    return progress;
1142 }
1143 
1144 bool
1145 dxil_nir_lower_double_math(nir_shader *shader)
1146 {
1147    return nir_shader_instructions_pass(shader,
1148                                        dxil_nir_lower_double_math_instr,
1149                                        nir_metadata_control_flow |
1150                                        nir_metadata_loop_analysis,
1151                                        NULL);
1152 }
1153 
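/* State for dxil_nir_lower_system_values_to_zero(): loads of any listed
 * system value, either via the dedicated load intrinsic or via a load_deref
 * of the matching system-value variable, are replaced with a constant 0.
 */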
1154 typedef struct {
1155    gl_system_value *values;
1156    uint32_t count;
1157 } zero_system_values_state;
1158 
1159 static bool
1160 lower_system_value_to_zero_filter(const nir_instr* instr, const void* cb_state)
1161 {
1162    if (instr->type != nir_instr_type_intrinsic) {
1163       return false;
1164    }
1165 
1166    nir_intrinsic_instr* intrin = nir_instr_as_intrinsic(instr);
1167 
1168    /* All the intrinsics we care about are loads */
1169    if (!nir_intrinsic_infos[intrin->intrinsic].has_dest)
1170       return false;
1171 
1172    zero_system_values_state* state = (zero_system_values_state*)cb_state;
1173    for (uint32_t i = 0; i < state->count; ++i) {
1174       gl_system_value value = state->values[i];
1175       nir_intrinsic_op value_op = nir_intrinsic_from_system_value(value);
1176 
1177       if (intrin->intrinsic == value_op) {
1178          return true;
1179       } else if (intrin->intrinsic == nir_intrinsic_load_deref) {
1180          nir_deref_instr* deref = nir_src_as_deref(intrin->src[0]);
1181          if (!nir_deref_mode_is(deref, nir_var_system_value))
1182             return false;
1183 
1184          nir_variable* var = deref->var;
1185          if (var->data.location == value) {
1186             return true;
1187          }
1188       }
1189    }
1190 
1191    return false;
1192 }
1193 
1194 static nir_def*
1195 lower_system_value_to_zero_instr(nir_builder* b, nir_instr* instr, void* _state)
1196 {
1197    return nir_imm_int(b, 0);
1198 }
1199 
1200 bool
1201 dxil_nir_lower_system_values_to_zero(nir_shader* shader,
1202                                      gl_system_value* system_values,
1203                                      uint32_t count)
1204 {
1205    zero_system_values_state state = { system_values, count };
1206    return nir_shader_lower_instructions(shader,
1207       lower_system_value_to_zero_filter,
1208       lower_system_value_to_zero_instr,
1209       &state);
1210 }
1211 
1212 static void
1213 lower_load_local_group_size(nir_builder *b, nir_intrinsic_instr *intr)
1214 {
1215    b->cursor = nir_after_instr(&intr->instr);
1216 
1217    nir_const_value v[3] = {
1218       nir_const_value_for_int(b->shader->info.workgroup_size[0], 32),
1219       nir_const_value_for_int(b->shader->info.workgroup_size[1], 32),
1220       nir_const_value_for_int(b->shader->info.workgroup_size[2], 32)
1221    };
1222    nir_def *size = nir_build_imm(b, 3, 32, v);
1223    nir_def_replace(&intr->def, size);
1224 }
1225 
1226 static bool
1227 lower_system_values_impl(nir_builder *b, nir_intrinsic_instr *intr,
1228                          void *_state)
1229 {
1230    switch (intr->intrinsic) {
1231    case nir_intrinsic_load_workgroup_size:
1232       lower_load_local_group_size(b, intr);
1233       return true;
1234    default:
1235       return false;
1236    }
1237 }
1238 
1239 bool
1240 dxil_nir_lower_system_values(nir_shader *shader)
1241 {
1242    return nir_shader_intrinsics_pass(shader, lower_system_values_impl,
1243                                      nir_metadata_control_flow | nir_metadata_loop_analysis,
1244                                      NULL);
1245 }
1246 
1247 static const struct glsl_type *
1248 get_bare_samplers_for_type(const struct glsl_type *type, bool is_shadow)
1249 {
1250    const struct glsl_type *base_sampler_type =
1251       is_shadow ?
1252       glsl_bare_shadow_sampler_type() : glsl_bare_sampler_type();
1253    return glsl_type_wrap_in_arrays(base_sampler_type, type);
1254 }
1255 
1256 static const struct glsl_type *
1257 get_textures_for_sampler_type(const struct glsl_type *type)
1258 {
1259    return glsl_type_wrap_in_arrays(
1260       glsl_sampler_type_to_texture(
1261          glsl_without_array(type)), type);
1262 }
1263 
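/* Ensures every texture instruction that needs a sampler references a bare
 * sampler variable of matching shadow-ness, cloning the original sampler
 * variable when needed. The hash table in `data` caches clones, keyed by
 * sampler index for index-based access or by (descriptor set, binding) for
 * deref-based access; deref chains are rebuilt to point at the clone.
 */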
1264 static bool
1265 redirect_sampler_derefs(struct nir_builder *b, nir_instr *instr, void *data)
1266 {
1267    if (instr->type != nir_instr_type_tex)
1268       return false;
1269 
1270    nir_tex_instr *tex = nir_instr_as_tex(instr);
1271 
1272    int sampler_idx = nir_tex_instr_src_index(tex, nir_tex_src_sampler_deref);
1273    if (sampler_idx == -1) {
1274       /* No sampler deref - does this instruction even need a sampler? If not,
1275        * sampler_index doesn't necessarily point to a sampler, so early-out.
1276        */
1277       if (!nir_tex_instr_need_sampler(tex))
1278          return false;
1279 
1280       /* No derefs but needs a sampler, must be using indices */
1281       nir_variable *bare_sampler = _mesa_hash_table_u64_search(data, tex->sampler_index);
1282 
1283       /* Already have a bare sampler here */
1284       if (bare_sampler)
1285          return false;
1286 
1287       nir_variable *old_sampler = NULL;
1288       nir_foreach_variable_with_modes(var, b->shader, nir_var_uniform) {
1289          if (var->data.binding <= tex->sampler_index &&
1290              var->data.binding + glsl_type_get_sampler_count(var->type) >
1291                 tex->sampler_index) {
1292 
1293             /* Already have a bare sampler for this binding and it is of the
1294              * correct type, add it to the table */
1295             if (glsl_type_is_bare_sampler(glsl_without_array(var->type)) &&
1296                 glsl_sampler_type_is_shadow(glsl_without_array(var->type)) ==
1297                    tex->is_shadow) {
1298                _mesa_hash_table_u64_insert(data, tex->sampler_index, var);
1299                return false;
1300             }
1301 
1302             old_sampler = var;
1303          }
1304       }
1305 
1306       assert(old_sampler);
1307 
1308       /* Clone the original sampler to a bare sampler of the correct type */
1309       bare_sampler = nir_variable_clone(old_sampler, b->shader);
1310       nir_shader_add_variable(b->shader, bare_sampler);
1311 
1312       bare_sampler->type =
1313          get_bare_samplers_for_type(old_sampler->type, tex->is_shadow);
1314       _mesa_hash_table_u64_insert(data, tex->sampler_index, bare_sampler);
1315       return true;
1316    }
1317 
1318    /* Using derefs means we have to rewrite the deref chain in addition to cloning */
1319    nir_deref_instr *final_deref = nir_src_as_deref(tex->src[sampler_idx].src);
1320    nir_deref_path path;
1321    nir_deref_path_init(&path, final_deref, NULL);
1322 
1323    nir_deref_instr *old_tail = path.path[0];
1324    assert(old_tail->deref_type == nir_deref_type_var);
1325    nir_variable *old_var = old_tail->var;
1326    if (glsl_type_is_bare_sampler(glsl_without_array(old_var->type)) &&
1327        glsl_sampler_type_is_shadow(glsl_without_array(old_var->type)) ==
1328           tex->is_shadow) {
1329       nir_deref_path_finish(&path);
1330       return false;
1331    }
1332 
1333    uint64_t var_key = ((uint64_t)old_var->data.descriptor_set << 32) |
1334                       old_var->data.binding;
1335    nir_variable *new_var = _mesa_hash_table_u64_search(data, var_key);
1336    if (!new_var) {
1337       new_var = nir_variable_clone(old_var, b->shader);
1338       nir_shader_add_variable(b->shader, new_var);
1339       new_var->type =
1340          get_bare_samplers_for_type(old_var->type, tex->is_shadow);
1341       _mesa_hash_table_u64_insert(data, var_key, new_var);
1342    }
1343 
1344    b->cursor = nir_after_instr(&old_tail->instr);
1345    nir_deref_instr *new_tail = nir_build_deref_var(b, new_var);
1346 
1347    for (unsigned i = 1; path.path[i]; ++i) {
1348       b->cursor = nir_after_instr(&path.path[i]->instr);
1349       new_tail = nir_build_deref_follower(b, new_tail, path.path[i]);
1350    }
1351 
1352    nir_deref_path_finish(&path);
1353    nir_src_rewrite(&tex->src[sampler_idx].src, &new_tail->def);
1354    return true;
1355 }
1356 
1357 static bool
1358 redirect_texture_derefs(struct nir_builder *b, nir_instr *instr, void *data)
1359 {
1360    if (instr->type != nir_instr_type_tex)
1361       return false;
1362 
1363    nir_tex_instr *tex = nir_instr_as_tex(instr);
1364 
1365    int texture_idx = nir_tex_instr_src_index(tex, nir_tex_src_texture_deref);
1366    if (texture_idx == -1) {
1367       /* No derefs, must be using indices */
1368       nir_variable *bare_sampler = _mesa_hash_table_u64_search(data, tex->texture_index);
1369 
1370       /* Already have a texture here */
1371       if (bare_sampler)
1372          return false;
1373 
1374       nir_variable *typed_sampler = NULL;
1375       nir_foreach_variable_with_modes(var, b->shader, nir_var_uniform) {
1376          if (var->data.binding <= tex->texture_index &&
1377              var->data.binding + glsl_type_get_texture_count(var->type) > tex->texture_index) {
1378             /* Already have a texture for this binding, add it to the table */
1379             _mesa_hash_table_u64_insert(data, tex->texture_index, var);
1380             return false;
1381          }
1382 
1383          if (var->data.binding <= tex->texture_index &&
1384              var->data.binding + glsl_type_get_sampler_count(var->type) > tex->texture_index &&
1385              !glsl_type_is_bare_sampler(glsl_without_array(var->type))) {
1386             typed_sampler = var;
1387          }
1388       }
1389 
1390       /* Clone the typed sampler to a texture and we're done */
1391       assert(typed_sampler);
1392       bare_sampler = nir_variable_clone(typed_sampler, b->shader);
1393       bare_sampler->type = get_textures_for_sampler_type(typed_sampler->type);
1394       nir_shader_add_variable(b->shader, bare_sampler);
1395       _mesa_hash_table_u64_insert(data, tex->texture_index, bare_sampler);
1396       return true;
1397    }
1398 
1399    /* Using derefs means we have to rewrite the deref chain in addition to cloning */
1400    nir_deref_instr *final_deref = nir_src_as_deref(tex->src[texture_idx].src);
1401    nir_deref_path path;
1402    nir_deref_path_init(&path, final_deref, NULL);
1403 
1404    nir_deref_instr *old_tail = path.path[0];
1405    assert(old_tail->deref_type == nir_deref_type_var);
1406    nir_variable *old_var = old_tail->var;
1407    if (glsl_type_is_texture(glsl_without_array(old_var->type)) ||
1408        glsl_type_is_image(glsl_without_array(old_var->type))) {
1409       nir_deref_path_finish(&path);
1410       return false;
1411    }
1412 
1413    uint64_t var_key = ((uint64_t)old_var->data.descriptor_set << 32) |
1414                       old_var->data.binding;
1415    nir_variable *new_var = _mesa_hash_table_u64_search(data, var_key);
1416    if (!new_var) {
1417       new_var = nir_variable_clone(old_var, b->shader);
1418       new_var->type = get_textures_for_sampler_type(old_var->type);
1419       nir_shader_add_variable(b->shader, new_var);
1420       _mesa_hash_table_u64_insert(data, var_key, new_var);
1421    }
1422 
1423    b->cursor = nir_after_instr(&old_tail->instr);
1424    nir_deref_instr *new_tail = nir_build_deref_var(b, new_var);
1425 
1426    for (unsigned i = 1; path.path[i]; ++i) {
1427       b->cursor = nir_after_instr(&path.path[i]->instr);
1428       new_tail = nir_build_deref_follower(b, new_tail, path.path[i]);
1429    }
1430 
1431    nir_deref_path_finish(&path);
1432    nir_src_rewrite(&tex->src[texture_idx].src, &new_tail->def);
1433 
1434    return true;
1435 }
1436 
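/* DXIL binds textures (SRVs) and samplers in separate register spaces, so a
 * GLSL-style typed/combined sampler has to be split in two: the first pass
 * below redirects sampler derefs and indices to bare-sampler variables, the
 * second redirects texture accesses to texture-typed variables. The hash
 * table caches the cloned variables, keyed by index or by set/binding. */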
1437 bool
1438 dxil_nir_split_typed_samplers(nir_shader *nir)
1439 {
1440    struct hash_table_u64 *hash_table = _mesa_hash_table_u64_create(NULL);
1441 
1442    bool progress = nir_shader_instructions_pass(nir, redirect_sampler_derefs,
1443       nir_metadata_control_flow | nir_metadata_loop_analysis, hash_table);
1444 
1445    _mesa_hash_table_u64_clear(hash_table);
1446 
1447    progress |= nir_shader_instructions_pass(nir, redirect_texture_derefs,
1448       nir_metadata_control_flow | nir_metadata_loop_analysis, hash_table);
1449 
1450    _mesa_hash_table_u64_destroy(hash_table);
1451    return progress;
1452 }
1453 
1454 
1455 static bool
1456 lower_sysval_to_load_input_impl(nir_builder *b, nir_intrinsic_instr *intr,
1457                                 void *data)
1458 {
1459    gl_system_value sysval = SYSTEM_VALUE_MAX;
1460    switch (intr->intrinsic) {
1461    case nir_intrinsic_load_instance_id:
1462       sysval = SYSTEM_VALUE_INSTANCE_ID;
1463       break;
1464    case nir_intrinsic_load_vertex_id_zero_base:
1465       sysval = SYSTEM_VALUE_VERTEX_ID_ZERO_BASE;
1466       break;
1467    default:
1468       return false;
1469    }
1470 
1471    nir_variable **sysval_vars = (nir_variable **)data;
1472    nir_variable *var = sysval_vars[sysval];
1473    assert(var);
1474 
1475    const nir_alu_type dest_type = nir_get_nir_type_for_glsl_type(var->type);
1476    const unsigned bit_size = intr->def.bit_size;
1477 
1478    b->cursor = nir_before_instr(&intr->instr);
1479    nir_def *result = nir_load_input(b, intr->def.num_components, bit_size, nir_imm_int(b, 0),
1480       .base = var->data.driver_location, .dest_type = dest_type);
1481 
1482    nir_def_rewrite_uses(&intr->def, result);
1483    return true;
1484 }
1485 
1486 bool
1487 dxil_nir_lower_sysval_to_load_input(nir_shader *s, nir_variable **sysval_vars)
1488 {
1489    return nir_shader_intrinsics_pass(s, lower_sysval_to_load_input_impl,
1490                                      nir_metadata_control_flow,
1491                                      sysval_vars);
1492 }
1493 
1494 /* Comparison function to sort IO values so that normal varyings come first,
1495  * then system values, and then system-generated values.
1496  */
1497 static int
1498 variable_location_cmp(const nir_variable* a, const nir_variable* b)
1499 {
1500    // Sort by stream, driver_location, location, location_frac, then index
1501    // If all else is equal, sort full vectors before partial ones
1502    unsigned a_location = a->data.location;
1503    if (a_location >= VARYING_SLOT_PATCH0)
1504       a_location -= VARYING_SLOT_PATCH0;
1505    unsigned b_location = b->data.location;
1506    if (b_location >= VARYING_SLOT_PATCH0)
1507       b_location -= VARYING_SLOT_PATCH0;
1508    unsigned a_stream = a->data.stream & ~NIR_STREAM_PACKED;
1509    unsigned b_stream = b->data.stream & ~NIR_STREAM_PACKED;
1510    return a_stream != b_stream ?
1511             a_stream - b_stream :
1512             a->data.driver_location != b->data.driver_location ?
1513                a->data.driver_location - b->data.driver_location :
1514                a_location !=  b_location ?
1515                   a_location - b_location :
1516                   a->data.location_frac != b->data.location_frac ?
1517                      a->data.location_frac - b->data.location_frac :
1518                      a->data.index != b->data.index ?
1519                         a->data.index - b->data.index :
1520                         glsl_get_component_slots(b->type) - glsl_get_component_slots(a->type);
1521 }
1522 
1523 /* Order varyings according to driver location */
1524 void
1525 dxil_sort_by_driver_location(nir_shader* s, nir_variable_mode modes)
1526 {
1527    nir_sort_variables_with_modes(s, variable_location_cmp, modes);
1528 }
1529 
1530 /* Sort PS outputs so that color outputs come first */
1531 void
1532 dxil_sort_ps_outputs(nir_shader* s)
1533 {
1534    nir_foreach_variable_with_modes_safe(var, s, nir_var_shader_out) {
1535       /* We use the driver_location here to avoid introducing a new
1536        * struct or member variable. The true, updated driver location
1537        * will be written below, after sorting. */
1538       switch (var->data.location) {
1539       case FRAG_RESULT_DEPTH:
1540          var->data.driver_location = 1;
1541          break;
1542       case FRAG_RESULT_STENCIL:
1543          var->data.driver_location = 2;
1544          break;
1545       case FRAG_RESULT_SAMPLE_MASK:
1546          var->data.driver_location = 3;
1547          break;
1548       default:
1549          var->data.driver_location = 0;
1550       }
1551    }
1552 
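   /* Color outputs were tagged with driver_location 0 above, so the comparator
    * sorts them ahead of depth (1), stencil (2) and sample mask (3). Final
    * driver locations are then assigned sequentially below. */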
1553    nir_sort_variables_with_modes(s, variable_location_cmp,
1554                                  nir_var_shader_out);
1555 
1556    unsigned driver_loc = 0;
1557    nir_foreach_variable_with_modes(var, s, nir_var_shader_out) {
1558       /* Fractional vars should use the same driver_location as the base. These will
1559        * get fully merged during signature processing.
1560        */
1561       var->data.driver_location = var->data.location_frac ? driver_loc - 1 : driver_loc++;
1562    }
1563 }
1564 
1565 enum dxil_sysvalue_type {
1566    DXIL_NO_SYSVALUE = 0,
1567    DXIL_USED_SYSVALUE,
1568    DXIL_UNUSED_NO_SYSVALUE,
1569    DXIL_SYSVALUE,
1570    DXIL_GENERATED_SYSVALUE,
1571 };
1572 
1573 static enum dxil_sysvalue_type
1574 nir_var_to_dxil_sysvalue_type(nir_variable *var, uint64_t other_stage_mask,
1575                               const BITSET_WORD *other_stage_frac_mask)
1576 {
1577    switch (var->data.location) {
1578    case VARYING_SLOT_FACE:
1579       return DXIL_GENERATED_SYSVALUE;
1580    case VARYING_SLOT_POS:
1581    case VARYING_SLOT_PRIMITIVE_ID:
1582    case VARYING_SLOT_CLIP_DIST0:
1583    case VARYING_SLOT_CLIP_DIST1:
1584    case VARYING_SLOT_PSIZ:
1585    case VARYING_SLOT_TESS_LEVEL_INNER:
1586    case VARYING_SLOT_TESS_LEVEL_OUTER:
1587    case VARYING_SLOT_VIEWPORT:
1588    case VARYING_SLOT_LAYER:
1589    case VARYING_SLOT_VIEW_INDEX:
1590       if (!((1ull << var->data.location) & other_stage_mask))
1591          return DXIL_SYSVALUE;
1592       return DXIL_USED_SYSVALUE;
1593    default:
1594       if (var->data.location < VARYING_SLOT_PATCH0 &&
1595           !((1ull << var->data.location) & other_stage_mask))
1596          return DXIL_UNUSED_NO_SYSVALUE;
1597       if (var->data.location_frac && other_stage_frac_mask &&
1598           var->data.location >= VARYING_SLOT_VAR0 &&
1599           !BITSET_TEST(other_stage_frac_mask, ((var->data.location - VARYING_SLOT_VAR0) * 4 + var->data.location_frac)))
1600          return DXIL_UNUSED_NO_SYSVALUE;
1601       return DXIL_NO_SYSVALUE;
1602    }
1603 }
1604 
1605 /* Order between stage values so that normal varyings come first,
1606  * then sysvalues and then system generated values.
1607  */
1608 void
1609 dxil_reassign_driver_locations(nir_shader* s, nir_variable_mode modes,
1610    uint64_t other_stage_mask, const BITSET_WORD *other_stage_frac_mask)
1611 {
1612    nir_foreach_variable_with_modes_safe(var, s, modes) {
1613       /* We use the driver_location here to avoid introducing a new
1614        * struct or member variable. The true, updated driver location
1615        * will be written below, after sorting. */
1616       var->data.driver_location = nir_var_to_dxil_sysvalue_type(var, other_stage_mask, other_stage_frac_mask);
1617    }
1618 
1619    nir_sort_variables_with_modes(s, variable_location_cmp, modes);
1620 
1621    unsigned driver_loc = 0, driver_patch_loc = 0;
1622    nir_foreach_variable_with_modes(var, s, modes) {
1623       /* Overlap patches with non-patch */
1624       var->data.driver_location = var->data.patch ?
1625          driver_patch_loc++ : driver_loc++;
1626    }
1627 }
1628 
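/* Vulkan allows an array of UBOs of size 1 to be indexed with a possibly
 * dynamic value; since any index other than 0 would be out of bounds, the
 * resource index is rewritten to a constant 0 so the binding becomes static
 * for DXIL. */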
1629 static bool
1630 lower_ubo_array_one_to_static(struct nir_builder *b,
1631                               nir_intrinsic_instr *intrin,
1632                               void *cb_data)
1633 {
1634    if (intrin->intrinsic != nir_intrinsic_load_vulkan_descriptor)
1635       return false;
1636 
1637    nir_variable *var =
1638       nir_get_binding_variable(b->shader, nir_chase_binding(intrin->src[0]));
1639 
1640    if (!var)
1641       return false;
1642 
1643    if (!glsl_type_is_array(var->type) || glsl_array_size(var->type) != 1)
1644       return false;
1645 
1646    nir_intrinsic_instr *index = nir_src_as_intrinsic(intrin->src[0]);
1647    /* We currently do not support reindex */
1648    assert(index && index->intrinsic == nir_intrinsic_vulkan_resource_index);
1649 
1650    if (nir_src_is_const(index->src[0]) && nir_src_as_uint(index->src[0]) == 0)
1651       return false;
1652 
1653    if (nir_intrinsic_desc_type(index) != VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER)
1654       return false;
1655 
1656    b->cursor = nir_instr_remove(&index->instr);
1657 
1658    // Indexing out of bounds on an array of UBOs is considered undefined
1659    // behavior. Therefore, we just hardcode the index to 0.
1660    uint8_t bit_size = index->def.bit_size;
1661    nir_def *zero = nir_imm_intN_t(b, 0, bit_size);
1662    nir_def *dest =
1663       nir_vulkan_resource_index(b, index->num_components, bit_size, zero,
1664                                 .desc_set = nir_intrinsic_desc_set(index),
1665                                 .binding = nir_intrinsic_binding(index),
1666                                 .desc_type = nir_intrinsic_desc_type(index));
1667 
1668    nir_def_rewrite_uses(&index->def, dest);
1669 
1670    return true;
1671 }
1672 
1673 bool
1674 dxil_nir_lower_ubo_array_one_to_static(nir_shader *s)
1675 {
1676    bool progress = nir_shader_intrinsics_pass(s,
1677                                               lower_ubo_array_one_to_static,
1678                                               nir_metadata_none, NULL);
1679 
1680    return progress;
1681 }
1682 
1683 static bool
1684 is_fquantize2f16(const nir_instr *instr, const void *data)
1685 {
1686    if (instr->type != nir_instr_type_alu)
1687       return false;
1688 
1689    nir_alu_instr *alu = nir_instr_as_alu(instr);
1690    return alu->op == nir_op_fquantize2f16;
1691 }
1692 
1693 static nir_def *
1694 lower_fquantize2f16(struct nir_builder *b, nir_instr *instr, void *data)
1695 {
1696    /*
1697     * SpvOpQuantizeToF16 documentation says:
1698     *
1699     * "
1700     * If Value is an infinity, the result is the same infinity.
1701     * If Value is a NaN, the result is a NaN, but not necessarily the same NaN.
1702     * If Value is positive with a magnitude too large to represent as a 16-bit
1703     * floating-point value, the result is positive infinity. If Value is negative
1704     * with a magnitude too large to represent as a 16-bit floating-point value,
1705     * the result is negative infinity. If the magnitude of Value is too small to
1706     * represent as a normalized 16-bit floating-point value, the result may be
1707     * either +0 or -0.
1708     * "
1709     *
1710     * which we turn into:
1711     *
1712     *   if (val < MIN_FLOAT16)
1713     *      return -INFINITY;
1714     *   else if (val > MAX_FLOAT16)
1715     *      return +INFINITY;
1716     *   else if (fabs(val) < SMALLEST_NORMALIZED_FLOAT16 && sign(val) != 0)
1717     *      return -0.0f;
1718     *   else if (fabs(val) < SMALLEST_NORMALIZED_FLOAT16 && sign(val) == 0)
1719     *      return +0.0f;
1720     *   else
1721     *      return round(val);
1722     */
1723    nir_alu_instr *alu = nir_instr_as_alu(instr);
1724    nir_def *src =
1725       alu->src[0].src.ssa;
1726 
1727    nir_def *neg_inf_cond =
1728       nir_flt_imm(b, src, -65504.0f);
1729    nir_def *pos_inf_cond =
1730       nir_fgt_imm(b, src, 65504.0f);
1731    nir_def *zero_cond =
1732       nir_flt_imm(b, nir_fabs(b, src), ldexpf(1.0, -14));
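   /* The flush-to-zero value keeps only the sign bit, preserving signed zero.
    * "round" masks off the low 23 - 10 = 13 mantissa bits that a binary16
    * value cannot represent, i.e. it truncates toward zero. */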
1733    nir_def *zero = nir_iand_imm(b, src, 1 << 31);
1734    nir_def *round = nir_iand_imm(b, src, ~BITFIELD_MASK(13));
1735 
1736    nir_def *res =
1737       nir_bcsel(b, neg_inf_cond, nir_imm_float(b, -INFINITY), round);
1738    res = nir_bcsel(b, pos_inf_cond, nir_imm_float(b, INFINITY), res);
1739    res = nir_bcsel(b, zero_cond, zero, res);
1740    return res;
1741 }
1742 
1743 bool
1744 dxil_nir_lower_fquantize2f16(nir_shader *s)
1745 {
1746    return nir_shader_lower_instructions(s, is_fquantize2f16, lower_fquantize2f16, NULL);
1747 }
1748 
1749 static bool
1750 fix_io_uint_deref_types(struct nir_builder *builder, nir_instr *instr, void *data)
1751 {
1752    if (instr->type != nir_instr_type_deref)
1753       return false;
1754 
1755    nir_deref_instr *deref = nir_instr_as_deref(instr);
1756    nir_variable *var = nir_deref_instr_get_variable(deref);
1757 
1758    if (var == data) {
1759       deref->type = glsl_type_wrap_in_arrays(glsl_uint_type(), deref->type);
1760       return true;
1761    }
1762 
1763    return false;
1764 }
1765 
1766 static bool
1767 fix_io_uint_type(nir_shader *s, nir_variable_mode modes, int slot)
1768 {
1769    nir_variable *fixed_var = NULL;
1770    nir_foreach_variable_with_modes(var, s, modes) {
1771       if (var->data.location == slot) {
1772          const struct glsl_type *plain_type = glsl_without_array(var->type);
1773          if (plain_type == glsl_uint_type())
1774             return false;
1775 
1776          assert(plain_type == glsl_int_type());
1777          var->type = glsl_type_wrap_in_arrays(glsl_uint_type(), var->type);
1778          fixed_var = var;
1779          break;
1780       }
1781    }
1782 
1783    assert(fixed_var);
1784 
1785    return nir_shader_instructions_pass(s, fix_io_uint_deref_types,
1786                                        nir_metadata_all, fixed_var);
1787 }
1788 
1789 bool
1790 dxil_nir_fix_io_uint_type(nir_shader *s, uint64_t in_mask, uint64_t out_mask)
1791 {
1792    if (!(s->info.outputs_written & out_mask) &&
1793        !(s->info.inputs_read & in_mask))
1794       return false;
1795 
1796    bool progress = false;
1797 
1798    while (in_mask) {
1799       int slot = u_bit_scan64(&in_mask);
1800       progress |= (s->info.inputs_read & (1ull << slot)) &&
1801                   fix_io_uint_type(s, nir_var_shader_in, slot);
1802    }
1803 
1804    while (out_mask) {
1805       int slot = u_bit_scan64(&out_mask);
1806       progress |= (s->info.outputs_written & (1ull << slot)) &&
1807                   fix_io_uint_type(s, nir_var_shader_out, slot);
1808    }
1809 
1810    return progress;
1811 }
1812 
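/* Lower terminate/terminate_if to a demote followed by a return from the
 * (fully inlined) entry point; the return is placed behind a branch on the
 * discard condition so it ends up as the last instruction of its own block. */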
1813 static bool
1814 lower_kill(struct nir_builder *builder, nir_intrinsic_instr *intr,
1815            void *_cb_data)
1816 {
1817    if (intr->intrinsic != nir_intrinsic_terminate &&
1818        intr->intrinsic != nir_intrinsic_terminate_if)
1819       return false;
1820 
1821    builder->cursor = nir_instr_remove(&intr->instr);
1822    nir_def *condition;
1823 
1824    if (intr->intrinsic == nir_intrinsic_terminate) {
1825       nir_demote(builder);
1826       condition = nir_imm_true(builder);
1827    } else {
1828       nir_demote_if(builder, intr->src[0].ssa);
1829       condition = intr->src[0].ssa;
1830    }
1831 
1832    /* Create a new block by branching on the discard condition so that this return
1833     * is definitely the last instruction in its own block */
1834    nir_if *nif = nir_push_if(builder, condition);
1835    nir_jump(builder, nir_jump_return);
1836    nir_pop_if(builder, nif);
1837 
1838    return true;
1839 }
1840 
1841 bool
1842 dxil_nir_lower_discard_and_terminate(nir_shader *s)
1843 {
1844    if (s->info.stage != MESA_SHADER_FRAGMENT)
1845       return false;
1846 
1847    // This pass only works if all functions have been inlined
1848    assert(exec_list_length(&s->functions) == 1);
1849    return nir_shader_intrinsics_pass(s, lower_kill, nir_metadata_none, NULL);
1850 }
1851 
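/* Widen partial position stores to a full vec4: channels that were not
 * written are filled with zero and the store is rewritten to start at
 * component 0 with a 0xf write mask. */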
1852 static bool
1853 update_writes(struct nir_builder *b, nir_intrinsic_instr *intr, void *_state)
1854 {
1855    if (intr->intrinsic != nir_intrinsic_store_output)
1856       return false;
1857 
1858    nir_io_semantics io = nir_intrinsic_io_semantics(intr);
1859    if (io.location != VARYING_SLOT_POS)
1860       return false;
1861 
1862    nir_def *src = intr->src[0].ssa;
1863    unsigned write_mask = nir_intrinsic_write_mask(intr);
1864    if (src->num_components == 4 && write_mask == 0xf)
1865       return false;
1866 
1867    b->cursor = nir_before_instr(&intr->instr);
1868    unsigned first_comp = nir_intrinsic_component(intr);
1869    nir_def *channels[4] = { NULL, NULL, NULL, NULL };
1870    assert(first_comp + src->num_components <= ARRAY_SIZE(channels));
1871    for (unsigned i = 0; i < src->num_components; ++i)
1872       if (write_mask & (1 << i))
1873          channels[i + first_comp] = nir_channel(b, src, i);
1874    for (unsigned i = 0; i < 4; ++i)
1875       if (!channels[i])
1876          channels[i] = nir_imm_intN_t(b, 0, src->bit_size);
1877 
1878    intr->num_components = 4;
1879    nir_src_rewrite(&intr->src[0], nir_vec(b, channels, 4));
1880    nir_intrinsic_set_component(intr, 0);
1881    nir_intrinsic_set_write_mask(intr, 0xf);
1882    return true;
1883 }
1884 
1885 bool
1886 dxil_nir_ensure_position_writes(nir_shader *s)
1887 {
1888    if (s->info.stage != MESA_SHADER_VERTEX &&
1889        s->info.stage != MESA_SHADER_GEOMETRY &&
1890        s->info.stage != MESA_SHADER_TESS_EVAL)
1891       return false;
1892    if ((s->info.outputs_written & VARYING_BIT_POS) == 0)
1893       return false;
1894 
1895    return nir_shader_intrinsics_pass(s, update_writes,
1896                                        nir_metadata_control_flow,
1897                                        NULL);
1898 }
1899 
1900 static bool
1901 is_sample_pos(const nir_instr *instr, const void *_data)
1902 {
1903    if (instr->type != nir_instr_type_intrinsic)
1904       return false;
1905    nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
1906    return intr->intrinsic == nir_intrinsic_load_sample_pos;
1907 }
1908 
1909 static nir_def *
1910 lower_sample_pos(nir_builder *b, nir_instr *instr, void *_data)
1911 {
1912    return nir_load_sample_pos_from_id(b, 32, nir_load_sample_id(b));
1913 }
1914 
1915 bool
1916 dxil_nir_lower_sample_pos(nir_shader *s)
1917 {
1918    return nir_shader_lower_instructions(s, is_sample_pos, lower_sample_pos, NULL);
1919 }
1920 
1921 static bool
1922 lower_subgroup_id(nir_builder *b, nir_intrinsic_instr *intr, void *data)
1923 {
1924    if (intr->intrinsic != nir_intrinsic_load_subgroup_id)
1925       return false;
1926 
1927    b->cursor = nir_before_impl(b->impl);
1928    if (b->shader->info.stage == MESA_SHADER_COMPUTE &&
1929        b->shader->info.workgroup_size[1] == 1 &&
1930        b->shader->info.workgroup_size[2] == 1) {
1931       /* When using Nx1x1 groups, use a simple stable algorithm
1932        * which is almost guaranteed to be correct. */
1933       nir_def *subgroup_id = nir_udiv(b, nir_load_local_invocation_index(b), nir_load_subgroup_size(b));
1934       nir_def_rewrite_uses(&intr->def, subgroup_id);
1935       return true;
1936    }
1937 
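   /* General case: thread 0 zeroes a shared counter, all threads barrier, then
    * one elected lane per subgroup atomically takes the next counter value and
    * the result is broadcast to the rest of the subgroup via
    * read_first_invocation. */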
1938    nir_def **subgroup_id = (nir_def **)data;
1939    if (*subgroup_id == NULL) {
1940       nir_variable *subgroup_id_counter = nir_variable_create(b->shader, nir_var_mem_shared, glsl_uint_type(), "dxil_SubgroupID_counter");
1941       nir_variable *subgroup_id_local = nir_local_variable_create(b->impl, glsl_uint_type(), "dxil_SubgroupID_local");
1942       nir_store_var(b, subgroup_id_local, nir_imm_int(b, 0), 1);
1943 
1944       nir_deref_instr *counter_deref = nir_build_deref_var(b, subgroup_id_counter);
1945       nir_def *tid = nir_load_local_invocation_index(b);
1946       nir_if *nif = nir_push_if(b, nir_ieq_imm(b, tid, 0));
1947       nir_store_deref(b, counter_deref, nir_imm_int(b, 0), 1);
1948       nir_pop_if(b, nif);
1949 
1950       nir_barrier(b,
1951                          .execution_scope = SCOPE_WORKGROUP,
1952                          .memory_scope = SCOPE_WORKGROUP,
1953                          .memory_semantics = NIR_MEMORY_ACQ_REL,
1954                          .memory_modes = nir_var_mem_shared);
1955 
1956       nif = nir_push_if(b, nir_elect(b, 1));
1957       nir_def *subgroup_id_first_thread = nir_deref_atomic(b, 32, &counter_deref->def, nir_imm_int(b, 1),
1958                                                                .atomic_op = nir_atomic_op_iadd);
1959       nir_store_var(b, subgroup_id_local, subgroup_id_first_thread, 1);
1960       nir_pop_if(b, nif);
1961 
1962       nir_def *subgroup_id_loaded = nir_load_var(b, subgroup_id_local);
1963       *subgroup_id = nir_read_first_invocation(b, subgroup_id_loaded);
1964    }
1965    nir_def_rewrite_uses(&intr->def, *subgroup_id);
1966    return true;
1967 }
1968 
1969 bool
1970 dxil_nir_lower_subgroup_id(nir_shader *s)
1971 {
1972    nir_def *subgroup_id = NULL;
1973    return nir_shader_intrinsics_pass(s, lower_subgroup_id, nir_metadata_none,
1974                                      &subgroup_id);
1975 }
1976 
1977 static bool
1978 lower_num_subgroups(nir_builder *b, nir_intrinsic_instr *intr, void *data)
1979 {
1980    if (intr->intrinsic != nir_intrinsic_load_num_subgroups)
1981       return false;
1982 
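   /* num_subgroups = ceil(workgroup invocation count / subgroup size),
    * computed as (x * y * z + subgroup_size - 1) / subgroup_size. */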
1983    b->cursor = nir_before_instr(&intr->instr);
1984    nir_def *subgroup_size = nir_load_subgroup_size(b);
1985    nir_def *size_minus_one = nir_iadd_imm(b, subgroup_size, -1);
1986    nir_def *workgroup_size_vec = nir_load_workgroup_size(b);
1987    nir_def *workgroup_size = nir_imul(b, nir_channel(b, workgroup_size_vec, 0),
1988                                              nir_imul(b, nir_channel(b, workgroup_size_vec, 1),
1989                                                          nir_channel(b, workgroup_size_vec, 2)));
1990    nir_def *ret = nir_idiv(b, nir_iadd(b, workgroup_size, size_minus_one), subgroup_size);
1991    nir_def_rewrite_uses(&intr->def, ret);
1992    return true;
1993 }
1994 
1995 bool
1996 dxil_nir_lower_num_subgroups(nir_shader *s)
1997 {
1998    return nir_shader_intrinsics_pass(s, lower_num_subgroups,
1999                                        nir_metadata_control_flow |
2000                                        nir_metadata_loop_analysis, NULL);
2001 }
2002 
2003 
2004 static const struct glsl_type *
2005 get_cast_type(unsigned bit_size)
2006 {
2007    switch (bit_size) {
2008    case 64:
2009       return glsl_int64_t_type();
2010    case 32:
2011       return glsl_int_type();
2012    case 16:
2013       return glsl_int16_t_type();
2014    case 8:
2015       return glsl_int8_t_type();
2016    }
2017    unreachable("Invalid bit_size");
2018 }
2019 
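/* Split an under-aligned load into a sequence of loads of 'alignment' bytes
 * each, done through a deref cast to an integer type of that size, and then
 * reassemble the original value with nir_extract_bits. E.g. a 16-byte vec4
 * with 4-byte alignment becomes four 32-bit loads. */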
2020 static void
2021 split_unaligned_load(nir_builder *b, nir_intrinsic_instr *intrin, unsigned alignment)
2022 {
2023    enum gl_access_qualifier access = nir_intrinsic_access(intrin);
2024    nir_def *srcs[NIR_MAX_VEC_COMPONENTS * NIR_MAX_VEC_COMPONENTS * sizeof(int64_t) / 8];
2025    unsigned comp_size = intrin->def.bit_size / 8;
2026    unsigned num_comps = intrin->def.num_components;
2027 
2028    b->cursor = nir_before_instr(&intrin->instr);
2029 
2030    nir_deref_instr *ptr = nir_src_as_deref(intrin->src[0]);
2031 
2032    const struct glsl_type *cast_type = get_cast_type(alignment * 8);
2033    nir_deref_instr *cast = nir_build_deref_cast(b, &ptr->def, ptr->modes, cast_type, alignment);
2034 
2035    unsigned num_loads = DIV_ROUND_UP(comp_size * num_comps, alignment);
2036    for (unsigned i = 0; i < num_loads; ++i) {
2037       nir_deref_instr *elem = nir_build_deref_ptr_as_array(b, cast, nir_imm_intN_t(b, i, cast->def.bit_size));
2038       srcs[i] = nir_load_deref_with_access(b, elem, access);
2039    }
2040 
2041    nir_def *new_dest = nir_extract_bits(b, srcs, num_loads, 0, num_comps, intrin->def.bit_size);
2042    nir_def_replace(&intrin->def, new_dest);
2043 }
2044 
2045 static void
2046 split_unaligned_store(nir_builder *b, nir_intrinsic_instr *intrin, unsigned alignment)
2047 {
2048    enum gl_access_qualifier access = nir_intrinsic_access(intrin);
2049 
2050    nir_def *value = intrin->src[1].ssa;
2051    unsigned comp_size = value->bit_size / 8;
2052    unsigned num_comps = value->num_components;
2053 
2054    b->cursor = nir_before_instr(&intrin->instr);
2055 
2056    nir_deref_instr *ptr = nir_src_as_deref(intrin->src[0]);
2057 
2058    const struct glsl_type *cast_type = get_cast_type(alignment * 8);
2059    nir_deref_instr *cast = nir_build_deref_cast(b, &ptr->def, ptr->modes, cast_type, alignment);
2060 
2061    unsigned num_stores = DIV_ROUND_UP(comp_size * num_comps, alignment);
2062    for (unsigned i = 0; i < num_stores; ++i) {
2063       nir_def *substore_val = nir_extract_bits(b, &value, 1, i * alignment * 8, 1, alignment * 8);
2064       nir_deref_instr *elem = nir_build_deref_ptr_as_array(b, cast, nir_imm_intN_t(b, i, cast->def.bit_size));
2065       nir_store_deref_with_access(b, elem, substore_val, ~0, access);
2066    }
2067 
2068    nir_instr_remove(&intrin->instr);
2069 }
2070 
2071 bool
2072 dxil_nir_split_unaligned_loads_stores(nir_shader *shader, nir_variable_mode modes)
2073 {
2074    bool progress = false;
2075 
2076    nir_foreach_function_impl(impl, shader) {
2077       nir_builder b = nir_builder_create(impl);
2078 
2079       nir_foreach_block(block, impl) {
2080          nir_foreach_instr_safe(instr, block) {
2081             if (instr->type != nir_instr_type_intrinsic)
2082                continue;
2083             nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
2084             if (intrin->intrinsic != nir_intrinsic_load_deref &&
2085                 intrin->intrinsic != nir_intrinsic_store_deref)
2086                continue;
2087             nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
2088             if (!nir_deref_mode_may_be(deref, modes))
2089                continue;
2090 
2091             unsigned align_mul = 0, align_offset = 0;
2092             nir_get_explicit_deref_align(deref, true, &align_mul, &align_offset);
2093 
2094             unsigned alignment = align_offset ? 1 << (ffs(align_offset) - 1) : align_mul;
2095 
2096             /* We can load anything at 4-byte alignment, except for
2097              * UBOs (AKA CBs where the granularity is 16 bytes).
2098              */
2099             unsigned req_align = (nir_deref_mode_is_one_of(deref, nir_var_mem_ubo | nir_var_mem_push_const) ? 16 : 4);
2100             if (alignment >= req_align)
2101                continue;
2102 
2103             nir_def *val;
2104             if (intrin->intrinsic == nir_intrinsic_load_deref) {
2105                val = &intrin->def;
2106             } else {
2107                val = intrin->src[1].ssa;
2108             }
2109 
2110             unsigned scalar_byte_size = glsl_type_is_boolean(deref->type) ? 4 : glsl_get_bit_size(deref->type) / 8;
2111             unsigned num_components =
2112                /* If the vector stride is larger than the scalar size, lower_explicit_io will
2113                 * turn this into multiple scalar loads anyway, so we don't have to split it here. */
2114                glsl_get_explicit_stride(deref->type) > scalar_byte_size ? 1 :
2115                (val->num_components == 3 ? 4 : val->num_components);
2116             unsigned natural_alignment = scalar_byte_size * num_components;
2117 
2118             if (alignment >= natural_alignment)
2119                continue;
2120 
2121             if (intrin->intrinsic == nir_intrinsic_load_deref)
2122                split_unaligned_load(&b, intrin, alignment);
2123             else
2124                split_unaligned_store(&b, intrin, alignment);
2125             progress = true;
2126          }
2127       }
2128    }
2129 
2130    return progress;
2131 }
2132 
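/* An inclusive scan can be expressed as the matching exclusive scan whose
 * result is then combined with the invocation's own input value using the
 * same reduction op. */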
2133 static void
2134 lower_inclusive_to_exclusive(nir_builder *b, nir_intrinsic_instr *intr)
2135 {
2136    b->cursor = nir_after_instr(&intr->instr);
2137 
2138    nir_op op = nir_intrinsic_reduction_op(intr);
2139    intr->intrinsic = nir_intrinsic_exclusive_scan;
2140    nir_intrinsic_set_reduction_op(intr, op);
2141 
2142    nir_def *final_val = nir_build_alu2(b, nir_intrinsic_reduction_op(intr),
2143                                            &intr->def, intr->src[0].ssa);
2144    nir_def_rewrite_uses_after(&intr->def, final_val, final_val->parent_instr);
2145 }
2146 
2147 static bool
2148 lower_subgroup_scan(nir_builder *b, nir_intrinsic_instr *intr, void *data)
2149 {
2150    switch (intr->intrinsic) {
2151    case nir_intrinsic_exclusive_scan:
2152    case nir_intrinsic_inclusive_scan:
2153       switch ((nir_op)nir_intrinsic_reduction_op(intr)) {
2154       case nir_op_iadd:
2155       case nir_op_fadd:
2156       case nir_op_imul:
2157       case nir_op_fmul:
2158          if (intr->intrinsic == nir_intrinsic_exclusive_scan)
2159             return false;
2160          lower_inclusive_to_exclusive(b, intr);
2161          return true;
2162       default:
2163          break;
2164       }
2165       break;
2166    default:
2167       return false;
2168    }
2169 
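   /* For reduction ops the native scan path above cannot handle, emit a loop
    * over the subgroup: start from the operation's identity value and, for
    * every active lane that precedes this one (or does not follow it, for
    * inclusive scans), read its input with read_invocation and fold it into
    * the accumulator. */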
2170    b->cursor = nir_before_instr(&intr->instr);
2171    nir_op op = nir_intrinsic_reduction_op(intr);
2172    nir_def *subgroup_id = nir_load_subgroup_invocation(b);
2173    nir_def *subgroup_size = nir_load_subgroup_size(b);
2174    nir_def *active_threads = nir_ballot(b, 4, 32, nir_imm_true(b));
2175    nir_def *base_value;
2176    uint32_t bit_size = intr->def.bit_size;
2177    if (op == nir_op_iand || op == nir_op_umin)
2178       base_value = nir_imm_intN_t(b, ~0ull, bit_size);
2179    else if (op == nir_op_imin)
2180       base_value = nir_imm_intN_t(b, (1ull << (bit_size - 1)) - 1, bit_size);
2181    else if (op == nir_op_imax)
2182       base_value = nir_imm_intN_t(b, 1ull << (bit_size - 1), bit_size);
2183    else if (op == nir_op_fmax)
2184       base_value = nir_imm_floatN_t(b, -INFINITY, bit_size);
2185    else if (op == nir_op_fmin)
2186       base_value = nir_imm_floatN_t(b, INFINITY, bit_size);
2187    else
2188       base_value = nir_imm_intN_t(b, 0, bit_size);
2189 
2190    nir_variable *loop_counter_var = nir_local_variable_create(b->impl, glsl_uint_type(), "subgroup_loop_counter");
2191    nir_variable *result_var = nir_local_variable_create(b->impl,
2192                                                         glsl_vector_type(nir_get_glsl_base_type_for_nir_type(
2193                                                            nir_op_infos[op].input_types[0] | bit_size), 1),
2194                                                         "subgroup_loop_result");
2195    nir_store_var(b, loop_counter_var, nir_imm_int(b, 0), 1);
2196    nir_store_var(b, result_var, base_value, 1);
2197    nir_loop *loop = nir_push_loop(b);
2198    nir_def *loop_counter = nir_load_var(b, loop_counter_var);
2199 
2200    nir_if *nif = nir_push_if(b, nir_ilt(b, loop_counter, subgroup_size));
2201    nir_def *other_thread_val = nir_read_invocation(b, intr->src[0].ssa, loop_counter);
2202    nir_def *thread_in_range = intr->intrinsic == nir_intrinsic_inclusive_scan ?
2203       nir_ige(b, subgroup_id, loop_counter) :
2204       nir_ilt(b, loop_counter, subgroup_id);
2205    nir_def *thread_active = nir_ballot_bitfield_extract(b, 1, active_threads, loop_counter);
2206 
2207    nir_if *if_active_thread = nir_push_if(b, nir_iand(b, thread_in_range, thread_active));
2208    nir_def *result = nir_build_alu2(b, op, nir_load_var(b, result_var), other_thread_val);
2209    nir_store_var(b, result_var, result, 1);
2210    nir_pop_if(b, if_active_thread);
2211 
2212    nir_store_var(b, loop_counter_var, nir_iadd_imm(b, loop_counter, 1), 1);
2213    nir_jump(b, nir_jump_continue);
2214    nir_pop_if(b, nif);
2215 
2216    nir_jump(b, nir_jump_break);
2217    nir_pop_loop(b, loop);
2218 
2219    result = nir_load_var(b, result_var);
2220    nir_def_rewrite_uses(&intr->def, result);
2221    return true;
2222 }
2223 
2224 bool
2225 dxil_nir_lower_unsupported_subgroup_scan(nir_shader *s)
2226 {
2227    bool ret = nir_shader_intrinsics_pass(s, lower_subgroup_scan,
2228                                          nir_metadata_none, NULL);
2229    if (ret) {
2230       /* Lower the ballot bitfield tests */
2231       nir_lower_subgroups_options options = { .ballot_bit_size = 32, .ballot_components = 4 };
2232       nir_lower_subgroups(s, &options);
2233    }
2234    return ret;
2235 }
2236 
2237 bool
2238 dxil_nir_forward_front_face(nir_shader *nir)
2239 {
2240    assert(nir->info.stage == MESA_SHADER_FRAGMENT);
2241 
2242    nir_variable *var = nir_find_variable_with_location(nir, nir_var_shader_in, VARYING_SLOT_FACE);
2243    if (var) {
2244       var->data.location = VARYING_SLOT_VAR12;
2245       return true;
2246    }
2247    return false;
2248 }
2249 
2250 static bool
2251 move_consts(nir_builder *b, nir_instr *instr, void *data)
2252 {
2253    bool progress = false;
2254    switch (instr->type) {
2255    case nir_instr_type_load_const: {
2256       /* Sink load_const to their uses if there's multiple */
2257       nir_load_const_instr *load_const = nir_instr_as_load_const(instr);
2258       if (!list_is_singular(&load_const->def.uses)) {
2259          nir_foreach_use_safe(src, &load_const->def) {
2260             b->cursor = nir_before_src(src);
2261             nir_load_const_instr *new_load = nir_load_const_instr_create(b->shader,
2262                                                                          load_const->def.num_components,
2263                                                                          load_const->def.bit_size);
2264             memcpy(new_load->value, load_const->value, sizeof(load_const->value[0]) * load_const->def.num_components);
2265             nir_builder_instr_insert(b, &new_load->instr);
2266             nir_src_rewrite(src, &new_load->def);
2267             progress = true;
2268          }
2269       }
2270       return progress;
2271    }
2272    default:
2273       return false;
2274    }
2275 }
2276 
2277 /* Sink all consts so that they only have a single use.
2278  * The DXIL backend will already de-dupe the constants to the
2279  * same dxil_value if they have the same type, but this allows a single constant
2280  * to have different types without bitcasts. */
2281 bool
2282 dxil_nir_move_consts(nir_shader *s)
2283 {
2284    return nir_shader_instructions_pass(s, move_consts,
2285                                        nir_metadata_control_flow,
2286                                        NULL);
2287 }
2288 
2289 static void
2290 clear_pass_flags(nir_function_impl *impl)
2291 {
2292    nir_foreach_block(block, impl) {
2293       nir_foreach_instr(instr, block) {
2294          instr->pass_flags = 0;
2295       }
2296    }
2297 }
2298 
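/* Push every user of a def onto the worklist. If the def feeds an if
 * condition, every instruction inside that if is pushed too, since the
 * control flow makes all of them depend on the value. */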
2299 static bool
2300 add_def_to_worklist(nir_def *def, void *state)
2301 {
2302    nir_foreach_use_including_if(src, def) {
2303       if (nir_src_is_if(src)) {
2304          nir_if *nif = nir_src_parent_if(src);
2305          nir_foreach_block_in_cf_node(block, &nif->cf_node) {
2306             nir_foreach_instr(instr, block)
2307                nir_instr_worklist_push_tail(state, instr);
2308          }
2309       } else
2310          nir_instr_worklist_push_tail(state, nir_src_parent_instr(src));
2311    }
2312    return true;
2313 }
2314 
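/* Record, in input_bits, every (row, component) slot of the input signature
 * that this load can read; rows are 4 components wide, and a non-constant row
 * index marks every row of the signature record. Domain-shader patch
 * constants are redirected to the second I/O table. */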
2315 static bool
2316 set_input_bits(struct dxil_module *mod, nir_intrinsic_instr *intr, BITSET_WORD *input_bits, uint32_t ***tables, const uint32_t **table_sizes)
2317 {
2318    if (intr->intrinsic == nir_intrinsic_load_view_index) {
2319       BITSET_SET(input_bits, 0);
2320       return true;
2321    }
2322 
2323    bool any_bits_set = false;
2324    nir_src *row_src = intr->intrinsic == nir_intrinsic_load_per_vertex_input ? &intr->src[1] : &intr->src[0];
2325    bool is_patch_constant = mod->shader_kind == DXIL_DOMAIN_SHADER && intr->intrinsic == nir_intrinsic_load_input;
2326    const struct dxil_signature_record *sig_rec = is_patch_constant ?
2327       &mod->patch_consts[nir_intrinsic_base(intr)] :
2328       &mod->inputs[mod->input_mappings[nir_intrinsic_base(intr)]];
2329    if (is_patch_constant) {
2330       /* Redirect to the second I/O table */
2331       *tables = *tables + 1;
2332       *table_sizes = *table_sizes + 1;
2333    }
2334    for (uint32_t component = 0; component < intr->num_components; ++component) {
2335       uint32_t base_element = 0;
2336       uint32_t num_elements = sig_rec->num_elements;
2337       if (nir_src_is_const(*row_src)) {
2338          base_element = (uint32_t)nir_src_as_uint(*row_src);
2339          num_elements = 1;
2340       }
2341       for (uint32_t element = 0; element < num_elements; ++element) {
2342          uint32_t row = sig_rec->elements[element + base_element].reg;
2343          if (row == 0xffffffff)
2344             continue;
2345          BITSET_SET(input_bits, row * 4 + component + nir_intrinsic_component(intr));
2346          any_bits_set = true;
2347       }
2348    }
2349    return any_bits_set;
2350 }
2351 
2352 static bool
2353 set_output_bits(struct dxil_module *mod, nir_intrinsic_instr *intr, BITSET_WORD *input_bits, uint32_t **tables, const uint32_t *table_sizes)
2354 {
2355    bool any_bits_set = false;
2356    nir_src *row_src = intr->intrinsic == nir_intrinsic_store_per_vertex_output ? &intr->src[2] : &intr->src[1];
2357    bool is_patch_constant = mod->shader_kind == DXIL_HULL_SHADER && intr->intrinsic == nir_intrinsic_store_output;
2358    const struct dxil_signature_record *sig_rec = is_patch_constant ?
2359       &mod->patch_consts[nir_intrinsic_base(intr)] :
2360       &mod->outputs[nir_intrinsic_base(intr)];
2361    for (uint32_t component = 0; component < intr->num_components; ++component) {
2362       uint32_t base_element = 0;
2363       uint32_t num_elements = sig_rec->num_elements;
2364       if (nir_src_is_const(*row_src)) {
2365          base_element = (uint32_t)nir_src_as_uint(*row_src);
2366          num_elements = 1;
2367       }
2368       for (uint32_t element = 0; element < num_elements; ++element) {
2369          uint32_t row = sig_rec->elements[element + base_element].reg;
2370          if (row == 0xffffffff)
2371             continue;
2372          uint32_t stream = sig_rec->elements[element + base_element].stream;
2373          uint32_t table_idx = is_patch_constant ? 1 : stream;
2374          uint32_t *table = tables[table_idx];
2375          uint32_t output_component = component + nir_intrinsic_component(intr);
2376          uint32_t input_component;
2377          BITSET_FOREACH_SET(input_component, input_bits, 32 * 4) {
2378             uint32_t *table_for_input_component = table + table_sizes[table_idx] * input_component;
2379             BITSET_SET(table_for_input_component, row * 4 + output_component);
2380             any_bits_set = true;
2381          }
2382       }
2383    }
2384    return any_bits_set;
2385 }
2386 
2387 static bool
2388 propagate_input_to_output_dependencies(struct dxil_module *mod, nir_intrinsic_instr *load_intr, uint32_t **tables, const uint32_t *table_sizes)
2389 {
2390    /* Which input components are being loaded by this instruction */
2391    BITSET_DECLARE(input_bits, 32 * 4) = { 0 };
2392    if (!set_input_bits(mod, load_intr, input_bits, &tables, &table_sizes))
2393       return false;
2394 
2395    nir_instr_worklist *worklist = nir_instr_worklist_create();
2396    nir_instr_worklist_push_tail(worklist, &load_intr->instr);
2397    bool any_bits_set = false;
2398    nir_foreach_instr_in_worklist(instr, worklist) {
2399       if (instr->pass_flags)
2400          continue;
2401 
2402       instr->pass_flags = 1;
2403       nir_foreach_def(instr, add_def_to_worklist, worklist);
2404       switch (instr->type) {
2405       case nir_instr_type_jump: {
2406          nir_jump_instr *jump = nir_instr_as_jump(instr);
2407          switch (jump->type) {
2408          case nir_jump_break:
2409          case nir_jump_continue: {
2410             nir_cf_node *parent = &instr->block->cf_node;
2411             while (parent->type != nir_cf_node_loop)
2412                parent = parent->parent;
2413             nir_foreach_block_in_cf_node(block, parent)
2414                nir_foreach_instr(i, block)
2415                nir_instr_worklist_push_tail(worklist, i);
2416             }
2417             break;
2418          default:
2419             unreachable("Don't expect any other jumps");
2420          }
2421          break;
2422       }
2423       case nir_instr_type_intrinsic: {
2424          nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
2425          switch (intr->intrinsic) {
2426          case nir_intrinsic_store_output:
2427          case nir_intrinsic_store_per_vertex_output:
2428             any_bits_set |= set_output_bits(mod, intr, input_bits, tables, table_sizes);
2429             break;
2430             /* TODO: Memory writes */
2431          default:
2432             break;
2433          }
2434          break;
2435       }
2436       default:
2437          break;
2438       }
2439    }
2440 
2441    nir_instr_worklist_destroy(worklist);
2442    return any_bits_set;
2443 }
2444 
2445 /* For every input load, compute the set of output stores that it can contribute to.
2446  * If it's used for control flow, then any instruction in the CFG that it impacts
2447  * is considered to contribute.
2448  * Ideally, we should also handle stores to outputs/memory and then loads from that
2449  * output/memory, but this is non-trivial and unclear how much impact that would have. */
2450 bool
2451 dxil_nir_analyze_io_dependencies(struct dxil_module *mod, nir_shader *s)
2452 {
2453    bool any_outputs = false;
2454    for (uint32_t i = 0; i < 4; ++i)
2455       any_outputs |= mod->num_psv_outputs[i] > 0;
2456    if (mod->shader_kind == DXIL_HULL_SHADER)
2457       any_outputs |= mod->num_psv_patch_consts > 0;
2458    if (!any_outputs)
2459       return false;
2460 
2461    bool any_bits_set = false;
2462    nir_foreach_function(func, s) {
2463       assert(func->impl);
2464       /* Hull shaders have a patch constant function */
2465       assert(func->is_entrypoint || s->info.stage == MESA_SHADER_TESS_CTRL);
2466 
2467       /* Pass 1: input/view ID -> output dependencies */
2468       nir_foreach_block(block, func->impl) {
2469          nir_foreach_instr(instr, block) {
2470             if (instr->type != nir_instr_type_intrinsic)
2471                continue;
2472             nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
2473             uint32_t **tables = mod->io_dependency_table;
2474             const uint32_t *table_sizes = mod->dependency_table_dwords_per_input;
2475             switch (intr->intrinsic) {
2476             case nir_intrinsic_load_view_index:
2477                tables = mod->viewid_dependency_table;
2478                FALLTHROUGH;
2479             case nir_intrinsic_load_input:
2480             case nir_intrinsic_load_per_vertex_input:
2481             case nir_intrinsic_load_interpolated_input:
2482                break;
2483             default:
2484                continue;
2485             }
2486 
2487             clear_pass_flags(func->impl);
2488             any_bits_set |= propagate_input_to_output_dependencies(mod, intr, tables, table_sizes);
2489          }
2490       }
2491 
2492       /* Pass 2: output -> output dependencies */
2493       /* TODO */
2494    }
2495    return any_bits_set;
2496 }
2497 
2498 static enum pipe_format
2499 get_format_for_var(unsigned num_comps, enum glsl_base_type sampled_type)
2500 {
2501    switch (sampled_type) {
2502    case GLSL_TYPE_INT:
2503    case GLSL_TYPE_INT64:
2504    case GLSL_TYPE_INT16:
2505       switch (num_comps) {
2506       case 1: return PIPE_FORMAT_R32_SINT;
2507       case 2: return PIPE_FORMAT_R32G32_SINT;
2508       case 3: return PIPE_FORMAT_R32G32B32_SINT;
2509       case 4: return PIPE_FORMAT_R32G32B32A32_SINT;
2510       default: unreachable("Invalid num_comps");
2511       }
2512    case GLSL_TYPE_UINT:
2513    case GLSL_TYPE_UINT64:
2514    case GLSL_TYPE_UINT16:
2515       switch (num_comps) {
2516       case 1: return PIPE_FORMAT_R32_UINT;
2517       case 2: return PIPE_FORMAT_R32G32_UINT;
2518       case 3: return PIPE_FORMAT_R32G32B32_UINT;
2519       case 4: return PIPE_FORMAT_R32G32B32A32_UINT;
2520       default: unreachable("Invalid num_comps");
2521       }
2522    case GLSL_TYPE_FLOAT:
2523    case GLSL_TYPE_FLOAT16:
2524    case GLSL_TYPE_DOUBLE:
2525       switch (num_comps) {
2526       case 1: return PIPE_FORMAT_R32_FLOAT;
2527       case 2: return PIPE_FORMAT_R32G32_FLOAT;
2528       case 3: return PIPE_FORMAT_R32G32B32_FLOAT;
2529       case 4: return PIPE_FORMAT_R32G32B32A32_FLOAT;
2530       default: unreachable("Invalid num_comps");
2531       }
2532    default: unreachable("Invalid sampler return type");
2533    }
2534 }
2535 
2536 static unsigned
2537 aoa_size(const struct glsl_type *type)
2538 {
2539    return glsl_type_is_array(type) ? glsl_get_aoa_size(type) : 1;
2540 }
2541 
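/* Guess a format for images that were declared without one: scan every image
 * access that can address this variable, widen the guess to the widest
 * load/store component count seen, force a single-component format as soon as
 * an atomic is found, and default to 4 components if nothing matches. */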
2542 static bool
2543 guess_image_format_for_var(nir_shader *s, nir_variable *var)
2544 {
2545    const struct glsl_type *base_type = glsl_without_array(var->type);
2546    if (!glsl_type_is_image(base_type))
2547       return false;
2548    if (var->data.image.format != PIPE_FORMAT_NONE)
2549       return false;
2550 
2551    nir_foreach_function_impl(impl, s) {
2552       nir_foreach_block(block, impl) {
2553          nir_foreach_instr(instr, block) {
2554             if (instr->type != nir_instr_type_intrinsic)
2555                continue;
2556             nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
2557             switch (intr->intrinsic) {
2558             case nir_intrinsic_image_deref_load:
2559             case nir_intrinsic_image_deref_store:
2560             case nir_intrinsic_image_deref_atomic:
2561             case nir_intrinsic_image_deref_atomic_swap:
2562                if (nir_intrinsic_get_var(intr, 0) != var)
2563                   continue;
2564                break;
2565             case nir_intrinsic_image_load:
2566             case nir_intrinsic_image_store:
2567             case nir_intrinsic_image_atomic:
2568             case nir_intrinsic_image_atomic_swap: {
2569                unsigned binding = nir_src_as_uint(intr->src[0]);
2570                if (binding < var->data.binding ||
2571                    binding >= var->data.binding + aoa_size(var->type))
2572                   continue;
2573                break;
2574                }
2575             default:
2576                continue;
2577             }
2578             break;
2579 
2580             switch (intr->intrinsic) {
2581             case nir_intrinsic_image_deref_load:
2582             case nir_intrinsic_image_load:
2583             case nir_intrinsic_image_deref_store:
2584             case nir_intrinsic_image_store:
2585                /* Increase unknown formats up to 4 components if a 4-component accessor is used */
2586                if (intr->num_components > util_format_get_nr_components(var->data.image.format))
2587                   var->data.image.format = get_format_for_var(intr->num_components, glsl_get_sampler_result_type(base_type));
2588                break;
2589             default:
2590                /* If an atomic is used, the image format must be 1-component; return immediately */
2591                var->data.image.format = get_format_for_var(1, glsl_get_sampler_result_type(base_type));
2592                return true;
2593             }
2594          }
2595       }
2596    }
2597    /* Dunno what it is, assume 4-component */
2598    if (var->data.image.format == PIPE_FORMAT_NONE)
2599       var->data.image.format = get_format_for_var(4, glsl_get_sampler_result_type(base_type));
2600    return true;
2601 }
2602 
2603 static void
2604 update_intrinsic_format_and_type(nir_intrinsic_instr *intr, nir_variable *var)
2605 {
2606    nir_intrinsic_set_format(intr, var->data.image.format);
2607    nir_alu_type alu_type =
2608       nir_get_nir_type_for_glsl_base_type(glsl_get_sampler_result_type(glsl_without_array(var->type)));
2609    if (nir_intrinsic_has_src_type(intr))
2610       nir_intrinsic_set_src_type(intr, alu_type);
2611    else if (nir_intrinsic_has_dest_type(intr))
2612       nir_intrinsic_set_dest_type(intr, alu_type);
2613 }
2614 
2615 static bool
2616 update_intrinsic_formats(nir_builder *b, nir_intrinsic_instr *intr,
2617                          void *data)
2618 {
2619    if (!nir_intrinsic_has_format(intr))
2620       return false;
2621    nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
2622    if (deref) {
2623       nir_variable *var = nir_deref_instr_get_variable(deref);
2624       if (var)
2625          update_intrinsic_format_and_type(intr, var);
2626       return var != NULL;
2627    }
2628 
2629    if (!nir_intrinsic_has_range_base(intr))
2630       return false;
2631 
2632    unsigned binding = nir_src_as_uint(intr->src[0]);
2633    nir_foreach_variable_with_modes(var, b->shader, nir_var_image) {
2634       if (var->data.binding <= binding &&
2635           var->data.binding + aoa_size(var->type) > binding) {
2636          update_intrinsic_format_and_type(intr, var);
2637          return true;
2638       }
2639    }
2640    return false;
2641 }
2642 
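/* Assign formats to image variables that were declared without one (e.g. GL
 * writeonly images), then refresh the format/type indices on all image
 * intrinsics so they agree with the variables they reference.
 */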
2643 bool
2644 dxil_nir_guess_image_formats(nir_shader *s)
2645 {
2646    bool progress = false;
2647    nir_foreach_variable_with_modes(var, s, nir_var_image) {
2648       progress |= guess_image_format_for_var(s, var);
2649    }
2650    nir_shader_intrinsics_pass(s, update_intrinsic_formats, nir_metadata_all,
2651                               NULL);
2652    return progress;
2653 }
2654 
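/* Flag every variable of the given modes that matches this binding/descriptor
 * set pair as coherent.
 */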
2655 static void
2656 set_binding_variables_coherent(nir_shader *s, nir_binding binding, nir_variable_mode modes)
2657 {
2658    nir_foreach_variable_with_modes(var, s, modes) {
2659       if (var->data.binding == binding.binding &&
2660           var->data.descriptor_set == binding.desc_set) {
2661          var->data.access |= ACCESS_COHERENT;
2662       }
2663    }
2664 }
2665 
2666 static void
2667 set_deref_variables_coherent(nir_shader *s, nir_deref_instr *deref)
2668 {
2669    while (deref->deref_type != nir_deref_type_var &&
2670           deref->deref_type != nir_deref_type_cast) {
2671       deref = nir_deref_instr_parent(deref);
2672    }
2673    if (deref->deref_type == nir_deref_type_var) {
2674       deref->var->data.access |= ACCESS_COHERENT;
2675       return;
2676    }
2677 
2678    /* For derefs with casts, we only support pre-lowered Vulkan accesses */
2679    assert(deref->deref_type == nir_deref_type_cast);
2680    nir_intrinsic_instr *cast_src = nir_instr_as_intrinsic(deref->parent.ssa->parent_instr);
2681    assert(cast_src->intrinsic == nir_intrinsic_load_vulkan_descriptor);
2682    nir_binding binding = nir_chase_binding(cast_src->src[0]);
2683    set_binding_variables_coherent(s, binding, nir_var_mem_ssbo);
2684 }
2685 
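/* Build an atomic that is equivalent to the given load or store: loads become
 * an atomic iadd of zero (which returns the memory value unchanged), stores
 * become an atomic exchange of the stored value.  Returns NULL for intrinsics
 * that are not handled here.
 */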
2686 static nir_def *
2687 get_atomic_for_load_store(nir_builder *b, nir_intrinsic_instr *intr, unsigned bit_size)
2688 {
2689    nir_def *zero = nir_imm_intN_t(b, 0, bit_size);
2690    switch (intr->intrinsic) {
2691    case nir_intrinsic_load_deref:
2692       return nir_deref_atomic(b, bit_size, intr->src[0].ssa, zero, .atomic_op = nir_atomic_op_iadd);
2693    case nir_intrinsic_load_ssbo:
2694       return nir_ssbo_atomic(b, bit_size, intr->src[0].ssa, intr->src[1].ssa, zero, .atomic_op = nir_atomic_op_iadd);
2695    case nir_intrinsic_image_deref_load:
2696       return nir_image_deref_atomic(b, bit_size, intr->src[0].ssa, intr->src[1].ssa, intr->src[2].ssa, zero, .atomic_op = nir_atomic_op_iadd);
2697    case nir_intrinsic_image_load:
2698       return nir_image_atomic(b, bit_size, intr->src[0].ssa, intr->src[1].ssa, intr->src[2].ssa, zero, .atomic_op = nir_atomic_op_iadd);
2699    case nir_intrinsic_store_deref:
2700       return nir_deref_atomic(b, bit_size, intr->src[0].ssa, intr->src[1].ssa, .atomic_op = nir_atomic_op_xchg);
2701    case nir_intrinsic_store_ssbo:
2702       return nir_ssbo_atomic(b, bit_size, intr->src[1].ssa, intr->src[2].ssa, intr->src[0].ssa, .atomic_op = nir_atomic_op_xchg);
2703    case nir_intrinsic_image_deref_store:
2704       return nir_image_deref_atomic(b, bit_size, intr->src[0].ssa, intr->src[1].ssa, intr->src[2].ssa, intr->src[3].ssa, .atomic_op = nir_atomic_op_xchg);
2705    case nir_intrinsic_image_store:
2706       return nir_image_atomic(b, bit_size, intr->src[0].ssa, intr->src[1].ssa, intr->src[2].ssa, intr->src[3].ssa, .atomic_op = nir_atomic_op_xchg);
2707    default:
2708       return NULL;
2709    }
2710 }
2711 
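/* DXIL models coherence per resource (globallycoherent) rather than per
 * access.  Coherent loads/stores of a single component of at least 32 bits
 * are rewritten as the equivalent atomic from get_atomic_for_load_store();
 * accesses that can't be expressed that way (sub-32-bit or multi-component)
 * instead have ACCESS_COHERENT propagated back onto the variables backing the
 * resource.
 */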
2712 static bool
2713 lower_coherent_load_store(nir_builder *b, nir_intrinsic_instr *intr, void *context)
2714 {
2715    if (!nir_intrinsic_has_access(intr) || (nir_intrinsic_access(intr) & ACCESS_COHERENT) == 0)
2716       return false;
2717 
2718    nir_def *atomic_def = NULL;
2719    b->cursor = nir_before_instr(&intr->instr);
2720    switch (intr->intrinsic) {
2721    case nir_intrinsic_load_deref:
2722    case nir_intrinsic_load_ssbo:
2723    case nir_intrinsic_image_deref_load:
2724    case nir_intrinsic_image_load: {
2725       if (intr->def.bit_size < 32 || intr->def.num_components > 1) {
2726          if (intr->intrinsic == nir_intrinsic_load_deref)
2727             set_deref_variables_coherent(b->shader, nir_src_as_deref(intr->src[0]));
2728          else {
2729             nir_binding binding = {0};
2730             if (nir_src_is_const(intr->src[0]))
2731                binding.binding = nir_src_as_uint(intr->src[0]);
2732             set_binding_variables_coherent(b->shader, binding,
2733                                            intr->intrinsic == nir_intrinsic_load_ssbo ? nir_var_mem_ssbo : nir_var_image);
2734          }
2735          return false;
2736       }
2737 
2738       atomic_def = get_atomic_for_load_store(b, intr, intr->def.bit_size);
2739       nir_def_rewrite_uses(&intr->def, atomic_def);
2740       break;
2741    }
2742    case nir_intrinsic_store_deref:
2743    case nir_intrinsic_store_ssbo:
2744    case nir_intrinsic_image_deref_store:
2745    case nir_intrinsic_image_store: {
2746       int resource_idx = intr->intrinsic == nir_intrinsic_store_ssbo ? 1 : 0;
2747       int value_idx = intr->intrinsic == nir_intrinsic_store_ssbo ? 0 :
2748          intr->intrinsic == nir_intrinsic_store_deref ? 1 : 3;
2749       unsigned num_components = nir_intrinsic_has_write_mask(intr) ?
2750          util_bitcount(nir_intrinsic_write_mask(intr)) : intr->src[value_idx].ssa->num_components;
2751       if (intr->src[value_idx].ssa->bit_size < 32 || num_components > 1) {
2752          if (intr->intrinsic == nir_intrinsic_store_deref)
2753             set_deref_variables_coherent(b->shader, nir_src_as_deref(intr->src[resource_idx]));
2754          else {
2755             nir_binding binding = {0};
2756             if (nir_src_is_const(intr->src[resource_idx]))
2757                binding.binding = nir_src_as_uint(intr->src[resource_idx]);
2758             set_binding_variables_coherent(b->shader, binding,
2759                                            intr->intrinsic == nir_intrinsic_store_ssbo ? nir_var_mem_ssbo : nir_var_image);
2760          }
2761          return false;
2762       }
2763 
2764       atomic_def = get_atomic_for_load_store(b, intr, intr->src[value_idx].ssa->bit_size);
2765       break;
2766    }
2767    default:
2768       return false;
2769    }
2770 
2771    nir_intrinsic_instr *atomic = nir_instr_as_intrinsic(atomic_def->parent_instr);
2772    nir_intrinsic_set_access(atomic, nir_intrinsic_access(intr));
2773    if (nir_intrinsic_has_image_dim(intr))
2774       nir_intrinsic_set_image_dim(atomic, nir_intrinsic_image_dim(intr));
2775    if (nir_intrinsic_has_image_array(intr))
2776       nir_intrinsic_set_image_array(atomic, nir_intrinsic_image_array(intr));
2777    if (nir_intrinsic_has_format(intr))
2778       nir_intrinsic_set_format(atomic, nir_intrinsic_format(intr));
2779    if (nir_intrinsic_has_range_base(intr))
2780       nir_intrinsic_set_range_base(atomic, nir_intrinsic_range_base(intr));
2781    nir_instr_remove(&intr->instr);
2782    return true;
2783 }
2784 
2785 bool
2786 dxil_nir_lower_coherent_loads_and_stores(nir_shader *s)
2787 {
2788    return nir_shader_intrinsics_pass(s, lower_coherent_load_store,
2789                                      nir_metadata_control_flow | nir_metadata_loop_analysis,
2790                                      NULL);
2791 }
2792 
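/* Masks describing which varying slots the adjacent stage writes (when
 * pruning inputs) or reads (when pruning outputs): one bit per generic slot,
 * one bit per patch slot, and an optional per-component bitset for VARn slots.
 */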
2793 struct undefined_varying_masks {
2794    uint64_t io_mask;
2795    uint32_t patch_io_mask;
2796    const BITSET_WORD *frac_io_mask;
2797 };
2798 
2799 static bool
2800 is_dead_in_variable(nir_variable *var, void *data)
2801 {
2802    switch (var->data.location) {
2803    /* Only these slots can be system-generated values in addition to varyings */
2804    case VARYING_SLOT_PRIMITIVE_ID:
2805    case VARYING_SLOT_FACE:
2806    case VARYING_SLOT_VIEW_INDEX:
2807       return false;
2808    /* Tessellation input vars must remain untouched */
2809    case VARYING_SLOT_TESS_LEVEL_INNER:
2810    case VARYING_SLOT_TESS_LEVEL_OUTER:
2811       return false;
2812    default:
2813       return true;
2814    }
2815 }
2816 
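/* Replace loads of input varyings that the previous stage never writes with
 * zero (see the comment below for why zero is used rather than undef).
 */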
2817 static bool
2818 kill_undefined_varyings(struct nir_builder *b,
2819                         nir_instr *instr,
2820                         void *data)
2821 {
2822    const struct undefined_varying_masks *masks = data;
2823 
2824    if (instr->type != nir_instr_type_intrinsic)
2825       return false;
2826 
2827    nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
2828 
2829    if (intr->intrinsic != nir_intrinsic_load_deref)
2830       return false;
2831 
2832    nir_variable *var = nir_intrinsic_get_var(intr, 0);
2833    if (!var || var->data.mode != nir_var_shader_in)
2834       return false;
2835 
2836    if (!is_dead_in_variable(var, NULL))
2837       return false;
2838 
2839    uint32_t loc = var->data.patch && var->data.location >= VARYING_SLOT_PATCH0 ?
2840       var->data.location - VARYING_SLOT_PATCH0 :
2841       var->data.location;
2842    uint64_t written = var->data.patch && var->data.location >= VARYING_SLOT_PATCH0 ?
2843       masks->patch_io_mask : masks->io_mask;
2844    if (BITFIELD64_RANGE(loc, glsl_varying_count(var->type)) & written) {
2845       if (!masks->frac_io_mask || !var->data.location_frac ||
2846           var->data.location < VARYING_SLOT_VAR0 ||
2847           BITSET_TEST(masks->frac_io_mask, (var->data.location - VARYING_SLOT_VAR0) * 4 + var->data.location_frac))
2848       return false;
2849    }
2850 
2851    b->cursor = nir_after_instr(instr);
2852    /* Note: zero is used instead of undef, because optimization is not run here, but is
2853     * run later on. If we load an undef here, and that undef ends up being used to store
2854     * to position later on, that can cause some or all of the components in that position
2855     * write to be removed, which is problematic especially in the case of all components,
2856     * since that would remove the store instruction, and would make it tricky to satisfy
2857     * the DXIL requirements of writing all position components.
2858     */
2859    nir_def *zero = nir_imm_zero(b, intr->def.num_components,
2860                                        intr->def.bit_size);
2861    nir_def_replace(&intr->def, zero);
2862    return true;
2863 }
2864 
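/* Given the masks of varyings actually written by the previous stage, zero
 * out loads of inputs that are never written and then delete the now-dead
 * input variables.
 */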
2865 bool
2866 dxil_nir_kill_undefined_varyings(nir_shader *shader, uint64_t prev_stage_written_mask, uint32_t prev_stage_patch_written_mask,
2867                                  const BITSET_WORD *prev_stage_frac_output_mask)
2868 {
2869    struct undefined_varying_masks masks = {
2870       .io_mask = prev_stage_written_mask,
2871       .patch_io_mask = prev_stage_patch_written_mask,
2872       .frac_io_mask = prev_stage_frac_output_mask
2873    };
2874    bool progress = nir_shader_instructions_pass(shader,
2875                                                 kill_undefined_varyings,
2876                                                 nir_metadata_control_flow |
2877                                                 nir_metadata_loop_analysis,
2878                                                 (void *)&masks);
2879    if (progress) {
2880       nir_opt_dce(shader);
2881       nir_remove_dead_derefs(shader);
2882    }
2883 
2884    const struct nir_remove_dead_variables_options options = {
2885       .can_remove_var = is_dead_in_variable,
2886       .can_remove_var_data = &masks,
2887    };
2888    progress |= nir_remove_dead_variables(shader, nir_var_shader_in, &options);
2889    return progress;
2890 }
2891 
2892 static bool
2893 is_dead_out_variable(nir_variable *var, void *data)
2894 {
2895    return !nir_slot_is_sysval_output(var->data.location, MESA_SHADER_NONE);
2896 }
2897 
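/* Drop stores to output varyings that the next stage never reads; loads of
 * such outputs are replaced with zero so any remaining uses stay valid.
 */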
2898 static bool
2899 kill_unused_outputs(struct nir_builder *b,
2900                     nir_instr *instr,
2901                     void *data)
2902 {
2903    const struct undefined_varying_masks *masks = data;
2904 
2905    if (instr->type != nir_instr_type_intrinsic)
2906       return false;
2907 
2908    nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
2909 
2910    if (intr->intrinsic != nir_intrinsic_store_deref &&
2911        intr->intrinsic != nir_intrinsic_load_deref)
2912       return false;
2913 
2914    nir_variable *var = nir_intrinsic_get_var(intr, 0);
2915    if (!var || var->data.mode != nir_var_shader_out ||
2916        /* always_active_io can mean two things: xfb or GL separable shaders. We can't delete
2917         * varyings that are used for xfb (we'll just sort them last), but we must delete varyings
2918         * that are mismatching between TCS and TES. Fortunately TCS can't do xfb, so we can ignore
2919         * the always_active_io bit for TCS outputs. */
2920        (b->shader->info.stage != MESA_SHADER_TESS_CTRL && var->data.always_active_io))
2921       return false;
2922 
2923    if (!is_dead_out_variable(var, NULL))
2924       return false;
2925 
2926    unsigned loc = var->data.patch && var->data.location >= VARYING_SLOT_PATCH0 ?
2927       var->data.location - VARYING_SLOT_PATCH0 :
2928       var->data.location;
2929    uint64_t read = var->data.patch && var->data.location >= VARYING_SLOT_PATCH0 ?
2930       masks->patch_io_mask : masks->io_mask;
2931    if (BITFIELD64_RANGE(loc, glsl_varying_count(var->type)) & read) {
2932       if (!masks->frac_io_mask || !var->data.location_frac ||
2933           var->data.location < VARYING_SLOT_VAR0 ||
2934           BITSET_TEST(masks->frac_io_mask, (var->data.location - VARYING_SLOT_VAR0) * 4 + var->data.location_frac))
2935       return false;
2936    }
2937 
2938    if (intr->intrinsic == nir_intrinsic_load_deref) {
2939       b->cursor = nir_after_instr(&intr->instr);
2940       nir_def *zero = nir_imm_zero(b, intr->def.num_components, intr->def.bit_size);
2941       nir_def_rewrite_uses(&intr->def, zero);
2942    }
2943    nir_instr_remove(instr);
2944    return true;
2945 }
2946 
2947 bool
2948 dxil_nir_kill_unused_outputs(nir_shader *shader, uint64_t next_stage_read_mask, uint32_t next_stage_patch_read_mask,
2949                              const BITSET_WORD *next_stage_frac_input_mask)
2950 {
2951    struct undefined_varying_masks masks = {
2952       .io_mask = next_stage_read_mask,
2953       .patch_io_mask = next_stage_patch_read_mask,
2954       .frac_io_mask = next_stage_frac_input_mask
2955    };
2956 
2957    bool progress = nir_shader_instructions_pass(shader,
2958                                                 kill_unused_outputs,
2959                                                 nir_metadata_control_flow |
2960                                                 nir_metadata_loop_analysis,
2961                                                 (void *)&masks);
2962 
2963    if (progress) {
2964       nir_opt_dce(shader);
2965       nir_remove_dead_derefs(shader);
2966    }
2967    const struct nir_remove_dead_variables_options options = {
2968       .can_remove_var = is_dead_out_variable,
2969       .can_remove_var_data = &masks,
2970    };
2971    progress |= nir_remove_dead_variables(shader, nir_var_shader_out, &options);
2972    return progress;
2973 }
2974