/*
 * Copyright © 2021 Valve Corporation
 * SPDX-License-Identifier: MIT
 */

#include "ir3.h"
#include "ir3_nir.h"
#include "util/ralloc.h"

/* Lower several macro-instructions needed for shader subgroup support that
 * must be turned into if statements. We do this after RA and post-RA
 * scheduling to give the scheduler a chance to rearrange them, because RA
 * may need to insert OPC_META_READ_FIRST to handle splitting live ranges, and
 * also because some (e.g. BALLOT and READ_FIRST) must produce a shared
 * register that cannot be spilled to a normal register until after the if,
 * which makes implementing spilling more complicated if they are already
 * lowered.
 */
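
/* As a rough sketch of the result, elect.macro is expanded below (see the
 * OPC_ELECT_MACRO handling in lower_instr()) into something like:
 *
 *    before_block:
 *       mov dst, 0
 *       getone #then_block
 *    then_block:
 *       mov dst, 1
 *       jump #after_block
 *    after_block:
 *       ...
 *
 * with the other macros following the same split-block/if pattern.
 */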

static void
replace_pred(struct ir3_block *block, struct ir3_block *old_pred,
             struct ir3_block *new_pred)
{
   for (unsigned i = 0; i < block->predecessors_count; i++) {
      if (block->predecessors[i] == old_pred) {
         block->predecessors[i] = new_pred;
         return;
      }
   }
}

static void
replace_physical_pred(struct ir3_block *block, struct ir3_block *old_pred,
                      struct ir3_block *new_pred)
{
   for (unsigned i = 0; i < block->physical_predecessors_count; i++) {
      if (block->physical_predecessors[i] == old_pred) {
         block->physical_predecessors[i] = new_pred;
         return;
      }
   }
}

static void
mov_immed(struct ir3_register *dst, struct ir3_block *block, unsigned immed)
{
   struct ir3_instruction *mov = ir3_instr_create(block, OPC_MOV, 1, 1);
   struct ir3_register *mov_dst = ir3_dst_create(mov, dst->num, dst->flags);
   mov_dst->wrmask = dst->wrmask;
   struct ir3_register *src = ir3_src_create(
      mov, INVALID_REG, (dst->flags & IR3_REG_HALF) | IR3_REG_IMMED);
   src->uim_val = immed;
   mov->cat1.dst_type = (dst->flags & IR3_REG_HALF) ? TYPE_U16 : TYPE_U32;
   mov->cat1.src_type = mov->cat1.dst_type;
   mov->repeat = util_last_bit(mov_dst->wrmask) - 1;
}

static void
mov_reg(struct ir3_block *block, struct ir3_register *dst,
        struct ir3_register *src)
{
   struct ir3_instruction *mov = ir3_instr_create(block, OPC_MOV, 1, 1);

   struct ir3_register *mov_dst =
      ir3_dst_create(mov, dst->num, dst->flags & (IR3_REG_HALF | IR3_REG_SHARED));
   struct ir3_register *mov_src =
      ir3_src_create(mov, src->num, src->flags & (IR3_REG_HALF | IR3_REG_SHARED));
   mov_dst->wrmask = dst->wrmask;
   mov_src->wrmask = src->wrmask;
   mov->repeat = util_last_bit(mov_dst->wrmask) - 1;

   mov->cat1.dst_type = (dst->flags & IR3_REG_HALF) ? TYPE_U16 : TYPE_U32;
   mov->cat1.src_type = (src->flags & IR3_REG_HALF) ? TYPE_U16 : TYPE_U32;
}

static void
binop(struct ir3_block *block, opc_t opc, struct ir3_register *dst,
      struct ir3_register *src0, struct ir3_register *src1)
{
   struct ir3_instruction *instr = ir3_instr_create(block, opc, 1, 2);

   unsigned flags = dst->flags & IR3_REG_HALF;
   struct ir3_register *instr_dst = ir3_dst_create(instr, dst->num, flags);
   struct ir3_register *instr_src0 = ir3_src_create(instr, src0->num, flags);
   struct ir3_register *instr_src1 = ir3_src_create(instr, src1->num, flags);

   instr_dst->wrmask = dst->wrmask;
   instr_src0->wrmask = src0->wrmask;
   instr_src1->wrmask = src1->wrmask;
   instr->repeat = util_last_bit(instr_dst->wrmask) - 1;
}

static void
triop(struct ir3_block *block, opc_t opc, struct ir3_register *dst,
      struct ir3_register *src0, struct ir3_register *src1,
      struct ir3_register *src2)
{
   struct ir3_instruction *instr = ir3_instr_create(block, opc, 1, 3);

   unsigned flags = dst->flags & IR3_REG_HALF;
   struct ir3_register *instr_dst = ir3_dst_create(instr, dst->num, flags);
   struct ir3_register *instr_src0 = ir3_src_create(instr, src0->num, flags);
   struct ir3_register *instr_src1 = ir3_src_create(instr, src1->num, flags);
   struct ir3_register *instr_src2 = ir3_src_create(instr, src2->num, flags);

   instr_dst->wrmask = dst->wrmask;
   instr_src0->wrmask = src0->wrmask;
   instr_src1->wrmask = src1->wrmask;
   instr_src2->wrmask = src2->wrmask;
   instr->repeat = util_last_bit(instr_dst->wrmask) - 1;
}

static void
do_reduce(struct ir3_block *block, reduce_op_t opc,
          struct ir3_register *dst, struct ir3_register *src0,
          struct ir3_register *src1)
{
   switch (opc) {
#define CASE(name)                                                             \
   case REDUCE_OP_##name:                                                      \
      binop(block, OPC_##name, dst, src0, src1);                               \
      break;

   CASE(ADD_U)
   CASE(ADD_F)
   CASE(MUL_F)
   CASE(MIN_U)
   CASE(MIN_S)
   CASE(MIN_F)
   CASE(MAX_U)
   CASE(MAX_S)
   CASE(MAX_F)
   CASE(AND_B)
   CASE(OR_B)
   CASE(XOR_B)

#undef CASE

   case REDUCE_OP_MUL_U:
      if (dst->flags & IR3_REG_HALF) {
         binop(block, OPC_MUL_S24, dst, src0, src1);
      } else {
         /* 32-bit multiplication macro - see ir3_nir_imul */
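         /* That is, with a = src0 and b = src1 (all values unsigned),
          * roughly:
          *
          *    dst = mull.u(a, b)          // a.lo * b.lo
          *    dst = madsh.m16(a, b, dst)  // += (a.hi * b.lo) << 16
          *    dst = madsh.m16(b, a, dst)  // += (b.hi * a.lo) << 16
          *
          * which together produce the low 32 bits of a * b.
          */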
         binop(block, OPC_MULL_U, dst, src0, src1);
         triop(block, OPC_MADSH_M16, dst, src0, src1, dst);
         triop(block, OPC_MADSH_M16, dst, src1, src0, dst);
      }
      break;
   }
}

static struct ir3_block *
split_block(struct ir3 *ir, struct ir3_block *before_block,
            struct ir3_instruction *instr)
{
   struct ir3_block *after_block = ir3_block_create(ir);
   list_add(&after_block->node, &before_block->node);

   for (unsigned i = 0; i < ARRAY_SIZE(before_block->successors); i++) {
      after_block->successors[i] = before_block->successors[i];
      if (after_block->successors[i])
         replace_pred(after_block->successors[i], before_block, after_block);
   }

   for (unsigned i = 0; i < before_block->physical_successors_count; i++) {
      replace_physical_pred(before_block->physical_successors[i],
                            before_block, after_block);
   }

   ralloc_steal(after_block, before_block->physical_successors);
   after_block->physical_successors = before_block->physical_successors;
   after_block->physical_successors_sz = before_block->physical_successors_sz;
   after_block->physical_successors_count =
      before_block->physical_successors_count;

   before_block->successors[0] = before_block->successors[1] = NULL;
   before_block->physical_successors = NULL;
   before_block->physical_successors_count = 0;
   before_block->physical_successors_sz = 0;

   foreach_instr_from_safe (rem_instr, &instr->node,
                            &before_block->instr_list) {
      list_del(&rem_instr->node);
      list_addtail(&rem_instr->node, &after_block->instr_list);
      rem_instr->block = after_block;
   }

   after_block->divergent_condition = before_block->divergent_condition;
   before_block->divergent_condition = false;
   return after_block;
}

static void
link_blocks(struct ir3_block *pred, struct ir3_block *succ, unsigned index)
{
   pred->successors[index] = succ;
   ir3_block_add_predecessor(succ, pred);
   ir3_block_link_physical(pred, succ);
}

static void
link_blocks_jump(struct ir3_block *pred, struct ir3_block *succ)
{
   ir3_JUMP(pred);
   link_blocks(pred, succ, 0);
}

static void
link_blocks_branch(struct ir3_block *pred, struct ir3_block *target,
                   struct ir3_block *fallthrough, unsigned opc, unsigned flags,
                   struct ir3_instruction *condition)
{
   unsigned nsrc = condition ? 1 : 0;
   struct ir3_instruction *branch = ir3_instr_create(pred, opc, 0, nsrc);
   branch->flags |= flags;

   if (condition) {
      struct ir3_register *cond_dst = condition->dsts[0];
      struct ir3_register *src =
         ir3_src_create(branch, cond_dst->num, cond_dst->flags);
      src->def = cond_dst;
   }

   link_blocks(pred, target, 0);
   link_blocks(pred, fallthrough, 1);

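   /* ball/bany evaluate the condition across the whole subgroup, so every
    * fiber takes the same side of the branch and the condition isn't
    * divergent; the other branches used here may diverge per-fiber.
    */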
   if (opc != OPC_BALL && opc != OPC_BANY) {
      pred->divergent_condition = true;
   }
}

static struct ir3_block *
create_if(struct ir3 *ir, struct ir3_block *before_block,
          struct ir3_block *after_block, unsigned opc, unsigned flags,
          struct ir3_instruction *condition)
{
   struct ir3_block *then_block = ir3_block_create(ir);
   list_add(&then_block->node, &before_block->node);

   link_blocks_branch(before_block, then_block, after_block, opc, flags,
                      condition);
   link_blocks_jump(then_block, after_block);

   return then_block;
}

static bool
lower_instr(struct ir3 *ir, struct ir3_block **block, struct ir3_instruction *instr)
{
   switch (instr->opc) {
   case OPC_BALLOT_MACRO:
   case OPC_ANY_MACRO:
   case OPC_ALL_MACRO:
   case OPC_ELECT_MACRO:
   case OPC_READ_COND_MACRO:
   case OPC_SCAN_MACRO:
   case OPC_SCAN_CLUSTERS_MACRO:
      break;
   case OPC_READ_FIRST_MACRO:
      /* Moves to shared registers read the first active fiber, so we can just
       * turn read_first.macro into a move. However we must still use the macro
       * and lower it late because in ir3_cp we need to distinguish between
       * moves where all source fibers contain the same value, which can be copy
       * propagated, and moves generated from API-level ReadFirstInvocation
       * which cannot.
       */
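      /* In other words, roughly:
       *
       *    read_first.macro dst(shared), src  ->  mov dst(shared), src
       */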
      assert(instr->dsts[0]->flags & IR3_REG_SHARED);
      instr->opc = OPC_MOV;
      instr->cat1.dst_type = TYPE_U32;
      instr->cat1.src_type =
         (instr->srcs[0]->flags & IR3_REG_HALF) ? TYPE_U16 : TYPE_U32;
      return false;
   default:
      return false;
   }

   struct ir3_block *before_block = *block;
   struct ir3_block *after_block = split_block(ir, before_block, instr);

   if (instr->opc == OPC_SCAN_MACRO) {
      /* The pseudo-code for the scan macro is:
       *
       * while (true) {
       *    header:
       *    if (elect()) {
       *       exit:
       *       exclusive = reduce;
       *       inclusive = src OP exclusive;
       *       reduce = inclusive;
       *       break;
       *    }
       *    footer:
       * }
       *
       * This is based on the blob's sequence, and carefully crafted to avoid
       * using the shared register "reduce" except in move instructions, since
       * using it in the actual OP isn't possible for half-registers.
       */
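      /* The blocks created below correspond to the labels above: before_block
       * jumps to header, header does a getone branch to exit with footer as
       * the fallthrough, exit jumps to after_block, and footer jumps back to
       * header.
       */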
      struct ir3_block *header = ir3_block_create(ir);
      list_add(&header->node, &before_block->node);

      struct ir3_block *exit = ir3_block_create(ir);
      list_add(&exit->node, &header->node);

      struct ir3_block *footer = ir3_block_create(ir);
      list_add(&footer->node, &exit->node);
      footer->reconvergence_point = true;

      after_block->reconvergence_point = true;

      link_blocks_jump(before_block, header);

      link_blocks_branch(header, exit, footer, OPC_GETONE,
                         IR3_INSTR_NEEDS_HELPERS, NULL);

      link_blocks_jump(exit, after_block);
      ir3_block_link_physical(exit, footer);

      link_blocks_jump(footer, header);

      struct ir3_register *exclusive = instr->dsts[0];
      struct ir3_register *inclusive = instr->dsts[1];
      struct ir3_register *reduce = instr->dsts[2];
      struct ir3_register *src = instr->srcs[0];

      mov_reg(exit, exclusive, reduce);
      do_reduce(exit, instr->cat1.reduce_op, inclusive, src, exclusive);
      mov_reg(exit, reduce, inclusive);
   } else if (instr->opc == OPC_SCAN_CLUSTERS_MACRO) {
      /* The pseudo-code for the scan clusters macro is:
       *
       * while (true) {
       *    body:
       *    scratch = reduce;
       *
       *    inclusive = inclusive_src OP scratch;
       *
       *    static if (is exclusive scan)
       *       exclusive = exclusive_src OP scratch
       *
       *    if (getlast()) {
       *       store:
       *       reduce = inclusive;
       *       if (elect())
       *          break;
       *    } else {
       *       break;
       *    }
       * }
       * after_block:
       */
      struct ir3_block *body = ir3_block_create(ir);
      list_add(&body->node, &before_block->node);

      struct ir3_block *store = ir3_block_create(ir);
      list_add(&store->node, &body->node);

      after_block->reconvergence_point = true;

      link_blocks_jump(before_block, body);

      link_blocks_branch(body, store, after_block, OPC_GETLAST, 0, NULL);

      link_blocks_branch(store, after_block, body, OPC_GETONE,
                         IR3_INSTR_NEEDS_HELPERS, NULL);

      struct ir3_register *reduce = instr->dsts[0];
      struct ir3_register *inclusive = instr->dsts[1];
      struct ir3_register *inclusive_src = instr->srcs[1];

      /* We need to perform the following operations:
       * - inclusive = inclusive_src OP reduce
       * - exclusive = exclusive_src OP reduce (iff exclusive scan)
       * Since reduce is initially in a shared register, we need to copy it to a
       * scratch register before performing the operations.
       *
       * The scratch register used is:
       * - an explicitly allocated one if op is 32b mul_u.
       *   - necessary because we cannot do 'foo = foo mul_u bar' since mul_u
       *     clobbers its destination.
       * - exclusive if this is an exclusive scan (and not 32b mul_u).
       *   - since we calculate inclusive first.
       * - inclusive otherwise.
       *
       * In all cases, this is the last destination.
       */
      struct ir3_register *scratch = instr->dsts[instr->dsts_count - 1];

      mov_reg(body, scratch, reduce);
      do_reduce(body, instr->cat1.reduce_op, inclusive, inclusive_src, scratch);

      /* exclusive scan */
      if (instr->srcs_count == 3) {
         struct ir3_register *exclusive_src = instr->srcs[2];
         struct ir3_register *exclusive = instr->dsts[2];
         do_reduce(body, instr->cat1.reduce_op, exclusive, exclusive_src,
                   scratch);
      }

      mov_reg(store, reduce, inclusive);
   } else {
      /* For ballot, the destination must be initialized to 0 before we do
       * the movmsk because the condition may be 0 and then the movmsk will
       * be skipped.
       */
      if (instr->opc == OPC_BALLOT_MACRO) {
         mov_immed(instr->dsts[0], before_block, 0);
      }

      struct ir3_instruction *condition = NULL;
      unsigned branch_opc = 0;
      unsigned branch_flags = 0;

      switch (instr->opc) {
      case OPC_BALLOT_MACRO:
      case OPC_READ_COND_MACRO:
      case OPC_ANY_MACRO:
      case OPC_ALL_MACRO:
         condition = instr->srcs[0]->def->instr;
         break;
      default:
         break;
      }

      switch (instr->opc) {
      case OPC_BALLOT_MACRO:
      case OPC_READ_COND_MACRO:
         after_block->reconvergence_point = true;
         branch_opc = OPC_BR;
         break;
      case OPC_ANY_MACRO:
         branch_opc = OPC_BANY;
         break;
      case OPC_ALL_MACRO:
         branch_opc = OPC_BALL;
         break;
      case OPC_ELECT_MACRO:
         after_block->reconvergence_point = true;
         branch_opc = OPC_GETONE;
         branch_flags = instr->flags & IR3_INSTR_NEEDS_HELPERS;
         break;
      default:
         unreachable("bad opcode");
      }

      struct ir3_block *then_block =
         create_if(ir, before_block, after_block, branch_opc, branch_flags,
                   condition);

      switch (instr->opc) {
      case OPC_ALL_MACRO:
      case OPC_ANY_MACRO:
      case OPC_ELECT_MACRO:
         mov_immed(instr->dsts[0], then_block, 1);
         mov_immed(instr->dsts[0], before_block, 0);
         break;

      case OPC_BALLOT_MACRO: {
         unsigned comp_count = util_last_bit(instr->dsts[0]->wrmask);
         struct ir3_instruction *movmsk =
            ir3_instr_create(then_block, OPC_MOVMSK, 1, 0);
         ir3_dst_create(movmsk, instr->dsts[0]->num, instr->dsts[0]->flags);
         movmsk->repeat = comp_count - 1;
         break;
      }

      case OPC_READ_COND_MACRO: {
         struct ir3_instruction *mov =
            ir3_instr_create(then_block, OPC_MOV, 1, 1);
         ir3_dst_create(mov, instr->dsts[0]->num, instr->dsts[0]->flags);
         struct ir3_register *new_src = ir3_src_create(mov, 0, 0);
         *new_src = *instr->srcs[1];
         mov->cat1.dst_type = TYPE_U32;
         mov->cat1.src_type =
            (new_src->flags & IR3_REG_HALF) ? TYPE_U16 : TYPE_U32;
         mov->flags |= IR3_INSTR_NEEDS_HELPERS;
         break;
      }

      default:
         unreachable("bad opcode");
      }
   }

   *block = after_block;
   list_delinit(&instr->node);
   return true;
}

static bool
lower_block(struct ir3 *ir, struct ir3_block **block)
{
   bool progress = false;

   bool inner_progress;
   do {
      inner_progress = false;
      foreach_instr (instr, &(*block)->instr_list) {
         if (lower_instr(ir, block, instr)) {
            /* restart the loop with the new block we created because the
             * iterator has been invalidated.
             */
            progress = inner_progress = true;
            break;
         }
      }
   } while (inner_progress);

   return progress;
}

bool
ir3_lower_subgroups(struct ir3 *ir)
{
   bool progress = false;

   foreach_block (block, &ir->block_list)
      progress |= lower_block(ir, &block);

   return progress;
}

static bool
filter_scan_reduce(const nir_instr *instr, const void *data)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   switch (intrin->intrinsic) {
   case nir_intrinsic_reduce:
   case nir_intrinsic_inclusive_scan:
   case nir_intrinsic_exclusive_scan:
      return true;
   default:
      return false;
   }
}

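/* Lower reduce/scan intrinsics to ir3's building blocks: the brcst_active_ir3
 * steps with cluster sizes 2, 4, and 8 accumulate an inclusive (and, for
 * exclusive scans, an exclusive) scan within clusters of 8 fibers, and the
 * *_clusters_ir3 intrinsics, which eventually become the SCAN_CLUSTERS macro
 * handled above, combine the per-cluster results across the subgroup.
 */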
static nir_def *
lower_scan_reduce(struct nir_builder *b, nir_instr *instr, void *data)
{
   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   unsigned bit_size = intrin->def.bit_size;

   nir_op op = nir_intrinsic_reduction_op(intrin);
   nir_const_value ident_val = nir_alu_binop_identity(op, bit_size);
   nir_def *ident = nir_build_imm(b, 1, bit_size, &ident_val);
   nir_def *inclusive = intrin->src[0].ssa;
   nir_def *exclusive = ident;

   for (unsigned cluster_size = 2; cluster_size <= 8; cluster_size *= 2) {
      nir_def *brcst = nir_brcst_active_ir3(b, ident, inclusive,
                                            .cluster_size = cluster_size);
      inclusive = nir_build_alu2(b, op, inclusive, brcst);

      if (intrin->intrinsic == nir_intrinsic_exclusive_scan)
         exclusive = nir_build_alu2(b, op, exclusive, brcst);
   }

   switch (intrin->intrinsic) {
   case nir_intrinsic_reduce:
      return nir_reduce_clusters_ir3(b, inclusive, .reduction_op = op);
   case nir_intrinsic_inclusive_scan:
      return nir_inclusive_scan_clusters_ir3(b, inclusive, .reduction_op = op);
   case nir_intrinsic_exclusive_scan:
      return nir_exclusive_scan_clusters_ir3(b, inclusive, exclusive,
                                             .reduction_op = op);
   default:
      unreachable("filtered intrinsic");
   }
}

bool
ir3_nir_opt_subgroups(nir_shader *nir, struct ir3_shader_variant *v)
{
   if (!v->compiler->has_getfiberid)
      return false;

   return nir_shader_lower_instructions(nir, filter_scan_reduce,
                                        lower_scan_reduce, NULL);
}