/*
 * Copyright © 2021 Valve Corporation
 * SPDX-License-Identifier: MIT
 */

#include "ir3.h"
#include "ir3_nir.h"
#include "util/ralloc.h"

/* Lower several macro-instructions needed for shader subgroup support that
 * must be turned into if statements. We do this after RA and post-RA
 * scheduling to give the scheduler a chance to rearrange them, because RA
 * may need to insert OPC_META_READ_FIRST to handle splitting live ranges, and
 * also because some (e.g. BALLOT and READ_FIRST) must produce a shared
 * register that cannot be spilled to a normal register until after the if,
 * which makes implementing spilling more complicated if they are already
 * lowered.
 */

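/* CFG helpers: swap old_pred for new_pred in a block's predecessor array
 * (replace_pred) or physical predecessor array (replace_physical_pred).
 * Only the first matching entry is updated.
 */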
static void
replace_pred(struct ir3_block *block, struct ir3_block *old_pred,
             struct ir3_block *new_pred)
{
   for (unsigned i = 0; i < block->predecessors_count; i++) {
      if (block->predecessors[i] == old_pred) {
         block->predecessors[i] = new_pred;
         return;
      }
   }
}

static void
replace_physical_pred(struct ir3_block *block, struct ir3_block *old_pred,
                      struct ir3_block *new_pred)
{
   for (unsigned i = 0; i < block->physical_predecessors_count; i++) {
      if (block->physical_predecessors[i] == old_pred) {
         block->physical_predecessors[i] = new_pred;
         return;
      }
   }
}

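/* Append "mov dst, #immed" to the end of block, matching dst's half/full
 * width and using (rptN) to cover every component in dst's wrmask.
 */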
static void
mov_immed(struct ir3_register *dst, struct ir3_block *block, unsigned immed)
{
   struct ir3_instruction *mov = ir3_instr_create(block, OPC_MOV, 1, 1);
   struct ir3_register *mov_dst = ir3_dst_create(mov, dst->num, dst->flags);
   mov_dst->wrmask = dst->wrmask;
   struct ir3_register *src = ir3_src_create(
      mov, INVALID_REG, (dst->flags & IR3_REG_HALF) | IR3_REG_IMMED);
   src->uim_val = immed;
   mov->cat1.dst_type = (dst->flags & IR3_REG_HALF) ? TYPE_U16 : TYPE_U32;
   mov->cat1.src_type = mov->cat1.dst_type;
   mov->repeat = util_last_bit(mov_dst->wrmask) - 1;
}

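/* Append a register-to-register mov to block, preserving the HALF/SHARED
 * flags and wrmask of the original registers.
 */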
static void
mov_reg(struct ir3_block *block, struct ir3_register *dst,
        struct ir3_register *src)
{
   struct ir3_instruction *mov = ir3_instr_create(block, OPC_MOV, 1, 1);

   struct ir3_register *mov_dst =
      ir3_dst_create(mov, dst->num, dst->flags & (IR3_REG_HALF | IR3_REG_SHARED));
   struct ir3_register *mov_src =
      ir3_src_create(mov, src->num, src->flags & (IR3_REG_HALF | IR3_REG_SHARED));
   mov_dst->wrmask = dst->wrmask;
   mov_src->wrmask = src->wrmask;
   mov->repeat = util_last_bit(mov_dst->wrmask) - 1;

   mov->cat1.dst_type = (dst->flags & IR3_REG_HALF) ? TYPE_U16 : TYPE_U32;
   mov->cat1.src_type = (src->flags & IR3_REG_HALF) ? TYPE_U16 : TYPE_U32;
}

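/* binop()/triop() append a two- or three-source ALU instruction of the given
 * opcode. All operands inherit dst's half-register flag, wrmasks are copied
 * from the corresponding registers, and (rptN) covers dst's components.
 */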
static void
binop(struct ir3_block *block, opc_t opc, struct ir3_register *dst,
      struct ir3_register *src0, struct ir3_register *src1)
{
   struct ir3_instruction *instr = ir3_instr_create(block, opc, 1, 2);

   unsigned flags = dst->flags & IR3_REG_HALF;
   struct ir3_register *instr_dst = ir3_dst_create(instr, dst->num, flags);
   struct ir3_register *instr_src0 = ir3_src_create(instr, src0->num, flags);
   struct ir3_register *instr_src1 = ir3_src_create(instr, src1->num, flags);

   instr_dst->wrmask = dst->wrmask;
   instr_src0->wrmask = src0->wrmask;
   instr_src1->wrmask = src1->wrmask;
   instr->repeat = util_last_bit(instr_dst->wrmask) - 1;
}

static void
triop(struct ir3_block *block, opc_t opc, struct ir3_register *dst,
      struct ir3_register *src0, struct ir3_register *src1,
      struct ir3_register *src2)
{
   struct ir3_instruction *instr = ir3_instr_create(block, opc, 1, 3);

   unsigned flags = dst->flags & IR3_REG_HALF;
   struct ir3_register *instr_dst = ir3_dst_create(instr, dst->num, flags);
   struct ir3_register *instr_src0 = ir3_src_create(instr, src0->num, flags);
   struct ir3_register *instr_src1 = ir3_src_create(instr, src1->num, flags);
   struct ir3_register *instr_src2 = ir3_src_create(instr, src2->num, flags);

   instr_dst->wrmask = dst->wrmask;
   instr_src0->wrmask = src0->wrmask;
   instr_src1->wrmask = src1->wrmask;
   instr_src2->wrmask = src2->wrmask;
   instr->repeat = util_last_bit(instr_dst->wrmask) - 1;
}

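/* Emit the ALU instruction(s) for one step of a subgroup reduction,
 * i.e. dst = src0 OP src1. Most ops are a single binop; 32-bit MUL_U needs
 * the three-instruction mull.u + madsh.m16 sequence, while the 16-bit case
 * can use a single mul.s24.
 */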
static void
do_reduce(struct ir3_block *block, reduce_op_t opc,
          struct ir3_register *dst, struct ir3_register *src0,
          struct ir3_register *src1)
{
   switch (opc) {
#define CASE(name)                                                             \
   case REDUCE_OP_##name:                                                      \
      binop(block, OPC_##name, dst, src0, src1);                               \
      break;

   CASE(ADD_U)
   CASE(ADD_F)
   CASE(MUL_F)
   CASE(MIN_U)
   CASE(MIN_S)
   CASE(MIN_F)
   CASE(MAX_U)
   CASE(MAX_S)
   CASE(MAX_F)
   CASE(AND_B)
   CASE(OR_B)
   CASE(XOR_B)

#undef CASE

   case REDUCE_OP_MUL_U:
      if (dst->flags & IR3_REG_HALF) {
         binop(block, OPC_MUL_S24, dst, src0, src1);
      } else {
         /* 32-bit multiplication macro - see ir3_nir_imul */
         binop(block, OPC_MULL_U, dst, src0, src1);
         triop(block, OPC_MADSH_M16, dst, src0, src1, dst);
         triop(block, OPC_MADSH_M16, dst, src1, src0, dst);
      }
      break;
   }
}

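/* Split before_block at instr: create after_block immediately after it, move
 * instr and everything following it there, and transfer before_block's
 * logical and physical successor edges (plus its divergent_condition flag) to
 * after_block. before_block is left with no successors so the caller can wire
 * up the new control flow.
 */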
static struct ir3_block *
split_block(struct ir3 *ir, struct ir3_block *before_block,
            struct ir3_instruction *instr)
{
   struct ir3_block *after_block = ir3_block_create(ir);
   list_add(&after_block->node, &before_block->node);

   for (unsigned i = 0; i < ARRAY_SIZE(before_block->successors); i++) {
      after_block->successors[i] = before_block->successors[i];
      if (after_block->successors[i])
         replace_pred(after_block->successors[i], before_block, after_block);
   }

   for (unsigned i = 0; i < before_block->physical_successors_count; i++) {
      replace_physical_pred(before_block->physical_successors[i],
                            before_block, after_block);
   }

   ralloc_steal(after_block, before_block->physical_successors);
   after_block->physical_successors = before_block->physical_successors;
   after_block->physical_successors_sz = before_block->physical_successors_sz;
   after_block->physical_successors_count =
      before_block->physical_successors_count;

   before_block->successors[0] = before_block->successors[1] = NULL;
   before_block->physical_successors = NULL;
   before_block->physical_successors_count = 0;
   before_block->physical_successors_sz = 0;

   foreach_instr_from_safe (rem_instr, &instr->node,
                            &before_block->instr_list) {
      list_del(&rem_instr->node);
      list_addtail(&rem_instr->node, &after_block->instr_list);
      rem_instr->block = after_block;
   }

   after_block->divergent_condition = before_block->divergent_condition;
   before_block->divergent_condition = false;
   return after_block;
}

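/* Make succ the successor of pred at the given index, keeping the
 * predecessor list and the physical CFG edge in sync.
 */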
static void
link_blocks(struct ir3_block *pred, struct ir3_block *succ, unsigned index)
{
   pred->successors[index] = succ;
   ir3_block_add_predecessor(succ, pred);
   ir3_block_link_physical(pred, succ);
}

static void
link_blocks_jump(struct ir3_block *pred, struct ir3_block *succ)
{
   ir3_JUMP(pred);
   link_blocks(pred, succ, 0);
}

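/* Terminate pred with a branch of the given opcode: "target" becomes the
 * taken successor and "fallthrough" the not-taken one. ball/bany decide the
 * branch from the condition across the whole wave, so only the other branch
 * opcodes mark the block's condition as divergent.
 */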
static void
link_blocks_branch(struct ir3_block *pred, struct ir3_block *target,
                   struct ir3_block *fallthrough, unsigned opc, unsigned flags,
                   struct ir3_instruction *condition)
{
   unsigned nsrc = condition ? 1 : 0;
   struct ir3_instruction *branch = ir3_instr_create(pred, opc, 0, nsrc);
   branch->flags |= flags;

   if (condition) {
      struct ir3_register *cond_dst = condition->dsts[0];
      struct ir3_register *src =
         ir3_src_create(branch, cond_dst->num, cond_dst->flags);
      src->def = cond_dst;
   }

   link_blocks(pred, target, 0);
   link_blocks(pred, fallthrough, 1);

   if (opc != OPC_BALL && opc != OPC_BANY) {
      pred->divergent_condition = true;
   }
}

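/* Build the CFG for "if (branch) { then_block }": before_block branches to
 * then_block (taken) or falls through to after_block, and then_block jumps
 * to after_block. Returns then_block so the caller can fill it in.
 */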
static struct ir3_block *
create_if(struct ir3 *ir, struct ir3_block *before_block,
          struct ir3_block *after_block, unsigned opc, unsigned flags,
          struct ir3_instruction *condition)
{
   struct ir3_block *then_block = ir3_block_create(ir);
   list_add(&then_block->node, &before_block->node);

   link_blocks_branch(before_block, then_block, after_block, opc, flags,
                      condition);
   link_blocks_jump(then_block, after_block);

   return then_block;
}

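/* Lower a single subgroup macro. Returns true if the containing block was
 * split, in which case *block is updated to the block that now follows the
 * new control flow and the caller must restart iteration. READ_FIRST is
 * rewritten in place into a plain mov and returns false.
 */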
static bool
lower_instr(struct ir3 *ir, struct ir3_block **block, struct ir3_instruction *instr)
{
   switch (instr->opc) {
   case OPC_BALLOT_MACRO:
   case OPC_ANY_MACRO:
   case OPC_ALL_MACRO:
   case OPC_ELECT_MACRO:
   case OPC_READ_COND_MACRO:
   case OPC_SCAN_MACRO:
   case OPC_SCAN_CLUSTERS_MACRO:
      break;
   case OPC_READ_FIRST_MACRO:
      /* Moves to shared registers read the first active fiber, so we can just
       * turn read_first.macro into a move. However we must still use the macro
       * and lower it late because in ir3_cp we need to distinguish between
       * moves where all source fibers contain the same value, which can be copy
       * propagated, and moves generated from API-level ReadFirstInvocation
       * which cannot.
       */
      assert(instr->dsts[0]->flags & IR3_REG_SHARED);
      instr->opc = OPC_MOV;
      instr->cat1.dst_type = TYPE_U32;
      instr->cat1.src_type =
         (instr->srcs[0]->flags & IR3_REG_HALF) ? TYPE_U16 : TYPE_U32;
      return false;
   default:
      return false;
   }

   struct ir3_block *before_block = *block;
   struct ir3_block *after_block = split_block(ir, before_block, instr);

   if (instr->opc == OPC_SCAN_MACRO) {
      /* The pseudo-code for the scan macro is:
       *
       * while (true) {
       *    header:
       *    if (elect()) {
       *       exit:
       *       exclusive = reduce;
       *       inclusive = src OP exclusive;
       *       reduce = inclusive;
       *       break;
       *    }
       *    footer:
       * }
       *
       * This is based on the blob's sequence, and carefully crafted to avoid
       * using the shared register "reduce" except in move instructions, since
       * using it in the actual OP isn't possible for half-registers.
       */
      struct ir3_block *header = ir3_block_create(ir);
      list_add(&header->node, &before_block->node);

      struct ir3_block *exit = ir3_block_create(ir);
      list_add(&exit->node, &header->node);

      struct ir3_block *footer = ir3_block_create(ir);
      list_add(&footer->node, &exit->node);
      footer->reconvergence_point = true;

      after_block->reconvergence_point = true;

      link_blocks_jump(before_block, header);

      link_blocks_branch(header, exit, footer, OPC_GETONE,
                         IR3_INSTR_NEEDS_HELPERS, NULL);

      link_blocks_jump(exit, after_block);
      ir3_block_link_physical(exit, footer);

      link_blocks_jump(footer, header);

      struct ir3_register *exclusive = instr->dsts[0];
      struct ir3_register *inclusive = instr->dsts[1];
      struct ir3_register *reduce = instr->dsts[2];
      struct ir3_register *src = instr->srcs[0];

      mov_reg(exit, exclusive, reduce);
      do_reduce(exit, instr->cat1.reduce_op, inclusive, src, exclusive);
      mov_reg(exit, reduce, inclusive);
   } else if (instr->opc == OPC_SCAN_CLUSTERS_MACRO) {
      /* The pseudo-code for the clustered scan macro is:
       *
       * while (true) {
       *    body:
       *    scratch = reduce;
       *
       *    inclusive = inclusive_src OP scratch;
       *
       *    static if (is exclusive scan)
       *       exclusive = exclusive_src OP scratch
       *
       *    if (getlast()) {
       *       store:
       *       reduce = inclusive;
       *       if (elect())
       *           break;
       *    } else {
       *       break;
       *    }
       * }
       * after_block:
       */
      struct ir3_block *body = ir3_block_create(ir);
      list_add(&body->node, &before_block->node);

      struct ir3_block *store = ir3_block_create(ir);
      list_add(&store->node, &body->node);

      after_block->reconvergence_point = true;

      link_blocks_jump(before_block, body);

      link_blocks_branch(body, store, after_block, OPC_GETLAST, 0, NULL);

      link_blocks_branch(store, after_block, body, OPC_GETONE,
                         IR3_INSTR_NEEDS_HELPERS, NULL);

      struct ir3_register *reduce = instr->dsts[0];
      struct ir3_register *inclusive = instr->dsts[1];
      struct ir3_register *inclusive_src = instr->srcs[1];

      /* We need to perform the following operations:
       *  - inclusive = inclusive_src OP reduce
       *  - exclusive = exclusive_src OP reduce (iff exclusive scan)
       * Since reduce is initially in a shared register, we need to copy it to a
       * scratch register before performing the operations.
       *
       * The scratch register used is:
       *  - an explicitly allocated one if op is 32b mul_u.
       *    - necessary because we cannot do 'foo = foo mul_u bar' since mul_u
       *      clobbers its destination.
       *  - exclusive if this is an exclusive scan (and not 32b mul_u).
       *    - since we calculate inclusive first.
       *  - inclusive otherwise.
       *
       * In all cases, this is the last destination.
       */
      struct ir3_register *scratch = instr->dsts[instr->dsts_count - 1];

      mov_reg(body, scratch, reduce);
      do_reduce(body, instr->cat1.reduce_op, inclusive, inclusive_src, scratch);

      /* exclusive scan */
      if (instr->srcs_count == 3) {
         struct ir3_register *exclusive_src = instr->srcs[2];
         struct ir3_register *exclusive = instr->dsts[2];
         do_reduce(body, instr->cat1.reduce_op, exclusive, exclusive_src,
                   scratch);
      }

      mov_reg(store, reduce, inclusive);
   } else {
      /* For ballot, the destination must be initialized to 0 before we do
       * the movmsk because the condition may be 0 and then the movmsk will
       * be skipped.
       */
      if (instr->opc == OPC_BALLOT_MACRO) {
         mov_immed(instr->dsts[0], before_block, 0);
      }

      struct ir3_instruction *condition = NULL;
      unsigned branch_opc = 0;
      unsigned branch_flags = 0;

      switch (instr->opc) {
      case OPC_BALLOT_MACRO:
      case OPC_READ_COND_MACRO:
      case OPC_ANY_MACRO:
      case OPC_ALL_MACRO:
         condition = instr->srcs[0]->def->instr;
         break;
      default:
         break;
      }

      switch (instr->opc) {
      case OPC_BALLOT_MACRO:
      case OPC_READ_COND_MACRO:
         after_block->reconvergence_point = true;
         branch_opc = OPC_BR;
         break;
      case OPC_ANY_MACRO:
         branch_opc = OPC_BANY;
         break;
      case OPC_ALL_MACRO:
         branch_opc = OPC_BALL;
         break;
      case OPC_ELECT_MACRO:
         after_block->reconvergence_point = true;
         branch_opc = OPC_GETONE;
         branch_flags = instr->flags & IR3_INSTR_NEEDS_HELPERS;
         break;
      default:
         unreachable("bad opcode");
      }

      struct ir3_block *then_block =
         create_if(ir, before_block, after_block, branch_opc, branch_flags,
                   condition);

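      /* Produce the result: ANY/ALL/ELECT write 0 in before_block and 1 on
       * the taken path, BALLOT emits its movmsk on the taken path, and
       * READ_COND moves its value operand into the shared destination there.
       */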
      switch (instr->opc) {
      case OPC_ALL_MACRO:
      case OPC_ANY_MACRO:
      case OPC_ELECT_MACRO:
         mov_immed(instr->dsts[0], then_block, 1);
         mov_immed(instr->dsts[0], before_block, 0);
         break;

      case OPC_BALLOT_MACRO: {
         unsigned comp_count = util_last_bit(instr->dsts[0]->wrmask);
         struct ir3_instruction *movmsk =
            ir3_instr_create(then_block, OPC_MOVMSK, 1, 0);
         ir3_dst_create(movmsk, instr->dsts[0]->num, instr->dsts[0]->flags);
         movmsk->repeat = comp_count - 1;
         break;
      }

      case OPC_READ_COND_MACRO: {
         struct ir3_instruction *mov =
            ir3_instr_create(then_block, OPC_MOV, 1, 1);
         ir3_dst_create(mov, instr->dsts[0]->num, instr->dsts[0]->flags);
         struct ir3_register *new_src = ir3_src_create(mov, 0, 0);
         *new_src = *instr->srcs[1];
         mov->cat1.dst_type = TYPE_U32;
         mov->cat1.src_type =
            (new_src->flags & IR3_REG_HALF) ? TYPE_U16 : TYPE_U32;
         mov->flags |= IR3_INSTR_NEEDS_HELPERS;
         break;
      }

      default:
         unreachable("bad opcode");
      }
   }

   *block = after_block;
   list_delinit(&instr->node);
   return true;
}

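/* Scan a block for subgroup macros, restarting from the updated block
 * whenever lower_instr() splits it, since splitting invalidates the
 * instruction iterator.
 */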
static bool
lower_block(struct ir3 *ir, struct ir3_block **block)
{
   bool progress = false;

   bool inner_progress;
   do {
      inner_progress = false;
      foreach_instr (instr, &(*block)->instr_list) {
         if (lower_instr(ir, block, instr)) {
            /* restart the loop with the new block we created because the
             * iterator has been invalidated.
             */
            progress = inner_progress = true;
            break;
         }
      }
   } while (inner_progress);

   return progress;
}

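/* Entry point for the lowering described at the top of the file. Note that
 * lower_block() advances its block pointer past any blocks it creates, so
 * the foreach_block below resumes after the newly inserted blocks.
 */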
bool
ir3_lower_subgroups(struct ir3 *ir)
{
   bool progress = false;

   foreach_block (block, &ir->block_list)
      progress |= lower_block(ir, &block);

   return progress;
}

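/* NIR-level preparation for the macros above: match the subgroup
 * reduce/scan intrinsics that lower_scan_reduce() knows how to expand.
 */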
static bool
filter_scan_reduce(const nir_instr *instr, const void *data)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   switch (intrin->intrinsic) {
   case nir_intrinsic_reduce:
   case nir_intrinsic_inclusive_scan:
   case nir_intrinsic_exclusive_scan:
      return true;
   default:
      return false;
   }
}

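/* Expand a reduce/scan intrinsic into per-cluster scans plus a cluster-level
 * macro. The brcst_active_ir3 cascade with cluster sizes 2, 4 and 8 builds an
 * inclusive (and, for exclusive scans, an exclusive) scan within each cluster
 * of 8 fibers; the *_clusters_ir3 intrinsics then combine the per-cluster
 * results and eventually become the SCAN_CLUSTERS_MACRO handled above.
 */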
static nir_def *
lower_scan_reduce(struct nir_builder *b, nir_instr *instr, void *data)
{
   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   unsigned bit_size = intrin->def.bit_size;

   nir_op op = nir_intrinsic_reduction_op(intrin);
   nir_const_value ident_val = nir_alu_binop_identity(op, bit_size);
   nir_def *ident = nir_build_imm(b, 1, bit_size, &ident_val);
   nir_def *inclusive = intrin->src[0].ssa;
   nir_def *exclusive = ident;

   for (unsigned cluster_size = 2; cluster_size <= 8; cluster_size *= 2) {
      nir_def *brcst = nir_brcst_active_ir3(b, ident, inclusive,
                                            .cluster_size = cluster_size);
      inclusive = nir_build_alu2(b, op, inclusive, brcst);

      if (intrin->intrinsic == nir_intrinsic_exclusive_scan)
         exclusive = nir_build_alu2(b, op, exclusive, brcst);
   }

   switch (intrin->intrinsic) {
   case nir_intrinsic_reduce:
      return nir_reduce_clusters_ir3(b, inclusive, .reduction_op = op);
   case nir_intrinsic_inclusive_scan:
      return nir_inclusive_scan_clusters_ir3(b, inclusive, .reduction_op = op);
   case nir_intrinsic_exclusive_scan:
      return nir_exclusive_scan_clusters_ir3(b, inclusive, exclusive,
                                             .reduction_op = op);
   default:
      unreachable("filtered intrinsic");
   }
}

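/* The clustered scan/reduce path above is only used when the compiler
 * reports getfiberid support; otherwise this pass makes no changes.
 */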
bool
ir3_nir_opt_subgroups(nir_shader *nir, struct ir3_shader_variant *v)
{
   if (!v->compiler->has_getfiberid)
      return false;

   return nir_shader_lower_instructions(nir, filter_scan_reduce,
                                        lower_scan_reduce, NULL);
}