xref: /aosp_15_r20/external/mesa3d/src/gallium/drivers/r600/sfn/sfn_nir_lower_fs_out_to_vector.cpp (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /* -*- mesa-c++  -*-
2  * Copyright 2019 Collabora LTD
3  * Author: Gert Wollny <[email protected]>
4  * SPDX-License-Identifier: MIT
5  */
6 
7 #include "sfn_nir_lower_fs_out_to_vector.h"
8 
9 #include "nir_builder.h"
10 #include "nir_deref.h"
11 #include "util/u_math.h"
12 
13 #include <algorithm>
14 #include <array>
15 #include <set>
16 #include <vector>
17 
18 namespace r600 {
19 
20 using std::array;
21 using std::multiset;
22 using std::vector;
23 
24 struct nir_intrinsic_instr_less {
operator ()r600::nir_intrinsic_instr_less25    bool operator()(const nir_intrinsic_instr *lhs, const nir_intrinsic_instr *rhs) const
26    {
27       nir_variable *vlhs = nir_intrinsic_get_var(lhs, 0);
28       nir_variable *vrhs = nir_intrinsic_get_var(rhs, 0);
29 
30       auto ltype = glsl_get_base_type(vlhs->type);
31       auto rtype = glsl_get_base_type(vrhs->type);
32 
33       if (ltype != rtype)
34          return ltype < rtype;
35       return vlhs->data.location < vrhs->data.location;
36    }
37 };
38 
39 class NirLowerIOToVector {
40 public:
41    NirLowerIOToVector(int base_slot);
42    bool run(nir_function_impl *shader);
43 
44 protected:
45    bool var_can_merge(const nir_variable *lhs, const nir_variable *rhs);
46    bool var_can_rewrite(nir_variable *var) const;
47    void create_new_io_vars(nir_shader *shader);
48    void create_new_io_var(nir_shader *shader, unsigned location, unsigned comps);
49 
50    nir_deref_instr *clone_deref_array(nir_builder *b,
51                                       nir_deref_instr *dst_tail,
52                                       const nir_deref_instr *src_head);
53 
54    bool vectorize_block(nir_builder *b, nir_block *block);
55    bool instr_can_rewrite(nir_instr *instr);
56    bool vec_instr_set_remove(nir_builder *b, nir_instr *instr);
57 
58    using InstrSet = multiset<nir_intrinsic_instr *, nir_intrinsic_instr_less>;
59    using InstrSubSet = std::pair<InstrSet::iterator, InstrSet::iterator>;
60 
61    bool
62    vec_instr_stack_pop(nir_builder *b, InstrSubSet& ir_set, nir_intrinsic_instr *instr);
63 
64    array<array<nir_variable *, 4>, 16> m_vars;
65    InstrSet m_block_io;
66    int m_next_index;
67 
68 private:
69    virtual nir_variable_mode get_io_mode(nir_shader *shader) const = 0;
70    virtual bool instr_can_rewrite_type(nir_intrinsic_instr *intr) const = 0;
71    virtual bool var_can_rewrite_slot(nir_variable *var) const = 0;
72    virtual void create_new_io(nir_builder *b,
73                               nir_intrinsic_instr *intr,
74                               nir_variable *var,
75                               nir_def **srcs,
76                               unsigned first_comp,
77                               unsigned num_comps) = 0;
78 
79    int m_base_slot;
80 };
81 
82 class NirLowerFSOutToVector : public NirLowerIOToVector {
83 public:
84    NirLowerFSOutToVector();
85 
86 private:
87    nir_variable_mode get_io_mode(nir_shader *shader) const override;
88    bool var_can_rewrite_slot(nir_variable *var) const override;
89    void create_new_io(nir_builder *b,
90                       nir_intrinsic_instr *intr,
91                       nir_variable *var,
92                       nir_def **srcs,
93                       unsigned first_comp,
94                       unsigned num_comps) override;
95    bool instr_can_rewrite_type(nir_intrinsic_instr *intr) const override;
96 
97    nir_def *create_combined_vector(nir_builder *b,
98                                        nir_def **srcs,
99                                        int first_comp,
100                                        int num_comp);
101 };
102 
103 bool
r600_lower_fs_out_to_vector(nir_shader * shader)104 r600_lower_fs_out_to_vector(nir_shader *shader)
105 {
106    NirLowerFSOutToVector processor;
107 
108    assert(shader->info.stage == MESA_SHADER_FRAGMENT);
109    bool progress = false;
110 
111    nir_foreach_function_impl(impl, shader) {
112       progress |= processor.run(impl);
113    }
114    return progress;
115 }
116 
NirLowerIOToVector(int base_slot)117 NirLowerIOToVector::NirLowerIOToVector(int base_slot):
118     m_next_index(0),
119     m_base_slot(base_slot)
120 {
121    for (auto& a : m_vars)
122       for (auto& aa : a)
123          aa = nullptr;
124 }
125 
126 bool
run(nir_function_impl * impl)127 NirLowerIOToVector::run(nir_function_impl *impl)
128 {
129    nir_builder b = nir_builder_create(impl);
130 
131    nir_metadata_require(impl, nir_metadata_dominance);
132    create_new_io_vars(impl->function->shader);
133 
134    bool progress = vectorize_block(&b, nir_start_block(impl));
135    if (progress) {
136       nir_metadata_preserve(impl, nir_metadata_control_flow);
137    } else {
138       nir_metadata_preserve(impl, nir_metadata_all);
139    }
140    return progress;
141 }
142 
143 void
create_new_io_vars(nir_shader * shader)144 NirLowerIOToVector::create_new_io_vars(nir_shader *shader)
145 {
146    nir_variable_mode mode = get_io_mode(shader);
147 
148    bool can_rewrite_vars = false;
149    nir_foreach_variable_with_modes(var, shader, mode)
150    {
151       if (var_can_rewrite(var)) {
152          can_rewrite_vars = true;
153          unsigned loc = var->data.location - m_base_slot;
154          m_vars[loc][var->data.location_frac] = var;
155       }
156    }
157 
158    if (!can_rewrite_vars)
159       return;
160 
161    /* We don't handle combining vars of different type e.g. different array
162     * lengths.
163     */
164    for (unsigned i = 0; i < 16; i++) {
165       unsigned comps = 0;
166 
167       for (unsigned j = 0; j < 3; j++) {
168          if (!m_vars[i][j])
169             continue;
170 
171          for (unsigned k = j + 1; k < 4; k++) {
172             if (!m_vars[i][k])
173                continue;
174 
175             if (!var_can_merge(m_vars[i][j], m_vars[i][k]))
176                continue;
177 
178             /* Set comps */
179             for (unsigned n = 0; n < glsl_get_components(m_vars[i][j]->type); ++n)
180                comps |= 1 << (m_vars[i][j]->data.location_frac + n);
181 
182             for (unsigned n = 0; n < glsl_get_components(m_vars[i][k]->type); ++n)
183                comps |= 1 << (m_vars[i][k]->data.location_frac + n);
184          }
185       }
186       if (comps)
187          create_new_io_var(shader, i, comps);
188    }
189 }
190 
191 bool
var_can_merge(const nir_variable * lhs,const nir_variable * rhs)192 NirLowerIOToVector::var_can_merge(const nir_variable *lhs, const nir_variable *rhs)
193 {
194    return (glsl_get_base_type(lhs->type) == glsl_get_base_type(rhs->type));
195 }
196 
197 void
create_new_io_var(nir_shader * shader,unsigned location,unsigned comps)198 NirLowerIOToVector::create_new_io_var(nir_shader *shader,
199                                       unsigned location,
200                                       unsigned comps)
201 {
202    unsigned num_comps = util_bitcount(comps);
203    assert(num_comps > 1);
204 
205    /* Note: u_bit_scan() strips a component of the comps bitfield here */
206    unsigned first_comp = u_bit_scan(&comps);
207 
208    nir_variable *var = nir_variable_clone(m_vars[location][first_comp], shader);
209    var->data.location_frac = first_comp;
210    var->type = glsl_replace_vector_type(var->type, num_comps);
211 
212    nir_shader_add_variable(shader, var);
213 
214    m_vars[location][first_comp] = var;
215 
216    while (comps) {
217       const int comp = u_bit_scan(&comps);
218       if (m_vars[location][comp]) {
219          m_vars[location][comp] = var;
220       }
221    }
222 }
223 
224 bool
var_can_rewrite(nir_variable * var) const225 NirLowerIOToVector::var_can_rewrite(nir_variable *var) const
226 {
227    /* Skip complex types we don't split in the first place */
228    if (!glsl_type_is_vector_or_scalar(glsl_without_array(var->type)))
229       return false;
230 
231    if (glsl_get_bit_size(glsl_without_array(var->type)) != 32)
232       return false;
233 
234    return var_can_rewrite_slot(var);
235 }
236 
237 bool
vectorize_block(nir_builder * b,nir_block * block)238 NirLowerIOToVector::vectorize_block(nir_builder *b, nir_block *block)
239 {
240    bool progress = false;
241 
242    nir_foreach_instr_safe(instr, block)
243    {
244       if (instr_can_rewrite(instr)) {
245          instr->index = m_next_index++;
246          nir_intrinsic_instr *ir = nir_instr_as_intrinsic(instr);
247          m_block_io.insert(ir);
248       }
249    }
250 
251    for (unsigned i = 0; i < block->num_dom_children; i++) {
252       nir_block *child = block->dom_children[i];
253       progress |= vectorize_block(b, child);
254    }
255 
256    nir_foreach_instr_reverse_safe(instr, block)
257    {
258       progress |= vec_instr_set_remove(b, instr);
259    }
260    m_block_io.clear();
261 
262    return progress;
263 }
264 
265 bool
instr_can_rewrite(nir_instr * instr)266 NirLowerIOToVector::instr_can_rewrite(nir_instr *instr)
267 {
268    if (instr->type != nir_instr_type_intrinsic)
269       return false;
270 
271    nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
272 
273    if (intr->num_components > 3)
274       return false;
275 
276    return instr_can_rewrite_type(intr);
277 }
278 
279 bool
vec_instr_set_remove(nir_builder * b,nir_instr * instr)280 NirLowerIOToVector::vec_instr_set_remove(nir_builder *b, nir_instr *instr)
281 {
282    if (!instr_can_rewrite(instr))
283       return false;
284 
285    nir_intrinsic_instr *ir = nir_instr_as_intrinsic(instr);
286    auto entry = m_block_io.equal_range(ir);
287    if (entry.first != m_block_io.end()) {
288       vec_instr_stack_pop(b, entry, ir);
289    }
290    return true;
291 }
292 
293 nir_deref_instr *
clone_deref_array(nir_builder * b,nir_deref_instr * dst_tail,const nir_deref_instr * src_head)294 NirLowerIOToVector::clone_deref_array(nir_builder *b,
295                                       nir_deref_instr *dst_tail,
296                                       const nir_deref_instr *src_head)
297 {
298    const nir_deref_instr *parent = nir_deref_instr_parent(src_head);
299 
300    if (!parent)
301       return dst_tail;
302 
303    assert(src_head->deref_type == nir_deref_type_array);
304 
305    dst_tail = clone_deref_array(b, dst_tail, parent);
306 
307    return nir_build_deref_array(b, dst_tail, src_head->arr.index.ssa);
308 }
309 
NirLowerFSOutToVector()310 NirLowerFSOutToVector::NirLowerFSOutToVector():
311     NirLowerIOToVector(FRAG_RESULT_COLOR)
312 {
313 }
314 
315 bool
var_can_rewrite_slot(nir_variable * var) const316 NirLowerFSOutToVector::var_can_rewrite_slot(nir_variable *var) const
317 {
318    return ((var->data.mode == nir_var_shader_out) &&
319            ((var->data.location == FRAG_RESULT_COLOR) ||
320             ((var->data.location >= FRAG_RESULT_DATA0) &&
321              (var->data.location <= FRAG_RESULT_DATA7))));
322 }
323 
324 bool
vec_instr_stack_pop(nir_builder * b,InstrSubSet & ir_set,nir_intrinsic_instr * instr)325 NirLowerIOToVector::vec_instr_stack_pop(nir_builder *b,
326                                         InstrSubSet& ir_set,
327                                         nir_intrinsic_instr *instr)
328 {
329    vector<nir_intrinsic_instr *> ir_sorted_set(ir_set.first, ir_set.second);
330    std::sort(ir_sorted_set.begin(),
331              ir_sorted_set.end(),
332              [](const nir_intrinsic_instr *lhs, const nir_intrinsic_instr *rhs) {
333                 return lhs->instr.index > rhs->instr.index;
334              });
335 
336    nir_intrinsic_instr *intr = *ir_sorted_set.begin();
337    nir_variable *var = nir_intrinsic_get_var(intr, 0);
338 
339    unsigned loc = var->data.location - m_base_slot;
340 
341    nir_variable *new_var = m_vars[loc][var->data.location_frac];
342    unsigned num_comps = glsl_get_vector_elements(glsl_without_array(new_var->type));
343    unsigned old_num_comps = glsl_get_vector_elements(glsl_without_array(var->type));
344 
345    /* Don't bother walking the stack if this component can't be vectorised. */
346    if (old_num_comps > 3) {
347       return false;
348    }
349 
350    if (new_var == var) {
351       return false;
352    }
353 
354    b->cursor = nir_after_instr(&intr->instr);
355    nir_undef_instr *instr_undef = nir_undef_instr_create(b->shader, 1, 32);
356    nir_builder_instr_insert(b, &instr_undef->instr);
357 
358    nir_def *srcs[4];
359    for (int i = 0; i < 4; i++) {
360       srcs[i] = &instr_undef->def;
361    }
362    srcs[var->data.location_frac] = intr->src[1].ssa;
363 
364    for (auto k = ir_sorted_set.begin() + 1; k != ir_sorted_set.end(); ++k) {
365       nir_intrinsic_instr *intr2 = *k;
366       nir_variable *var2 = nir_intrinsic_get_var(intr2, 0);
367       unsigned loc2 = var->data.location - m_base_slot;
368 
369       if (m_vars[loc][var->data.location_frac] !=
370           m_vars[loc2][var2->data.location_frac]) {
371          continue;
372       }
373 
374       assert(glsl_get_vector_elements(glsl_without_array(var2->type)) < 4);
375 
376       if (srcs[var2->data.location_frac] == &instr_undef->def) {
377          assert(intr2->src[1].ssa);
378          srcs[var2->data.location_frac] = intr2->src[1].ssa;
379       }
380       nir_instr_remove(&intr2->instr);
381    }
382 
383    create_new_io(b, intr, new_var, srcs, new_var->data.location_frac, num_comps);
384    return true;
385 }
386 
387 nir_variable_mode
get_io_mode(nir_shader * shader) const388 NirLowerFSOutToVector::get_io_mode(nir_shader *shader) const
389 {
390    return nir_var_shader_out;
391 }
392 
393 void
create_new_io(nir_builder * b,nir_intrinsic_instr * intr,nir_variable * var,nir_def ** srcs,unsigned first_comp,unsigned num_comps)394 NirLowerFSOutToVector::create_new_io(nir_builder *b,
395                                      nir_intrinsic_instr *intr,
396                                      nir_variable *var,
397                                      nir_def **srcs,
398                                      unsigned first_comp,
399                                      unsigned num_comps)
400 {
401    b->cursor = nir_before_instr(&intr->instr);
402 
403    nir_intrinsic_instr *new_intr = nir_intrinsic_instr_create(b->shader, intr->intrinsic);
404    new_intr->num_components = num_comps;
405 
406    nir_intrinsic_set_write_mask(new_intr, (1 << num_comps) - 1);
407 
408    nir_deref_instr *deref = nir_build_deref_var(b, var);
409    deref = clone_deref_array(b, deref, nir_src_as_deref(intr->src[0]));
410 
411    new_intr->src[0] = nir_src_for_ssa(&deref->def);
412    new_intr->src[1] =
413       nir_src_for_ssa(create_combined_vector(b, srcs, first_comp, num_comps));
414 
415    nir_builder_instr_insert(b, &new_intr->instr);
416 
417    /* Remove the old store intrinsic */
418    nir_instr_remove(&intr->instr);
419 }
420 
421 bool
instr_can_rewrite_type(nir_intrinsic_instr * intr) const422 NirLowerFSOutToVector::instr_can_rewrite_type(nir_intrinsic_instr *intr) const
423 {
424    if (intr->intrinsic != nir_intrinsic_store_deref)
425       return false;
426 
427    nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
428    if (!nir_deref_mode_is(deref, nir_var_shader_out))
429       return false;
430 
431    return var_can_rewrite(nir_deref_instr_get_variable(deref));
432 }
433 
434 nir_def *
create_combined_vector(nir_builder * b,nir_def ** srcs,int first_comp,int num_comp)435 NirLowerFSOutToVector::create_combined_vector(nir_builder *b,
436                                               nir_def **srcs,
437                                               int first_comp,
438                                               int num_comp)
439 {
440    nir_op op;
441    switch (num_comp) {
442    case 2:
443       op = nir_op_vec2;
444       break;
445    case 3:
446       op = nir_op_vec3;
447       break;
448    case 4:
449       op = nir_op_vec4;
450       break;
451    default:
452       unreachable("combined vector must have 2 to 4 components");
453    }
454    nir_alu_instr *instr = nir_alu_instr_create(b->shader, op);
455    instr->exact = b->exact;
456 
457    int i = 0;
458    unsigned k = 0;
459    while (i < num_comp) {
460       nir_def *s = srcs[first_comp + k];
461       for (uint8_t kk = 0; kk < s->num_components && i < num_comp; ++kk) {
462          instr->src[i].src = nir_src_for_ssa(s);
463          instr->src[i].swizzle[0] = kk;
464          ++i;
465       }
466       k += s->num_components;
467    }
468 
469    nir_def_init(&instr->instr, &instr->def, num_comp, 32);
470    nir_builder_instr_insert(b, &instr->instr);
471    return &instr->def;
472 }
473 
474 } // namespace r600
475