/*
 * Copyright © 2014 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "nir.h"
#include "nir_builder.h"
#include "nir_builder_opcodes.h"
#include "nir_intrinsics_indices.h"

struct locals_to_regs_state {
   nir_builder builder;

   /* A hash table mapping derefs to register handles */
   struct hash_table *regs_table;

   /* Bit size to use for boolean registers */
   uint8_t bool_bitsize;

   bool progress;
};

/* The following two functions implement a hash and equality check for
 * variable dereferences.  When the hash or equality function encounters an
 * array, it ignores the array index and whether it is direct or indirect
 * entirely.
 */
static uint32_t
hash_deref(const void *void_deref)
{
   uint32_t hash = 0;

   for (const nir_deref_instr *deref = void_deref; deref;
        deref = nir_deref_instr_parent(deref)) {
      switch (deref->deref_type) {
      case nir_deref_type_var:
         return XXH32(&deref->var, sizeof(deref->var), hash);

      case nir_deref_type_array:
         continue; /* Do nothing */

      case nir_deref_type_struct:
         hash = XXH32(&deref->strct.index, sizeof(deref->strct.index), hash);
         continue;

      default:
         unreachable("Invalid deref type");
      }
   }

   unreachable("We should have hit a variable dereference");
}

static bool
derefs_equal(const void *void_a, const void *void_b)
{
   for (const nir_deref_instr *a = void_a, *b = void_b; a || b;
        a = nir_deref_instr_parent(a), b = nir_deref_instr_parent(b)) {
      if (a->deref_type != b->deref_type)
         return false;

      switch (a->deref_type) {
      case nir_deref_type_var:
         return a->var == b->var;

      case nir_deref_type_array:
         continue; /* Do nothing */

      case nir_deref_type_struct:
         if (a->strct.index != b->strct.index)
            return false;
         continue;

      default:
         unreachable("Invalid deref type");
      }
   }

   unreachable("We should have hit a variable dereference");
}

static nir_def *
get_reg_for_deref(nir_deref_instr *deref, struct locals_to_regs_state *state)
{
   uint32_t hash = hash_deref(deref);

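   /* Registers cannot carry an initializer, so any variable initializers
    * must already have been lowered away before this pass runs.
    */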
   assert(nir_deref_instr_get_variable(deref)->constant_initializer == NULL &&
          nir_deref_instr_get_variable(deref)->pointer_initializer == NULL);

   struct hash_entry *entry =
      _mesa_hash_table_search_pre_hashed(state->regs_table, hash, deref);
   if (entry)
      return entry->data;

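   /* Walk up the deref chain and multiply the lengths of all enclosing
    * arrays together to get the total number of array elements the
    * register has to cover.
    */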
   unsigned array_size = 1;
   for (nir_deref_instr *d = deref; d; d = nir_deref_instr_parent(d)) {
      if (d->deref_type == nir_deref_type_array)
         array_size *= glsl_get_length(nir_deref_instr_parent(d)->type);
   }

   assert(glsl_type_is_vector_or_scalar(deref->type));

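   /* Booleans are 1-bit in NIR; give them a register of the bit size the
    * caller asked for instead.
    */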
   uint8_t bit_size = glsl_get_bit_size(deref->type);
   if (bit_size == 1)
      bit_size = state->bool_bitsize;

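   /* An array size of 0 declares a plain, non-array register. */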
   nir_def *reg = nir_decl_reg(&state->builder,
                               glsl_get_vector_elements(deref->type),
                               bit_size, array_size > 1 ? array_size : 0);

   _mesa_hash_table_insert_pre_hashed(state->regs_table, hash, deref, reg);

   return reg;
}

struct reg_location {
   nir_def *reg;          /* Result of the decl_reg intrinsic */
   nir_def *indirect;     /* Non-constant part of the array offset, or NULL */
   unsigned base_offset;  /* Constant part of the array offset */
};

static struct reg_location
get_deref_reg_location(nir_deref_instr *deref,
                       struct locals_to_regs_state *state)
{
   nir_builder *b = &state->builder;

   nir_def *reg = get_reg_for_deref(deref, state);
   nir_intrinsic_instr *decl = nir_instr_as_intrinsic(reg->parent_instr);

   /* It is possible for a user to create a shader that has an array with a
    * single element and then proceed to access it indirectly.  Indirectly
    * accessing a non-array register is not allowed in NIR.  In order to
    * handle this case we just convert it to a direct reference.
    */
   if (nir_intrinsic_num_array_elems(decl) == 0)
      return (struct reg_location){ .reg = reg };

   nir_def *indirect = NULL;
   unsigned base_offset = 0;

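   /* Accumulate constant array indices into base_offset until the first
    * non-constant index is seen; from then on everything, constant or not,
    * is folded into the indirect offset instead.
    */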
   unsigned inner_array_size = 1;
   for (const nir_deref_instr *d = deref; d; d = nir_deref_instr_parent(d)) {
      if (d->deref_type != nir_deref_type_array)
         continue;

      if (nir_src_is_const(d->arr.index) && !indirect) {
         base_offset += nir_src_as_uint(d->arr.index) * inner_array_size;
      } else {
         if (indirect) {
            assert(base_offset == 0);
         } else {
            indirect = nir_imm_int(b, base_offset);
            base_offset = 0;
         }

         nir_def *index = nir_i2iN(b, d->arr.index.ssa, 32);
         nir_def *offset = nir_imul_imm(b, index, inner_array_size);

         /* Avoid emitting iadd with 0, which is otherwise common, since this
          * pass runs late enough that nothing will clean it up.
          */
         nir_scalar scal = nir_get_scalar(indirect, 0);
         if (nir_scalar_is_const(scal))
            indirect = nir_iadd_imm(b, offset, nir_scalar_as_uint(scal));
         else
            indirect = nir_iadd(b, offset, indirect);
      }

      inner_array_size *= glsl_get_length(nir_deref_instr_parent(d)->type);
   }

   return (struct reg_location){
      .reg = reg,
      .indirect = indirect,
      .base_offset = base_offset
   };
}

static bool
lower_locals_to_regs_block(nir_block *block,
                           struct locals_to_regs_state *state)
{
   nir_builder *b = &state->builder;

   nir_foreach_instr_safe(instr, block) {
      if (instr->type != nir_instr_type_intrinsic)
         continue;

      nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

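      /* Replace loads and stores of function_temp derefs with the
       * corresponding register intrinsics; all other intrinsics are left
       * untouched.
       */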
      switch (intrin->intrinsic) {
      case nir_intrinsic_load_deref: {
         nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
         if (!nir_deref_mode_is(deref, nir_var_function_temp))
            continue;

         b->cursor = nir_after_instr(&intrin->instr);
         struct reg_location loc = get_deref_reg_location(deref, state);
         nir_intrinsic_instr *decl = nir_reg_get_decl(loc.reg);

         nir_def *value;
         unsigned num_array_elems = nir_intrinsic_num_array_elems(decl);
         unsigned num_components = nir_intrinsic_num_components(decl);
         unsigned bit_size = nir_intrinsic_bit_size(decl);

         if (loc.base_offset >= MAX2(num_array_elems, 1)) {
            /* Out-of-bounds read, return 0 instead. */
            value = nir_imm_zero(b, num_components, bit_size);
         } else if (loc.indirect != NULL) {
            value = nir_load_reg_indirect(b, num_components, bit_size,
                                          loc.reg, loc.indirect,
                                          .base = loc.base_offset);
         } else {
            value = nir_build_load_reg(b, num_components, bit_size,
                                       loc.reg, .base = loc.base_offset);
         }

         nir_def_rewrite_uses(&intrin->def, value);
         nir_instr_remove(&intrin->instr);
         state->progress = true;
         break;
      }

      case nir_intrinsic_store_deref: {
         nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
         if (!nir_deref_mode_is(deref, nir_var_function_temp))
            continue;

         b->cursor = nir_before_instr(&intrin->instr);

         struct reg_location loc = get_deref_reg_location(deref, state);
         nir_intrinsic_instr *decl = nir_reg_get_decl(loc.reg);

         nir_def *val = intrin->src[1].ssa;
         unsigned num_array_elems = nir_intrinsic_num_array_elems(decl);
         unsigned write_mask = nir_intrinsic_write_mask(intrin);

         if (loc.base_offset >= MAX2(num_array_elems, 1)) {
            /* Out-of-bounds write, just eliminate it. */
         } else if (loc.indirect) {
            nir_store_reg_indirect(b, val, loc.reg, loc.indirect,
                                   .base = loc.base_offset,
                                   .write_mask = write_mask);
         } else {
            nir_build_store_reg(b, val, loc.reg, .base = loc.base_offset,
                                .write_mask = write_mask);
         }

         nir_instr_remove(&intrin->instr);
         state->progress = true;
         break;
      }

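      /* Copies are expected to have been lowered to load/store pairs
       * (e.g. by nir_lower_var_copies) before this pass runs.
       */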
      case nir_intrinsic_copy_deref:
         unreachable("There should be no copies whatsoever at this point");
         break;

      default:
         continue;
      }
   }

   return true;
}

static bool
impl(nir_function_impl *impl, uint8_t bool_bitsize)
{
   struct locals_to_regs_state state;

   state.builder = nir_builder_create(impl);
   state.progress = false;
   state.regs_table = _mesa_hash_table_create(NULL, hash_deref, derefs_equal);
   state.bool_bitsize = bool_bitsize;

   nir_metadata_require(impl, nir_metadata_dominance);

   nir_foreach_block(block, impl) {
      lower_locals_to_regs_block(block, &state);
   }

   nir_metadata_preserve(impl, nir_metadata_control_flow);

   _mesa_hash_table_destroy(state.regs_table, NULL);

   return state.progress;
}

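/* Lower all nir_var_function_temp load_deref/store_deref access in the
 * shader to NIR register intrinsics.  Boolean variables get registers of
 * bool_bitsize bits.  Returns true if any progress was made.
 */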
bool
nir_lower_locals_to_regs(nir_shader *shader, uint8_t bool_bitsize)
{
   bool progress = false;

   nir_foreach_function_impl(func_impl, shader) {
      progress = impl(func_impl, bool_bitsize) || progress;
   }

   return progress;
}