xref: /aosp_15_r20/external/mesa3d/src/asahi/compiler/agx_nir_lower_subgroups.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
/*
 * Copyright 2023 Valve Corporation
 * SPDX-License-Identifier: MIT
 */

#include "compiler/nir/nir.h"
#include "compiler/nir/nir_builder.h"
#include "util/list.h"
#include "agx_nir.h"
#include "nir_builder_opcodes.h"
#include "nir_intrinsics.h"
#include "nir_intrinsics_indices.h"
#include "nir_opcodes.h"

/* XXX: cribbed from nak, move to common */
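/* Builds (n + (d - 1)) / d, i.e. unsigned ceiling division: with n = 65 and
 * d = 32, the emitted code computes (65 + 31) / 32 = 3.
 */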
static nir_def *
nir_udiv_round_up(nir_builder *b, nir_def *n, nir_def *d)
{
   return nir_udiv(b, nir_iadd(b, n, nir_iadd_imm(b, d, -1)), d);
}

static bool
lower(nir_builder *b, nir_intrinsic_instr *intr, void *data)
{
   b->cursor = nir_before_instr(&intr->instr);

   switch (intr->intrinsic) {
   case nir_intrinsic_vote_any: {
      /* We don't have vote instructions, but we have efficient ballots */
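      /* The ballot is a bitmask over active invocations, so any(x) reduces
       * to a nonzero test.
       */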
      nir_def *ballot = nir_ballot(b, 1, 32, intr->src[0].ssa);
      nir_def_rewrite_uses(&intr->def, nir_ine_imm(b, ballot, 0));
      return true;
   }

   case nir_intrinsic_vote_all: {
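      /* all(x) == !any(!x): ballot the negated condition and test for zero. */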
      nir_def *ballot = nir_ballot(b, 1, 32, nir_inot(b, intr->src[0].ssa));
      nir_def_rewrite_uses(&intr->def, nir_ieq_imm(b, ballot, 0));
      return true;
   }

   case nir_intrinsic_quad_vote_any: {
      nir_def *ballot = nir_quad_ballot_agx(b, 16, intr->src[0].ssa);
      nir_def_rewrite_uses(&intr->def, nir_ine_imm(b, ballot, 0));
      return true;
   }

   case nir_intrinsic_quad_vote_all: {
      nir_def *ballot =
         nir_quad_ballot_agx(b, 16, nir_inot(b, intr->src[0].ssa));
      nir_def_rewrite_uses(&intr->def, nir_ieq_imm(b, ballot, 0));
      return true;
   }

   case nir_intrinsic_elect: {
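      /* Elect the active invocation whose active ID (its index among the
       * active invocations) is zero.
       */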
      nir_def *active_id = nir_load_active_subgroup_invocation_agx(b, 16);
      nir_def_rewrite_uses(&intr->def, nir_ieq_imm(b, active_id, 0));
      return true;
   }

   case nir_intrinsic_first_invocation: {
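      /* Exactly one active invocation has active ID 0, so the ballot below
       * is one-hot and ufind_msb recovers that lane's absolute index.
       */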
      nir_def *active_id = nir_load_active_subgroup_invocation_agx(b, 16);
      nir_def *is_first = nir_ieq_imm(b, active_id, 0);
      nir_def *first_bit = nir_ballot(b, 1, 32, is_first);
      nir_def_rewrite_uses(&intr->def, nir_ufind_msb(b, first_bit));
      return true;
   }

   case nir_intrinsic_last_invocation: {
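      /* ballot(true) is the mask of active invocations; its most significant
       * set bit is the last active lane.
       */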
      nir_def *active_mask = nir_ballot(b, 1, 32, nir_imm_true(b));
      nir_def_rewrite_uses(&intr->def, nir_ufind_msb(b, active_mask));
      return true;
   }

   case nir_intrinsic_vote_ieq:
   case nir_intrinsic_vote_feq: {
      /* The common lowering does:
       *
       *    vote_all(x == read_first(x))
       *
       * This is not optimal for AGX, since we have ufind_msb but not ctz, so
       * it's cheaper to read the last invocation than the first. So we do:
       *
       *    vote_all(x == read_last(x))
       *
       * implemented with lowered instructions as
       *
       *    ballot(x != broadcast(x, msb(ballot(true)))) == 0
       */
      nir_def *active_mask = nir_ballot(b, 1, 32, nir_imm_true(b));
      nir_def *active_bit = nir_ufind_msb(b, active_mask);
      nir_def *other = nir_read_invocation(b, intr->src[0].ssa, active_bit);
      nir_def *is_ne;
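
      /* fneu is an unordered comparison (a NaN compares not-equal to
       * everything, itself included), while ine compares integer bits.
       */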
      if (intr->intrinsic == nir_intrinsic_vote_feq) {
         is_ne = nir_fneu(b, other, intr->src[0].ssa);
      } else {
         is_ne = nir_ine(b, other, intr->src[0].ssa);
      }

      nir_def *ballot = nir_ballot(b, 1, 32, is_ne);
      nir_def_rewrite_uses(&intr->def, nir_ieq_imm(b, ballot, 0));
      return true;
   }

   case nir_intrinsic_load_num_subgroups: {
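      /* num_subgroups = ceil(total threads / subgroup size). For example, an
       * 80-thread workgroup yields (80 + 31) / 32 = 3 subgroups.
       */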
      nir_def *workgroup_size = nir_load_workgroup_size(b);
      workgroup_size = nir_imul(b,
                                nir_imul(b, nir_channel(b, workgroup_size, 0),
                                         nir_channel(b, workgroup_size, 1)),
                                nir_channel(b, workgroup_size, 2));
      nir_def *subgroup_size = nir_imm_int(b, 32);
      nir_def *num_subgroups =
         nir_udiv_round_up(b, workgroup_size, subgroup_size);
      nir_def_rewrite_uses(&intr->def, num_subgroups);
      return true;
   }

   case nir_intrinsic_shuffle: {
      nir_def *data = intr->src[0].ssa;
      nir_def *target = intr->src[1].ssa;

      /* The hardware shuffle instruction chooses a single index within the
       * target quad to shuffle each source quad with. Consequently, the low
       * 2 bits of the shuffle index must be quad-uniform. To implement an
       * arbitrary shuffle, read each of the four possible low-2-bit indices
       * separately and select the matching result.
       */
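      /* target & 0x1c rounds a lane index in [0, 31] down to its quad base,
       * a multiple of 4.
       */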
      nir_def *quad_start = nir_iand_imm(b, target, 0x1c);
      nir_def *result = NULL;
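
      /* Each target_i has constant low 2 bits (equal to i) across the quad,
       * satisfying the hardware restriction; select whichever of the four
       * reads matches the requested lane.
       */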
      for (unsigned i = 0; i < 4; ++i) {
         nir_def *target_i = nir_iadd_imm(b, quad_start, i);
         nir_def *shuf = nir_read_invocation(b, data, target_i);

         if (result)
            result = nir_bcsel(b, nir_ieq(b, target, target_i), shuf, result);
         else
            result = shuf;
      }

      nir_def_rewrite_uses(&intr->def, result);
      return true;
   }

   case nir_intrinsic_inclusive_scan: {
      /* If we got here, we support the corresponding exclusive scan in
       * hardware, so rewrite to an exclusive scan plus one extra application
       * of the reduction op on the original source.
       */
      nir_op red_op = nir_intrinsic_reduction_op(intr);
      nir_def *data = intr->src[0].ssa;

      b->cursor = nir_after_instr(&intr->instr);
      intr->intrinsic = nir_intrinsic_exclusive_scan;
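      /* inclusive_scan(op, x) == op(exclusive_scan(op, x), x). Rewriting only
       * the uses after accum keeps accum's own use of the scan intact.
       */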
      nir_def *accum = nir_build_alu2(b, red_op, data, &intr->def);
      nir_def_rewrite_uses_after(&intr->def, accum, accum->parent_instr);
      return true;
   }

   case nir_intrinsic_ballot: {
      /* Optimize popcount(ballot(true)) to load_active_subgroup_count_agx() */
      if (!nir_src_is_const(intr->src[0]) || !nir_src_as_bool(intr->src[0]) ||
          !list_is_singular(&intr->def.uses))
         return false;

      nir_src *use = list_first_entry(&intr->def.uses, nir_src, use_link);
      nir_instr *parent = nir_src_parent_instr(use);
      if (parent->type != nir_instr_type_alu)
         return false;

      nir_alu_instr *alu = nir_instr_as_alu(parent);
      if (alu->op != nir_op_bit_count)
         return false;
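
      /* bit_count(ballot(true)) is exactly the number of active invocations,
       * which the hardware reports directly.
       */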
      nir_def_rewrite_uses(&alu->def,
                           nir_load_active_subgroup_count_agx(b, 32));
      return true;
   }

   default:
      return false;
   }
}

static bool
lower_subgroup_filter(const nir_instr *instr, UNUSED const void *data)
{
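   /* For nir_lower_subgroups: return true to let the common pass lower the
    * instruction per the options below; return false to keep it intact for
    * the hardware paths.
    */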
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   /* Use default behaviour for everything but scans */
   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
   if (intr->intrinsic != nir_intrinsic_exclusive_scan &&
       intr->intrinsic != nir_intrinsic_inclusive_scan &&
       intr->intrinsic != nir_intrinsic_reduce)
      return true;
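
   /* Scalarize vector scans and lower boolean (1-bit) scans in common code;
    * the hardware paths handle only scalar, non-boolean types.
    */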
   if (intr->def.num_components > 1 || intr->def.bit_size == 1)
      return true;

   /* Hardware supports quad ops but no other clustered reductions. */
   if (nir_intrinsic_has_cluster_size(intr)) {
      unsigned cluster = nir_intrinsic_cluster_size(intr);
      if (cluster && cluster != 4 && cluster < 32)
         return true;
   }

   switch (nir_intrinsic_reduction_op(intr)) {
   case nir_op_imul:
      /* No imul hardware scan, so always lower it. */
      return true;

   case nir_op_iadd:
   case nir_op_iand:
   case nir_op_ixor:
   case nir_op_ior:
      /* These have dedicated 64-bit lowering paths that use the 32-bit
       * hardware instructions, so they are likely better than the full
       * lowering.
       */
      return false;

   default:
      /* Otherwise, lower 64-bit scans: the hardware ops are at most 32-bit. */
      return intr->def.bit_size == 64;
   }
}

bool
agx_nir_lower_subgroups(nir_shader *s)
{
   /* First, do as much common lowering as we can */
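   /* AGX subgroups are 32 invocations wide, so a full ballot fits in one
    * 32-bit scalar (ballot_components = 1, ballot_bit_size = 32).
    */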
   nir_lower_subgroups_options opts = {
      .filter = lower_subgroup_filter,
      .lower_read_first_invocation = true,
      .lower_inverse_ballot = true,
      .lower_to_scalar = true,
      .lower_relative_shuffle = true,
      .lower_rotate_to_shuffle = true,
      .lower_subgroup_masks = true,
      .lower_reduce = true,
      .ballot_components = 1,
      .ballot_bit_size = 32,
      .subgroup_size = 32,
   };

   bool progress = nir_lower_subgroups(s, &opts);

   /* Then do AGX-only lowerings on top */
   progress |=
      nir_shader_intrinsics_pass(s, lower, nir_metadata_control_flow, NULL);

   return progress;
}