xref: /aosp_15_r20/external/mesa3d/src/compiler/nir/nir_lower_reg_intrinsics_to_ssa.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright 2023 Valve Corporation
3  * Copyright 2014 Intel Corporation
4  * SPDX-License-Identifier: MIT
5  */
6 
7 #include "nir.h"
8 #include "nir_builder.h"
9 #include "nir_intrinsics.h"
10 #include "nir_intrinsics_indices.h"
11 #include "nir_phi_builder.h"
12 #include "nir_vla.h"
13 
14 static bool
should_lower_reg(nir_intrinsic_instr * decl)15 should_lower_reg(nir_intrinsic_instr *decl)
16 {
17    /* This pass only really works on "plain" registers. In particular,
18     * base/indirects are not handled. If it's a packed or array register,
19     * just set the value to NULL so that the rewrite portion of the pass
20     * will know to ignore it.
21     */
22    return nir_intrinsic_num_array_elems(decl) == 0;
23 }
24 
25 struct regs_to_ssa_state {
26    nir_builder b;
27 
28    /* Scratch bitset for use in setup_reg */
29    unsigned defs_words;
30    BITSET_WORD *defs;
31 
32    struct nir_phi_builder *phi_builder;
33    struct nir_phi_builder_value **values;
34 };
35 
36 static void
setup_reg(nir_intrinsic_instr * decl,struct regs_to_ssa_state * state)37 setup_reg(nir_intrinsic_instr *decl, struct regs_to_ssa_state *state)
38 {
39    if (nir_def_is_unused(&decl->def)) {
40       nir_instr_remove(&decl->instr);
41       return;
42    }
43 
44    assert(state->values[decl->def.index] == NULL);
45    if (!should_lower_reg(decl))
46       return;
47 
48    const unsigned num_components = nir_intrinsic_num_components(decl);
49    const unsigned bit_size = nir_intrinsic_bit_size(decl);
50 
51    memset(state->defs, 0, state->defs_words * sizeof(*state->defs));
52 
53    nir_foreach_reg_store(store, decl)
54       BITSET_SET(state->defs, nir_src_parent_instr(store)->block->index);
55 
56    state->values[decl->def.index] =
57       nir_phi_builder_add_value(state->phi_builder, num_components,
58                                 bit_size, state->defs);
59 }
60 
61 static void
rewrite_load(nir_intrinsic_instr * load,struct regs_to_ssa_state * state)62 rewrite_load(nir_intrinsic_instr *load, struct regs_to_ssa_state *state)
63 {
64    nir_block *block = load->instr.block;
65    nir_def *reg = load->src[0].ssa;
66 
67    struct nir_phi_builder_value *value = state->values[reg->index];
68    if (!value)
69       return;
70 
71    nir_intrinsic_instr *decl = nir_instr_as_intrinsic(reg->parent_instr);
72    nir_def *def = nir_phi_builder_value_get_block_def(value, block);
73 
74    nir_def_replace(&load->def, def);
75 
76    if (nir_def_is_unused(&decl->def))
77       nir_instr_remove(&decl->instr);
78 }
79 
80 static void
rewrite_store(nir_intrinsic_instr * store,struct regs_to_ssa_state * state)81 rewrite_store(nir_intrinsic_instr *store, struct regs_to_ssa_state *state)
82 {
83    nir_block *block = store->instr.block;
84    nir_def *new_value = store->src[0].ssa;
85    nir_def *reg = store->src[1].ssa;
86 
87    struct nir_phi_builder_value *value = state->values[reg->index];
88    if (!value)
89       return;
90 
91    nir_intrinsic_instr *decl = nir_instr_as_intrinsic(reg->parent_instr);
92    unsigned num_components = nir_intrinsic_num_components(decl);
93    unsigned write_mask = nir_intrinsic_write_mask(store);
94 
95    /* Implement write masks by combining together the old/new values */
96    if (write_mask != BITFIELD_MASK(num_components)) {
97       nir_def *old_value =
98          nir_phi_builder_value_get_block_def(value, block);
99 
100       nir_def *channels[NIR_MAX_VEC_COMPONENTS] = { NULL };
101       state->b.cursor = nir_before_instr(&store->instr);
102 
103       for (unsigned i = 0; i < num_components; ++i) {
104          if (write_mask & BITFIELD_BIT(i))
105             channels[i] = nir_channel(&state->b, new_value, i);
106          else
107             channels[i] = nir_channel(&state->b, old_value, i);
108       }
109 
110       new_value = nir_vec(&state->b, channels, num_components);
111    }
112 
113    nir_phi_builder_value_set_block_def(value, block, new_value);
114    nir_instr_remove(&store->instr);
115 
116    if (nir_def_is_unused(&decl->def))
117       nir_instr_remove(&decl->instr);
118 }
119 
120 bool
nir_lower_reg_intrinsics_to_ssa_impl(nir_function_impl * impl)121 nir_lower_reg_intrinsics_to_ssa_impl(nir_function_impl *impl)
122 {
123    bool need_lower_reg = false;
124    nir_foreach_reg_decl(reg, impl) {
125       if (should_lower_reg(reg)) {
126          need_lower_reg = true;
127          break;
128       }
129    }
130    if (!need_lower_reg) {
131       nir_metadata_preserve(impl, nir_metadata_all);
132       return false;
133    }
134 
135    nir_metadata_require(impl, nir_metadata_control_flow);
136    nir_index_ssa_defs(impl);
137 
138    void *dead_ctx = ralloc_context(NULL);
139    struct regs_to_ssa_state state;
140    state.b = nir_builder_create(impl);
141    state.defs_words = BITSET_WORDS(impl->num_blocks);
142    state.defs = ralloc_array(dead_ctx, BITSET_WORD, state.defs_words);
143    state.phi_builder = nir_phi_builder_create(state.b.impl);
144    state.values = rzalloc_array(dead_ctx, struct nir_phi_builder_value *,
145                                 impl->ssa_alloc);
146 
147    nir_foreach_block_unstructured(block, impl) {
148       nir_foreach_instr_safe(instr, block) {
149          if (instr->type != nir_instr_type_intrinsic)
150             continue;
151 
152          nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
153          switch (intr->intrinsic) {
154          case nir_intrinsic_decl_reg:
155             setup_reg(intr, &state);
156             break;
157          case nir_intrinsic_load_reg:
158             rewrite_load(intr, &state);
159             break;
160          case nir_intrinsic_store_reg:
161             rewrite_store(intr, &state);
162             break;
163          default:
164             break;
165          }
166       }
167    }
168 
169    nir_phi_builder_finish(state.phi_builder);
170 
171    ralloc_free(dead_ctx);
172 
173    nir_metadata_preserve(impl, nir_metadata_control_flow);
174    return true;
175 }
176 
177 bool
nir_lower_reg_intrinsics_to_ssa(nir_shader * shader)178 nir_lower_reg_intrinsics_to_ssa(nir_shader *shader)
179 {
180    bool progress = false;
181 
182    nir_foreach_function_impl(impl, shader) {
183       progress |= nir_lower_reg_intrinsics_to_ssa_impl(impl);
184    }
185 
186    return progress;
187 }
188