xref: /aosp_15_r20/external/mesa3d/src/gallium/drivers/radeonsi/si_nir_optim.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright 2021 Advanced Micro Devices, Inc.
3  *
4  * SPDX-License-Identifier: MIT
5  */
6 
7 #include "si_pipe.h"
8 #include "nir.h"
9 #include "nir_builder.h"
10 #include "nir_worklist.h"
11 
12 
13 static bool
add_src_instr_to_worklist(nir_src * src,void * wl)14 add_src_instr_to_worklist(nir_src *src, void *wl)
15 {
16    nir_instr_worklist_push_tail(wl, src->ssa->parent_instr);
17    return true;
18 }
19 
20 static int
get_tex_unit(nir_tex_instr * tex)21 get_tex_unit(nir_tex_instr *tex)
22 {
23    int tex_index = nir_tex_instr_src_index(tex, nir_tex_src_texture_deref);
24    if (tex_index >= 0) {
25       nir_deref_instr *deref = nir_src_as_deref(tex->src[tex_index].src);
26       nir_variable *var = nir_deref_instr_get_variable(deref);
27       return var ? var->data.binding : 0;
28    }
29    return -1;
30 }
31 
32 static int
check_instr_depends_on_tex(nir_intrinsic_instr * store)33 check_instr_depends_on_tex(nir_intrinsic_instr *store)
34 {
35    int texunit = -1;
36    struct set *instrs = _mesa_set_create(NULL, _mesa_hash_pointer,
37                                          _mesa_key_pointer_equal);
38    nir_instr_worklist *work = nir_instr_worklist_create();
39 
40    _mesa_set_add(instrs, &store->instr);
41    add_src_instr_to_worklist(&store->src[0], work);
42 
43    nir_foreach_instr_in_worklist(instr, work) {
44       /* Don't process an instruction twice */
45       if (_mesa_set_search(instrs, instr))
46          continue;
47 
48       _mesa_set_add(instrs, instr);
49 
50       if (instr->type == nir_instr_type_alu ||
51           instr->type == nir_instr_type_load_const) {
52          /* TODO: ubo, etc */
53          if (!nir_foreach_src(instr, add_src_instr_to_worklist, work))
54             break;
55          continue;
56       } else if (instr->type == nir_instr_type_tex) {
57          if (texunit != -1) {
58             /* We can only depend on a single tex */
59             texunit = -1;
60             break;
61          } else {
62             texunit = get_tex_unit(nir_instr_as_tex(instr));
63             continue;
64          }
65       } else {
66          break;
67       }
68    }
69 
70    nir_instr_worklist_destroy(work);
71    _mesa_set_destroy(instrs, NULL);
72    return texunit;
73 }
74 
75 static bool
get_output_as_const_value(nir_shader * shader,float values[4])76 get_output_as_const_value(nir_shader *shader, float values[4])
77 {
78    nir_foreach_function_impl(impl, shader) {
79       nir_foreach_block_reverse(block, impl) {
80          nir_foreach_instr_reverse_safe(instr, block) {
81             switch (instr->type) {
82                case nir_instr_type_intrinsic: {
83                   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
84                   if (intrin->intrinsic == nir_intrinsic_store_output) {
85                      nir_const_value *c = nir_src_as_const_value(intrin->src[0]);
86                      if (c) {
87                         nir_const_value_to_array(values, c, 4, f32);
88                         return true;
89                      }
90                      return false;
91                   }
92                   FALLTHROUGH;
93                }
94                default:
95                   continue;
96             }
97          }
98       }
99    }
100    return false;
101 }
102 
103 struct replace_param {
104    float value[4];
105    int *texunit;
106 };
107 
108 static bool
store_instr_depends_on_tex(nir_builder * b,nir_intrinsic_instr * intrin,void * state)109 store_instr_depends_on_tex(nir_builder *b, nir_intrinsic_instr *intrin,
110                            void *state)
111 {
112    if (intrin->intrinsic != nir_intrinsic_store_output)
113       return false;
114 
115    struct replace_param *p = (struct replace_param*) state;
116    *(p->texunit) = check_instr_depends_on_tex(intrin);
117 
118    return *(p->texunit) != -1;
119 }
120 
121 
122 static bool
replace_tex_by_imm(nir_builder * b,nir_instr * instr,void * state)123 replace_tex_by_imm(nir_builder *b, nir_instr *instr, void *state)
124 {
125    if (instr->type != nir_instr_type_tex)
126       return false;
127 
128    nir_tex_instr *tex = nir_instr_as_tex(instr);
129    struct replace_param *p = (struct replace_param*) state;
130 
131    if (get_tex_unit(tex) != *(p->texunit))
132       return false;
133 
134    b->cursor = nir_instr_remove(&tex->instr);
135    nir_def *imm = nir_imm_vec4(b, p->value[0], p->value[1], p->value[2], p->value[3]);
136    nir_def_rewrite_uses(&tex->def, imm);
137    return true;
138 }
139 
140 
141 /* This function returns true if a shader' sole output becomes constant when
142  * a given texunit is replaced by a constant value.
143  * The input constant value is passed as 'in' and the determined constant
144  * value is stored in 'out'. The texunit is also remembered.
145  */
146 bool
si_nir_is_output_const_if_tex_is_const(nir_shader * shader,float * in,float * out,int * texunit)147 si_nir_is_output_const_if_tex_is_const(nir_shader *shader, float *in, float *out, int *texunit)
148 {
149    assert(shader->info.stage == MESA_SHADER_FRAGMENT);
150 
151    if (BITSET_COUNT(shader->info.textures_used) == 0 ||
152        util_bitcount64(shader->info.outputs_written) != 1)
153       return false;
154 
155    struct replace_param p;
156    memcpy(p.value, in, 4 * sizeof(float));
157    p.texunit = texunit;
158 
159    /* Test if the single store_output only depends on constants and a single texture op */
160    if (nir_shader_intrinsics_pass(shader, store_instr_depends_on_tex, nir_metadata_all, &p)) {
161       assert(*p.texunit != -1);
162 
163       /* Replace nir_tex_instr using texunit by vec4(v) */
164       nir_shader_instructions_pass(shader, replace_tex_by_imm,
165                                    nir_metadata_control_flow, &p);
166 
167       /* Optimize the cloned shader */
168       bool progress;
169       do {
170          progress = false;
171          NIR_PASS(progress, shader, nir_copy_prop);
172          NIR_PASS(progress, shader, nir_opt_remove_phis);
173          NIR_PASS(progress, shader, nir_opt_dce);
174          NIR_PASS(progress, shader, nir_opt_dead_cf);
175          NIR_PASS(progress, shader, nir_opt_algebraic);
176          NIR_PASS(progress, shader, nir_opt_constant_folding);
177       } while (progress);
178 
179       /* Is the output a constant value? */
180       if (get_output_as_const_value(shader, out))
181          return true;
182    }
183 
184    return false;
185 }
186