/*
 * Copyright 2023 Valve Corporation
 * Copyright 2014 Intel Corporation
 * SPDX-License-Identifier: MIT
 */

#include "nir.h"
#include "nir_builder.h"

/*
 * This file implements a simple pass that lowers vecN instructions to a
 * series of register stores with partial write masks.
 */
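
/*
 * For example (schematic NIR-like notation, not exact nir_print output):
 *
 *    vec3 ssa_5 = vec3 ssa_1.x, ssa_2.y, ssa_2.z
 *
 * becomes
 *
 *    decl_reg vec3 r0
 *    store_reg ssa_1.xxx, r0, write_mask=0b001
 *    store_reg ssa_2.xyz, r0, write_mask=0b110
 *
 * after which every use of ssa_5 is rewritten to a load_reg of r0.
 */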

struct data {
   nir_instr_writemask_filter_cb cb;
   const void *data;
};

/**
 * For a given starting write-mask channel and corresponding source index in
 * the vec instruction, insert a store_reg to the vector register whose write
 * mask covers every channel that reads from the same SSA source.
 *
 * Returns the write mask of the store, so the caller's loop knows which
 * channels have already been processed.
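 *
 * For example (hypothetical IR), given vec3 ssa_0.x, ssa_1.y, ssa_0.z and
 * start_idx == 0, channels 0 and 2 both read ssa_0, so a single store_reg
 * of a swizzle of ssa_0 is emitted with write_mask 0b101, and 0b101 is
 * returned.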
 */
static unsigned
insert_store(nir_builder *b, nir_def *reg, nir_alu_instr *vec,
             unsigned start_idx)
{
   assert(start_idx < nir_op_infos[vec->op].num_inputs);
   nir_def *src = vec->src[start_idx].src.ssa;

   unsigned num_components = vec->def.num_components;
   assert(num_components == nir_op_infos[vec->op].num_inputs);
   unsigned write_mask = 0;
   unsigned swiz[NIR_MAX_VEC_COMPONENTS] = { 0 };

   for (unsigned i = start_idx; i < num_components; i++) {
      if (vec->src[i].src.ssa == src) {
         write_mask |= BITFIELD_BIT(i);
         swiz[i] = vec->src[i].swizzle[0];
      }
   }

   /* No sense storing from an undef; just return the write mask. */
   if (src->parent_instr->type == nir_instr_type_undef)
      return write_mask;

   b->cursor = nir_before_instr(&vec->instr);
   nir_build_store_reg(b, nir_swizzle(b, src, swiz, num_components), reg,
                       .write_mask = write_mask);
   return write_mask;
}

static bool
has_replicated_dest(nir_alu_instr *alu)
{
   return alu->op == nir_op_fdot2_replicated ||
          alu->op == nir_op_fdot3_replicated ||
          alu->op == nir_op_fdot4_replicated ||
          alu->op == nir_op_fdph_replicated;
}

/* Attempts to coalesce the "move" from the given source of the vec into the
 * destination of the instruction generating the value. If, for whatever
 * reason, we cannot coalesce the move, it does nothing and returns 0, and
 * the caller can fall back to insert_store as normal.
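 *
 * For example (hypothetical IR), if ssa_2 = fadd ssa_0, ssa_1 is used only
 * by vec2 ssa_2.y, ssa_2.x, we can reswizzle the fadd's sources, widen its
 * destination, and store the result into the register directly rather than
 * emitting separate copies.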
 */
static unsigned
try_coalesce(nir_builder *b, nir_def *reg, nir_alu_instr *vec,
             unsigned start_idx, struct data *data)
{
   assert(start_idx < nir_op_infos[vec->op].num_inputs);

   /* If we are going to do a reswizzle, then the vecN operation must be the
    * only use of the source value.
    */
   nir_foreach_use_including_if(src, vec->src[start_idx].src.ssa) {
      if (nir_src_is_if(src))
         return 0;

      if (nir_src_parent_instr(src) != &vec->instr)
         return 0;
   }

   if (vec->src[start_idx].src.ssa->parent_instr->type != nir_instr_type_alu)
      return 0;

   nir_alu_instr *src_alu =
      nir_instr_as_alu(vec->src[start_idx].src.ssa->parent_instr);

   if (has_replicated_dest(src_alu)) {
      /* The fdot instruction is special: it replicates its result to all
       * components.  This means that we can always rewrite its destination
       * and we don't need to swizzle anything.
       */
   } else {
      /* We can only reswizzle an instruction if it writes its destination
       * per-component.  The one exception, the fdotN instructions which
       * implicitly splat their result out to all channels, is handled above.
       */
      if (nir_op_infos[src_alu->op].output_size != 0)
         return 0;

      /* If we are going to reswizzle the instruction, we can't have any
       * non-per-component sources either.
       */
      for (unsigned j = 0; j < nir_op_infos[src_alu->op].num_inputs; j++)
         if (nir_op_infos[src_alu->op].input_sizes[j] != 0)
            return 0;
   }

   /* Only vecN instructions have more than 4 sources, and those are
    * disallowed by the above check for non-per-component sources.  This
    * assumption saves us a bit of stack memory.
    */
   assert(nir_op_infos[src_alu->op].num_inputs <= 4);

   /* Stash off all of the ALU instruction's swizzles. */
   uint8_t swizzles[4][NIR_MAX_VEC_COMPONENTS];
   for (unsigned j = 0; j < nir_op_infos[src_alu->op].num_inputs; j++)
      for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
         swizzles[j][i] = src_alu->src[j].swizzle[i];

   unsigned dest_components = vec->def.num_components;
   assert(dest_components == nir_op_infos[vec->op].num_inputs);

   /* Generate the final write mask. */
   nir_component_mask_t write_mask = 0;
   for (unsigned i = start_idx; i < dest_components; i++) {
      if (vec->src[i].src.ssa != &src_alu->def)
         continue;

      write_mask |= BITFIELD_BIT(i);
   }

   /* If the instruction would be vectorized but the backend doesn't support
    * vectorizing this op, abort.
    */
   if (data->cb && !data->cb(&src_alu->instr, write_mask, data->data))
      return 0;

   for (unsigned i = 0; i < dest_components; i++) {
      bool valid = write_mask & BITFIELD_BIT(i);

      /* At this point, the given vec source matches up with the ALU
       * instruction so we can re-swizzle that component to match.
       */
      if (has_replicated_dest(src_alu)) {
         /* Since the destination is a single replicated value, we don't need
          * to do any reswizzling.
          */
      } else {
         for (unsigned j = 0; j < nir_op_infos[src_alu->op].num_inputs; j++) {
            /* For channels we're extending out of nowhere, use a benign
             * swizzle so we don't read invalid components and trip
             * nir_validate.
             */
            unsigned c = valid ? vec->src[i].swizzle[0] : 0;

            src_alu->src[j].swizzle[i] = swizzles[j][c];
         }
      }

      /* Clear the no-longer-needed vec source. */
      if (valid)
         nir_instr_clear_src(&vec->instr, &vec->src[i].src);
   }

   /* We've cleared the only use of the destination... */
   assert(list_is_empty(&src_alu->def.uses));

   /* ...so we can replace it with the bigger destination accommodating the
    * whole vector that will be masked for the store.
    */
   unsigned bit_size = vec->def.bit_size;
   assert(bit_size == src_alu->def.bit_size);
   nir_def_init(&src_alu->instr, &src_alu->def, dest_components, bit_size);

   /* Then we can store that ALU result directly into the register. */
   b->cursor = nir_after_instr(&src_alu->instr);
   nir_build_store_reg(b, &src_alu->def, reg, .write_mask = write_mask);

   return write_mask;
}

static bool
lower(nir_builder *b, nir_instr *instr, void *data_)
{
   struct data *data = data_;
   if (instr->type != nir_instr_type_alu)
      return false;

   nir_alu_instr *vec = nir_instr_as_alu(instr);
   if (!nir_op_is_vec(vec->op))
      return false;

   unsigned num_components = vec->def.num_components;

   /* Special case: if all sources are the same, just swizzle instead to
    * avoid the extra copies from a register.
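    *
    * For example (hypothetical IR), vec2 ssa_0.y, ssa_0.x only reads ssa_0,
    * so it becomes the swizzle ssa_0.yx and no register is needed.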
    */
   bool need_reg = false;
   for (unsigned i = 1; i < num_components; ++i) {
      if (!nir_srcs_equal(vec->src[0].src, vec->src[i].src)) {
         need_reg = true;
         break;
      }
   }

   b->cursor = nir_before_instr(instr);

   if (need_reg) {
      /* We'll replace with a register. Declare one for the purpose. */
      nir_def *reg = nir_decl_reg(b, num_components, vec->def.bit_size, 0);

      unsigned finished_write_mask = 0;
      for (unsigned i = 0; i < num_components; i++) {
         /* Try to coalesce the move. */
         if (!(finished_write_mask & BITFIELD_BIT(i)))
            finished_write_mask |= try_coalesce(b, reg, vec, i, data);

         /* Otherwise fall back on the simple path. */
         if (!(finished_write_mask & BITFIELD_BIT(i)))
            finished_write_mask |= insert_store(b, reg, vec, i);
      }

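      /* All former readers of the vec result now load from the register. */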
      nir_rewrite_uses_to_load_reg(b, &vec->def, reg);
   } else {
      /* Otherwise, we replace with a swizzle. */
      unsigned swiz[NIR_MAX_VEC_COMPONENTS] = { 0 };

      for (unsigned i = 0; i < num_components; ++i) {
         swiz[i] = vec->src[i].swizzle[0];
      }

      nir_def *swizzled = nir_swizzle(b, vec->src[0].src.ssa, swiz,
                                      num_components);
      nir_def_rewrite_uses(&vec->def, swizzled);
   }

   nir_instr_remove(&vec->instr);
   nir_instr_free(&vec->instr);
   return true;
}

bool
nir_lower_vec_to_regs(nir_shader *shader, nir_instr_writemask_filter_cb cb,
                      const void *_data)
{
   struct data data = {
      .cb = cb,
      .data = _data
   };

   return nir_shader_instructions_pass(shader, lower,
                                       nir_metadata_control_flow,
                                       &data);
}
264