/*
 * Copyright © 2022 Advanced Micro Devices, Inc.
 *
 * SPDX-License-Identifier: MIT
 */

/* Convert 8-bit and 16-bit loads to 32 bits. This is for drivers that don't
 * support non-32-bit loads.
 *
 * This pass only transforms load intrinsics lowered by nir_lower_explicit_io,
 * so this pass should run after it.
 *
 * nir_opt_load_store_vectorize should be run before this because it analyzes
 * offset calculations and recomputes align_mul and align_offset.
 *
 * It's recommended to run nir_opt_algebraic and (optionally) ALU
 * scalarization after this pass.
 *
 * Running nir_opt_load_store_vectorize after this pass may lead to further
 * vectorization, e.g. adjacent 2x16-bit and 1x32-bit loads will become
 * 2x32-bit loads.
 */
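
/* A minimal usage sketch (illustrative only; the option values, the
 * "progress"/"vectorize_opts" variables and the surrounding passes are
 * assumptions, not part of this file):
 *
 *    ac_nir_lower_subdword_options opts = {
 *       .modes_1_comp = nir_var_mem_ubo | nir_var_mem_ssbo,
 *       .modes_N_comps = nir_var_mem_ubo | nir_var_mem_ssbo,
 *    };
 *    NIR_PASS(progress, nir, nir_opt_load_store_vectorize, &vectorize_opts);
 *    NIR_PASS(progress, nir, ac_nir_lower_subdword_loads, opts);
 *    NIR_PASS(progress, nir, nir_opt_algebraic);
 */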

#include "util/u_math.h"
#include "ac_nir.h"
#include "nir_builder.h"

static bool
lower_subdword_loads(nir_builder *b, nir_instr *instr, void *data)
{
   ac_nir_lower_subdword_options *options = data;

   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
   unsigned num_components = intr->num_components;
   nir_variable_mode modes =
      num_components == 1 ? options->modes_1_comp
                          : options->modes_N_comps;

   switch (intr->intrinsic) {
   case nir_intrinsic_load_ubo:
      if (!(modes & nir_var_mem_ubo))
         return false;
      break;
   case nir_intrinsic_load_ssbo:
      if (!(modes & nir_var_mem_ssbo))
         return false;
      break;
   case nir_intrinsic_load_global:
      if (!(modes & nir_var_mem_global))
         return false;
      break;
   case nir_intrinsic_load_push_constant:
      if (!(modes & nir_var_mem_push_const))
         return false;
      break;
   default:
      return false;
   }

   unsigned bit_size = intr->def.bit_size;
   if (bit_size >= 32)
      return false;

   assert(bit_size == 8 || bit_size == 16);

   unsigned component_size = bit_size / 8;
   unsigned comp_per_dword = 4 / component_size;

   /* Get the offset alignment relative to the closest dword. */
   unsigned align_mul = MIN2(nir_intrinsic_align_mul(intr), 4);
   unsigned align_offset = nir_intrinsic_align_offset(intr) % align_mul;
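   /* For example (illustrative values): a load declared with align_mul=16
    * and align_offset=6 ends up with align_mul=4 and align_offset=2 here,
    * i.e. it is known to start 2 bytes past a dword boundary.
    */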

   nir_src *src_offset = nir_get_io_offset_src(intr);
   nir_def *offset = src_offset->ssa;
   nir_def *result = &intr->def;

   /* Change the load to 32 bits per channel, update the channel count,
    * and increase the declared load alignment.
    */
   intr->def.bit_size = 32;

   if (align_mul == 4 && align_offset == 0) {
      intr->num_components = intr->def.num_components =
         DIV_ROUND_UP(num_components, comp_per_dword);

      /* Aligned loads. Just bitcast the vector and trim it if there are
       * trailing unused elements.
       */
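      /* For example (illustrative): an aligned 3x16-bit load has just been
       * turned into a 2x32-bit load, and nir_extract_bits below reinterprets
       * its low 48 bits as the 3 requested 16-bit components.
       */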
      b->cursor = nir_after_instr(instr);
      result = nir_extract_bits(b, &result, 1, 0, num_components, bit_size);

      nir_def_rewrite_uses_after(&intr->def, result,
                                 result->parent_instr);
      return true;
   }

   b->cursor = nir_before_instr(instr);

   if (nir_intrinsic_has_base(intr)) {
      offset = nir_iadd_imm(b, offset, nir_intrinsic_base(intr));
      nir_intrinsic_set_base(intr, 0);
   }

   /* Multi-component unaligned loads may straddle the dword boundary.
    * E.g. for 2 components, we need to load an extra dword, and so on.
    */
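   /* Worked example (illustrative): a 4x16-bit load with align_mul=4 and
    * align_offset=2 covers bytes 2..9 relative to a dword boundary, so the
    * computation below yields DIV_ROUND_UP(4 - 4 + 2 + 8, 4) = 3 dwords.
    */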
   intr->num_components = intr->def.num_components =
      DIV_ROUND_UP(4 - align_mul + align_offset + num_components * component_size, 4);

   nir_intrinsic_set_align(intr,
                           MAX2(nir_intrinsic_align_mul(intr), 4),
                           nir_intrinsic_align_offset(intr) & ~0x3);

   if (align_mul == 4) {
      /* Unaligned loads with an aligned non-constant base offset (which is
       * X * align_mul) and a constant added offset (align_offset).
       */
      assert(align_offset <= 3);
      assert(align_offset % component_size == 0);
      unsigned comp_offset = align_offset / component_size;
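      /* For example (illustrative): a 1x16-bit load with align_mul=4 and
       * align_offset=2 gets comp_offset=1, so after the offset rewrite below
       * nir_extract_bits picks bits [16:31] of the loaded dword.
       */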

      /* The offset is very likely an iadd that adds align_offset.
       * Subtracting align_offset here should allow that iadd to be
       * eliminated.
       */
      nir_src_rewrite(src_offset, nir_iadd_imm(b, offset, -align_offset));

      b->cursor = nir_after_instr(instr);
      result = nir_extract_bits(b, &result, 1, comp_offset * bit_size,
                                num_components, bit_size);

      nir_def_rewrite_uses_after(&intr->def, result,
                                 result->parent_instr);
      return true;
   }

   /* Fully unaligned loads. We overfetch by up to 1 dword and then bitshift
    * the whole vector.
    */
   assert(align_mul <= 2 && align_offset <= 3);

   /* Round the offset down to a dword boundary by masking out the low bits. */
   nir_src_rewrite(src_offset, nir_iand_imm(b, offset, ~0x3));

   /* We need to shift bits in the loaded vector by this number. */
   b->cursor = nir_after_instr(instr);
   nir_def *shift = nir_ishl_imm(b, nir_iand_imm(b, offset, 0x3), 3);
   nir_def *rev_shift32 = nir_isub_imm(b, 32, shift);

   nir_def *elems[NIR_MAX_VEC_COMPONENTS];

154 /* "shift" can be only be one of: 0, 8, 16, 24
155 *
156 * When we shift by (32 - shift) and shift is 0, resulting in a shift by 32,
157 * which is the same as a shift by 0, we need to convert the shifted number
158 * to u64 to get the shift by 32 that we want.
159 *
160 * The following algorithms are used to shift the vector.
161 *
162 * 64-bit variant (shr64 + shl64 + or32 per 2 elements):
163 * for (i = 0; i < num_components / 2 - 1; i++) {
164 * qword1 = pack(src[i * 2 + 0], src[i * 2 + 1]) >> shift;
165 * dword2 = u2u32(u2u64(src[i * 2 + 2]) << (32 - shift));
166 * dst[i * 2 + 0] = unpack_64_2x32_x(qword1);
167 * dst[i * 2 + 1] = unpack_64_2x32_y(qword1) | dword2;
168 * }
169 * i *= 2;
170 *
171 * 32-bit variant (shr32 + shl64 + or32 per element):
172 * for (; i < num_components - 1; i++)
173 * dst[i] = (src[i] >> shift) |
174 * u2u32(u2u64(src[i + 1]) << (32 - shift));
175 */
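   /* Concrete case (illustrative): a 2x16-bit load with align_mul=2 fetches
    * 2 dwords and "shift" is 0 or 16 at run time. After the shift below, all
    * 32 requested bits end up in the first result dword, which the final
    * nir_extract_bits reinterprets as 2x16 bits.
    */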
   unsigned i = 0;

   if (intr->num_components >= 2) {
      /* Use the 64-bit algorithm as described above. */
      for (i = 0; i < intr->num_components / 2 - 1; i++) {
         nir_def *qword1, *dword2;

         qword1 = nir_pack_64_2x32_split(b,
                                         nir_channel(b, result, i * 2 + 0),
                                         nir_channel(b, result, i * 2 + 1));
         qword1 = nir_ushr(b, qword1, shift);
         dword2 = nir_ishl(b, nir_u2u64(b, nir_channel(b, result, i * 2 + 2)),
                           rev_shift32);
         dword2 = nir_u2u32(b, dword2);

         elems[i * 2 + 0] = nir_unpack_64_2x32_split_x(b, qword1);
         elems[i * 2 + 1] =
            nir_ior(b, nir_unpack_64_2x32_split_y(b, qword1), dword2);
      }
      i *= 2;

      /* Use the 32-bit algorithm for the remainder of the vector. */
      for (; i < intr->num_components - 1; i++) {
         elems[i] =
            nir_ior(b,
                    nir_ushr(b, nir_channel(b, result, i), shift),
                    nir_u2u32(b,
                              nir_ishl(b, nir_u2u64(b, nir_channel(b, result, i + 1)),
                                       rev_shift32)));
      }
   }

   /* Shift the last element. */
   elems[i] = nir_ushr(b, nir_channel(b, result, i), shift);

   result = nir_vec(b, elems, intr->num_components);
   result = nir_extract_bits(b, &result, 1, 0, num_components, bit_size);

   nir_def_rewrite_uses_after(&intr->def, result,
                              result->parent_instr);
   return true;
}

bool
ac_nir_lower_subdword_loads(nir_shader *nir, ac_nir_lower_subdword_options options)
{
   return nir_shader_instructions_pass(nir, lower_subdword_loads,
                                       nir_metadata_control_flow, &options);
}