/*
 * Copyright 2024 Advanced Micro Devices, Inc.
 *
 * SPDX-License-Identifier: MIT
 */

/**
 * This pass:
 * - vectorizes lowered input/output loads and stores
 * - vectorizes low and high 16-bit loads and stores by merging them into
 *   a single 32-bit load or store (except load_interpolated_input, which has
 *   to keep bit_size=16)
 * - performs DCE of output stores that overwrite the previous value by writing
 *   into the same slot and component.
 *
 * Vectorization is only local within basic blocks. No vectorization occurs
 * across basic block boundaries, barriers (TCS outputs only), emits
 * (GS outputs only), or output load <-> output store dependencies.
 *
 * All loads and stores must be scalar. 64-bit loads and stores are forbidden.
 *
 * For each basic block, the time complexity is O(n*log(n)) where n is
 * the number of IO instructions within that block.
 */
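
/* For illustration only (a hypothetical NIR snippet, not taken from a real
 * shader), the pass rewrites a run of scalar stores to one output slot:
 *
 *    store_output(%1, 0), component=0, io location=VAR0
 *    store_output(%2, 0), component=1, io location=VAR0
 *
 * into a single vector store:
 *
 *    %3 = vec2 %1, %2
 *    store_output(%3, 0), component=0, write_mask=0x3, io location=VAR0
 *
 * Scalar loads of the same slot are similarly merged into one vector load
 * followed by per-channel extracts.
 */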

#include "nir.h"
#include "nir_builder.h"
#include "util/u_dynarray.h"

/* Return 0 if loads/stores are vectorizable. Return 1 or -1 to define
 * an ordering between non-vectorizable instructions. This is used by qsort
 * to sort all gathered instructions into groups of vectorizable instructions.
 */
static int
compare_is_not_vectorizable(nir_intrinsic_instr *a, nir_intrinsic_instr *b)
{
   if (a->intrinsic != b->intrinsic)
      return a->intrinsic > b->intrinsic ? 1 : -1;

   nir_src *offset0 = nir_get_io_offset_src(a);
   nir_src *offset1 = nir_get_io_offset_src(b);
   if (offset0 && offset0->ssa != offset1->ssa)
      return offset0->ssa->index > offset1->ssa->index ? 1 : -1;

   nir_src *array_idx0 = nir_get_io_arrayed_index_src(a);
   nir_src *array_idx1 = nir_get_io_arrayed_index_src(b);
   if (array_idx0 && array_idx0->ssa != array_idx1->ssa)
      return array_idx0->ssa->index > array_idx1->ssa->index ? 1 : -1;

   /* Compare barycentrics or vertex index. */
   if ((a->intrinsic == nir_intrinsic_load_interpolated_input ||
        a->intrinsic == nir_intrinsic_load_input_vertex) &&
       a->src[0].ssa != b->src[0].ssa)
      return a->src[0].ssa->index > b->src[0].ssa->index ? 1 : -1;

   nir_io_semantics sem0 = nir_intrinsic_io_semantics(a);
   nir_io_semantics sem1 = nir_intrinsic_io_semantics(b);
   if (sem0.location != sem1.location)
      return sem0.location > sem1.location ? 1 : -1;

   /* The mediump flag isn't mergeable. */
   if (sem0.medium_precision != sem1.medium_precision)
      return sem0.medium_precision > sem1.medium_precision ? 1 : -1;

   /* Don't merge per-view attributes with non-per-view attributes. */
   if (sem0.per_view != sem1.per_view)
      return sem0.per_view > sem1.per_view ? 1 : -1;

   if (sem0.interp_explicit_strict != sem1.interp_explicit_strict)
      return sem0.interp_explicit_strict > sem1.interp_explicit_strict ? 1 : -1;

   /* Only load_interpolated_input can't merge low and high halves of 16-bit
    * loads/stores.
    */
   if (a->intrinsic == nir_intrinsic_load_interpolated_input &&
       sem0.high_16bits != sem1.high_16bits)
      return sem0.high_16bits > sem1.high_16bits ? 1 : -1;

   nir_shader *shader =
      nir_cf_node_get_function(&a->instr.block->cf_node)->function->shader;

   /* Compare the types. */
   if (!(shader->options->io_options & nir_io_vectorizer_ignores_types)) {
      unsigned type_a, type_b;

      if (nir_intrinsic_has_src_type(a)) {
         type_a = nir_intrinsic_src_type(a);
         type_b = nir_intrinsic_src_type(b);
      } else {
         type_a = nir_intrinsic_dest_type(a);
         type_b = nir_intrinsic_dest_type(b);
      }

      if (type_a != type_b)
         return type_a > type_b ? 1 : -1;
   }

   return 0;
}

static int
compare_intr(const void *xa, const void *xb)
{
   nir_intrinsic_instr *a = *(nir_intrinsic_instr **)xa;
   nir_intrinsic_instr *b = *(nir_intrinsic_instr **)xb;

   int comp = compare_is_not_vectorizable(a, b);
   if (comp)
      return comp;

   /* qsort isn't stable. This ensures that later stores aren't moved before
    * earlier stores.
    */
   return a->instr.index > b->instr.index ? 1 : -1;
}
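
/* For illustration only (a hypothetical gathered batch): after qsort with
 * compare_intr, vectorizable instructions end up next to each other, e.g.
 *
 *    store_output VAR0.x, store_output VAR0.y, store_output VAR0.z,
 *    store_output VAR1.x, store_output VAR1.y,
 *    load_input   VAR2.x, load_input   VAR2.y
 *
 * forms three groups (different locations/intrinsics), and within each group
 * the original program order is kept because ties fall back to instr.index.
 */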

static void
vectorize_load(nir_intrinsic_instr *chan[8], unsigned start, unsigned count,
               bool merge_low_high_16_to_32)
{
   nir_intrinsic_instr *first = NULL;

   /* Find the earliest instruction. The vectorized load will be inserted
    * before it.
    */
   for (unsigned i = start; i < start + count; i++) {
      first = !first || chan[i]->instr.index < first->instr.index ?
                 chan[i] : first;
      if (merge_low_high_16_to_32) {
         first = !first || chan[4 + i]->instr.index < first->instr.index ?
                    chan[4 + i] : first;
      }
   }

   /* Insert the vectorized load. */
   nir_builder b = nir_builder_at(nir_before_instr(&first->instr));
   nir_intrinsic_instr *new_intr =
      nir_intrinsic_instr_create(b.shader, first->intrinsic);

   new_intr->num_components = count;
   nir_def_init(&new_intr->instr, &new_intr->def, count,
                merge_low_high_16_to_32 ? 32 : first->def.bit_size);
   memcpy(new_intr->src, first->src,
          nir_intrinsic_infos[first->intrinsic].num_srcs * sizeof(nir_src));
   nir_intrinsic_copy_const_indices(new_intr, first);
   nir_intrinsic_set_component(new_intr, start);

   if (merge_low_high_16_to_32) {
      nir_io_semantics sem = nir_intrinsic_io_semantics(new_intr);
      sem.high_16bits = 0;
      nir_intrinsic_set_io_semantics(new_intr, sem);
      nir_intrinsic_set_dest_type(new_intr,
                                  (nir_intrinsic_dest_type(new_intr) & ~16) | 32);
   }

   nir_builder_instr_insert(&b, &new_intr->instr);
   nir_def *def = &new_intr->def;

   /* Replace the scalar loads. */
   if (merge_low_high_16_to_32) {
      for (unsigned i = start; i < start + count; i++) {
         nir_def *comp = nir_channel(&b, def, i - start);

         nir_def_rewrite_uses(&chan[i]->def,
                              nir_unpack_32_2x16_split_x(&b, comp));
         nir_def_rewrite_uses(&chan[4 + i]->def,
                              nir_unpack_32_2x16_split_y(&b, comp));
         nir_instr_remove(&chan[i]->instr);
         nir_instr_remove(&chan[4 + i]->instr);
      }
   } else {
      for (unsigned i = start; i < start + count; i++) {
         nir_def_replace(&chan[i]->def, nir_channel(&b, def, i - start));
      }
   }
}
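
/* For illustration only (hypothetical NIR, not from the original sources):
 * with merge_low_high_16_to_32, a pair of 16-bit loads of the low and high
 * halves of one 32-bit component
 *
 *    16x1 %lo = load_input, component=0, io high_16bits=0
 *    16x1 %hi = load_input, component=0, io high_16bits=1
 *
 * becomes one 32-bit load whose halves are re-extracted:
 *
 *    32x1 %v  = load_input, component=0
 *    16x1 %lo = unpack_32_2x16_split_x %v
 *    16x1 %hi = unpack_32_2x16_split_y %v
 */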

static void
vectorize_store(nir_intrinsic_instr *chan[8], unsigned start, unsigned count,
                bool merge_low_high_16_to_32)
{
   nir_intrinsic_instr *last = NULL;

   /* Find the latest instruction. It will be rewritten into the vectorized
    * store.
    */
   for (unsigned i = start; i < start + count; i++) {
      last = !last || chan[i]->instr.index > last->instr.index ?
                chan[i] : last;
      if (merge_low_high_16_to_32) {
         last = !last || chan[4 + i]->instr.index > last->instr.index ?
                   chan[4 + i] : last;
      }
   }

   /* Change the last instruction to a vectorized store. Update xfb first
    * because we need to read some info from "last" before overwriting it.
    */
   if (nir_intrinsic_has_io_xfb(last)) {
      nir_io_xfb xfb[2] = {{{{0}}}};

      for (unsigned i = start; i < start + count; i++) {
         xfb[i / 2].out[i % 2] =
            (i < 2 ? nir_intrinsic_io_xfb(chan[i]) :
                     nir_intrinsic_io_xfb2(chan[i])).out[i % 2];

         /* Merging low and high 16 bits into 32 bits isn't possible with xfb
          * in some cases (and it's not implemented for the cases where it is
          * possible).
          */
         assert(!xfb[i / 2].out[i % 2].num_components ||
                !merge_low_high_16_to_32);
      }

      /* Now vectorize xfb info by merging the individual elements. */
      for (unsigned i = start; i < start + count; i++) {
         /* mediump means that xfb upconverts to 32 bits when writing to
          * memory.
          */
         unsigned xfb_comp_size =
            nir_intrinsic_io_semantics(chan[i]).medium_precision ?
               32 : chan[i]->src[0].ssa->bit_size;

         for (unsigned j = i + 1; j < start + count; j++) {
            if (xfb[i / 2].out[i % 2].buffer != xfb[j / 2].out[j % 2].buffer ||
                xfb[i / 2].out[i % 2].offset != xfb[j / 2].out[j % 2].offset +
                xfb_comp_size * (j - i))
               break;

            xfb[i / 2].out[i % 2].num_components++;
            memset(&xfb[j / 2].out[j % 2], 0, sizeof(xfb[j / 2].out[j % 2]));
         }
      }

      nir_intrinsic_set_io_xfb(last, xfb[0]);
      nir_intrinsic_set_io_xfb2(last, xfb[1]);
   }

   /* Update gs_streams. */
   unsigned gs_streams = 0;
   for (unsigned i = start; i < start + count; i++) {
      gs_streams |= (nir_intrinsic_io_semantics(chan[i]).gs_streams & 0x3) <<
                    ((i - start) * 2);
   }

   nir_io_semantics sem = nir_intrinsic_io_semantics(last);
   sem.gs_streams = gs_streams;

   /* Update other flags. */
   for (unsigned i = start; i < start + count; i++) {
      if (!nir_intrinsic_io_semantics(chan[i]).no_sysval_output)
         sem.no_sysval_output = 0;
      if (!nir_intrinsic_io_semantics(chan[i]).no_varying)
         sem.no_varying = 0;
      if (nir_intrinsic_io_semantics(chan[i]).invariant)
         sem.invariant = 1;
   }

   if (merge_low_high_16_to_32) {
      /* Update "no" flags for high bits. */
      for (unsigned i = start; i < start + count; i++) {
         if (!nir_intrinsic_io_semantics(chan[4 + i]).no_sysval_output)
            sem.no_sysval_output = 0;
         if (!nir_intrinsic_io_semantics(chan[4 + i]).no_varying)
            sem.no_varying = 0;
         if (nir_intrinsic_io_semantics(chan[4 + i]).invariant)
            sem.invariant = 1;
      }

      /* Update the type. */
      sem.high_16bits = 0;
      nir_intrinsic_set_src_type(last,
                                 (nir_intrinsic_src_type(last) & ~16) | 32);
   }

   /* TODO: Merge names? */

   /* Update the rest. */
   nir_intrinsic_set_io_semantics(last, sem);
   nir_intrinsic_set_component(last, start);
   nir_intrinsic_set_write_mask(last, BITFIELD_MASK(count));
   last->num_components = count;

   nir_builder b = nir_builder_at(nir_before_instr(&last->instr));

   /* Replace the stored scalar with the vector. */
   if (merge_low_high_16_to_32) {
      nir_def *value[4];
      for (unsigned i = start; i < start + count; i++) {
         value[i] = nir_pack_32_2x16_split(&b, chan[i]->src[0].ssa,
                                           chan[4 + i]->src[0].ssa);
      }

      nir_src_rewrite(&last->src[0], nir_vec(&b, &value[start], count));
   } else {
      nir_def *value[4];
      for (unsigned i = start; i < start + count; i++)
         value[i] = chan[i]->src[0].ssa;

      nir_src_rewrite(&last->src[0], nir_vec(&b, &value[start], count));
   }

   /* Remove the scalar stores. */
   for (unsigned i = start; i < start + count; i++) {
      if (chan[i] != last)
         nir_instr_remove(&chan[i]->instr);
      if (merge_low_high_16_to_32 && chan[4 + i] != last)
         nir_instr_remove(&chan[4 + i]->instr);
   }
}
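
/* Worked example for the gs_streams update above (hypothetical values):
 * vectorizing components 1..2 (start=1, count=2) where channel 1 writes
 * stream 2 and channel 2 writes stream 1 yields
 *
 *    gs_streams = (2 << 0) | (1 << 2) = 0x6
 *
 * i.e. the 2-bit stream of each merged channel is packed relative to the
 * start component of the new vector store.
 */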

/* Vectorize one slot of scalar instructions. chan[8] holds the channels
 * (the last 4 are the high 16-bit channels).
 */
static bool
vectorize_slot(nir_intrinsic_instr *chan[8], unsigned mask)
{
   bool progress = false;

   /* First, merge low and high 16-bit halves into 32 bits separately when
    * possible. Then vectorize what's left.
    */
   for (int merge_low_high_16_to_32 = 1; merge_low_high_16_to_32 >= 0;
        merge_low_high_16_to_32--) {
      unsigned scan_mask;

      if (merge_low_high_16_to_32) {
         /* Get the subset of the mask where both low and high bits are set. */
         scan_mask = 0;
         for (unsigned i = 0; i < 4; i++) {
            unsigned low_high_bits = BITFIELD_BIT(i) | BITFIELD_BIT(i + 4);

            if ((mask & low_high_bits) == low_high_bits) {
               /* Merging low and high 16 bits into 32 bits isn't possible
                * with xfb in some cases (and it's not implemented for the
                * cases where it is possible).
                */
               if (nir_intrinsic_has_io_xfb(chan[i])) {
                  unsigned hi = i + 4;

                  if ((i < 2 ? nir_intrinsic_io_xfb(chan[i])
                             : nir_intrinsic_io_xfb2(chan[i])).out[i % 2].num_components ||
                      (i < 2 ? nir_intrinsic_io_xfb(chan[hi])
                             : nir_intrinsic_io_xfb2(chan[hi])).out[i % 2].num_components)
                     continue;
               }

               /* The GS stream must be the same for both halves. */
               if ((nir_intrinsic_io_semantics(chan[i]).gs_streams & 0x3) !=
                   (nir_intrinsic_io_semantics(chan[4 + i]).gs_streams & 0x3))
                  continue;

               scan_mask |= BITFIELD_BIT(i);
               mask &= ~low_high_bits;
            }
         }
      } else {
         scan_mask = mask;
      }

      while (scan_mask) {
         int start, count;

         u_bit_scan_consecutive_range(&scan_mask, &start, &count);

         if (count == 1 && !merge_low_high_16_to_32)
            continue; /* There is nothing to vectorize. */

         bool is_load = nir_intrinsic_infos[chan[start]->intrinsic].has_dest;

         if (is_load)
            vectorize_load(chan, start, count, merge_low_high_16_to_32);
         else
            vectorize_store(chan, start, count, merge_low_high_16_to_32);

         progress = true;
      }
   }

   return progress;
}
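
/* Worked example for the mask handling above (hypothetical values): for
 * mask = 0b00010111 (low halves of channels 0..2 plus the high half of
 * channel 0), only channel 0 has both halves present, so the 16->32 merge
 * pass runs with scan_mask = 0b0001 and clears bits 0 and 4, leaving
 * mask = 0b0110. The second pass then gets the consecutive run
 * (start=1, count=2) from u_bit_scan_consecutive_range and emits one vec2
 * load/store for channels 1..2.
 */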

static bool
vectorize_batch(struct util_dynarray *io_instructions)
{
   unsigned num_instr = util_dynarray_num_elements(io_instructions, void *);

   /* We need at least 2 instructions to have something to do. */
   if (num_instr <= 1) {
      /* Clear the array. The next block will reuse it. */
      util_dynarray_clear(io_instructions);
      return false;
   }

   /* The instructions are sorted such that groups of vectorizable
    * instructions are next to each other. Multiple incompatible
    * groups of vectorizable instructions can occur in this array.
    * The reason why 2 groups would be incompatible is that they
    * could have a different intrinsic, indirect index, array index,
    * vertex index, barycentrics, or location. Each group is vectorized
    * separately.
    *
    * This reorders instructions in the array, but not in the shader.
    */
   qsort(io_instructions->data, num_instr, sizeof(void*), compare_intr);

   nir_intrinsic_instr *chan[8] = {0}, *prev = NULL;
   unsigned chan_mask = 0;
   bool progress = false;

   /* Vectorize all groups.
    *
    * The channels for each group are gathered. If 2 stores overwrite
    * the same channel, the earlier store is DCE'd here.
    */
   util_dynarray_foreach(io_instructions, nir_intrinsic_instr *, intr) {
      /* If the next instruction is not vectorizable, vectorize what
       * we have gathered so far.
       */
      if (prev && compare_is_not_vectorizable(prev, *intr)) {
         /* We need at least 2 instructions to have something to do. */
         if (util_bitcount(chan_mask) > 1)
            progress |= vectorize_slot(chan, chan_mask);

         prev = NULL;
         memset(chan, 0, sizeof(chan));
         chan_mask = 0;
      }

      /* This performs DCE of output stores because the previous value
       * is being overwritten.
       */
      unsigned index = nir_intrinsic_io_semantics(*intr).high_16bits * 4 +
                       nir_intrinsic_component(*intr);
      bool is_store = !nir_intrinsic_infos[(*intr)->intrinsic].has_dest;
      if (is_store && chan[index])
         nir_instr_remove(&chan[index]->instr);

      /* Gather the channel. */
      chan[index] = *intr;
      prev = *intr;
      chan_mask |= BITFIELD_BIT(index);
   }

   /* Vectorize the last group. */
   if (prev && util_bitcount(chan_mask) > 1)
      progress |= vectorize_slot(chan, chan_mask);

   /* Clear the array. The next block will reuse it. */
   util_dynarray_clear(io_instructions);
   return progress;
}
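
/* For illustration only (hypothetical NIR): because chan[] keeps only the
 * most recent store per channel within a group, a sequence like
 *
 *    store_output(%a, 0), component=0, io location=VAR0
 *    store_output(%b, 0), component=0, io location=VAR0
 *
 * with no intervening load, barrier, or emit drops the first store, so only
 * %b survives (and may then be vectorized with other channels of VAR0).
 */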

bool
nir_opt_vectorize_io(nir_shader *shader, nir_variable_mode modes)
{
   assert(!(modes & ~(nir_var_shader_in | nir_var_shader_out)));

   if (shader->info.stage == MESA_SHADER_FRAGMENT &&
       shader->options->io_options & nir_io_prefer_scalar_fs_inputs)
      modes &= ~nir_var_shader_in;

   if ((shader->info.stage == MESA_SHADER_TESS_CTRL ||
        shader->info.stage == MESA_SHADER_GEOMETRY) &&
       util_bitcount(modes) == 2) {
      /* When vectorizing TCS and GS IO, inputs can ignore barriers and emits,
       * but that's only safe when outputs aren't being gathered at the same
       * time, so vectorize the two modes separately.
       */
      return nir_opt_vectorize_io(shader, nir_var_shader_in) ||
             nir_opt_vectorize_io(shader, nir_var_shader_out);
   }

   /* Initialize dynamic arrays. */
   struct util_dynarray io_instructions;
   util_dynarray_init(&io_instructions, NULL);
   bool global_progress = false;

   nir_foreach_function_impl(impl, shader) {
      bool progress = false;
      nir_metadata_require(impl, nir_metadata_instr_index);

      nir_foreach_block(block, impl) {
         BITSET_DECLARE(has_output_loads, NUM_TOTAL_VARYING_SLOTS * 8);
         BITSET_DECLARE(has_output_stores, NUM_TOTAL_VARYING_SLOTS * 8);
         BITSET_ZERO(has_output_loads);
         BITSET_ZERO(has_output_stores);

         /* Gather load/store intrinsics within the block. */
         nir_foreach_instr(instr, block) {
            if (instr->type != nir_instr_type_intrinsic)
               continue;

            nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
            bool is_load = nir_intrinsic_infos[intr->intrinsic].has_dest;
            bool is_output = false;
            nir_io_semantics sem = {0};
            unsigned index = 0;

            if (nir_intrinsic_has_io_semantics(intr)) {
               sem = nir_intrinsic_io_semantics(intr);
               assert(sem.location < NUM_TOTAL_VARYING_SLOTS);
               index = sem.location * 8 + sem.high_16bits * 4 +
                       nir_intrinsic_component(intr);
            }

            switch (intr->intrinsic) {
            case nir_intrinsic_load_input:
            case nir_intrinsic_load_per_primitive_input:
            case nir_intrinsic_load_input_vertex:
            case nir_intrinsic_load_interpolated_input:
            case nir_intrinsic_load_per_vertex_input:
               if (!(modes & nir_var_shader_in))
                  continue;
               break;

            case nir_intrinsic_load_output:
            case nir_intrinsic_load_per_vertex_output:
            case nir_intrinsic_load_per_primitive_output:
            case nir_intrinsic_store_output:
            case nir_intrinsic_store_per_vertex_output:
            case nir_intrinsic_store_per_primitive_output:
               if (!(modes & nir_var_shader_out))
                  continue;

               /* Break the batch if an output load is followed by an output
                * store to the same channel and vice versa.
                */
               if (BITSET_TEST(is_load ? has_output_stores : has_output_loads,
                               index)) {
                  progress |= vectorize_batch(&io_instructions);
                  BITSET_ZERO(has_output_loads);
                  BITSET_ZERO(has_output_stores);
               }
               is_output = true;
               break;

            case nir_intrinsic_barrier:
               /* Don't vectorize across TCS barriers. */
               if (modes & nir_var_shader_out &&
                   nir_intrinsic_memory_modes(intr) & nir_var_shader_out) {
                  progress |= vectorize_batch(&io_instructions);
                  BITSET_ZERO(has_output_loads);
                  BITSET_ZERO(has_output_stores);
               }
               continue;

            case nir_intrinsic_emit_vertex:
               /* Don't vectorize across GS emits. */
               progress |= vectorize_batch(&io_instructions);
               BITSET_ZERO(has_output_loads);
               BITSET_ZERO(has_output_stores);
               continue;

            default:
               continue;
            }

            /* Only scalar 16 and 32-bit instructions are allowed. */
            ASSERTED nir_def *value = is_load ? &intr->def : intr->src[0].ssa;
            assert(value->num_components == 1);
            assert(value->bit_size == 16 || value->bit_size == 32);

            util_dynarray_append(&io_instructions, void *, intr);
            if (is_output)
               BITSET_SET(is_load ? has_output_loads : has_output_stores, index);
         }

         progress |= vectorize_batch(&io_instructions);
      }

      nir_metadata_preserve(impl, progress ? (nir_metadata_block_index |
                                              nir_metadata_dominance) :
                                             nir_metadata_all);
      global_progress |= progress;
   }
   util_dynarray_fini(&io_instructions);

   return global_progress;
}
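
/* A minimal usage sketch (illustrative only; "nir" is assumed to be a
 * scalarized shader whose IO has already been lowered to explicit
 * load/store intrinsics):
 *
 *    bool progress = false;
 *    NIR_PASS(progress, nir, nir_opt_vectorize_io,
 *             nir_var_shader_in | nir_var_shader_out);
 */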