/*
 * Copyright © 2018 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "nir.h"
#include "nir_builder.h"
#include "nir_deref.h"
#include "nir_vla.h"

#include "util/set.h"
#include "util/u_math.h"

static bool
is_array_deref_of_vec(nir_deref_instr *deref)
{
   if (deref->deref_type != nir_deref_type_array &&
       deref->deref_type != nir_deref_type_array_wildcard)
      return false;

   nir_deref_instr *parent = nir_deref_instr_parent(deref);
   return glsl_type_is_vector_or_scalar(parent->type);
}

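/* Returns the set of variables that are used through at least one deref with
 * a complex use (as defined by nir_deref_instr_has_complex_use, allowing
 * atomics).  The splitting passes below never split such variables.
 */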
static struct set *
get_complex_used_vars(nir_shader *shader, void *mem_ctx)
{
   struct set *complex_vars = _mesa_pointer_set_create(mem_ctx);

   nir_foreach_function_impl(impl, shader) {
      nir_foreach_block(block, impl) {
         nir_foreach_instr(instr, block) {
            if (instr->type != nir_instr_type_deref)
               continue;

            nir_deref_instr *deref = nir_instr_as_deref(instr);

            /* We only need to consider var derefs because
             * nir_deref_instr_has_complex_use is recursive.
             */
            if (deref->deref_type == nir_deref_type_var &&
                nir_deref_instr_has_complex_use(deref,
                                                nir_deref_instr_has_complex_use_allow_atomics))
               _mesa_set_add(complex_vars, deref->var);
         }
      }
   }

   return complex_vars;
}

struct split_var_state {
   void *mem_ctx;

   nir_shader *shader;
   nir_function_impl *impl;

   nir_variable *base_var;
};

struct field {
   struct field *parent;

   const struct glsl_type *type;

   unsigned num_fields;
   struct field *fields;

   /* The field currently being recursed */
   unsigned current_index;

   nir_variable *var;
};

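/* Returns the number of array (or matrix) levels wrapping a vector or scalar
 * type, or -1 if the type is not an array of vectors.
 */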
static int
num_array_levels_in_array_of_vector_type(const struct glsl_type *type)
{
   int num_levels = 0;
   while (true) {
      if (glsl_type_is_array_or_matrix(type)) {
         num_levels++;
         type = glsl_get_array_element(type);
      } else if (glsl_type_is_vector_or_scalar(type) &&
                 !glsl_type_is_cmat(type)) {
         /* glsl_type_is_vector_or_scalar would more accurately be called "can
          * be an r-value that isn't an array, structure, or matrix".  This
          * optimization pass really shouldn't do anything to cooperative
          * matrices.  These matrices will eventually be lowered to something
          * else (dependent on the backend), and that thing may (or may not)
          * be handled by this or another pass.
          */
         return num_levels;
      } else {
         /* Not an array of vectors */
         return -1;
      }
   }
}

static nir_constant *
gather_constant_initializers(nir_constant *src,
                             nir_variable *var,
                             const struct glsl_type *type,
                             struct field *field,
                             struct split_var_state *state)
{
   if (!src)
      return NULL;
   if (glsl_type_is_array(type)) {
      const struct glsl_type *element = glsl_get_array_element(type);
      assert(src->num_elements == glsl_get_length(type));
      nir_constant *dst = rzalloc(var, nir_constant);
      dst->num_elements = src->num_elements;
      dst->elements = rzalloc_array(var, nir_constant *, src->num_elements);
      for (unsigned i = 0; i < src->num_elements; ++i) {
         dst->elements[i] = gather_constant_initializers(src->elements[i], var, element, field, state);
      }
      return dst;
   } else if (glsl_type_is_struct(type)) {
      const struct glsl_type *element = glsl_get_struct_field(type, field->current_index);
      return gather_constant_initializers(src->elements[field->current_index], var, element, &field->fields[field->current_index], state);
   } else {
      return nir_constant_clone(src, var);
   }
}

static void
init_field_for_type(struct field *field, struct field *parent,
                    const struct glsl_type *type,
                    const char *name,
                    struct split_var_state *state)
{
   *field = (struct field){
      .parent = parent,
      .type = type,
   };

   const struct glsl_type *struct_type = glsl_without_array(type);
   if (glsl_type_is_struct_or_ifc(struct_type)) {
      field->num_fields = glsl_get_length(struct_type),
      field->fields = ralloc_array(state->mem_ctx, struct field,
                                   field->num_fields);
      for (unsigned i = 0; i < field->num_fields; i++) {
         char *field_name = NULL;
         if (name) {
            field_name = ralloc_asprintf(state->mem_ctx, "%s_%s", name,
                                         glsl_get_struct_elem_name(struct_type, i));
         } else {
            field_name = ralloc_asprintf(state->mem_ctx, "{unnamed %s}_%s",
                                         glsl_get_type_name(struct_type),
                                         glsl_get_struct_elem_name(struct_type, i));
         }
         field->current_index = i;
         init_field_for_type(&field->fields[i], field,
                             glsl_get_struct_field(struct_type, i),
                             field_name, state);
      }
   } else {
      const struct glsl_type *var_type = type;
      struct field *root = field;
      for (struct field *f = field->parent; f; f = f->parent) {
         var_type = glsl_type_wrap_in_arrays(var_type, f->type);
         root = f;
      }

      nir_variable_mode mode = state->base_var->data.mode;
      if (mode == nir_var_function_temp) {
         field->var = nir_local_variable_create(state->impl, var_type, name);
      } else {
         field->var = nir_variable_create(state->shader, mode, var_type, name);
      }
      field->var->data.ray_query = state->base_var->data.ray_query;
      field->var->constant_initializer = gather_constant_initializers(state->base_var->constant_initializer,
                                                                      field->var, state->base_var->type,
                                                                      root, state);
   }
}

static bool
split_var_list_structs(nir_shader *shader,
                       nir_function_impl *impl,
                       struct exec_list *vars,
                       nir_variable_mode mode,
                       struct hash_table *var_field_map,
                       struct set **complex_vars,
                       void *mem_ctx)
{
   struct split_var_state state = {
      .mem_ctx = mem_ctx,
      .shader = shader,
      .impl = impl,
   };

   struct exec_list split_vars;
   exec_list_make_empty(&split_vars);

   /* To avoid list confusion (we'll be adding things as we split variables),
    * pull all of the variables we plan to split off of the list
    */
   nir_foreach_variable_in_list_safe(var, vars) {
      if (var->data.mode != mode)
         continue;

      if (!glsl_type_is_struct_or_ifc(glsl_without_array(var->type)))
         continue;

      if (*complex_vars == NULL)
         *complex_vars = get_complex_used_vars(shader, mem_ctx);

      /* We can't split a variable that's referenced with a deref that has any
       * sort of complex usage.
       */
      if (_mesa_set_search(*complex_vars, var))
         continue;

      exec_node_remove(&var->node);
      exec_list_push_tail(&split_vars, &var->node);
   }

   nir_foreach_variable_in_list(var, &split_vars) {
      state.base_var = var;

      struct field *root_field = ralloc(mem_ctx, struct field);
      init_field_for_type(root_field, NULL, var->type, var->name, &state);
      _mesa_hash_table_insert(var_field_map, var, root_field);
   }

   return !exec_list_is_empty(&split_vars);
}

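/* Walks all derefs in the impl and, for every vector/scalar deref rooted in a
 * variable that was split, rebuilds an equivalent deref chain on the matching
 * per-field variable and rewrites all uses to point at it.
 */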
static void
split_struct_derefs_impl(nir_function_impl *impl,
                         struct hash_table *var_field_map,
                         nir_variable_mode modes,
                         void *mem_ctx)
{
   nir_builder b = nir_builder_create(impl);

   nir_foreach_block(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         if (instr->type != nir_instr_type_deref)
            continue;

         nir_deref_instr *deref = nir_instr_as_deref(instr);
         if (!nir_deref_mode_may_be(deref, modes))
            continue;

         /* Clean up any dead derefs we find lying around.  They may refer to
          * variables we're planning to split.
          */
         if (nir_deref_instr_remove_if_unused(deref))
            continue;

         if (!glsl_type_is_vector_or_scalar(deref->type))
            continue;

         nir_variable *base_var = nir_deref_instr_get_variable(deref);
         /* If we can't chase back to the variable, then we're a complex use.
          * This should have been detected by get_complex_used_vars() and the
          * variable should not have been split.  However, we have no way of
          * knowing that here, so we just have to trust it.
          */
         if (base_var == NULL)
            continue;

         struct hash_entry *entry =
            _mesa_hash_table_search(var_field_map, base_var);
         if (!entry)
            continue;

         struct field *root_field = entry->data;

         nir_deref_path path;
         nir_deref_path_init(&path, deref, mem_ctx);

         struct field *tail_field = root_field;
         for (unsigned i = 0; path.path[i]; i++) {
            if (path.path[i]->deref_type != nir_deref_type_struct)
               continue;

            assert(i > 0);
            assert(glsl_type_is_struct_or_ifc(path.path[i - 1]->type));
            assert(path.path[i - 1]->type ==
                   glsl_without_array(tail_field->type));

            tail_field = &tail_field->fields[path.path[i]->strct.index];
         }
         nir_variable *split_var = tail_field->var;

         nir_deref_instr *new_deref = NULL;
         for (unsigned i = 0; path.path[i]; i++) {
            nir_deref_instr *p = path.path[i];
            b.cursor = nir_after_instr(&p->instr);

            switch (p->deref_type) {
            case nir_deref_type_var:
               assert(new_deref == NULL);
               new_deref = nir_build_deref_var(&b, split_var);
               break;

            case nir_deref_type_array:
            case nir_deref_type_array_wildcard:
               new_deref = nir_build_deref_follower(&b, new_deref, p);
               break;

            case nir_deref_type_struct:
               /* Nothing to do; we're splitting structs */
               break;

            default:
               unreachable("Invalid deref type in path");
            }
         }

         assert(new_deref->type == deref->type);
         nir_def_rewrite_uses(&deref->def,
                              &new_deref->def);
         nir_deref_instr_remove_if_unused(deref);
      }
   }
}

/** A pass for splitting structs into multiple variables
 *
 * This pass splits arrays of structs into multiple variables, one for each
 * (possibly nested) structure member.  After this pass completes, no
 * variables of the given mode will contain a struct type.
 */
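/* As an illustration (names invented for the example), a variable such as
 *
 *    struct { vec4 color; float depth; } data[4];
 *
 * ends up replaced by two variables roughly equivalent to
 *
 *    vec4 data_color[4];
 *    float data_depth[4];
 *
 * with every struct deref rewritten to address the matching new variable.
 */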
bool
nir_split_struct_vars(nir_shader *shader, nir_variable_mode modes)
{
   void *mem_ctx = ralloc_context(NULL);
   struct hash_table *var_field_map =
      _mesa_pointer_hash_table_create(mem_ctx);
   struct set *complex_vars = NULL;

   bool has_global_splits = false;
   nir_variable_mode global_modes = modes & ~nir_var_function_temp;
   if (global_modes) {
      has_global_splits = split_var_list_structs(shader, NULL,
                                                 &shader->variables,
                                                 global_modes,
                                                 var_field_map,
                                                 &complex_vars,
                                                 mem_ctx);
   }

   bool progress = false;
   nir_foreach_function_impl(impl, shader) {
      bool has_local_splits = false;
      if (modes & nir_var_function_temp) {
         has_local_splits = split_var_list_structs(shader, impl,
                                                   &impl->locals,
                                                   nir_var_function_temp,
                                                   var_field_map,
                                                   &complex_vars,
                                                   mem_ctx);
      }

      if (has_global_splits || has_local_splits) {
         split_struct_derefs_impl(impl, var_field_map,
                                  modes, mem_ctx);

         nir_metadata_preserve(impl, nir_metadata_control_flow);
         progress = true;
      } else {
         nir_metadata_preserve(impl, nir_metadata_all);
      }
   }

   ralloc_free(mem_ctx);

   return progress;
}

struct array_level_info {
   unsigned array_len;
   bool split;
};

struct array_split {
   /* Only set if this is the tail end of the splitting */
   nir_variable *var;

   unsigned num_splits;
   struct array_split *splits;
};

struct array_var_info {
   nir_variable *base_var;

   const struct glsl_type *split_var_type;

   bool split_var;
   struct array_split root_split;

   unsigned num_levels;
   struct array_level_info levels[0];
};

static bool
init_var_list_array_infos(nir_shader *shader,
                          struct exec_list *vars,
                          nir_variable_mode mode,
                          struct hash_table *var_info_map,
                          struct set **complex_vars,
                          void *mem_ctx)
{
   bool has_array = false;

   nir_foreach_variable_in_list(var, vars) {
      if (var->data.mode != mode)
         continue;

      int num_levels = num_array_levels_in_array_of_vector_type(var->type);
      if (num_levels <= 0)
         continue;

      if (*complex_vars == NULL)
         *complex_vars = get_complex_used_vars(shader, mem_ctx);

      /* We can't split a variable that's referenced with a deref that has any
       * sort of complex usage.
       */
      if (_mesa_set_search(*complex_vars, var))
         continue;

      struct array_var_info *info =
         rzalloc_size(mem_ctx, sizeof(*info) +
                                  num_levels * sizeof(info->levels[0]));

      info->base_var = var;
      info->num_levels = num_levels;

      const struct glsl_type *type = var->type;
      for (int i = 0; i < num_levels; i++) {
         info->levels[i].array_len = glsl_get_length(type);
         type = glsl_get_array_element(type);

         /* All levels start out initially as split */
         info->levels[i].split = true;
      }

      _mesa_hash_table_insert(var_info_map, var, info);
      has_array = true;
   }

   return has_array;
}

static struct array_var_info *
get_array_var_info(nir_variable *var,
                   struct hash_table *var_info_map)
{
   struct hash_entry *entry =
      _mesa_hash_table_search(var_info_map, var);
   return entry ? entry->data : NULL;
}

static struct array_var_info *
get_array_deref_info(nir_deref_instr *deref,
                     struct hash_table *var_info_map,
                     nir_variable_mode modes)
{
   if (!nir_deref_mode_may_be(deref, modes))
      return NULL;

   nir_variable *var = nir_deref_instr_get_variable(deref);
   if (var == NULL)
      return NULL;

   return get_array_var_info(var, var_info_map);
}

static void
mark_array_deref_used(nir_deref_instr *deref,
                      struct hash_table *var_info_map,
                      nir_variable_mode modes,
                      void *mem_ctx)
{
   struct array_var_info *info =
      get_array_deref_info(deref, var_info_map, modes);
   if (!info)
      return;

   nir_deref_path path;
   nir_deref_path_init(&path, deref, mem_ctx);

   /* Walk the path and look for indirects.  If we have an array deref with an
    * indirect, mark the given level as not being split.
    */
   for (unsigned i = 0; i < info->num_levels; i++) {
      nir_deref_instr *p = path.path[i + 1];
      if (p->deref_type == nir_deref_type_array &&
          !nir_src_is_const(p->arr.index))
         info->levels[i].split = false;
   }
}

static void
mark_array_usage_impl(nir_function_impl *impl,
                      struct hash_table *var_info_map,
                      nir_variable_mode modes,
                      void *mem_ctx)
{
   nir_foreach_block(block, impl) {
      nir_foreach_instr(instr, block) {
         if (instr->type != nir_instr_type_intrinsic)
            continue;

         nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
         switch (intrin->intrinsic) {
         case nir_intrinsic_copy_deref:
            mark_array_deref_used(nir_src_as_deref(intrin->src[1]),
                                  var_info_map, modes, mem_ctx);
            FALLTHROUGH;

         case nir_intrinsic_load_deref:
         case nir_intrinsic_store_deref:
            mark_array_deref_used(nir_src_as_deref(intrin->src[0]),
                                  var_info_map, modes, mem_ctx);
            break;

         default:
            break;
         }
      }
   }
}

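/* Recursively creates the new variables for a split array.  Levels that stay
 * split fan out into one child per constant array index; the first unsplit
 * level (or the vector itself) terminates the recursion with a real
 * nir_variable whose name encodes the constant indices, e.g. "(foo[2][*])".
 */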
static void
create_split_array_vars(struct array_var_info *var_info,
                        unsigned level,
                        struct array_split *split,
                        const char *name,
                        nir_shader *shader,
                        nir_function_impl *impl,
                        void *mem_ctx)
{
   while (level < var_info->num_levels && !var_info->levels[level].split) {
      name = ralloc_asprintf(mem_ctx, "%s[*]", name);
      level++;
   }

   if (level == var_info->num_levels) {
      /* We add parens to the variable name so it looks like "(foo[2][*])" so
       * that further derefs will look like "(foo[2][*])[ssa_6]"
       */
      name = ralloc_asprintf(mem_ctx, "(%s)", name);

      nir_variable_mode mode = var_info->base_var->data.mode;
      if (mode == nir_var_function_temp) {
         split->var = nir_local_variable_create(impl,
                                                var_info->split_var_type, name);
      } else {
         split->var = nir_variable_create(shader, mode,
                                          var_info->split_var_type, name);
      }
      split->var->data.ray_query = var_info->base_var->data.ray_query;
   } else {
      assert(var_info->levels[level].split);
      split->num_splits = var_info->levels[level].array_len;
      split->splits = rzalloc_array(mem_ctx, struct array_split,
                                    split->num_splits);
      for (unsigned i = 0; i < split->num_splits; i++) {
         create_split_array_vars(var_info, level + 1, &split->splits[i],
                                 ralloc_asprintf(mem_ctx, "%s[%d]", name, i),
                                 shader, impl, mem_ctx);
      }
   }
}

static bool
split_var_list_arrays(nir_shader *shader,
                      nir_function_impl *impl,
                      struct exec_list *vars,
                      nir_variable_mode mode,
                      struct hash_table *var_info_map,
                      void *mem_ctx)
{
   struct exec_list split_vars;
   exec_list_make_empty(&split_vars);

   nir_foreach_variable_in_list_safe(var, vars) {
      if (var->data.mode != mode)
         continue;

      struct array_var_info *info = get_array_var_info(var, var_info_map);
      if (!info)
         continue;

      bool has_split = false;
      const struct glsl_type *split_type =
         glsl_without_array_or_matrix(var->type);
      for (int i = info->num_levels - 1; i >= 0; i--) {
         if (info->levels[i].split) {
            has_split = true;
            continue;
         }

         /* If the original type was a matrix type, we'd like to keep that so
          * we don't convert matrices into arrays.
          */
         if (i == info->num_levels - 1 &&
             glsl_type_is_matrix(glsl_without_array(var->type))) {
            split_type = glsl_matrix_type(glsl_get_base_type(split_type),
                                          glsl_get_components(split_type),
                                          info->levels[i].array_len);
         } else {
            split_type = glsl_array_type(split_type, info->levels[i].array_len, 0);
         }
      }

      if (has_split) {
         info->split_var_type = split_type;
         /* To avoid list confusion (we'll be adding things as we split
          * variables), pull all of the variables we plan to split off of the
          * main variable list.
          */
         exec_node_remove(&var->node);
         exec_list_push_tail(&split_vars, &var->node);
      } else {
         assert(split_type == glsl_get_bare_type(var->type));
         /* If we're not modifying this variable, delete the info so we skip
          * it faster in later passes.
          */
         _mesa_hash_table_remove_key(var_info_map, var);
      }
   }

   nir_foreach_variable_in_list(var, &split_vars) {
      struct array_var_info *info = get_array_var_info(var, var_info_map);
      create_split_array_vars(info, 0, &info->root_split, var->name,
                              shader, impl, mem_ctx);
   }

   return !exec_list_is_empty(&split_vars);
}

static bool
deref_has_split_wildcard(nir_deref_path *path,
                         struct array_var_info *info)
{
   if (info == NULL)
      return false;

   assert(path->path[0]->var == info->base_var);
   for (unsigned i = 0; i < info->num_levels; i++) {
      if (path->path[i + 1]->deref_type == nir_deref_type_array_wildcard &&
          info->levels[i].split)
         return true;
   }

   return false;
}

static bool
array_path_is_out_of_bounds(nir_deref_path *path,
                            struct array_var_info *info)
{
   if (info == NULL)
      return false;

   assert(path->path[0]->var == info->base_var);
   for (unsigned i = 0; i < info->num_levels; i++) {
      nir_deref_instr *p = path->path[i + 1];
      if (p->deref_type == nir_deref_type_array_wildcard)
         continue;

      if (nir_src_is_const(p->arr.index) &&
          nir_src_as_uint(p->arr.index) >= info->levels[i].array_len)
         return true;
   }

   return false;
}

static void
emit_split_copies(nir_builder *b,
                  struct array_var_info *dst_info, nir_deref_path *dst_path,
                  unsigned dst_level, nir_deref_instr *dst,
                  struct array_var_info *src_info, nir_deref_path *src_path,
                  unsigned src_level, nir_deref_instr *src)
{
   nir_deref_instr *dst_p, *src_p;

   while ((dst_p = dst_path->path[dst_level + 1])) {
      if (dst_p->deref_type == nir_deref_type_array_wildcard)
         break;

      dst = nir_build_deref_follower(b, dst, dst_p);
      dst_level++;
   }

   while ((src_p = src_path->path[src_level + 1])) {
      if (src_p->deref_type == nir_deref_type_array_wildcard)
         break;

      src = nir_build_deref_follower(b, src, src_p);
      src_level++;
   }

   if (src_p == NULL || dst_p == NULL) {
      assert(src_p == NULL && dst_p == NULL);
      nir_copy_deref(b, dst, src);
   } else {
      assert(dst_p->deref_type == nir_deref_type_array_wildcard &&
             src_p->deref_type == nir_deref_type_array_wildcard);

      if ((dst_info && dst_info->levels[dst_level].split) ||
          (src_info && src_info->levels[src_level].split)) {
         /* There are no indirects at this level on either the source or the
          * destination, so we are lowering it.
          */
         assert(glsl_get_length(dst_path->path[dst_level]->type) ==
                glsl_get_length(src_path->path[src_level]->type));
         unsigned len = glsl_get_length(dst_path->path[dst_level]->type);
         for (unsigned i = 0; i < len; i++) {
            emit_split_copies(b, dst_info, dst_path, dst_level + 1,
                              nir_build_deref_array_imm(b, dst, i),
                              src_info, src_path, src_level + 1,
                              nir_build_deref_array_imm(b, src, i));
         }
      } else {
         /* Neither side is being split so we just keep going */
         emit_split_copies(b, dst_info, dst_path, dst_level + 1,
                           nir_build_deref_array_wildcard(b, dst),
                           src_info, src_path, src_level + 1,
                           nir_build_deref_array_wildcard(b, src));
      }
   }
}

static void
split_array_copies_impl(nir_function_impl *impl,
                        struct hash_table *var_info_map,
                        nir_variable_mode modes,
                        void *mem_ctx)
{
   nir_builder b = nir_builder_create(impl);

   nir_foreach_block(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         if (instr->type != nir_instr_type_intrinsic)
            continue;

         nir_intrinsic_instr *copy = nir_instr_as_intrinsic(instr);
         if (copy->intrinsic != nir_intrinsic_copy_deref)
            continue;

         nir_deref_instr *dst_deref = nir_src_as_deref(copy->src[0]);
         nir_deref_instr *src_deref = nir_src_as_deref(copy->src[1]);

         struct array_var_info *dst_info =
            get_array_deref_info(dst_deref, var_info_map, modes);
         struct array_var_info *src_info =
            get_array_deref_info(src_deref, var_info_map, modes);

         if (!src_info && !dst_info)
            continue;

         nir_deref_path dst_path, src_path;
         nir_deref_path_init(&dst_path, dst_deref, mem_ctx);
         nir_deref_path_init(&src_path, src_deref, mem_ctx);

         if (!deref_has_split_wildcard(&dst_path, dst_info) &&
             !deref_has_split_wildcard(&src_path, src_info))
            continue;

         b.cursor = nir_instr_remove(&copy->instr);

         emit_split_copies(&b, dst_info, &dst_path, 0, dst_path.path[0],
                           src_info, &src_path, 0, src_path.path[0]);
      }
   }
}

static void
split_array_access_impl(nir_function_impl *impl,
                        struct hash_table *var_info_map,
                        nir_variable_mode modes,
                        void *mem_ctx)
{
   nir_builder b = nir_builder_create(impl);

   nir_foreach_block(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         if (instr->type == nir_instr_type_deref) {
            /* Clean up any dead derefs we find lying around.  They may refer
             * to variables we're planning to split.
             */
            nir_deref_instr *deref = nir_instr_as_deref(instr);
            if (nir_deref_mode_may_be(deref, modes))
               nir_deref_instr_remove_if_unused(deref);
            continue;
         }

         if (instr->type != nir_instr_type_intrinsic)
            continue;

         nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
         if (intrin->intrinsic != nir_intrinsic_load_deref &&
             intrin->intrinsic != nir_intrinsic_store_deref &&
             intrin->intrinsic != nir_intrinsic_copy_deref)
            continue;

         const unsigned num_derefs =
            intrin->intrinsic == nir_intrinsic_copy_deref ? 2 : 1;

         for (unsigned d = 0; d < num_derefs; d++) {
            nir_deref_instr *deref = nir_src_as_deref(intrin->src[d]);

            struct array_var_info *info =
               get_array_deref_info(deref, var_info_map, modes);
            if (!info)
               continue;

            nir_deref_path path;
            nir_deref_path_init(&path, deref, mem_ctx);

            b.cursor = nir_before_instr(&intrin->instr);

            if (array_path_is_out_of_bounds(&path, info)) {
               /* If one of the derefs is out-of-bounds, we just delete the
                * instruction.  If a destination is out of bounds, then it may
                * have been in-bounds prior to shrinking so we don't want to
                * accidentally stomp something.  However, we've already proven
                * that it will never be read so it's safe to delete.  If a
                * source is out of bounds then it is loading random garbage.
                * For loads, we replace their uses with an undef instruction
                * and for copies we just delete the copy since it was writing
                * undefined garbage anyway and we may as well leave the random
                * garbage in the destination alone.
                */
               if (intrin->intrinsic == nir_intrinsic_load_deref) {
                  nir_def *u =
                     nir_undef(&b, intrin->def.num_components,
                               intrin->def.bit_size);
                  nir_def_rewrite_uses(&intrin->def,
                                       u);
               }
               nir_instr_remove(&intrin->instr);
               for (unsigned i = 0; i < num_derefs; i++)
                  nir_deref_instr_remove_if_unused(nir_src_as_deref(intrin->src[i]));
               break;
            }

            struct array_split *split = &info->root_split;
            for (unsigned i = 0; i < info->num_levels; i++) {
               if (info->levels[i].split) {
                  nir_deref_instr *p = path.path[i + 1];
                  unsigned index = nir_src_as_uint(p->arr.index);
                  assert(index < info->levels[i].array_len);
                  split = &split->splits[index];
               }
            }
            assert(!split->splits && split->var);

            nir_deref_instr *new_deref = nir_build_deref_var(&b, split->var);
            for (unsigned i = 0; i < info->num_levels; i++) {
               if (!info->levels[i].split) {
                  new_deref = nir_build_deref_follower(&b, new_deref,
                                                       path.path[i + 1]);
               }
            }

            if (is_array_deref_of_vec(deref))
               new_deref = nir_build_deref_follower(&b, new_deref, deref);

            assert(new_deref->type == deref->type);

            /* Rewrite the deref source to point to the split one */
            nir_src_rewrite(&intrin->src[d], &new_deref->def);
            nir_deref_instr_remove_if_unused(deref);
         }
      }
   }
}

/** A pass for splitting arrays of vectors into multiple variables
 *
 * This pass looks at arrays (possibly multiple levels) of vectors (not
 * structures or other types) and tries to split them into piles of variables,
 * one for each array element.  The heuristic used is simple: if a given array
 * level is never used with an indirect, that array level will get split.
 *
 * This pass probably could handle structures easily enough but making a pass
 * that could see through an array of structures of arrays would be difficult
 * so it's best to just run nir_split_struct_vars first.
 */
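/* As a rough illustration (names invented for the example), given
 *
 *    vec4 arr[3];   // only ever indexed with constant indices
 *
 * the pass creates three vec4 variables named "(arr[0])", "(arr[1])", and
 * "(arr[2])" and rewrites each constant-indexed access to use them directly.
 * Any array level that is ever indexed indirectly is left as an array.
 */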
bool
nir_split_array_vars(nir_shader *shader, nir_variable_mode modes)
{
   void *mem_ctx = ralloc_context(NULL);
   struct hash_table *var_info_map = _mesa_pointer_hash_table_create(mem_ctx);
   struct set *complex_vars = NULL;

   bool has_global_array = false;
   if (modes & (~nir_var_function_temp)) {
      has_global_array = init_var_list_array_infos(shader,
                                                   &shader->variables,
                                                   modes,
                                                   var_info_map,
                                                   &complex_vars,
                                                   mem_ctx);
   }

   bool has_any_array = false;
   nir_foreach_function_impl(impl, shader) {
      bool has_local_array = false;
      if (modes & nir_var_function_temp) {
         has_local_array = init_var_list_array_infos(shader,
                                                     &impl->locals,
                                                     nir_var_function_temp,
                                                     var_info_map,
                                                     &complex_vars,
                                                     mem_ctx);
      }

      if (has_global_array || has_local_array) {
         has_any_array = true;
         mark_array_usage_impl(impl, var_info_map, modes, mem_ctx);
      }
   }

   /* If we failed to find any arrays of vectors, bail early. */
   if (!has_any_array) {
      ralloc_free(mem_ctx);
      nir_shader_preserve_all_metadata(shader);
      return false;
   }

   bool has_global_splits = false;
   if (modes & (~nir_var_function_temp)) {
      has_global_splits = split_var_list_arrays(shader, NULL,
                                                &shader->variables,
                                                modes,
                                                var_info_map, mem_ctx);
   }

   bool progress = false;
   nir_foreach_function_impl(impl, shader) {
      bool has_local_splits = false;
      if (modes & nir_var_function_temp) {
         has_local_splits = split_var_list_arrays(shader, impl,
                                                  &impl->locals,
                                                  nir_var_function_temp,
                                                  var_info_map, mem_ctx);
      }

      if (has_global_splits || has_local_splits) {
         split_array_copies_impl(impl, var_info_map, modes, mem_ctx);
         split_array_access_impl(impl, var_info_map, modes, mem_ctx);

         nir_metadata_preserve(impl, nir_metadata_control_flow);
         progress = true;
      } else {
         nir_metadata_preserve(impl, nir_metadata_all);
      }
   }

   ralloc_free(mem_ctx);

   return progress;
}

struct array_level_usage {
   unsigned array_len;

   /* The value UINT_MAX will be used to indicate an indirect */
   unsigned max_read;
   unsigned max_written;

   /* True if there is a copy that isn't to/from a shrinkable array */
   bool has_external_copy;
   struct set *levels_copied;
};

struct vec_var_usage {
   /* Convenience set of all components this variable has */
   nir_component_mask_t all_comps;

   nir_component_mask_t comps_read;
   nir_component_mask_t comps_written;

   nir_component_mask_t comps_kept;

   /* True if there is a copy that isn't to/from a shrinkable vector */
   bool has_external_copy;
   bool has_complex_use;
   struct set *vars_copied;

   unsigned num_levels;
   struct array_level_usage levels[0];
};

static struct vec_var_usage *
get_vec_var_usage(nir_variable *var,
                  struct hash_table *var_usage_map,
                  bool add_usage_entry, void *mem_ctx)
{
   struct hash_entry *entry = _mesa_hash_table_search(var_usage_map, var);
   if (entry)
      return entry->data;

   if (!add_usage_entry)
      return NULL;

   /* Check to make sure that we are working with an array of vectors.  We
    * don't bother to shrink single vectors because we figure that we can
    * clean it up better with SSA than by inserting piles of vecN instructions
    * to compact results.
    */
   int num_levels = num_array_levels_in_array_of_vector_type(var->type);
   if (num_levels < 1)
      return NULL; /* Not an array of vectors */

   struct vec_var_usage *usage =
      rzalloc_size(mem_ctx, sizeof(*usage) +
                               num_levels * sizeof(usage->levels[0]));

   usage->num_levels = num_levels;
   const struct glsl_type *type = var->type;
   for (unsigned i = 0; i < num_levels; i++) {
      usage->levels[i].array_len = glsl_get_length(type);
      type = glsl_get_array_element(type);
   }
   assert(glsl_type_is_vector_or_scalar(type));

   usage->all_comps = (1 << glsl_get_components(type)) - 1;

   _mesa_hash_table_insert(var_usage_map, var, usage);

   return usage;
}

static struct vec_var_usage *
get_vec_deref_usage(nir_deref_instr *deref,
                    struct hash_table *var_usage_map,
                    nir_variable_mode modes,
                    bool add_usage_entry, void *mem_ctx)
{
   if (!nir_deref_mode_may_be(deref, modes))
      return NULL;

   nir_variable *var = nir_deref_instr_get_variable(deref);
   if (var == NULL)
      return NULL;

   return get_vec_var_usage(nir_deref_instr_get_variable(deref),
                            var_usage_map, add_usage_entry, mem_ctx);
}

static void
mark_deref_if_complex(nir_deref_instr *deref,
                      struct hash_table *var_usage_map,
                      nir_variable_mode modes,
                      void *mem_ctx)
{
   /* Only bother with var derefs because nir_deref_instr_has_complex_use is
    * recursive.
    */
   if (deref->deref_type != nir_deref_type_var)
      return;

   if (!(deref->var->data.mode & modes))
      return;

   if (!nir_deref_instr_has_complex_use(deref, nir_deref_instr_has_complex_use_allow_atomics))
      return;

   struct vec_var_usage *usage =
      get_vec_var_usage(deref->var, var_usage_map, true, mem_ctx);
   if (!usage)
      return;

   usage->has_complex_use = true;
}

static void
mark_deref_used(nir_deref_instr *deref,
                nir_component_mask_t comps_read,
                nir_component_mask_t comps_written,
                nir_deref_instr *copy_deref,
                struct hash_table *var_usage_map,
                nir_variable_mode modes,
                void *mem_ctx)
{
   if (!nir_deref_mode_may_be(deref, modes))
      return;

   nir_variable *var = nir_deref_instr_get_variable(deref);
   if (var == NULL)
      return;

   struct vec_var_usage *usage =
      get_vec_var_usage(var, var_usage_map, true, mem_ctx);
   if (!usage)
      return;

   if (is_array_deref_of_vec(deref)) {
      if (comps_read)
         comps_read = usage->all_comps;
      if (comps_written)
         comps_written = usage->all_comps;
   }

   usage->comps_read |= comps_read & usage->all_comps;
   usage->comps_written |= comps_written & usage->all_comps;

   struct vec_var_usage *copy_usage = NULL;
   if (copy_deref) {
      copy_usage = get_vec_deref_usage(copy_deref, var_usage_map, modes,
                                       true, mem_ctx);
      if (copy_usage) {
         if (usage->vars_copied == NULL) {
            usage->vars_copied = _mesa_pointer_set_create(mem_ctx);
         }
         _mesa_set_add(usage->vars_copied, copy_usage);
      } else {
         usage->has_external_copy = true;
      }
   }

   nir_deref_path path;
   nir_deref_path_init(&path, deref, mem_ctx);

   nir_deref_path copy_path;
   if (copy_usage)
      nir_deref_path_init(&copy_path, copy_deref, mem_ctx);

   unsigned copy_i = 0;
   for (unsigned i = 0; i < usage->num_levels; i++) {
      struct array_level_usage *level = &usage->levels[i];
      nir_deref_instr *deref = path.path[i + 1];
      assert(deref->deref_type == nir_deref_type_array ||
             deref->deref_type == nir_deref_type_array_wildcard);

      unsigned max_used;
      if (deref->deref_type == nir_deref_type_array) {
         max_used = nir_src_is_const(deref->arr.index) ? nir_src_as_uint(deref->arr.index) : UINT_MAX;
      } else {
         /* For wildcards, we read or wrote the whole thing. */
         assert(deref->deref_type == nir_deref_type_array_wildcard);
         max_used = level->array_len - 1;

         if (copy_usage) {
            /* Match each wildcard level with the level on copy_usage */
            for (; copy_path.path[copy_i + 1]; copy_i++) {
               if (copy_path.path[copy_i + 1]->deref_type ==
                   nir_deref_type_array_wildcard)
                  break;
            }
            struct array_level_usage *copy_level =
               &copy_usage->levels[copy_i++];

            if (level->levels_copied == NULL) {
               level->levels_copied = _mesa_pointer_set_create(mem_ctx);
            }
            _mesa_set_add(level->levels_copied, copy_level);
         } else {
            /* We have a wildcard and it comes from a variable we aren't
             * tracking; flag it and we'll know to not shorten this array.
             */
            level->has_external_copy = true;
         }
      }

      if (comps_written)
         level->max_written = MAX2(level->max_written, max_used);
      if (comps_read)
         level->max_read = MAX2(level->max_read, max_used);
   }
}

static bool
src_is_load_deref(nir_src src, nir_src deref_src)
{
   nir_intrinsic_instr *load = nir_src_as_intrinsic(src);
   if (load == NULL || load->intrinsic != nir_intrinsic_load_deref)
      return false;

   return load->src[0].ssa == deref_src.ssa;
}

/* Returns all non-self-referential components of a store instruction.  A
 * component is self-referential if it comes from the same component of a load
 * instruction on the same deref.  If the only data in a particular component
 * of a variable came directly from that component then it's undefined.  The
 * only way to get defined data into a component of a variable is for it to
 * get written there by something outside or from a different component.
 *
 * This is a fairly common pattern in shaders that come from either GLSL IR or
 * GLSLang because both glsl_to_nir and GLSLang implement write-masking with
 * load-vec-store.
 */
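/* For instance (illustrative NIR-like pseudocode), a GLSL write such as
 * "v.x = f" typically reaches NIR as
 *
 *    %old = load_deref &v
 *    store_deref &v, vec4(f, %old.y, %old.z, %old.w)
 *
 * Only the x component carries new data; y, z, and w are self-referential
 * and are discounted from the returned mask.
 */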
static nir_component_mask_t
get_non_self_referential_store_comps(nir_intrinsic_instr *store)
{
   nir_component_mask_t comps = nir_intrinsic_write_mask(store);

   nir_instr *src_instr = store->src[1].ssa->parent_instr;
   if (src_instr->type != nir_instr_type_alu)
      return comps;

   nir_alu_instr *src_alu = nir_instr_as_alu(src_instr);

   if (src_alu->op == nir_op_mov) {
      /* If it's just a swizzle of a load from the same deref, discount any
       * channels that don't move in the swizzle.
       */
      if (src_is_load_deref(src_alu->src[0].src, store->src[0])) {
         for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
            if (src_alu->src[0].swizzle[i] == i)
               comps &= ~(1u << i);
         }
      }
   } else if (nir_op_is_vec(src_alu->op)) {
      /* If it's a vec, discount any channels that are just loads from the
       * same deref put in the same spot.
       */
      for (unsigned i = 0; i < nir_op_infos[src_alu->op].num_inputs; i++) {
         if (src_is_load_deref(src_alu->src[i].src, store->src[0]) &&
             src_alu->src[i].swizzle[0] == i)
            comps &= ~(1u << i);
      }
   }

   return comps;
}

static void
find_used_components_impl(nir_function_impl *impl,
                          struct hash_table *var_usage_map,
                          nir_variable_mode modes,
                          void *mem_ctx)
{
   nir_foreach_block(block, impl) {
      nir_foreach_instr(instr, block) {
         if (instr->type == nir_instr_type_deref) {
            mark_deref_if_complex(nir_instr_as_deref(instr),
                                  var_usage_map, modes, mem_ctx);
         }

         if (instr->type != nir_instr_type_intrinsic)
            continue;

         nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
         switch (intrin->intrinsic) {
         case nir_intrinsic_load_deref:
            mark_deref_used(nir_src_as_deref(intrin->src[0]),
                            nir_def_components_read(&intrin->def), 0,
                            NULL, var_usage_map, modes, mem_ctx);
            break;

         case nir_intrinsic_store_deref:
            mark_deref_used(nir_src_as_deref(intrin->src[0]),
                            0, get_non_self_referential_store_comps(intrin),
                            NULL, var_usage_map, modes, mem_ctx);
            break;

         case nir_intrinsic_copy_deref: {
            /* Just mark everything used for copies. */
            nir_deref_instr *dst = nir_src_as_deref(intrin->src[0]);
            nir_deref_instr *src = nir_src_as_deref(intrin->src[1]);
            mark_deref_used(dst, 0, ~0, src, var_usage_map, modes, mem_ctx);
            mark_deref_used(src, ~0, 0, dst, var_usage_map, modes, mem_ctx);
            break;
         }

         default:
            break;
         }
      }
   }
}

static bool
shrink_vec_var_list(struct exec_list *vars,
                    nir_variable_mode mode,
                    struct hash_table *var_usage_map)
{
   /* Initialize the components kept field of each variable.  This is the
    * AND of the components written and components read.  If a component is
    * written but never read, it's dead.  If it is read but never written,
    * then all values read are undefined garbage and we may as well not read
    * them.
    *
    * The same logic applies to the array length.  We make the array length
    * the minimum required length between read and write and plan to
    * discard any OOB access.  The one exception here is indirect writes
    * because we don't know where they will land and we can't shrink an array
    * with indirect writes because previously in-bounds writes may become
    * out-of-bounds and have undefined behavior.
    *
    * Also, if we have a copy to/from something we can't shrink, we need
    * to leave components and array_len of any wildcards alone.
    */
   nir_foreach_variable_in_list(var, vars) {
      if (var->data.mode != mode)
         continue;

      struct vec_var_usage *usage =
         get_vec_var_usage(var, var_usage_map, false, NULL);
      if (!usage)
         continue;

      assert(usage->comps_kept == 0);
      if (usage->has_external_copy || usage->has_complex_use)
         usage->comps_kept = usage->all_comps;
      else
         usage->comps_kept = usage->comps_read & usage->comps_written;

      for (unsigned i = 0; i < usage->num_levels; i++) {
         struct array_level_usage *level = &usage->levels[i];
         assert(level->array_len > 0);

         if (level->max_written == UINT_MAX || level->has_external_copy ||
             usage->has_complex_use)
            continue; /* Can't shrink */

         unsigned max_used = MIN2(level->max_read, level->max_written);
         level->array_len = MIN2(max_used, level->array_len - 1) + 1;
      }
   }

   /* In order for variable copies to work, we have to have the same data type
    * on the source and the destination.  In order to satisfy this, we run a
    * little fixed-point algorithm to transitively ensure that we get enough
    * components and array elements for this to hold for all copies.
    */
   bool fp_progress;
   do {
      fp_progress = false;
      nir_foreach_variable_in_list(var, vars) {
         if (var->data.mode != mode)
            continue;

         struct vec_var_usage *var_usage =
            get_vec_var_usage(var, var_usage_map, false, NULL);
         if (!var_usage || !var_usage->vars_copied)
            continue;

         set_foreach(var_usage->vars_copied, copy_entry) {
            struct vec_var_usage *copy_usage = (void *)copy_entry->key;
            if (copy_usage->comps_kept != var_usage->comps_kept) {
               nir_component_mask_t comps_kept =
                  (var_usage->comps_kept | copy_usage->comps_kept);
               var_usage->comps_kept = comps_kept;
               copy_usage->comps_kept = comps_kept;
               fp_progress = true;
            }
         }

         for (unsigned i = 0; i < var_usage->num_levels; i++) {
            struct array_level_usage *var_level = &var_usage->levels[i];
            if (!var_level->levels_copied)
               continue;

            set_foreach(var_level->levels_copied, copy_entry) {
               struct array_level_usage *copy_level = (void *)copy_entry->key;
               if (var_level->array_len != copy_level->array_len) {
                  unsigned array_len =
                     MAX2(var_level->array_len, copy_level->array_len);
                  var_level->array_len = array_len;
                  copy_level->array_len = array_len;
                  fp_progress = true;
               }
            }
         }
      }
   } while (fp_progress);

   bool vars_shrunk = false;
   nir_foreach_variable_in_list_safe(var, vars) {
      if (var->data.mode != mode)
         continue;

      struct vec_var_usage *usage =
         get_vec_var_usage(var, var_usage_map, false, NULL);
      if (!usage)
         continue;

      bool shrunk = false;
      const struct glsl_type *vec_type = var->type;
      for (unsigned i = 0; i < usage->num_levels; i++) {
         /* If we've reduced the array to zero elements at some level, just
          * set comps_kept to 0 and delete the variable.
          */
         if (usage->levels[i].array_len == 0) {
            usage->comps_kept = 0;
            break;
         }

         assert(usage->levels[i].array_len <= glsl_get_length(vec_type));
         if (usage->levels[i].array_len < glsl_get_length(vec_type))
            shrunk = true;
         vec_type = glsl_get_array_element(vec_type);
      }
      assert(glsl_type_is_vector_or_scalar(vec_type));

      assert(usage->comps_kept == (usage->comps_kept & usage->all_comps));
      if (usage->comps_kept != usage->all_comps)
         shrunk = true;

      if (usage->comps_kept == 0) {
         /* This variable is dead, remove it */
         vars_shrunk = true;
         exec_node_remove(&var->node);
         continue;
      }

      if (!shrunk) {
         /* This variable doesn't need to be shrunk.  Remove it from the
          * hash table so later steps will ignore it.
          */
         _mesa_hash_table_remove_key(var_usage_map, var);
         continue;
      }

      /* Build the new var type */
      unsigned new_num_comps = util_bitcount(usage->comps_kept);
      const struct glsl_type *new_type =
         glsl_vector_type(glsl_get_base_type(vec_type), new_num_comps);
      for (int i = usage->num_levels - 1; i >= 0; i--) {
         assert(usage->levels[i].array_len > 0);
         /* If the original type was a matrix type, we'd like to keep that so
          * we don't convert matrices into arrays.
          */
         if (i == usage->num_levels - 1 &&
             glsl_type_is_matrix(glsl_without_array(var->type)) &&
             new_num_comps > 1 && usage->levels[i].array_len > 1) {
1452             new_type = glsl_matrix_type(glsl_get_base_type(new_type),
1453                                         new_num_comps,
1454                                         usage->levels[i].array_len);
1455          } else {
1456             new_type = glsl_array_type(new_type, usage->levels[i].array_len, 0);
1457          }
1458       }
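      /* Hypothetical example: a vec4[5] variable with comps_kept == x|z and
       * level 0 shrunk to 3 elements is rebuilt here as vec2[3].
       */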
1459       var->type = new_type;
1460 
1461       vars_shrunk = true;
1462    }
1463 
1464    return vars_shrunk;
1465 }
1466 
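/* Returns true if any constant array index along the deref chain is at or
 * beyond the shrunk array length recorded for that level.  Wildcard and
 * non-constant indices are treated as in-bounds.
 */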
1467 static bool
1468 vec_deref_is_oob(nir_deref_instr *deref,
1469                  struct vec_var_usage *usage)
1470 {
1471    nir_deref_path path;
1472    nir_deref_path_init(&path, deref, NULL);
1473 
1474    bool oob = false;
1475    for (unsigned i = 0; i < usage->num_levels; i++) {
1476       nir_deref_instr *p = path.path[i + 1];
1477       if (p->deref_type == nir_deref_type_array_wildcard)
1478          continue;
1479 
1480       if (nir_src_is_const(p->arr.index) &&
1481           nir_src_as_uint(p->arr.index) >= usage->levels[i].array_len) {
1482          oob = true;
1483          break;
1484       }
1485    }
1486 
1487    nir_deref_path_finish(&path);
1488 
1489    return oob;
1490 }
1491 
1492 static bool
1493 vec_deref_is_dead_or_oob(nir_deref_instr *deref,
1494                          struct hash_table *var_usage_map,
1495                          nir_variable_mode modes)
1496 {
1497    struct vec_var_usage *usage =
1498       get_vec_deref_usage(deref, var_usage_map, modes, false, NULL);
1499    if (!usage)
1500       return false;
1501 
1502    return usage->comps_kept == 0 || vec_deref_is_oob(deref, usage);
1503 }
1504 
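/* Rewrite all derefs and load/store/copy_deref intrinsics in the impl to
 * match the shrunk variable types: dead or out-of-bounds accesses are
 * removed (loads are replaced with undef), and loads/stores of partially
 * kept vectors are re-expanded or compacted around the new component
 * layout.
 */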
1505 static void
1506 shrink_vec_var_access_impl(nir_function_impl *impl,
1507                            struct hash_table *var_usage_map,
1508                            nir_variable_mode modes)
1509 {
1510    nir_builder b = nir_builder_create(impl);
1511 
1512    nir_foreach_block(block, impl) {
1513       nir_foreach_instr_safe(instr, block) {
1514          switch (instr->type) {
1515          case nir_instr_type_deref: {
1516             nir_deref_instr *deref = nir_instr_as_deref(instr);
1517             if (!nir_deref_mode_may_be(deref, modes))
1518                break;
1519 
1520             /* Clean up any dead derefs we find lying around.  They may refer
1521              * to variables we've deleted.
1522              */
1523             if (nir_deref_instr_remove_if_unused(deref))
1524                break;
1525 
1526             /* Update the type in the deref to keep the types consistent as
1527              * you walk down the chain.  We don't need to check if this is one
1528              * of the derefs we're shrinking because this is a no-op if it
1529              * isn't.  The worst that could happen is that we accidentally fix
1530              * an invalid deref.
1531              */
1532             if (deref->deref_type == nir_deref_type_var) {
1533                deref->type = deref->var->type;
1534             } else if (deref->deref_type == nir_deref_type_array ||
1535                        deref->deref_type == nir_deref_type_array_wildcard) {
1536                nir_deref_instr *parent = nir_deref_instr_parent(deref);
1537                assert(glsl_type_is_array(parent->type) ||
1538                       glsl_type_is_matrix(parent->type) ||
1539                       glsl_type_is_vector(parent->type));
1540                deref->type = glsl_get_array_element(parent->type);
1541             }
1542             break;
1543          }
1544 
1545          case nir_instr_type_intrinsic: {
1546             nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1547 
1548             /* If we have a copy whose source or destination has been deleted
1549              * because we determined the variable was dead, then we just
1550              * delete the copy instruction.  If the source variable was dead,
1551              * the copy was only writing undefined garbage anyway; if the
1552              * destination variable is dead, the write isn't needed at all.
1553              */
1554             if (intrin->intrinsic == nir_intrinsic_copy_deref) {
1555                nir_deref_instr *dst = nir_src_as_deref(intrin->src[0]);
1556                nir_deref_instr *src = nir_src_as_deref(intrin->src[1]);
1557                if (vec_deref_is_dead_or_oob(dst, var_usage_map, modes) ||
1558                    vec_deref_is_dead_or_oob(src, var_usage_map, modes)) {
1559                   nir_instr_remove(&intrin->instr);
1560                   nir_deref_instr_remove_if_unused(dst);
1561                   nir_deref_instr_remove_if_unused(src);
1562                }
1563                continue;
1564             }
1565 
1566             if (intrin->intrinsic != nir_intrinsic_load_deref &&
1567                 intrin->intrinsic != nir_intrinsic_store_deref)
1568                continue;
1569 
1570             nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
1571             if (!nir_deref_mode_may_be(deref, modes))
1572                continue;
1573 
1574             struct vec_var_usage *usage =
1575                get_vec_deref_usage(deref, var_usage_map, modes, false, NULL);
1576             if (!usage)
1577                continue;
1578 
1579             if (usage->comps_kept == 0 || vec_deref_is_oob(deref, usage)) {
1580                if (intrin->intrinsic == nir_intrinsic_load_deref) {
1581                   nir_def *u =
1582                      nir_undef(&b, intrin->def.num_components,
1583                                intrin->def.bit_size);
1584                   nir_def_rewrite_uses(&intrin->def,
1585                                        u);
1586                }
1587                nir_instr_remove(&intrin->instr);
1588                nir_deref_instr_remove_if_unused(deref);
1589                continue;
1590             }
1591 
1592             /* If we're not dropping any components, there's no need to
1593              * compact vectors.
1594              */
1595             if (usage->comps_kept == usage->all_comps)
1596                continue;
1597 
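            /* Hypothetical example with comps_kept == x|z on a vec4: the
             * load is shrunk to 2 components and its users are handed
             * vec4(load.x, undef, load.y, undef); a store has its value
             * swizzled down to (x, z) and its write mask compacted to the
             * two kept channels.
             */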
1598             if (intrin->intrinsic == nir_intrinsic_load_deref) {
1599                b.cursor = nir_after_instr(&intrin->instr);
1600 
1601                nir_def *undef =
1602                   nir_undef(&b, 1, intrin->def.bit_size);
1603                nir_def *vec_srcs[NIR_MAX_VEC_COMPONENTS];
1604                unsigned c = 0;
1605                for (unsigned i = 0; i < intrin->num_components; i++) {
1606                   if (usage->comps_kept & (1u << i))
1607                      vec_srcs[i] = nir_channel(&b, &intrin->def, c++);
1608                   else
1609                      vec_srcs[i] = undef;
1610                }
1611                nir_def *vec = nir_vec(&b, vec_srcs, intrin->num_components);
1612 
1613                nir_def_rewrite_uses_after(&intrin->def,
1614                                           vec,
1615                                           vec->parent_instr);
1616 
1617                /* The SSA def is now only used by the swizzle.  It's safe to
1618                 * shrink the number of components.
1619                 */
1620                assert(list_length(&intrin->def.uses) == c);
1621                intrin->num_components = c;
1622                intrin->def.num_components = c;
1623             } else {
1624                nir_component_mask_t write_mask =
1625                   nir_intrinsic_write_mask(intrin);
1626 
1627                unsigned swizzle[NIR_MAX_VEC_COMPONENTS];
1628                nir_component_mask_t new_write_mask = 0;
1629                unsigned c = 0;
1630                for (unsigned i = 0; i < intrin->num_components; i++) {
1631                   if (usage->comps_kept & (1u << i)) {
1632                      swizzle[c] = i;
1633                      if (write_mask & (1u << i))
1634                         new_write_mask |= 1u << c;
1635                      c++;
1636                   }
1637                }
1638 
1639                b.cursor = nir_before_instr(&intrin->instr);
1640 
1641                nir_def *swizzled =
1642                   nir_swizzle(&b, intrin->src[1].ssa, swizzle, c);
1643 
1644                /* Rewrite to use the compacted source */
1645                nir_src_rewrite(&intrin->src[1], swizzled);
1646                nir_intrinsic_set_write_mask(intrin, new_write_mask);
1647                intrin->num_components = c;
1648             }
1649             break;
1650          }
1651 
1652          default:
1653             break;
1654          }
1655       }
1656    }
1657 }
1658 
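/* Cheap check for whether this impl can see any variables of the given
 * modes: non-function_temp modes live in the shader's variable list, while
 * function_temp variables live in impl->locals.
 */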
1659 static bool
1660 function_impl_has_vars_with_modes(nir_function_impl *impl,
1661                                   nir_variable_mode modes)
1662 {
1663    nir_shader *shader = impl->function->shader;
1664 
1665    if (modes & ~nir_var_function_temp) {
1666       nir_foreach_variable_with_modes(var, shader,
1667                                       modes & ~nir_var_function_temp)
1668          return true;
1669    }
1670 
1671    if ((modes & nir_var_function_temp) && !exec_list_is_empty(&impl->locals))
1672       return true;
1673 
1674    return false;
1675 }
1676 
1677 /** Attempt to shrink arrays of vectors
1678  *
1679  * This pass looks at variables which contain a vector or an array (possibly
1680  * multiple dimensions) of vectors and attempts to lower to a smaller vector
1681  * multiple dimensions) of vectors and attempts to lower it to a smaller vector
1682  * vectors) is never really used, then that component will be removed.
1683  * Similarly, the pass attempts to shorten arrays based on what elements it
1684  * can prove are never read or never contain valid data.
1685  */
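/* Illustrative driver usage (not part of this file): the pass is typically
 * run from a driver's optimization loop, e.g.
 *
 *    NIR_PASS(progress, nir, nir_shrink_vec_array_vars,
 *             nir_var_function_temp | nir_var_shader_temp);
 */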
1686 bool
1687 nir_shrink_vec_array_vars(nir_shader *shader, nir_variable_mode modes)
1688 {
1689    assert((modes & (nir_var_shader_temp | nir_var_function_temp)) == modes);
1690 
1691    void *mem_ctx = ralloc_context(NULL);
1692 
1693    struct hash_table *var_usage_map =
1694       _mesa_pointer_hash_table_create(mem_ctx);
1695 
1696    bool has_vars_to_shrink = false;
1697    nir_foreach_function_impl(impl, shader) {
1698       /* Don't even bother crawling the IR if we don't have any variables.
1699        * Given that this pass deletes any unused variables, it's likely that
1700        * we will be in this scenario eventually.
1701        */
1702       if (function_impl_has_vars_with_modes(impl, modes)) {
1703          has_vars_to_shrink = true;
1704          find_used_components_impl(impl, var_usage_map,
1705                                    modes, mem_ctx);
1706       }
1707    }
1708    if (!has_vars_to_shrink) {
1709       ralloc_free(mem_ctx);
1710       nir_shader_preserve_all_metadata(shader);
1711       return false;
1712    }
1713 
1714    bool globals_shrunk = false;
1715    if (modes & nir_var_shader_temp) {
1716       globals_shrunk = shrink_vec_var_list(&shader->variables,
1717                                            nir_var_shader_temp,
1718                                            var_usage_map);
1719    }
1720 
1721    bool progress = false;
1722    nir_foreach_function_impl(impl, shader) {
1723       bool locals_shrunk = false;
1724       if (modes & nir_var_function_temp) {
1725          locals_shrunk = shrink_vec_var_list(&impl->locals,
1726                                              nir_var_function_temp,
1727                                              var_usage_map);
1728       }
1729 
1730       if (globals_shrunk || locals_shrunk) {
1731          shrink_vec_var_access_impl(impl, var_usage_map, modes);
1732 
1733          nir_metadata_preserve(impl, nir_metadata_control_flow);
1734          progress = true;
1735       } else {
1736          nir_metadata_preserve(impl, nir_metadata_all);
1737       }
1738    }
1739 
1740    ralloc_free(mem_ctx);
1741 
1742    return progress;
1743 }
1744