/*
 * Copyright © 2020 Google LLC
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

/**
 * @file
 *
 * Removes unused components of SSA defs.
 *
 * Due to various optimization passes (or frontend implementations,
 * particularly prog_to_nir), we may have instructions generating vectors
 * whose components don't get read by any instruction.
 *
 * For memory loads, while it can be tricky to eliminate unused low components
 * or channels in the middle of a writemask (you might need to increment some
 * offset from a load_uniform, for example), it is trivial to just drop the
 * trailing components. For select intrinsics that carry a component index,
 * this pass can also shrink away unused leading components. For vector ALU
 * instructions and load_const instructions whose results are only used by
 * other ALU instructions, this pass eliminates arbitrary channels as well as
 * duplicate channels, and reswizzles the uses.
 *
 * This pass is probably only of use to vector backends -- scalar backends
 * typically get unused def channel trimming by scalarizing and dead code
 * elimination.
 */
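/*
 * An illustrative sketch (the NIR textual syntax is approximate): a load
 * whose trailing channels are never read, such as
 *
 *    vec4 32 ssa_1 = intrinsic load_ubo (ssa_0, ssa_2) (...)
 *    vec1 32 ssa_3 = fmul ssa_1.x, ssa_1.y
 *
 * can be shrunk to
 *
 *    vec2 32 ssa_1 = intrinsic load_ubo (ssa_0, ssa_2) (...)
 *    vec1 32 ssa_3 = fmul ssa_1.x, ssa_1.y
 */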
#include "util/u_math.h"
#include "nir.h"
#include "nir_builder.h"

/*
 * Round up a vector size to a vector size that's valid in NIR. At present,
 * NIR supports only vec2-5, vec8, and vec16. Attempting to generate other
 * sizes will fail validation.
 */
static unsigned
round_up_components(unsigned n)
{
   return (n > 5) ? util_next_power_of_two(n) : n;
}
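/* For instance, round_up_components(6) and round_up_components(7) both round
 * up to 8, sizes 9 through 15 round up to 16, and sizes of 5 or less pass
 * through unchanged.
 */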

static void
reswizzle_alu_uses(nir_def *def, uint8_t *reswizzle)
{
   nir_foreach_use(use_src, def) {
      /* all uses must be ALU instructions */
      assert(nir_src_parent_instr(use_src)->type == nir_instr_type_alu);
      nir_alu_src *alu_src = (nir_alu_src *)use_src;

      /* reswizzle ALU sources */
      for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
         alu_src->swizzle[i] = reswizzle[alu_src->swizzle[i]];
   }
}
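/* An illustrative sketch: with reswizzle = {0, 1, 1, 2}, a source swizzle of
 * .xwzy (components 0, 3, 2, 1) becomes .xzyy, since old components 1 and 2
 * now both live in new component 1 and old component 3 moved down to new
 * component 2.
 */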

static bool
is_only_used_by_alu(nir_def *def)
{
   nir_foreach_use(use_src, def) {
      if (nir_src_parent_instr(use_src)->type != nir_instr_type_alu)
         return false;
   }

   return true;
}

static bool
shrink_dest_to_read_mask(nir_def *def, bool shrink_start)
{
   /* early out if there's nothing to do. */
   if (def->num_components == 1)
      return false;

   /* don't remove any channels if used by an intrinsic */
   nir_foreach_use(use_src, def) {
      if (nir_src_parent_instr(use_src)->type == nir_instr_type_intrinsic)
         return false;
   }

   unsigned mask = nir_def_components_read(def);

   /* If nothing was read, leave it up to DCE. */
   if (!mask)
      return false;

   nir_intrinsic_instr *intr = NULL;
   if (def->parent_instr->type == nir_instr_type_intrinsic)
      intr = nir_instr_as_intrinsic(def->parent_instr);

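   /* Only shrink leading channels when the def comes from an intrinsic that
    * carries a component index and is consumed exclusively by ALU
    * instructions, so the dropped channels can be compensated for by bumping
    * the component index and reswizzling the uses.
    */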
   shrink_start &= (intr != NULL) && nir_intrinsic_has_component(intr) &&
                   is_only_used_by_alu(def);

   int last_bit = util_last_bit(mask);
   int first_bit = shrink_start ? (ffs(mask) - 1) : 0;

   const unsigned comps = last_bit - first_bit;
   const unsigned rounded = round_up_components(comps);
   assert(rounded <= def->num_components);

   if ((def->num_components > rounded) || first_bit > 0) {
      def->num_components = rounded;

      if (first_bit) {
         assert(shrink_start);

         nir_intrinsic_set_component(intr, nir_intrinsic_component(intr) + first_bit);

         /* Reswizzle sources, which must be ALU since they have swizzles */
         uint8_t swizzle[NIR_MAX_VEC_COMPONENTS] = { 0 };
         for (unsigned i = 0; i < comps; ++i) {
            swizzle[first_bit + i] = i;
         }

         reswizzle_alu_uses(def, swizzle);
      }

      return true;
   }

   return false;
}
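/* An illustrative sketch of the shrink_start case: a component-indexed load
 * with a read mask of 0b0110 has its vec4 destination shrunk to a vec2, the
 * intrinsic's component index bumped by 1, and old channels 1 and 2 remapped
 * onto new channels 0 and 1 in every ALU use.
 */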

static bool
shrink_intrinsic_to_non_sparse(nir_intrinsic_instr *instr)
{
   unsigned mask = nir_def_components_read(&instr->def);
   int last_bit = util_last_bit(mask);

   /* If the sparse component is used, do nothing. */
   if (last_bit == instr->def.num_components)
      return false;

   instr->def.num_components -= 1;
   instr->num_components = instr->def.num_components;

   /* Switch to the non-sparse intrinsic. */
   switch (instr->intrinsic) {
   case nir_intrinsic_image_sparse_load:
      instr->intrinsic = nir_intrinsic_image_load;
      break;
   case nir_intrinsic_bindless_image_sparse_load:
      instr->intrinsic = nir_intrinsic_bindless_image_load;
      break;
   case nir_intrinsic_image_deref_sparse_load:
      instr->intrinsic = nir_intrinsic_image_deref_load;
      break;
   default:
      break;
   }

   return true;
}

static bool
opt_shrink_vector(nir_builder *b, nir_alu_instr *instr)
{
   nir_def *def = &instr->def;
   unsigned mask = nir_def_components_read(def);

   /* If nothing was read, leave it up to DCE. */
   if (mask == 0)
      return false;

   /* don't remove any channels if used by non-ALU */
   if (!is_only_used_by_alu(def))
      return false;

   uint8_t reswizzle[NIR_MAX_VEC_COMPONENTS] = { 0 };
   nir_scalar srcs[NIR_MAX_VEC_COMPONENTS] = { 0 };
   unsigned num_components = 0;
   for (unsigned i = 0; i < def->num_components; i++) {
      if (!((mask >> i) & 0x1))
         continue;

      nir_scalar scalar = nir_get_scalar(instr->src[i].src.ssa, instr->src[i].swizzle[0]);

      /* Try to reuse a component with the same value */
      unsigned j;
      for (j = 0; j < num_components; j++) {
         if (nir_scalar_equal(scalar, srcs[j])) {
            reswizzle[i] = j;
            break;
         }
      }

      /* Otherwise, just append the value */
      if (j == num_components) {
         srcs[num_components] = scalar;
         reswizzle[i] = num_components++;
      }
   }

   /* return if no component was removed */
   if (num_components == def->num_components)
      return false;

   /* create new vecN and replace uses */
   nir_def *new_vec = nir_vec_scalars(b, srcs, num_components);
   nir_def_rewrite_uses(def, new_vec);
   reswizzle_alu_uses(new_vec, reswizzle);

   return true;
}
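/* An illustrative sketch: vec4(a, b, a, c) whose .w channel is never read
 * shrinks to vec2(a, b); the duplicate .z channel folds onto .x, so a use
 * reading .xyz is reswizzled to .xyx on the new vector.
 */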

static bool
opt_shrink_vectors_alu(nir_builder *b, nir_alu_instr *instr)
{
   nir_def *def = &instr->def;

   /* Nothing to shrink */
   if (def->num_components == 1)
      return false;

   switch (instr->op) {
   /* don't use nir_op_is_vec() as not all vector sizes are supported. */
   case nir_op_vec4:
   case nir_op_vec3:
   case nir_op_vec2:
      return opt_shrink_vector(b, instr);
   default:
      if (nir_op_infos[instr->op].output_size != 0)
         return false;
      break;
   }

   /* don't remove any channels if used by non-ALU */
   if (!is_only_used_by_alu(def))
      return false;

   unsigned mask = nir_def_components_read(def);
   /* return if there is nothing to do */
   if (mask == 0)
      return false;

   uint8_t reswizzle[NIR_MAX_VEC_COMPONENTS] = { 0 };
   unsigned num_components = 0;
   bool progress = false;
   for (unsigned i = 0; i < def->num_components; i++) {
      /* skip unused components */
      if (!((mask >> i) & 0x1))
         continue;

      /* Try to reuse a component with the same swizzles */
      unsigned j;
      for (j = 0; j < num_components; j++) {
         bool duplicate_channel = true;
         for (unsigned k = 0; k < nir_op_infos[instr->op].num_inputs; k++) {
            if (nir_op_infos[instr->op].input_sizes[k] != 0 ||
                instr->src[k].swizzle[i] != instr->src[k].swizzle[j]) {
               duplicate_channel = false;
               break;
            }
         }

         if (duplicate_channel) {
            reswizzle[i] = j;
            progress = true;
            break;
         }
      }

      /* Otherwise, just append the value */
      if (j == num_components) {
         for (unsigned k = 0; k < nir_op_infos[instr->op].num_inputs; k++) {
            instr->src[k].swizzle[num_components] = instr->src[k].swizzle[i];
         }
         if (i != num_components)
            progress = true;
         reswizzle[i] = num_components++;
      }
   }

   /* update uses */
   if (progress)
      reswizzle_alu_uses(def, reswizzle);

   unsigned rounded = round_up_components(num_components);
   assert(rounded <= def->num_components);
   if (rounded < def->num_components)
      progress = true;

   /* update dest */
   def->num_components = rounded;

   return progress;
}
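/* An illustrative sketch: a vec4 fadd whose sources are swizzled .xxyz and
 * .yyzw computes the same value in channels 0 and 1, so channel 1 folds onto
 * channel 0 and the destination shrinks to a vec3, with all ALU uses
 * reswizzled accordingly.
 */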

static bool
opt_shrink_vectors_intrinsic(nir_builder *b, nir_intrinsic_instr *instr,
                             bool shrink_start)
{
   switch (instr->intrinsic) {
   case nir_intrinsic_load_uniform:
   case nir_intrinsic_load_ubo:
   case nir_intrinsic_load_input:
   case nir_intrinsic_load_per_primitive_input:
   case nir_intrinsic_load_input_vertex:
   case nir_intrinsic_load_per_vertex_input:
   case nir_intrinsic_load_interpolated_input:
   case nir_intrinsic_load_ssbo:
   case nir_intrinsic_load_push_constant:
   case nir_intrinsic_load_constant:
   case nir_intrinsic_load_shared:
   case nir_intrinsic_load_global:
   case nir_intrinsic_load_global_constant:
   case nir_intrinsic_load_kernel_input:
   case nir_intrinsic_load_scratch: {
      /* Must be a vectorized intrinsic that we can resize. */
      assert(instr->num_components != 0);

      /* Trim the dest to the used channels */
      if (!shrink_dest_to_read_mask(&instr->def, shrink_start))
         return false;

      instr->num_components = instr->def.num_components;
      return true;
   }
   case nir_intrinsic_image_sparse_load:
   case nir_intrinsic_bindless_image_sparse_load:
   case nir_intrinsic_image_deref_sparse_load:
      return shrink_intrinsic_to_non_sparse(instr);
   default:
      return false;
   }
}

static bool
opt_shrink_vectors_tex(nir_builder *b, nir_tex_instr *tex)
{
   if (!tex->is_sparse)
      return false;

   unsigned mask = nir_def_components_read(&tex->def);
   int last_bit = util_last_bit(mask);

   /* If the sparse component is used, do nothing. */
   if (last_bit == tex->def.num_components)
      return false;

   tex->def.num_components -= 1;
   tex->is_sparse = false;

   return true;
}

static bool
opt_shrink_vectors_load_const(nir_load_const_instr *instr)
{
   nir_def *def = &instr->def;

   /* early out if there's nothing to do. */
   if (def->num_components == 1)
      return false;

   /* don't remove any channels if used by non-ALU */
   if (!is_only_used_by_alu(def))
      return false;

   unsigned mask = nir_def_components_read(def);

   /* If nothing was read, leave it up to DCE. */
   if (!mask)
      return false;

   uint8_t reswizzle[NIR_MAX_VEC_COMPONENTS] = { 0 };
   unsigned num_components = 0;
   bool progress = false;
   for (unsigned i = 0; i < def->num_components; i++) {
      if (!((mask >> i) & 0x1))
         continue;

      /* Try to reuse a component with the same constant */
      unsigned j;
      for (j = 0; j < num_components; j++) {
         if (instr->value[i].u64 == instr->value[j].u64) {
            reswizzle[i] = j;
            progress = true;
            break;
         }
      }

      /* Otherwise, just append the value */
      if (j == num_components) {
         instr->value[num_components] = instr->value[i];
         if (i != num_components)
            progress = true;
         reswizzle[i] = num_components++;
      }
   }

   if (progress)
      reswizzle_alu_uses(def, reswizzle);

   unsigned rounded = round_up_components(num_components);
   assert(rounded <= def->num_components);
   if (rounded < def->num_components)
      progress = true;

   def->num_components = rounded;

   return progress;
}
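/* An illustrative sketch: load_const (0.0, 1.0, 0.0, 2.0) with every channel
 * read shrinks to load_const (0.0, 1.0, 2.0); the duplicate 0.0 in channel 2
 * folds onto channel 0, and uses of the old .z are reswizzled to .x.
 */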

static bool
opt_shrink_vectors_ssa_undef(nir_undef_instr *instr)
{
   return shrink_dest_to_read_mask(&instr->def, false);
}

static bool
opt_shrink_vectors_phi(nir_builder *b, nir_phi_instr *instr)
{
   nir_def *def = &instr->def;

   /* early out if there's nothing to do. */
   if (def->num_components == 1)
      return false;

   /* Ignore large vectors for now. */
   if (def->num_components > 4)
      return false;

   /* Check the uses. */
   nir_component_mask_t mask = 0;
   nir_foreach_use(src, def) {
      if (nir_src_parent_instr(src)->type != nir_instr_type_alu)
         return false;

      nir_alu_instr *alu = nir_instr_as_alu(nir_src_parent_instr(src));

      nir_alu_src *alu_src = exec_node_data(nir_alu_src, src, src);
      int src_idx = alu_src - &alu->src[0];
      nir_component_mask_t src_read_mask = nir_alu_instr_src_read_mask(alu, src_idx);

      nir_def *alu_def = &alu->def;

      /* We don't mark the channels used if the only reader is the original phi.
       * This can happen in the case of loops.
       */
      nir_foreach_use(alu_use_src, alu_def) {
         if (nir_src_parent_instr(alu_use_src) != &instr->instr) {
            mask |= src_read_mask;
         }
      }

      /* However, even if the instruction only points back at the phi, we still
       * need to check that the swizzles are trivial.
       */
      if (nir_op_is_vec(alu->op)) {
         if (src_idx != alu->src[src_idx].swizzle[0]) {
            mask |= src_read_mask;
         }
      } else if (!nir_alu_src_is_trivial_ssa(alu, src_idx)) {
         mask |= src_read_mask;
      }
   }

   /* DCE will handle this. */
   if (mask == 0)
      return false;

   /* Nothing to shrink? */
   if (BITFIELD_MASK(def->num_components) == mask)
      return false;

   /* Set up the reswizzles. */
   unsigned num_components = 0;
   uint8_t reswizzle[NIR_MAX_VEC_COMPONENTS] = { 0 };
   uint8_t src_reswizzle[NIR_MAX_VEC_COMPONENTS] = { 0 };
   for (unsigned i = 0; i < def->num_components; i++) {
      if (!((mask >> i) & 0x1))
         continue;
      src_reswizzle[num_components] = i;
      reswizzle[i] = num_components++;
   }

   /* Shrink the phi; this part is simple. */
   def->num_components = num_components;

   /* We can't swizzle phi sources directly, so insert an extra mov with the
    * correct swizzle and let the rest of nir_opt_shrink_vectors do its job on
    * the original source instruction. If the original source was used only in
    * the phi, the movs will disappear later under copy propagation.
    */
   nir_foreach_phi_src(phi_src, instr) {
      b->cursor = nir_after_instr_and_phis(phi_src->src.ssa->parent_instr);

      nir_alu_src alu_src = {
         .src = nir_src_for_ssa(phi_src->src.ssa)
      };

      for (unsigned i = 0; i < num_components; i++)
         alu_src.swizzle[i] = src_reswizzle[i];
      nir_def *mov = nir_mov_alu(b, alu_src, num_components);

      nir_src_rewrite(&phi_src->src, mov);
   }
   b->cursor = nir_before_instr(&instr->instr);

   /* Reswizzle readers. */
   reswizzle_alu_uses(def, reswizzle);

   return true;
}
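/* An illustrative sketch: a vec4 phi whose ALU readers only use .xy becomes a
 * vec2 phi; each predecessor value is funneled through a newly inserted
 * "mov src.xy", and those movs are expected to disappear under later copy
 * propagation.
 */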

static bool
opt_shrink_vectors_instr(nir_builder *b, nir_instr *instr, bool shrink_start)
{
   b->cursor = nir_before_instr(instr);

   switch (instr->type) {
   case nir_instr_type_alu:
      return opt_shrink_vectors_alu(b, nir_instr_as_alu(instr));

   case nir_instr_type_tex:
      return opt_shrink_vectors_tex(b, nir_instr_as_tex(instr));

   case nir_instr_type_intrinsic:
      return opt_shrink_vectors_intrinsic(b, nir_instr_as_intrinsic(instr),
                                          shrink_start);

   case nir_instr_type_load_const:
      return opt_shrink_vectors_load_const(nir_instr_as_load_const(instr));

   case nir_instr_type_undef:
      return opt_shrink_vectors_ssa_undef(nir_instr_as_undef(instr));

   case nir_instr_type_phi:
      return opt_shrink_vectors_phi(b, nir_instr_as_phi(instr));

   default:
      return false;
   }

   return true;
}

bool
nir_opt_shrink_vectors(nir_shader *shader, bool shrink_start)
{
   bool progress = false;

   nir_foreach_function_impl(impl, shader) {
      nir_builder b = nir_builder_create(impl);

      nir_foreach_block_reverse(block, impl) {
         nir_foreach_instr_reverse(instr, block) {
            progress |= opt_shrink_vectors_instr(&b, instr, shrink_start);
         }
      }

      if (progress) {
         nir_metadata_preserve(impl, nir_metadata_control_flow);
      } else {
         nir_metadata_preserve(impl, nir_metadata_all);
      }
   }

   return progress;
}
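/*
 * Typical usage is inside a driver's optimization loop, iterated to a fixed
 * point alongside cleanup passes. A minimal sketch (the exact pass selection
 * varies per driver):
 *
 *    bool progress;
 *    do {
 *       progress = false;
 *       progress |= nir_opt_shrink_vectors(shader, true);
 *       progress |= nir_copy_prop(shader);
 *       progress |= nir_opt_dce(shader);
 *    } while (progress);
 */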
581