/*
 * Copyright © 2016 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "vtn_private.h"

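/* Helper for emitting one NIR subgroup intrinsic from a SPIR-V subgroup
 * operation.  Vector/scalar sources map to a single intrinsic whose
 * destination has src0's type.  Composite (array/struct) sources are handled
 * by recursing per element, since the NIR subgroup intrinsics only operate
 * on vectors and scalars.  const_idx0/const_idx1 are stored in the
 * intrinsic's first two const_index slots; what they mean depends on nir_op
 * (e.g. reduction op and cluster size for reductions).
 */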
static struct vtn_ssa_value *
vtn_build_subgroup_instr(struct vtn_builder *b,
                         nir_intrinsic_op nir_op,
                         struct vtn_ssa_value *src0,
                         nir_def *index,
                         unsigned const_idx0,
                         unsigned const_idx1)
{
   /* Some of the subgroup operations take an index.  SPIR-V allows this to
    * be any integer type.  To make things simpler for drivers, we only
    * support 32-bit indices.
    */
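   /* nir_u2u32 zero-extends narrower indices and truncates wider ones.  Any
    * index that is actually in range for the subgroup fits in 32 bits, so
    * the conversion below is lossless for well-formed inputs.
    */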
   if (index && index->bit_size != 32)
      index = nir_u2u32(&b->nb, index);

   struct vtn_ssa_value *dst = vtn_create_ssa_value(b, src0->type);

   vtn_assert(dst->type == src0->type);
   if (!glsl_type_is_vector_or_scalar(dst->type)) {
      for (unsigned i = 0; i < glsl_get_length(dst->type); i++) {
         dst->elems[i] =
            vtn_build_subgroup_instr(b, nir_op, src0->elems[i], index,
                                     const_idx0, const_idx1);
      }
      return dst;
   }

   nir_intrinsic_instr *intrin =
      nir_intrinsic_instr_create(b->nb.shader, nir_op);
   nir_def_init_for_type(&intrin->instr, &intrin->def, dst->type);
   intrin->num_components = intrin->def.num_components;

   intrin->src[0] = nir_src_for_ssa(src0->def);
   if (index)
      intrin->src[1] = nir_src_for_ssa(index);

   intrin->const_index[0] = const_idx0;
   intrin->const_index[1] = const_idx1;

   nir_builder_instr_insert(&b->nb, &intrin->instr);

   dst->def = &intrin->def;

   return dst;
}

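/* Translates a SPIR-V subgroup/group opcode into NIR.  By SPIR-V convention,
 * w[1] is the result type <id> and w[2] is the result <id>; operands follow.
 * The core OpGroup* and OpGroupNonUniform* opcodes also carry an Execution
 * scope operand before the value operands, which the KHR/INTEL aliases lack;
 * the has_scope offsets below account for that difference.
 */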
void
vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
                    const uint32_t *w, unsigned count)
{
   struct vtn_type *dest_type = vtn_get_type(b, w[1]);

   switch (opcode) {
   case SpvOpGroupNonUniformElect: {
      vtn_fail_if(dest_type->type != glsl_bool_type(),
                  "OpGroupNonUniformElect must return a Bool");
      nir_intrinsic_instr *elect =
         nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_elect);
      nir_def_init_for_type(&elect->instr, &elect->def, dest_type->type);
      nir_builder_instr_insert(&b->nb, &elect->instr);
      vtn_push_nir_ssa(b, w[2], &elect->def);
      break;
   }

   case SpvOpGroupNonUniformBallot:
   case SpvOpSubgroupBallotKHR: {
      bool has_scope = (opcode != SpvOpSubgroupBallotKHR);
      vtn_fail_if(dest_type->type != glsl_vector_type(GLSL_TYPE_UINT, 4),
                  "OpGroupNonUniformBallot must return a uvec4");
      nir_intrinsic_instr *ballot =
         nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_ballot);
      ballot->src[0] = nir_src_for_ssa(vtn_get_nir_ssa(b, w[3 + has_scope]));
      nir_def_init(&ballot->instr, &ballot->def, 4, 32);
      ballot->num_components = 4;
      nir_builder_instr_insert(&b->nb, &ballot->instr);
      vtn_push_nir_ssa(b, w[2], &ballot->def);
      break;
   }

   case SpvOpGroupNonUniformInverseBallot: {
      nir_def *dest = nir_inverse_ballot(&b->nb, 1, vtn_get_nir_ssa(b, w[4]));
      vtn_push_nir_ssa(b, w[2], dest);
      break;
   }

   case SpvOpGroupNonUniformBallotBitExtract:
   case SpvOpGroupNonUniformBallotBitCount:
   case SpvOpGroupNonUniformBallotFindLSB:
   case SpvOpGroupNonUniformBallotFindMSB: {
      nir_def *src0, *src1 = NULL;
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformBallotBitExtract:
         op = nir_intrinsic_ballot_bitfield_extract;
         src0 = vtn_get_nir_ssa(b, w[4]);
         src1 = vtn_get_nir_ssa(b, w[5]);
         break;
      case SpvOpGroupNonUniformBallotBitCount:
         switch ((SpvGroupOperation)w[4]) {
         case SpvGroupOperationReduce:
            op = nir_intrinsic_ballot_bit_count_reduce;
            break;
         case SpvGroupOperationInclusiveScan:
            op = nir_intrinsic_ballot_bit_count_inclusive;
            break;
         case SpvGroupOperationExclusiveScan:
            op = nir_intrinsic_ballot_bit_count_exclusive;
            break;
         default:
            unreachable("Invalid group operation");
         }
         src0 = vtn_get_nir_ssa(b, w[5]);
         break;
      case SpvOpGroupNonUniformBallotFindLSB:
         op = nir_intrinsic_ballot_find_lsb;
         src0 = vtn_get_nir_ssa(b, w[4]);
         break;
      case SpvOpGroupNonUniformBallotFindMSB:
         op = nir_intrinsic_ballot_find_msb;
         src0 = vtn_get_nir_ssa(b, w[4]);
         break;
      default:
         unreachable("Unhandled opcode");
      }

      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader, op);

      intrin->src[0] = nir_src_for_ssa(src0);
      if (src1)
         intrin->src[1] = nir_src_for_ssa(src1);

      nir_def_init_for_type(&intrin->instr, &intrin->def,
                            dest_type->type);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      vtn_push_nir_ssa(b, w[2], &intrin->def);
      break;
   }

   case SpvOpGroupNonUniformBroadcastFirst:
   case SpvOpSubgroupFirstInvocationKHR: {
      bool has_scope = (opcode != SpvOpSubgroupFirstInvocationKHR);
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
                                  vtn_ssa_value(b, w[3 + has_scope]),
                                  NULL, 0, 0));
      break;
   }

   case SpvOpGroupNonUniformBroadcast:
   case SpvOpGroupBroadcast:
   case SpvOpSubgroupReadInvocationKHR: {
      bool has_scope = (opcode != SpvOpSubgroupReadInvocationKHR);
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
                                  vtn_ssa_value(b, w[3 + has_scope]),
                                  vtn_get_nir_ssa(b, w[4 + has_scope]), 0, 0));
      break;
   }

   case SpvOpGroupNonUniformAll:
   case SpvOpGroupNonUniformAny:
   case SpvOpGroupNonUniformAllEqual:
   case SpvOpGroupAll:
   case SpvOpGroupAny:
   case SpvOpSubgroupAllKHR:
   case SpvOpSubgroupAnyKHR:
   case SpvOpSubgroupAllEqualKHR: {
      vtn_fail_if(dest_type->type != glsl_bool_type(),
                  "OpGroupNonUniform(All|Any|AllEqual) must return a bool");
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformAll:
      case SpvOpGroupAll:
      case SpvOpSubgroupAllKHR:
         op = nir_intrinsic_vote_all;
         break;
      case SpvOpGroupNonUniformAny:
      case SpvOpGroupAny:
      case SpvOpSubgroupAnyKHR:
         op = nir_intrinsic_vote_any;
         break;
      case SpvOpSubgroupAllEqualKHR:
         op = nir_intrinsic_vote_ieq;
         break;
      case SpvOpGroupNonUniformAllEqual:
         switch (glsl_get_base_type(vtn_ssa_value(b, w[4])->type)) {
         case GLSL_TYPE_FLOAT:
         case GLSL_TYPE_FLOAT16:
         case GLSL_TYPE_DOUBLE:
            op = nir_intrinsic_vote_feq;
            break;
         case GLSL_TYPE_UINT:
         case GLSL_TYPE_INT:
         case GLSL_TYPE_UINT8:
         case GLSL_TYPE_INT8:
         case GLSL_TYPE_UINT16:
         case GLSL_TYPE_INT16:
         case GLSL_TYPE_UINT64:
         case GLSL_TYPE_INT64:
         case GLSL_TYPE_BOOL:
            op = nir_intrinsic_vote_ieq;
            break;
         default:
            unreachable("Unhandled type");
         }
         break;
      default:
         unreachable("Unhandled opcode");
      }

      nir_def *src0;
      if (opcode == SpvOpGroupNonUniformAll || opcode == SpvOpGroupAll ||
          opcode == SpvOpGroupNonUniformAny || opcode == SpvOpGroupAny ||
          opcode == SpvOpGroupNonUniformAllEqual) {
         src0 = vtn_get_nir_ssa(b, w[4]);
      } else {
         src0 = vtn_get_nir_ssa(b, w[3]);
      }
      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader, op);
      if (nir_intrinsic_infos[op].src_components[0] == 0)
         intrin->num_components = src0->num_components;
      intrin->src[0] = nir_src_for_ssa(src0);
      nir_def_init_for_type(&intrin->instr, &intrin->def,
                            dest_type->type);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      vtn_push_nir_ssa(b, w[2], &intrin->def);
      break;
   }

   case SpvOpGroupNonUniformShuffle:
   case SpvOpGroupNonUniformShuffleXor:
   case SpvOpGroupNonUniformShuffleUp:
   case SpvOpGroupNonUniformShuffleDown: {
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformShuffle:
         op = nir_intrinsic_shuffle;
         break;
      case SpvOpGroupNonUniformShuffleXor:
         op = nir_intrinsic_shuffle_xor;
         break;
      case SpvOpGroupNonUniformShuffleUp:
         op = nir_intrinsic_shuffle_up;
         break;
      case SpvOpGroupNonUniformShuffleDown:
         op = nir_intrinsic_shuffle_down;
         break;
      default:
         unreachable("Invalid opcode");
      }
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[4]),
                                  vtn_get_nir_ssa(b, w[5]), 0, 0));
      break;
   }

   case SpvOpSubgroupShuffleINTEL:
   case SpvOpSubgroupShuffleXorINTEL: {
      nir_intrinsic_op op = opcode == SpvOpSubgroupShuffleINTEL ?
         nir_intrinsic_shuffle : nir_intrinsic_shuffle_xor;
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[3]),
                                  vtn_get_nir_ssa(b, w[4]), 0, 0));
      break;
   }

   case SpvOpSubgroupShuffleUpINTEL:
   case SpvOpSubgroupShuffleDownINTEL: {
      /* TODO: Move this lower on the compiler stack, where we can move the
       * current/other data to adjacent registers to avoid doing a shuffle
       * twice.
       */

      nir_builder *nb = &b->nb;
      nir_def *size = nir_load_subgroup_size(nb);
      nir_def *delta = vtn_get_nir_ssa(b, w[5]);

      /* Rewrite UP in terms of DOWN.
       *
       * UP(a, b, delta) == DOWN(a, b, size - delta)
       */
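      /* Worked example, assuming size == 8 and delta == 3: UP gives
       * invocation i the value from lane i - 3, wrapping into the first
       * source (w[3]) for i < 3.  With delta rewritten to 8 - 3 == 5,
       * invocation 5 computes index == 10, fails index < size, and reads
       * the second source at index - size == 2 == 5 - 3, exactly what UP
       * requires; invocation 1 computes index == 6, passes the test, and
       * reads the first source at 6 == 1 - 3 + 8, the wrapped lane.
       */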
      if (opcode == SpvOpSubgroupShuffleUpINTEL)
         delta = nir_isub(nb, size, delta);

      nir_def *index = nir_iadd(nb, nir_load_subgroup_invocation(nb), delta);
      struct vtn_ssa_value *current =
         vtn_build_subgroup_instr(b, nir_intrinsic_shuffle, vtn_ssa_value(b, w[3]),
                                  index, 0, 0);

      struct vtn_ssa_value *next =
         vtn_build_subgroup_instr(b, nir_intrinsic_shuffle, vtn_ssa_value(b, w[4]),
                                  nir_isub(nb, index, size), 0, 0);

      nir_def *cond = nir_ilt(nb, index, size);
      vtn_push_nir_ssa(b, w[2], nir_bcsel(nb, cond, current->def, next->def));

      break;
   }

   case SpvOpGroupNonUniformRotateKHR: {
      const uint32_t cluster_size = count > 6 ? vtn_constant_uint(b, w[6]) : 0;
      vtn_fail_if(cluster_size && !IS_POT(cluster_size),
                  "Behavior is undefined unless ClusterSize is at least 1 and a power of 2.");

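      /* ClusterSize is an optional literal operand (present only when
       * count > 6); 0 means the rotate spans the whole subgroup.  It is
       * forwarded to the rotate intrinsic through const_idx0 below.
       */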
      struct vtn_ssa_value *value = vtn_ssa_value(b, w[4]);
      struct vtn_ssa_value *delta = vtn_ssa_value(b, w[5]);
      vtn_push_nir_ssa(b, w[2],
         vtn_build_subgroup_instr(b, nir_intrinsic_rotate,
                                  value, delta->def, cluster_size, 0)->def);
      break;
   }

   case SpvOpGroupNonUniformQuadBroadcast:
      /* From the Vulkan spec 1.3.269:
       *
       * 9.27. Quad Group Operations:
       * "Fragment shaders that statically execute quad group operations
       * must launch sufficient invocations to ensure their correct operation;"
       */
      if (b->shader->info.stage == MESA_SHADER_FRAGMENT)
         b->shader->info.fs.require_full_quads = true;

      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast,
                                  vtn_ssa_value(b, w[4]),
                                  vtn_get_nir_ssa(b, w[5]), 0, 0));
      break;

   case SpvOpGroupNonUniformQuadSwap: {
      if (b->shader->info.stage == MESA_SHADER_FRAGMENT)
         b->shader->info.fs.require_full_quads = true;

      unsigned direction = vtn_constant_uint(b, w[5]);
      nir_intrinsic_op op;
      switch (direction) {
      case 0:
         op = nir_intrinsic_quad_swap_horizontal;
         break;
      case 1:
         op = nir_intrinsic_quad_swap_vertical;
         break;
      case 2:
         op = nir_intrinsic_quad_swap_diagonal;
         break;
      default:
         vtn_fail("Invalid constant value in OpGroupNonUniformQuadSwap");
      }
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[4]), NULL, 0, 0));
      break;
   }

   case SpvOpGroupNonUniformQuadAllKHR: {
      nir_def *dest = nir_quad_vote_all(&b->nb, 1, vtn_get_nir_ssa(b, w[3]));
      vtn_push_nir_ssa(b, w[2], dest);
      break;
   }
   case SpvOpGroupNonUniformQuadAnyKHR: {
      nir_def *dest = nir_quad_vote_any(&b->nb, 1, vtn_get_nir_ssa(b, w[3]));
      vtn_push_nir_ssa(b, w[2], dest);
      break;
   }

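   /* Reductions and scans: the SPIR-V opcode selects the combining ALU op
    * (iadd, fmax, ...), while the GroupOperation literal in w[4] selects
    * between a full reduction, an inclusive scan, an exclusive scan, or a
    * clustered reduction.  Both are passed to vtn_build_subgroup_instr as
    * const indices on the NIR intrinsic.
    */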
   case SpvOpGroupNonUniformIAdd:
   case SpvOpGroupNonUniformFAdd:
   case SpvOpGroupNonUniformIMul:
   case SpvOpGroupNonUniformFMul:
   case SpvOpGroupNonUniformSMin:
   case SpvOpGroupNonUniformUMin:
   case SpvOpGroupNonUniformFMin:
   case SpvOpGroupNonUniformSMax:
   case SpvOpGroupNonUniformUMax:
   case SpvOpGroupNonUniformFMax:
   case SpvOpGroupNonUniformBitwiseAnd:
   case SpvOpGroupNonUniformBitwiseOr:
   case SpvOpGroupNonUniformBitwiseXor:
   case SpvOpGroupNonUniformLogicalAnd:
   case SpvOpGroupNonUniformLogicalOr:
   case SpvOpGroupNonUniformLogicalXor:
   case SpvOpGroupIAdd:
   case SpvOpGroupFAdd:
   case SpvOpGroupFMin:
   case SpvOpGroupUMin:
   case SpvOpGroupSMin:
   case SpvOpGroupFMax:
   case SpvOpGroupUMax:
   case SpvOpGroupSMax:
   case SpvOpGroupIAddNonUniformAMD:
   case SpvOpGroupFAddNonUniformAMD:
   case SpvOpGroupFMinNonUniformAMD:
   case SpvOpGroupUMinNonUniformAMD:
   case SpvOpGroupSMinNonUniformAMD:
   case SpvOpGroupFMaxNonUniformAMD:
   case SpvOpGroupUMaxNonUniformAMD:
   case SpvOpGroupSMaxNonUniformAMD: {
      nir_op reduction_op;
      switch (opcode) {
      case SpvOpGroupNonUniformIAdd:
      case SpvOpGroupIAdd:
      case SpvOpGroupIAddNonUniformAMD:
         reduction_op = nir_op_iadd;
         break;
      case SpvOpGroupNonUniformFAdd:
      case SpvOpGroupFAdd:
      case SpvOpGroupFAddNonUniformAMD:
         reduction_op = nir_op_fadd;
         break;
      case SpvOpGroupNonUniformIMul:
         reduction_op = nir_op_imul;
         break;
      case SpvOpGroupNonUniformFMul:
         reduction_op = nir_op_fmul;
         break;
      case SpvOpGroupNonUniformSMin:
      case SpvOpGroupSMin:
      case SpvOpGroupSMinNonUniformAMD:
         reduction_op = nir_op_imin;
         break;
      case SpvOpGroupNonUniformUMin:
      case SpvOpGroupUMin:
      case SpvOpGroupUMinNonUniformAMD:
         reduction_op = nir_op_umin;
         break;
      case SpvOpGroupNonUniformFMin:
      case SpvOpGroupFMin:
      case SpvOpGroupFMinNonUniformAMD:
         reduction_op = nir_op_fmin;
         break;
      case SpvOpGroupNonUniformSMax:
      case SpvOpGroupSMax:
      case SpvOpGroupSMaxNonUniformAMD:
         reduction_op = nir_op_imax;
         break;
      case SpvOpGroupNonUniformUMax:
      case SpvOpGroupUMax:
      case SpvOpGroupUMaxNonUniformAMD:
         reduction_op = nir_op_umax;
         break;
      case SpvOpGroupNonUniformFMax:
      case SpvOpGroupFMax:
      case SpvOpGroupFMaxNonUniformAMD:
         reduction_op = nir_op_fmax;
         break;
      case SpvOpGroupNonUniformBitwiseAnd:
      case SpvOpGroupNonUniformLogicalAnd:
         reduction_op = nir_op_iand;
         break;
      case SpvOpGroupNonUniformBitwiseOr:
      case SpvOpGroupNonUniformLogicalOr:
         reduction_op = nir_op_ior;
         break;
      case SpvOpGroupNonUniformBitwiseXor:
      case SpvOpGroupNonUniformLogicalXor:
         reduction_op = nir_op_ixor;
         break;
      default:
         unreachable("Invalid reduction operation");
      }

      nir_intrinsic_op op;
      unsigned cluster_size = 0;
      switch ((SpvGroupOperation)w[4]) {
      case SpvGroupOperationReduce:
         op = nir_intrinsic_reduce;
         break;
      case SpvGroupOperationInclusiveScan:
         op = nir_intrinsic_inclusive_scan;
         break;
      case SpvGroupOperationExclusiveScan:
         op = nir_intrinsic_exclusive_scan;
         break;
      case SpvGroupOperationClusteredReduce:
         op = nir_intrinsic_reduce;
         assert(count == 7);
         cluster_size = vtn_constant_uint(b, w[6]);
         break;
      default:
         unreachable("Invalid group operation");
      }

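      /* reduction_op lands in the intrinsic's first const index and
       * cluster_size (0 when not clustered) in the second; see
       * vtn_build_subgroup_instr above.
       */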
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[5]), NULL,
                                  reduction_op, cluster_size));
      break;
   }

   default:
      unreachable("Invalid SPIR-V opcode");
   }
}