/*
 * Copyright © 2016 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "vtn_private.h"

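/* Builds a single subgroup intrinsic.  Composite (struct/array) sources are
 * handled by recursing into each element.  "index" is an optional second
 * source (e.g. the invocation index for shuffles), and const_idx0/const_idx1
 * are forwarded to the intrinsic's const_index[0]/const_index[1] slots.
 */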
static struct vtn_ssa_value *
vtn_build_subgroup_instr(struct vtn_builder *b,
                         nir_intrinsic_op nir_op,
                         struct vtn_ssa_value *src0,
                         nir_def *index,
                         unsigned const_idx0,
                         unsigned const_idx1)
{
   /* Some of the subgroup operations take an index.  SPIR-V allows this to be
    * any integer type.  To make things simpler for drivers, we only support
    * 32-bit indices.
    */
   if (index && index->bit_size != 32)
      index = nir_u2u32(&b->nb, index);

   struct vtn_ssa_value *dst = vtn_create_ssa_value(b, src0->type);

   vtn_assert(dst->type == src0->type);
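   /* Subgroup intrinsics only operate on vectors and scalars.  For composite
    * types, build one intrinsic per element and gather the results.
    */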
   if (!glsl_type_is_vector_or_scalar(dst->type)) {
      for (unsigned i = 0; i < glsl_get_length(dst->type); i++) {
         dst->elems[i] =
            vtn_build_subgroup_instr(b, nir_op, src0->elems[i], index,
                                     const_idx0, const_idx1);
      }
      return dst;
   }

   nir_intrinsic_instr *intrin =
      nir_intrinsic_instr_create(b->nb.shader, nir_op);
   nir_def_init_for_type(&intrin->instr, &intrin->def, dst->type);
   intrin->num_components = intrin->def.num_components;

   intrin->src[0] = nir_src_for_ssa(src0->def);
   if (index)
      intrin->src[1] = nir_src_for_ssa(index);

   intrin->const_index[0] = const_idx0;
   intrin->const_index[1] = const_idx1;

   nir_builder_instr_insert(&b->nb, &intrin->instr);

   dst->def = &intrin->def;

   return dst;
}

void
vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
                    const uint32_t *w, unsigned count)
{
   struct vtn_type *dest_type = vtn_get_type(b, w[1]);

   switch (opcode) {
   case SpvOpGroupNonUniformElect: {
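      /* Elect is true in exactly one invocation: the active invocation with
       * the lowest subgroup-local id.
       */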
      vtn_fail_if(dest_type->type != glsl_bool_type(),
                  "OpGroupNonUniformElect must return a Bool");
      nir_intrinsic_instr *elect =
         nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_elect);
      nir_def_init_for_type(&elect->instr, &elect->def, dest_type->type);
      nir_builder_instr_insert(&b->nb, &elect->instr);
      vtn_push_nir_ssa(b, w[2], &elect->def);
      break;
   }

   case SpvOpGroupNonUniformBallot:
   case SpvOpSubgroupBallotKHR: {
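      /* The KHR opcode has no Execution Scope operand, so every following
       * operand index is shifted down by one relative to the core opcode.
       */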
      bool has_scope = (opcode != SpvOpSubgroupBallotKHR);
      vtn_fail_if(dest_type->type != glsl_vector_type(GLSL_TYPE_UINT, 4),
                  "OpGroupNonUniformBallot must return a uvec4");
      nir_intrinsic_instr *ballot =
         nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_ballot);
      ballot->src[0] = nir_src_for_ssa(vtn_get_nir_ssa(b, w[3 + has_scope]));
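      /* SPIR-V ballots are uvec4 regardless of the actual subgroup size, so
       * the NIR destination is always initialized as a 4x32 vector here.
       */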
      nir_def_init(&ballot->instr, &ballot->def, 4, 32);
      ballot->num_components = 4;
      nir_builder_instr_insert(&b->nb, &ballot->instr);
      vtn_push_nir_ssa(b, w[2], &ballot->def);
      break;
   }

   case SpvOpGroupNonUniformInverseBallot: {
      nir_def *dest = nir_inverse_ballot(&b->nb, 1, vtn_get_nir_ssa(b, w[4]));
      vtn_push_nir_ssa(b, w[2], dest);
      break;
   }

   case SpvOpGroupNonUniformBallotBitExtract:
   case SpvOpGroupNonUniformBallotBitCount:
   case SpvOpGroupNonUniformBallotFindLSB:
   case SpvOpGroupNonUniformBallotFindMSB: {
      nir_def *src0, *src1 = NULL;
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformBallotBitExtract:
         op = nir_intrinsic_ballot_bitfield_extract;
         src0 = vtn_get_nir_ssa(b, w[4]);
         src1 = vtn_get_nir_ssa(b, w[5]);
         break;
      case SpvOpGroupNonUniformBallotBitCount:
         switch ((SpvGroupOperation)w[4]) {
         case SpvGroupOperationReduce:
            op = nir_intrinsic_ballot_bit_count_reduce;
            break;
         case SpvGroupOperationInclusiveScan:
            op = nir_intrinsic_ballot_bit_count_inclusive;
            break;
         case SpvGroupOperationExclusiveScan:
            op = nir_intrinsic_ballot_bit_count_exclusive;
            break;
         default:
            unreachable("Invalid group operation");
         }
         src0 = vtn_get_nir_ssa(b, w[5]);
         break;
      case SpvOpGroupNonUniformBallotFindLSB:
         op = nir_intrinsic_ballot_find_lsb;
         src0 = vtn_get_nir_ssa(b, w[4]);
         break;
      case SpvOpGroupNonUniformBallotFindMSB:
         op = nir_intrinsic_ballot_find_msb;
         src0 = vtn_get_nir_ssa(b, w[4]);
         break;
      default:
         unreachable("Unhandled opcode");
      }

      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader, op);

      intrin->src[0] = nir_src_for_ssa(src0);
      if (src1)
         intrin->src[1] = nir_src_for_ssa(src1);

      nir_def_init_for_type(&intrin->instr, &intrin->def,
                            dest_type->type);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      vtn_push_nir_ssa(b, w[2], &intrin->def);
      break;
   }

   case SpvOpGroupNonUniformBroadcastFirst:
   case SpvOpSubgroupFirstInvocationKHR: {
      bool has_scope = (opcode != SpvOpSubgroupFirstInvocationKHR);
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
                                  vtn_ssa_value(b, w[3 + has_scope]),
                                  NULL, 0, 0));
      break;
   }

   case SpvOpGroupNonUniformBroadcast:
   case SpvOpGroupBroadcast:
   case SpvOpSubgroupReadInvocationKHR: {
      bool has_scope = (opcode != SpvOpSubgroupReadInvocationKHR);
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
                                  vtn_ssa_value(b, w[3 + has_scope]),
                                  vtn_get_nir_ssa(b, w[4 + has_scope]), 0, 0));
      break;
   }

   case SpvOpGroupNonUniformAll:
   case SpvOpGroupNonUniformAny:
   case SpvOpGroupNonUniformAllEqual:
   case SpvOpGroupAll:
   case SpvOpGroupAny:
   case SpvOpSubgroupAllKHR:
   case SpvOpSubgroupAnyKHR:
   case SpvOpSubgroupAllEqualKHR: {
      vtn_fail_if(dest_type->type != glsl_bool_type(),
                  "OpGroupNonUniform(All|Any|AllEqual) must return a bool");
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformAll:
      case SpvOpGroupAll:
      case SpvOpSubgroupAllKHR:
         op = nir_intrinsic_vote_all;
         break;
      case SpvOpGroupNonUniformAny:
      case SpvOpGroupAny:
      case SpvOpSubgroupAnyKHR:
         op = nir_intrinsic_vote_any;
         break;
      case SpvOpSubgroupAllEqualKHR:
         op = nir_intrinsic_vote_ieq;
         break;
      case SpvOpGroupNonUniformAllEqual:
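         /* Pick the comparison based on the source's base type: floats use
          * vote_feq so float comparison rules apply; integers and bools
          * compare bit-wise via vote_ieq.
          */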
         switch (glsl_get_base_type(vtn_ssa_value(b, w[4])->type)) {
         case GLSL_TYPE_FLOAT:
         case GLSL_TYPE_FLOAT16:
         case GLSL_TYPE_DOUBLE:
            op = nir_intrinsic_vote_feq;
            break;
         case GLSL_TYPE_UINT:
         case GLSL_TYPE_INT:
         case GLSL_TYPE_UINT8:
         case GLSL_TYPE_INT8:
         case GLSL_TYPE_UINT16:
         case GLSL_TYPE_INT16:
         case GLSL_TYPE_UINT64:
         case GLSL_TYPE_INT64:
         case GLSL_TYPE_BOOL:
            op = nir_intrinsic_vote_ieq;
            break;
         default:
            unreachable("Unhandled type");
         }
         break;
      default:
         unreachable("Unhandled opcode");
      }

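      /* The core and NonUniform opcodes carry an Execution Scope operand, so
       * their value operand lives at w[4]; the KHR opcodes have no scope and
       * take the value at w[3].
       */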
      nir_def *src0;
      if (opcode == SpvOpGroupNonUniformAll || opcode == SpvOpGroupAll ||
          opcode == SpvOpGroupNonUniformAny || opcode == SpvOpGroupAny ||
          opcode == SpvOpGroupNonUniformAllEqual) {
         src0 = vtn_get_nir_ssa(b, w[4]);
      } else {
         src0 = vtn_get_nir_ssa(b, w[3]);
      }
      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader, op);
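      /* A src_components entry of 0 marks a variable-sized source, in which
       * case num_components on the instruction tells NIR how many components
       * the source actually has.
       */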
      if (nir_intrinsic_infos[op].src_components[0] == 0)
         intrin->num_components = src0->num_components;
      intrin->src[0] = nir_src_for_ssa(src0);
      nir_def_init_for_type(&intrin->instr, &intrin->def,
                            dest_type->type);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      vtn_push_nir_ssa(b, w[2], &intrin->def);
      break;
   }

   case SpvOpGroupNonUniformShuffle:
   case SpvOpGroupNonUniformShuffleXor:
   case SpvOpGroupNonUniformShuffleUp:
   case SpvOpGroupNonUniformShuffleDown: {
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformShuffle:
         op = nir_intrinsic_shuffle;
         break;
      case SpvOpGroupNonUniformShuffleXor:
         op = nir_intrinsic_shuffle_xor;
         break;
      case SpvOpGroupNonUniformShuffleUp:
         op = nir_intrinsic_shuffle_up;
         break;
      case SpvOpGroupNonUniformShuffleDown:
         op = nir_intrinsic_shuffle_down;
         break;
      default:
         unreachable("Invalid opcode");
      }
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[4]),
                                  vtn_get_nir_ssa(b, w[5]), 0, 0));
      break;
   }

   case SpvOpSubgroupShuffleINTEL:
   case SpvOpSubgroupShuffleXorINTEL: {
      nir_intrinsic_op op = opcode == SpvOpSubgroupShuffleINTEL ?
         nir_intrinsic_shuffle : nir_intrinsic_shuffle_xor;
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[3]),
                                  vtn_get_nir_ssa(b, w[4]), 0, 0));
      break;
   }

   case SpvOpSubgroupShuffleUpINTEL:
   case SpvOpSubgroupShuffleDownINTEL: {
      /* TODO: Move this lower on the compiler stack, where we can move the
       * current/other data to adjacent registers to avoid doing a shuffle
       * twice.
       */

      nir_builder *nb = &b->nb;
      nir_def *size = nir_load_subgroup_size(nb);
      nir_def *delta = vtn_get_nir_ssa(b, w[5]);

      /* Rewrite UP in terms of DOWN.
       *
       *   UP(a, b, delta) == DOWN(a, b, size - delta)
       */
      if (opcode == SpvOpSubgroupShuffleUpINTEL)
         delta = nir_isub(nb, size, delta);

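      /* For example, with size == 16 and delta == 3, ShuffleUp in invocation
       * 5 computes index = 5 + (16 - 3) = 18.  Since 18 >= size, the bcsel
       * below picks the second shuffle, which reads element 18 - 16 = 2 of
       * the second data source, i.e. the value three invocations up.
       */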
      nir_def *index = nir_iadd(nb, nir_load_subgroup_invocation(nb), delta);
      struct vtn_ssa_value *current =
         vtn_build_subgroup_instr(b, nir_intrinsic_shuffle, vtn_ssa_value(b, w[3]),
                                  index, 0, 0);

      struct vtn_ssa_value *next =
         vtn_build_subgroup_instr(b, nir_intrinsic_shuffle, vtn_ssa_value(b, w[4]),
                                  nir_isub(nb, index, size), 0, 0);

      nir_def *cond = nir_ilt(nb, index, size);
      vtn_push_nir_ssa(b, w[2], nir_bcsel(nb, cond, current->def, next->def));

      break;
   }

   case SpvOpGroupNonUniformRotateKHR: {
      const uint32_t cluster_size = count > 6 ? vtn_constant_uint(b, w[6]) : 0;
      vtn_fail_if(cluster_size && !IS_POT(cluster_size),
                  "Behavior is undefined unless ClusterSize is at least 1 and a power of 2.");
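      /* ClusterSize is optional; cluster_size == 0 here means it was absent
       * and the rotate spans the whole subgroup.
       */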

      struct vtn_ssa_value *value = vtn_ssa_value(b, w[4]);
      struct vtn_ssa_value *delta = vtn_ssa_value(b, w[5]);
      vtn_push_nir_ssa(b, w[2],
         vtn_build_subgroup_instr(b, nir_intrinsic_rotate,
                                  value, delta->def, cluster_size, 0)->def);
      break;
   }

   case SpvOpGroupNonUniformQuadBroadcast:
      /* From the Vulkan spec 1.3.269:
       *
       * 9.27. Quad Group Operations:
       * "Fragment shaders that statically execute quad group operations
       * must launch sufficient invocations to ensure their correct operation;"
       */
      if (b->shader->info.stage == MESA_SHADER_FRAGMENT)
         b->shader->info.fs.require_full_quads = true;

      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast,
                                  vtn_ssa_value(b, w[4]),
                                  vtn_get_nir_ssa(b, w[5]), 0, 0));
      break;

   case SpvOpGroupNonUniformQuadSwap: {
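      /* Same full-quad requirement as OpGroupNonUniformQuadBroadcast above. */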
      if (b->shader->info.stage == MESA_SHADER_FRAGMENT)
         b->shader->info.fs.require_full_quads = true;

      unsigned direction = vtn_constant_uint(b, w[5]);
      nir_intrinsic_op op;
      switch (direction) {
      case 0:
         op = nir_intrinsic_quad_swap_horizontal;
         break;
      case 1:
         op = nir_intrinsic_quad_swap_vertical;
         break;
      case 2:
         op = nir_intrinsic_quad_swap_diagonal;
         break;
      default:
         vtn_fail("Invalid constant value in OpGroupNonUniformQuadSwap");
      }
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[4]), NULL, 0, 0));
      break;
   }

   case SpvOpGroupNonUniformQuadAllKHR: {
      nir_def *dest = nir_quad_vote_all(&b->nb, 1, vtn_get_nir_ssa(b, w[3]));
      vtn_push_nir_ssa(b, w[2], dest);
      break;
   }
   case SpvOpGroupNonUniformQuadAnyKHR: {
      nir_def *dest = nir_quad_vote_any(&b->nb, 1, vtn_get_nir_ssa(b, w[3]));
      vtn_push_nir_ssa(b, w[2], dest);
      break;
   }

   case SpvOpGroupNonUniformIAdd:
   case SpvOpGroupNonUniformFAdd:
   case SpvOpGroupNonUniformIMul:
   case SpvOpGroupNonUniformFMul:
   case SpvOpGroupNonUniformSMin:
   case SpvOpGroupNonUniformUMin:
   case SpvOpGroupNonUniformFMin:
   case SpvOpGroupNonUniformSMax:
   case SpvOpGroupNonUniformUMax:
   case SpvOpGroupNonUniformFMax:
   case SpvOpGroupNonUniformBitwiseAnd:
   case SpvOpGroupNonUniformBitwiseOr:
   case SpvOpGroupNonUniformBitwiseXor:
   case SpvOpGroupNonUniformLogicalAnd:
   case SpvOpGroupNonUniformLogicalOr:
   case SpvOpGroupNonUniformLogicalXor:
   case SpvOpGroupIAdd:
   case SpvOpGroupFAdd:
   case SpvOpGroupFMin:
   case SpvOpGroupUMin:
   case SpvOpGroupSMin:
   case SpvOpGroupFMax:
   case SpvOpGroupUMax:
   case SpvOpGroupSMax:
   case SpvOpGroupIAddNonUniformAMD:
   case SpvOpGroupFAddNonUniformAMD:
   case SpvOpGroupFMinNonUniformAMD:
   case SpvOpGroupUMinNonUniformAMD:
   case SpvOpGroupSMinNonUniformAMD:
   case SpvOpGroupFMaxNonUniformAMD:
   case SpvOpGroupUMaxNonUniformAMD:
   case SpvOpGroupSMaxNonUniformAMD: {
      nir_op reduction_op;
      switch (opcode) {
      case SpvOpGroupNonUniformIAdd:
      case SpvOpGroupIAdd:
      case SpvOpGroupIAddNonUniformAMD:
         reduction_op = nir_op_iadd;
         break;
      case SpvOpGroupNonUniformFAdd:
      case SpvOpGroupFAdd:
      case SpvOpGroupFAddNonUniformAMD:
         reduction_op = nir_op_fadd;
         break;
      case SpvOpGroupNonUniformIMul:
         reduction_op = nir_op_imul;
         break;
      case SpvOpGroupNonUniformFMul:
         reduction_op = nir_op_fmul;
         break;
      case SpvOpGroupNonUniformSMin:
      case SpvOpGroupSMin:
      case SpvOpGroupSMinNonUniformAMD:
         reduction_op = nir_op_imin;
         break;
      case SpvOpGroupNonUniformUMin:
      case SpvOpGroupUMin:
      case SpvOpGroupUMinNonUniformAMD:
         reduction_op = nir_op_umin;
         break;
      case SpvOpGroupNonUniformFMin:
      case SpvOpGroupFMin:
      case SpvOpGroupFMinNonUniformAMD:
         reduction_op = nir_op_fmin;
         break;
      case SpvOpGroupNonUniformSMax:
      case SpvOpGroupSMax:
      case SpvOpGroupSMaxNonUniformAMD:
         reduction_op = nir_op_imax;
         break;
      case SpvOpGroupNonUniformUMax:
      case SpvOpGroupUMax:
      case SpvOpGroupUMaxNonUniformAMD:
         reduction_op = nir_op_umax;
         break;
      case SpvOpGroupNonUniformFMax:
      case SpvOpGroupFMax:
      case SpvOpGroupFMaxNonUniformAMD:
         reduction_op = nir_op_fmax;
         break;
      case SpvOpGroupNonUniformBitwiseAnd:
      case SpvOpGroupNonUniformLogicalAnd:
         reduction_op = nir_op_iand;
         break;
      case SpvOpGroupNonUniformBitwiseOr:
      case SpvOpGroupNonUniformLogicalOr:
         reduction_op = nir_op_ior;
         break;
      case SpvOpGroupNonUniformBitwiseXor:
      case SpvOpGroupNonUniformLogicalXor:
         reduction_op = nir_op_ixor;
         break;
      default:
         unreachable("Invalid reduction operation");
      }

      nir_intrinsic_op op;
      unsigned cluster_size = 0;
      switch ((SpvGroupOperation)w[4]) {
      case SpvGroupOperationReduce:
         op = nir_intrinsic_reduce;
         break;
      case SpvGroupOperationInclusiveScan:
         op = nir_intrinsic_inclusive_scan;
         break;
      case SpvGroupOperationExclusiveScan:
         op = nir_intrinsic_exclusive_scan;
         break;
      case SpvGroupOperationClusteredReduce:
         op = nir_intrinsic_reduce;
         assert(count == 7);
         cluster_size = vtn_constant_uint(b, w[6]);
         break;
      default:
         unreachable("Invalid group operation");
      }

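      /* The reduction ALU op and (for clustered reduce) the cluster size are
       * passed through const_index[0] and const_index[1] of the intrinsic.
       */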
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[5]), NULL,
                                  reduction_op, cluster_size));
      break;
   }

   default:
      unreachable("Invalid SPIR-V opcode");
   }
}