/*
 * Copyright © 2022 Advanced Micro Devices, Inc.
 *
 * SPDX-License-Identifier: MIT
 */

/* Convert 8-bit and 16-bit loads to 32 bits. This is for drivers that don't
 * support non-32-bit loads.
 *
 * This pass only transforms load intrinsics lowered by nir_lower_explicit_io,
 * so this pass should run after it.
 *
 * nir_opt_load_store_vectorize should be run before this because it analyzes
 * offset calculations and recomputes align_mul and align_offset.
 *
 * nir_opt_algebraic and (optionally) ALU scalarization are recommended to be
 * run after this.
 *
 * Running nir_opt_load_store_vectorize after this pass may lead to further
 * vectorization, e.g. adjacent 2x16-bit and 1x32-bit loads will become
 * 2x32-bit loads.
 */
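
/* For example (illustrative NIR, not necessarily the exact textual output):
 * a dword-aligned vec2 16-bit UBO load
 *
 *    vec2 16 %r = load_ubo(%buf, %off)  (align_mul=4, align_offset=0)
 *
 * becomes a single 32-bit load whose result is bitcast back to two 16-bit
 * components:
 *
 *    vec1 32 %t = load_ubo(%buf, %off)  (align_mul=4, align_offset=0)
 *    vec2 16 %r = unpack_32_2x16(%t)
 */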

#include "util/u_math.h"
#include "ac_nir.h"
#include "nir_builder.h"

static bool
lower_subdword_loads(nir_builder *b, nir_instr *instr, void *data)
{
   ac_nir_lower_subdword_options *options = data;

   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
   unsigned num_components = intr->num_components;
   nir_variable_mode modes =
      num_components == 1 ? options->modes_1_comp
                          : options->modes_N_comps;

   switch (intr->intrinsic) {
   case nir_intrinsic_load_ubo:
      if (!(modes & nir_var_mem_ubo))
         return false;
      break;
   case nir_intrinsic_load_ssbo:
      if (!(modes & nir_var_mem_ssbo))
         return false;
      break;
   case nir_intrinsic_load_global:
      if (!(modes & nir_var_mem_global))
         return false;
      break;
   case nir_intrinsic_load_push_constant:
      if (!(modes & nir_var_mem_push_const))
         return false;
      break;
   default:
      return false;
   }

   unsigned bit_size = intr->def.bit_size;
   if (bit_size >= 32)
      return false;

   assert(bit_size == 8 || bit_size == 16);

   unsigned component_size = bit_size / 8;
   unsigned comp_per_dword = 4 / component_size;
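
   /* Worked example: bit_size == 16 gives component_size == 2 and
    * comp_per_dword == 2, so an aligned vec3 16-bit load will need
    * DIV_ROUND_UP(3, 2) == 2 dwords below; bit_size == 8 gives
    * comp_per_dword == 4.
    */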

   /* Get the offset alignment relative to the closest dword. */
   unsigned align_mul = MIN2(nir_intrinsic_align_mul(intr), 4);
   unsigned align_offset = nir_intrinsic_align_offset(intr) % align_mul;
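
   /* E.g. an incoming align_mul=8 with align_offset=6 becomes align_mul=4
    * with align_offset=2: the load starts 2 bytes past a dword boundary.
    */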

   nir_src *src_offset = nir_get_io_offset_src(intr);
   nir_def *offset = src_offset->ssa;
   nir_def *result = &intr->def;

   /* Change the load to 32 bits per channel, update the channel count,
    * and increase the declared load alignment.
    */
   intr->def.bit_size = 32;

   if (align_mul == 4 && align_offset == 0) {
      intr->num_components = intr->def.num_components =
         DIV_ROUND_UP(num_components, comp_per_dword);

      /* Aligned loads. Just bitcast the vector and trim it if there are
       * trailing unused elements.
       */
      b->cursor = nir_after_instr(instr);
      result = nir_extract_bits(b, &result, 1, 0, num_components, bit_size);

      nir_def_rewrite_uses_after(&intr->def, result,
                                 result->parent_instr);
      return true;
   }

   b->cursor = nir_before_instr(instr);

   if (nir_intrinsic_has_base(intr)) {
      offset = nir_iadd_imm(b, offset, nir_intrinsic_base(intr));
      nir_intrinsic_set_base(intr, 0);
   }

   /* Multi-component unaligned loads may straddle the dword boundary.
    * E.g. for 2 components, we need to load an extra dword, and so on.
    */
   intr->num_components = intr->def.num_components =
      DIV_ROUND_UP(4 - align_mul + align_offset + num_components * component_size, 4);
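
   /* E.g. a vec2 16-bit load (4 bytes of data) with align_mul=2 and
    * align_offset=0 may start 2 bytes into a dword, so it fetches
    * DIV_ROUND_UP(4 - 2 + 0 + 4, 4) == 2 dwords.
    */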

   nir_intrinsic_set_align(intr,
                           MAX2(nir_intrinsic_align_mul(intr), 4),
                           nir_intrinsic_align_offset(intr) & ~0x3);

   if (align_mul == 4) {
      /* Unaligned loads with an aligned non-constant base offset (which is
       * X * align_mul) and a constant added offset (align_offset).
       */
      assert(align_offset <= 3);
      assert(align_offset % component_size == 0);
      unsigned comp_offset = align_offset / component_size;

      /* The offset is most likely an "iadd" that adds align_offset.
       * Subtracting align_offset here should cancel out that addition.
       */
      nir_src_rewrite(src_offset, nir_iadd_imm(b, offset, -align_offset));

      b->cursor = nir_after_instr(instr);
      result = nir_extract_bits(b, &result, 1, comp_offset * bit_size,
                                num_components, bit_size);

      nir_def_rewrite_uses_after(&intr->def, result,
                                 result->parent_instr);
      return true;
   }

   /* Fully unaligned loads. We overfetch by up to 1 dword and then bitshift
    * the whole vector.
    */
   assert(align_mul <= 2 && align_offset <= 3);

   /* Round down to the closest dword by masking out the low bits. */
   nir_src_rewrite(src_offset, nir_iand_imm(b, offset, ~0x3));

   /* We need to shift bits in the loaded vector by this number. */
   b->cursor = nir_after_instr(instr);
   nir_def *shift = nir_ishl_imm(b, nir_iand_imm(b, offset, 0x3), 3);
   nir_def *rev_shift32 = nir_isub_imm(b, 32, shift);
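
   /* E.g. if the original offset is 2 bytes past a dword boundary, then
    * (offset & 0x3) == 2 and shift == 16: a 16-bit value loaded at that
    * offset sits in bits [16..31] of the fetched dword, and ushr by 16
    * moves it down to bits [0..15].
    */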

   nir_def *elems[NIR_MAX_VEC_COMPONENTS];

154    /* "shift" can be only be one of: 0, 8, 16, 24
155     *
156     * When we shift by (32 - shift) and shift is 0, resulting in a shift by 32,
157     * which is the same as a shift by 0, we need to convert the shifted number
158     * to u64 to get the shift by 32 that we want.
159     *
160     * The following algorithms are used to shift the vector.
161     *
162     * 64-bit variant (shr64 + shl64 + or32 per 2 elements):
163     *    for (i = 0; i < num_components / 2 - 1; i++) {
164     *       qword1 = pack(src[i * 2 + 0], src[i * 2 + 1]) >> shift;
165     *       dword2 = u2u32(u2u64(src[i * 2 + 2]) << (32 - shift));
166     *       dst[i * 2 + 0] = unpack_64_2x32_x(qword1);
167     *       dst[i * 2 + 1] = unpack_64_2x32_y(qword1) | dword2;
168     *    }
169     *    i *= 2;
170     *
171     * 32-bit variant (shr32 + shl64 + or32 per element):
172     *    for (; i < num_components - 1; i++)
173     *       dst[i] = (src[i] >> shift) |
174     *                u2u32(u2u64(src[i + 1]) << (32 - shift));
175     */
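
   /* For example, with shift == 16 and 2 fetched dwords, the loops below
    * produce:
    *    dst[0] = src[0] >> 16 | u2u32(u2u64(src[1]) << 16)
    *    dst[1] = src[1] >> 16
    */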
   unsigned i = 0;

   if (intr->num_components >= 2) {
      /* Use the 64-bit algorithm as described above. */
      for (i = 0; i < intr->num_components / 2 - 1; i++) {
         nir_def *qword1, *dword2;

         qword1 = nir_pack_64_2x32_split(b,
                                         nir_channel(b, result, i * 2 + 0),
                                         nir_channel(b, result, i * 2 + 1));
         qword1 = nir_ushr(b, qword1, shift);
         dword2 = nir_ishl(b, nir_u2u64(b, nir_channel(b, result, i * 2 + 2)),
                           rev_shift32);
         dword2 = nir_u2u32(b, dword2);

         elems[i * 2 + 0] = nir_unpack_64_2x32_split_x(b, qword1);
         elems[i * 2 + 1] =
            nir_ior(b, nir_unpack_64_2x32_split_y(b, qword1), dword2);
      }
      i *= 2;

      /* Use the 32-bit algorithm for the remainder of the vector. */
      for (; i < intr->num_components - 1; i++) {
         elems[i] =
            nir_ior(b,
                    nir_ushr(b, nir_channel(b, result, i), shift),
                    nir_u2u32(b,
                        nir_ishl(b, nir_u2u64(b, nir_channel(b, result, i + 1)),
                                 rev_shift32)));
      }
   }

   /* Shift the last element. */
   elems[i] = nir_ushr(b, nir_channel(b, result, i), shift);

   result = nir_vec(b, elems, intr->num_components);
   result = nir_extract_bits(b, &result, 1, 0, num_components, bit_size);

   nir_def_rewrite_uses_after(&intr->def, result,
                              result->parent_instr);
   return true;
}
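
/* Example driver usage (illustrative; which mode sets to lower depends on
 * what the target's load hardware supports):
 *
 *    ac_nir_lower_subdword_options opts = {
 *       .modes_1_comp = nir_var_mem_ubo,
 *       .modes_N_comps = nir_var_mem_ubo | nir_var_mem_global,
 *    };
 *    bool progress = ac_nir_lower_subdword_loads(nir, opts);
 */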

bool
ac_nir_lower_subdword_loads(nir_shader *nir, ac_nir_lower_subdword_options options)
{
   return nir_shader_instructions_pass(nir, lower_subdword_loads,
                                       nir_metadata_control_flow, &options);
}