1 /*
2 * Copyright © 2016 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24 #include <math.h>
25 #include "vtn_private.h"
26 #include "spirv_info.h"
27
28 /*
29 * Normally, column vectors in SPIR-V correspond to a single NIR SSA
30 * definition. But for matrix multiplies, we want to do one routine for
31 * multiplying a matrix by a matrix and then pretend that vectors are matrices
32 * with one column. So we "wrap" these things, and unwrap the result before we
33 * send it off.
34 */
35
36 static struct vtn_ssa_value *
wrap_matrix(struct vtn_builder * b,struct vtn_ssa_value * val)37 wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
38 {
39 if (val == NULL)
40 return NULL;
41
42 if (glsl_type_is_matrix(val->type))
43 return val;
44
45 struct vtn_ssa_value *dest = vtn_zalloc(b, struct vtn_ssa_value);
46 dest->type = glsl_get_bare_type(val->type);
47 dest->elems = vtn_alloc_array(b, struct vtn_ssa_value *, 1);
48 dest->elems[0] = val;
49
50 return dest;
51 }
52
53 static struct vtn_ssa_value *
unwrap_matrix(struct vtn_ssa_value * val)54 unwrap_matrix(struct vtn_ssa_value *val)
55 {
56 if (glsl_type_is_matrix(val->type))
57 return val;
58
59 return val->elems[0];
60 }
61
/* Multiply two matrices (or matrix/vector combinations — vectors arrive
 * here and are wrapped as single-column matrices, see the comment at the
 * top of this file).  Returns an unwrapped result:
 * dest[i] = sum over j of src0[j] * src1[i][j].
 */
static struct vtn_ssa_value *
matrix_multiply(struct vtn_builder *b,
                struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
{

   struct vtn_ssa_value *src0 = wrap_matrix(b, _src0);
   struct vtn_ssa_value *src1 = wrap_matrix(b, _src1);
   /* If a source was produced by OpTranspose, the untransposed original is
    * still available here (wrap_matrix passes NULL through).
    */
   struct vtn_ssa_value *src0_transpose = wrap_matrix(b, _src0->transposed);
   struct vtn_ssa_value *src1_transpose = wrap_matrix(b, _src1->transposed);

   unsigned src0_rows = glsl_get_vector_elements(src0->type);
   unsigned src0_columns = glsl_get_matrix_columns(src0->type);
   unsigned src1_columns = glsl_get_matrix_columns(src1->type);

   /* One result column per src1 column; a single column means the result is
    * really a vector, not a matrix.
    */
   const struct glsl_type *dest_type;
   if (src1_columns > 1) {
      dest_type = glsl_matrix_type(glsl_get_base_type(src0->type),
                                   src0_rows, src1_columns);
   } else {
      dest_type = glsl_vector_type(glsl_get_base_type(src0->type), src0_rows);
   }
   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);

   dest = wrap_matrix(b, dest);

   bool transpose_result = false;
   if (src0_transpose && src1_transpose) {
      /* transpose(A) * transpose(B) = transpose(B * A) — multiplying the
       * already-available originals avoids materializing both transposes.
       */
      src1 = src0_transpose;
      src0 = src1_transpose;
      src0_transpose = NULL;
      src1_transpose = NULL;
      transpose_result = true;
   }

   for (unsigned i = 0; i < src1_columns; i++) {
      /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
      /* Seed the accumulator with the last term as a plain multiply, then
       * fold the remaining terms in with ffma, walking j downwards.
       */
      dest->elems[i]->def =
         nir_fmul(&b->nb, src0->elems[src0_columns - 1]->def,
                  nir_channel(&b->nb, src1->elems[i]->def, src0_columns - 1));
      for (int j = src0_columns - 2; j >= 0; j--) {
         dest->elems[i]->def =
            nir_ffma(&b->nb, src0->elems[j]->def,
                     nir_channel(&b->nb, src1->elems[i]->def, j),
                     dest->elems[i]->def);
      }
   }

   dest = unwrap_matrix(dest);

   if (transpose_result)
      dest = vtn_ssa_transpose(b, dest);

   return dest;
}
117
118 static struct vtn_ssa_value *
mat_times_scalar(struct vtn_builder * b,struct vtn_ssa_value * mat,nir_def * scalar)119 mat_times_scalar(struct vtn_builder *b,
120 struct vtn_ssa_value *mat,
121 nir_def *scalar)
122 {
123 struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
124 for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
125 if (glsl_base_type_is_integer(glsl_get_base_type(mat->type)))
126 dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
127 else
128 dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
129 }
130
131 return dest;
132 }
133
/* Narrow a RelaxedPrecision value to 16 bits.  Values that are already
 * 16-bit (and bools) pass through unchanged.
 */
nir_def *
vtn_mediump_downconvert(struct vtn_builder *b, enum glsl_base_type base_type, nir_def *def)
{
   if (def->bit_size == 16)
      return def;

   switch (base_type) {
   case GLSL_TYPE_FLOAT:
      return nir_f2fmp(&b->nb, def);
   case GLSL_TYPE_INT:
   case GLSL_TYPE_UINT:
      /* i2imp covers both signednesses — it is just a width change. */
      return nir_i2imp(&b->nb, def);
   /* Workaround for 3DMark Wild Life which has RelaxedPrecision on
    * OpLogical* operations (which is forbidden by spec).
    */
   case GLSL_TYPE_BOOL:
      return def;
   default:
      unreachable("bad relaxed precision input type");
   }
}
155
156 struct vtn_ssa_value *
vtn_mediump_downconvert_value(struct vtn_builder * b,struct vtn_ssa_value * src)157 vtn_mediump_downconvert_value(struct vtn_builder *b, struct vtn_ssa_value *src)
158 {
159 if (!src)
160 return src;
161
162 struct vtn_ssa_value *srcmp = vtn_create_ssa_value(b, src->type);
163
164 if (src->transposed) {
165 srcmp->transposed = vtn_mediump_downconvert_value(b, src->transposed);
166 } else {
167 enum glsl_base_type base_type = glsl_get_base_type(src->type);
168
169 if (glsl_type_is_vector_or_scalar(src->type)) {
170 srcmp->def = vtn_mediump_downconvert(b, base_type, src->def);
171 } else {
172 assert(glsl_get_base_type(src->type) == GLSL_TYPE_FLOAT);
173 for (int i = 0; i < glsl_get_matrix_columns(src->type); i++)
174 srcmp->elems[i]->def = vtn_mediump_downconvert(b, base_type, src->elems[i]->def);
175 }
176 }
177
178 return srcmp;
179 }
180
181 static struct vtn_ssa_value *
vtn_handle_matrix_alu(struct vtn_builder * b,SpvOp opcode,struct vtn_ssa_value * src0,struct vtn_ssa_value * src1)182 vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
183 struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
184 {
185 switch (opcode) {
186 case SpvOpFNegate: {
187 struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
188 unsigned cols = glsl_get_matrix_columns(src0->type);
189 for (unsigned i = 0; i < cols; i++)
190 dest->elems[i]->def = nir_fneg(&b->nb, src0->elems[i]->def);
191 return dest;
192 }
193
194 case SpvOpFAdd: {
195 struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
196 unsigned cols = glsl_get_matrix_columns(src0->type);
197 for (unsigned i = 0; i < cols; i++)
198 dest->elems[i]->def =
199 nir_fadd(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
200 return dest;
201 }
202
203 case SpvOpFSub: {
204 struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
205 unsigned cols = glsl_get_matrix_columns(src0->type);
206 for (unsigned i = 0; i < cols; i++)
207 dest->elems[i]->def =
208 nir_fsub(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
209 return dest;
210 }
211
212 case SpvOpTranspose:
213 return vtn_ssa_transpose(b, src0);
214
215 case SpvOpMatrixTimesScalar:
216 if (src0->transposed) {
217 return vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
218 src1->def));
219 } else {
220 return mat_times_scalar(b, src0, src1->def);
221 }
222 break;
223
224 case SpvOpVectorTimesMatrix:
225 case SpvOpMatrixTimesVector:
226 case SpvOpMatrixTimesMatrix:
227 if (opcode == SpvOpVectorTimesMatrix) {
228 return matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
229 } else {
230 return matrix_multiply(b, src0, src1);
231 }
232 break;
233
234 default: vtn_fail_with_opcode("unknown matrix opcode", opcode);
235 }
236 }
237
/* Base nir_alu_type (without bit size) of the SOURCE operand of a SPIR-V
 * conversion opcode.
 */
static nir_alu_type
convert_op_src_type(SpvOp opcode)
{
   switch (opcode) {
   case SpvOpFConvert:
   case SpvOpConvertFToS:
   case SpvOpConvertFToU:
      return nir_type_float;
   case SpvOpSConvert:
   case SpvOpConvertSToF:
   case SpvOpSatConvertSToU:
      return nir_type_int;
   case SpvOpUConvert:
   case SpvOpConvertUToF:
   case SpvOpSatConvertUToS:
      return nir_type_uint;
   default:
      unreachable("Unhandled conversion op");
   }
}
258
/* Base nir_alu_type (without bit size) of the DESTINATION of a SPIR-V
 * conversion opcode.
 */
static nir_alu_type
convert_op_dst_type(SpvOp opcode)
{
   switch (opcode) {
   case SpvOpFConvert:
   case SpvOpConvertSToF:
   case SpvOpConvertUToF:
      return nir_type_float;
   case SpvOpSConvert:
   case SpvOpConvertFToS:
   case SpvOpSatConvertUToS:
      return nir_type_int;
   case SpvOpUConvert:
   case SpvOpConvertFToU:
   case SpvOpSatConvertSToU:
      return nir_type_uint;
   default:
      unreachable("Unhandled conversion op");
   }
}
279
/* Map a SPIR-V ALU opcode to the corresponding nir_op.
 *
 * Out-parameters:
 *   *swap  — set when the first two sources must be exchanged (NIR only has
 *            less-than / greater-or-equal comparisons, so greater-than and
 *            less-or-equal are built from their mirrors).
 *   *exact — set for the SpvOpFOrd*/SpvOpFUnord* comparisons, whose NaN
 *            behavior must not be optimized away.
 * src_bit_size/dst_bit_size are only consulted for the conversion opcodes.
 *
 * Fails SPIR-V translation (vtn_fail) on opcodes with no NIR equivalent.
 */
nir_op
vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
                                SpvOp opcode, bool *swap, bool *exact,
                                unsigned src_bit_size, unsigned dst_bit_size)
{
   /* Indicates that the first two arguments should be swapped. This is
    * used for implementing greater-than and less-than-or-equal.
    */
   *swap = false;

   *exact = false;

   switch (opcode) {
   case SpvOpSNegate:            return nir_op_ineg;
   case SpvOpFNegate:            return nir_op_fneg;
   case SpvOpNot:                return nir_op_inot;
   case SpvOpIAdd:               return nir_op_iadd;
   case SpvOpFAdd:               return nir_op_fadd;
   case SpvOpISub:               return nir_op_isub;
   case SpvOpFSub:               return nir_op_fsub;
   case SpvOpIMul:               return nir_op_imul;
   case SpvOpFMul:               return nir_op_fmul;
   case SpvOpUDiv:               return nir_op_udiv;
   case SpvOpSDiv:               return nir_op_idiv;
   case SpvOpFDiv:               return nir_op_fdiv;
   case SpvOpUMod:               return nir_op_umod;
   case SpvOpSMod:               return nir_op_imod;
   case SpvOpFMod:               return nir_op_fmod;
   case SpvOpSRem:               return nir_op_irem;
   case SpvOpFRem:               return nir_op_frem;

   case SpvOpShiftRightLogical:     return nir_op_ushr;
   case SpvOpShiftRightArithmetic:  return nir_op_ishr;
   case SpvOpShiftLeftLogical:      return nir_op_ishl;
   case SpvOpLogicalOr:             return nir_op_ior;
   case SpvOpLogicalEqual:          return nir_op_ieq;
   case SpvOpLogicalNotEqual:       return nir_op_ine;
   case SpvOpLogicalAnd:            return nir_op_iand;
   case SpvOpLogicalNot:            return nir_op_inot;
   case SpvOpBitwiseOr:             return nir_op_ior;
   case SpvOpBitwiseXor:            return nir_op_ixor;
   case SpvOpBitwiseAnd:            return nir_op_iand;
   case SpvOpSelect:                return nir_op_bcsel;
   case SpvOpIEqual:                return nir_op_ieq;

   case SpvOpBitFieldInsert:        return nir_op_bitfield_insert;
   case SpvOpBitFieldSExtract:      return nir_op_ibitfield_extract;
   case SpvOpBitFieldUExtract:      return nir_op_ubitfield_extract;
   case SpvOpBitReverse:            return nir_op_bitfield_reverse;

   case SpvOpUCountLeadingZerosINTEL: return nir_op_uclz;
   /* SpvOpUCountTrailingZerosINTEL is handled elsewhere. */
   case SpvOpAbsISubINTEL:          return nir_op_uabs_isub;
   case SpvOpAbsUSubINTEL:          return nir_op_uabs_usub;
   case SpvOpIAddSatINTEL:          return nir_op_iadd_sat;
   case SpvOpUAddSatINTEL:          return nir_op_uadd_sat;
   case SpvOpIAverageINTEL:         return nir_op_ihadd;
   case SpvOpUAverageINTEL:         return nir_op_uhadd;
   case SpvOpIAverageRoundedINTEL:  return nir_op_irhadd;
   case SpvOpUAverageRoundedINTEL:  return nir_op_urhadd;
   case SpvOpISubSatINTEL:          return nir_op_isub_sat;
   case SpvOpUSubSatINTEL:          return nir_op_usub_sat;
   case SpvOpIMul32x16INTEL:        return nir_op_imul_32x16;
   case SpvOpUMul32x16INTEL:        return nir_op_umul_32x16;

   /* The ordered / unordered operators need special implementation besides
    * the logical operator to use since they also need to check if operands are
    * ordered.
    */
   case SpvOpFOrdEqual:                            *exact = true;  return nir_op_feq;
   case SpvOpFUnordEqual:                          *exact = true;  return nir_op_feq;
   case SpvOpINotEqual:                                            return nir_op_ine;
   case SpvOpLessOrGreater:                        /* Deprecated, use OrdNotEqual */
   case SpvOpFOrdNotEqual:                         *exact = true;  return nir_op_fneu;
   case SpvOpFUnordNotEqual:                       *exact = true;  return nir_op_fneu;
   case SpvOpULessThan:                                            return nir_op_ult;
   case SpvOpSLessThan:                                            return nir_op_ilt;
   case SpvOpFOrdLessThan:                         *exact = true;  return nir_op_flt;
   case SpvOpFUnordLessThan:                       *exact = true;  return nir_op_flt;
   case SpvOpUGreaterThan:          *swap = true;                  return nir_op_ult;
   case SpvOpSGreaterThan:          *swap = true;                  return nir_op_ilt;
   case SpvOpFOrdGreaterThan:       *swap = true;  *exact = true;  return nir_op_flt;
   case SpvOpFUnordGreaterThan:     *swap = true;  *exact = true;  return nir_op_flt;
   case SpvOpULessThanEqual:        *swap = true;                  return nir_op_uge;
   case SpvOpSLessThanEqual:        *swap = true;                  return nir_op_ige;
   case SpvOpFOrdLessThanEqual:     *swap = true;  *exact = true;  return nir_op_fge;
   case SpvOpFUnordLessThanEqual:   *swap = true;  *exact = true;  return nir_op_fge;
   case SpvOpUGreaterThanEqual:                                    return nir_op_uge;
   case SpvOpSGreaterThanEqual:                                    return nir_op_ige;
   case SpvOpFOrdGreaterThanEqual:                 *exact = true;  return nir_op_fge;
   case SpvOpFUnordGreaterThanEqual:               *exact = true;  return nir_op_fge;

   /* Conversions: */
   case SpvOpQuantizeToF16:         return nir_op_fquantize2f16;
   case SpvOpUConvert:
   case SpvOpConvertFToU:
   case SpvOpConvertFToS:
   case SpvOpConvertSToF:
   case SpvOpConvertUToF:
   case SpvOpSConvert:
   case SpvOpFConvert: {
      /* Pick the exact conversion op from the (type, bit size) pair of
       * source and destination.
       */
      nir_alu_type src_type = convert_op_src_type(opcode) | src_bit_size;
      nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;
      return nir_type_conversion_op(src_type, dst_type, nir_rounding_mode_undef);
   }

   case SpvOpPtrCastToGeneric:   return nir_op_mov;
   case SpvOpGenericCastToPtr:   return nir_op_mov;

   case SpvOpIsNormal:     return nir_op_fisnormal;
   case SpvOpIsFinite:     return nir_op_fisfinite;

   default:
      vtn_fail("No NIR equivalent: %u", opcode);
   }
}
396
397 static void
handle_fp_fast_math(struct vtn_builder * b,UNUSED struct vtn_value * val,UNUSED int member,const struct vtn_decoration * dec,UNUSED void * _void)398 handle_fp_fast_math(struct vtn_builder *b, UNUSED struct vtn_value *val,
399 UNUSED int member, const struct vtn_decoration *dec,
400 UNUSED void *_void)
401 {
402 vtn_assert(dec->scope == VTN_DEC_DECORATION);
403 if (dec->decoration != SpvDecorationFPFastMathMode)
404 return;
405
406 SpvFPFastMathModeMask can_fast_math =
407 SpvFPFastMathModeAllowRecipMask |
408 SpvFPFastMathModeAllowContractMask |
409 SpvFPFastMathModeAllowReassocMask |
410 SpvFPFastMathModeAllowTransformMask;
411
412 if ((dec->operands[0] & can_fast_math) != can_fast_math)
413 b->nb.exact = true;
414
415 /* Decoration overrides defaults */
416 b->nb.fp_fast_math = 0;
417 if (!(dec->operands[0] & SpvFPFastMathModeNSZMask))
418 b->nb.fp_fast_math |=
419 FLOAT_CONTROLS_SIGNED_ZERO_PRESERVE_FP16 |
420 FLOAT_CONTROLS_SIGNED_ZERO_PRESERVE_FP32 |
421 FLOAT_CONTROLS_SIGNED_ZERO_PRESERVE_FP64;
422 if (!(dec->operands[0] & SpvFPFastMathModeNotNaNMask))
423 b->nb.fp_fast_math |=
424 FLOAT_CONTROLS_NAN_PRESERVE_FP16 |
425 FLOAT_CONTROLS_NAN_PRESERVE_FP32 |
426 FLOAT_CONTROLS_NAN_PRESERVE_FP64;
427 if (!(dec->operands[0] & SpvFPFastMathModeNotInfMask))
428 b->nb.fp_fast_math |=
429 FLOAT_CONTROLS_INF_PRESERVE_FP16 |
430 FLOAT_CONTROLS_INF_PRESERVE_FP32 |
431 FLOAT_CONTROLS_INF_PRESERVE_FP64;
432 }
433
/* Seed the builder's fp_fast_math state from the shader-wide float-controls
 * execution mode, then let any FPFastMathMode decoration on val override it.
 */
void
vtn_handle_fp_fast_math(struct vtn_builder *b, struct vtn_value *val)
{
   /* Take the NaN/Inf/SZ preserve bits from the execution mode and set them
    * on the builder, so the generated instructions can take it from it.
    * We only care about some of them, check nir_alu_instr for details.
    * We also copy all bit widths, because we can't easily get the correct one
    * here.
    */
#define FLOAT_CONTROLS2_BITS (FLOAT_CONTROLS_SIGNED_ZERO_INF_NAN_PRESERVE_FP16 | \
                              FLOAT_CONTROLS_SIGNED_ZERO_INF_NAN_PRESERVE_FP32 | \
                              FLOAT_CONTROLS_SIGNED_ZERO_INF_NAN_PRESERVE_FP64)
   /* Guards against the float_controls enum being renumbered without this
    * mask being updated.
    */
   static_assert(FLOAT_CONTROLS2_BITS == BITSET_MASK(9),
                 "enum float_controls and fp_fast_math out of sync!");
   b->nb.fp_fast_math = b->shader->info.float_controls_execution_mode &
                        FLOAT_CONTROLS2_BITS;
   vtn_foreach_decoration(b, val, handle_fp_fast_math, NULL);
#undef FLOAT_CONTROLS2_BITS
}
453
454 static void
handle_no_contraction(struct vtn_builder * b,UNUSED struct vtn_value * val,UNUSED int member,const struct vtn_decoration * dec,UNUSED void * _void)455 handle_no_contraction(struct vtn_builder *b, UNUSED struct vtn_value *val,
456 UNUSED int member, const struct vtn_decoration *dec,
457 UNUSED void *_void)
458 {
459 vtn_assert(dec->scope == VTN_DEC_DECORATION);
460 if (dec->decoration != SpvDecorationNoContraction)
461 return;
462
463 b->nb.exact = true;
464 }
465
/* Set b->nb.exact if val carries a NoContraction decoration. */
void
vtn_handle_no_contraction(struct vtn_builder *b, struct vtn_value *val)
{
   vtn_foreach_decoration(b, val, handle_no_contraction, NULL);
}
471
/* Translate a SPIR-V FPRoundingMode into the NIR equivalent.  RTP and RTN
 * are only accepted for OpenCL kernels; anything unknown fails translation.
 */
nir_rounding_mode
vtn_rounding_mode_to_nir(struct vtn_builder *b, SpvFPRoundingMode mode)
{
   switch (mode) {
   case SpvFPRoundingModeRTE:
      return nir_rounding_mode_rtne;
   case SpvFPRoundingModeRTZ:
      return nir_rounding_mode_rtz;
   case SpvFPRoundingModeRTP:
      vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
                  "FPRoundingModeRTP is only supported in kernels");
      return nir_rounding_mode_ru;
   case SpvFPRoundingModeRTN:
      vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
                  "FPRoundingModeRTN is only supported in kernels");
      return nir_rounding_mode_rd;
   default:
      vtn_fail("Unsupported rounding mode: %s",
               spirv_fproundingmode_to_string(mode));
      break;
   }
}
494
/* Conversion-related decorations collected from a value, consumed by the
 * SpvOp*Convert handling below.
 */
struct conversion_opts {
   nir_rounding_mode rounding_mode; /* nir_rounding_mode_undef if undecorated */
   bool saturate;                   /* SaturatedConversion (kernels only) */
};
499
500 static void
handle_conversion_opts(struct vtn_builder * b,UNUSED struct vtn_value * val,UNUSED int member,const struct vtn_decoration * dec,void * _opts)501 handle_conversion_opts(struct vtn_builder *b, UNUSED struct vtn_value *val,
502 UNUSED int member,
503 const struct vtn_decoration *dec, void *_opts)
504 {
505 struct conversion_opts *opts = _opts;
506
507 switch (dec->decoration) {
508 case SpvDecorationFPRoundingMode:
509 opts->rounding_mode = vtn_rounding_mode_to_nir(b, dec->operands[0]);
510 break;
511
512 case SpvDecorationSaturatedConversion:
513 vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
514 "Saturated conversions are only allowed in kernels");
515 opts->saturate = true;
516 break;
517
518 default:
519 break;
520 }
521 }
522
523 static void
handle_no_wrap(UNUSED struct vtn_builder * b,UNUSED struct vtn_value * val,UNUSED int member,const struct vtn_decoration * dec,void * _alu)524 handle_no_wrap(UNUSED struct vtn_builder *b, UNUSED struct vtn_value *val,
525 UNUSED int member,
526 const struct vtn_decoration *dec, void *_alu)
527 {
528 nir_alu_instr *alu = _alu;
529 switch (dec->decoration) {
530 case SpvDecorationNoSignedWrap:
531 alu->no_signed_wrap = true;
532 break;
533 case SpvDecorationNoUnsignedWrap:
534 alu->no_unsigned_wrap = true;
535 break;
536 default:
537 /* Do nothing. */
538 break;
539 }
540 }
541
542 static void
vtn_value_is_relaxed_precision_cb(struct vtn_builder * b,struct vtn_value * val,int member,const struct vtn_decoration * dec,void * void_ctx)543 vtn_value_is_relaxed_precision_cb(struct vtn_builder *b,
544 struct vtn_value *val, int member,
545 const struct vtn_decoration *dec, void *void_ctx)
546 {
547 bool *relaxed_precision = void_ctx;
548 switch (dec->decoration) {
549 case SpvDecorationRelaxedPrecision:
550 *relaxed_precision = true;
551 break;
552
553 default:
554 break;
555 }
556 }
557
558 bool
vtn_value_is_relaxed_precision(struct vtn_builder * b,struct vtn_value * val)559 vtn_value_is_relaxed_precision(struct vtn_builder *b, struct vtn_value *val)
560 {
561 bool result = false;
562 vtn_foreach_decoration(b, val,
563 vtn_value_is_relaxed_precision_cb, &result);
564 return result;
565 }
566
567 static bool
vtn_alu_op_mediump_16bit(struct vtn_builder * b,SpvOp opcode,struct vtn_value * dest_val)568 vtn_alu_op_mediump_16bit(struct vtn_builder *b, SpvOp opcode, struct vtn_value *dest_val)
569 {
570 if (!b->options->mediump_16bit_alu || !vtn_value_is_relaxed_precision(b, dest_val))
571 return false;
572
573 switch (opcode) {
574 case SpvOpDPdx:
575 case SpvOpDPdy:
576 case SpvOpDPdxFine:
577 case SpvOpDPdyFine:
578 case SpvOpDPdxCoarse:
579 case SpvOpDPdyCoarse:
580 case SpvOpFwidth:
581 case SpvOpFwidthFine:
582 case SpvOpFwidthCoarse:
583 return b->options->mediump_16bit_derivatives;
584 default:
585 return true;
586 }
587 }
588
/* Widen a 16-bit RelaxedPrecision result back to 32 bits; values that are
 * not 16-bit pass through unchanged.
 */
static nir_def *
vtn_mediump_upconvert(struct vtn_builder *b, enum glsl_base_type base_type, nir_def *def)
{
   if (def->bit_size != 16)
      return def;

   switch (base_type) {
   case GLSL_TYPE_FLOAT:
      return nir_f2f32(&b->nb, def);
   case GLSL_TYPE_INT:
      return nir_i2i32(&b->nb, def);
   case GLSL_TYPE_UINT:
      return nir_u2u32(&b->nb, def);
   default:
      unreachable("bad relaxed precision output type");
   }
}
606
607 void
vtn_mediump_upconvert_value(struct vtn_builder * b,struct vtn_ssa_value * value)608 vtn_mediump_upconvert_value(struct vtn_builder *b, struct vtn_ssa_value *value)
609 {
610 enum glsl_base_type base_type = glsl_get_base_type(value->type);
611
612 if (glsl_type_is_vector_or_scalar(value->type)) {
613 value->def = vtn_mediump_upconvert(b, base_type, value->def);
614 } else {
615 for (int i = 0; i < glsl_get_matrix_columns(value->type); i++)
616 value->elems[i]->def = vtn_mediump_upconvert(b, base_type, value->elems[i]->def);
617 }
618 }
619
620 void
vtn_handle_alu(struct vtn_builder * b,SpvOp opcode,const uint32_t * w,unsigned count)621 vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
622 const uint32_t *w, unsigned count)
623 {
624 struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
625 const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
626
627 if (glsl_type_is_cmat(dest_type)) {
628 vtn_handle_cooperative_alu(b, dest_val, dest_type, opcode, w, count);
629 return;
630 }
631
632 vtn_handle_no_contraction(b, dest_val);
633 vtn_handle_fp_fast_math(b, dest_val);
634 bool mediump_16bit = vtn_alu_op_mediump_16bit(b, opcode, dest_val);
635
636 /* Collect the various SSA sources */
637 const unsigned num_inputs = count - 3;
638 struct vtn_ssa_value *vtn_src[4] = { NULL, };
639 for (unsigned i = 0; i < num_inputs; i++) {
640 vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
641 if (mediump_16bit)
642 vtn_src[i] = vtn_mediump_downconvert_value(b, vtn_src[i]);
643 }
644
645 if (glsl_type_is_matrix(vtn_src[0]->type) ||
646 (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
647 struct vtn_ssa_value *dest = vtn_handle_matrix_alu(b, opcode, vtn_src[0], vtn_src[1]);
648
649 if (mediump_16bit)
650 vtn_mediump_upconvert_value(b, dest);
651
652 vtn_push_ssa_value(b, w[2], dest);
653 b->nb.exact = b->exact;
654 return;
655 }
656
657 struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
658 nir_def *src[4] = { NULL, };
659 for (unsigned i = 0; i < num_inputs; i++) {
660 vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
661 src[i] = vtn_src[i]->def;
662 }
663
664 switch (opcode) {
665 case SpvOpAny:
666 dest->def = nir_bany(&b->nb, src[0]);
667 break;
668
669 case SpvOpAll:
670 dest->def = nir_ball(&b->nb, src[0]);
671 break;
672
673 case SpvOpOuterProduct: {
674 for (unsigned i = 0; i < src[1]->num_components; i++) {
675 dest->elems[i]->def =
676 nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
677 }
678 break;
679 }
680
681 case SpvOpDot:
682 dest->def = nir_fdot(&b->nb, src[0], src[1]);
683 break;
684
685 case SpvOpIAddCarry:
686 vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
687 dest->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
688 dest->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
689 break;
690
691 case SpvOpISubBorrow:
692 vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
693 dest->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
694 dest->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
695 break;
696
697 case SpvOpUMulExtended: {
698 vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
699 if (src[0]->bit_size == 32) {
700 nir_def *umul = nir_umul_2x32_64(&b->nb, src[0], src[1]);
701 dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
702 dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
703 } else {
704 dest->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
705 dest->elems[1]->def = nir_umul_high(&b->nb, src[0], src[1]);
706 }
707 break;
708 }
709
710 case SpvOpSMulExtended: {
711 vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
712 if (src[0]->bit_size == 32) {
713 nir_def *umul = nir_imul_2x32_64(&b->nb, src[0], src[1]);
714 dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
715 dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
716 } else {
717 dest->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
718 dest->elems[1]->def = nir_imul_high(&b->nb, src[0], src[1]);
719 }
720 break;
721 }
722
723 case SpvOpDPdx:
724 dest->def = nir_ddx(&b->nb, src[0]);
725 break;
726 case SpvOpDPdxFine:
727 dest->def = nir_ddx_fine(&b->nb, src[0]);
728 break;
729 case SpvOpDPdxCoarse:
730 dest->def = nir_ddx_coarse(&b->nb, src[0]);
731 break;
732 case SpvOpDPdy:
733 dest->def = nir_ddy(&b->nb, src[0]);
734 break;
735 case SpvOpDPdyFine:
736 dest->def = nir_ddy_fine(&b->nb, src[0]);
737 break;
738 case SpvOpDPdyCoarse:
739 dest->def = nir_ddy_coarse(&b->nb, src[0]);
740 break;
741
742 case SpvOpFwidth:
743 dest->def = nir_fadd(&b->nb,
744 nir_fabs(&b->nb, nir_ddx(&b->nb, src[0])),
745 nir_fabs(&b->nb, nir_ddy(&b->nb, src[0])));
746 break;
747 case SpvOpFwidthFine:
748 dest->def = nir_fadd(&b->nb,
749 nir_fabs(&b->nb, nir_ddx_fine(&b->nb, src[0])),
750 nir_fabs(&b->nb, nir_ddy_fine(&b->nb, src[0])));
751 break;
752 case SpvOpFwidthCoarse:
753 dest->def = nir_fadd(&b->nb,
754 nir_fabs(&b->nb, nir_ddx_coarse(&b->nb, src[0])),
755 nir_fabs(&b->nb, nir_ddy_coarse(&b->nb, src[0])));
756 break;
757
758 case SpvOpVectorTimesScalar:
759 /* The builder will take care of splatting for us. */
760 dest->def = nir_fmul(&b->nb, src[0], src[1]);
761 break;
762
763 case SpvOpIsNan: {
764 const bool save_exact = b->nb.exact;
765
766 b->nb.exact = true;
767 dest->def = nir_fneu(&b->nb, src[0], src[0]);
768 b->nb.exact = save_exact;
769 break;
770 }
771
772 case SpvOpOrdered: {
773 const bool save_exact = b->nb.exact;
774
775 b->nb.exact = true;
776 dest->def = nir_iand(&b->nb, nir_feq(&b->nb, src[0], src[0]),
777 nir_feq(&b->nb, src[1], src[1]));
778 b->nb.exact = save_exact;
779 break;
780 }
781
782 case SpvOpUnordered: {
783 const bool save_exact = b->nb.exact;
784
785 b->nb.exact = true;
786 dest->def = nir_ior(&b->nb, nir_fneu(&b->nb, src[0], src[0]),
787 nir_fneu(&b->nb, src[1], src[1]));
788 b->nb.exact = save_exact;
789 break;
790 }
791
792 case SpvOpIsInf: {
793 nir_def *inf = nir_imm_floatN_t(&b->nb, INFINITY, src[0]->bit_size);
794 dest->def = nir_ieq(&b->nb, nir_fabs(&b->nb, src[0]), inf);
795 break;
796 }
797
798 case SpvOpFUnordEqual: {
799 const bool save_exact = b->nb.exact;
800
801 b->nb.exact = true;
802
803 /* This could also be implemented as !(a < b || b < a). If one or both
804 * of the source are numbers, later optimization passes can easily
805 * eliminate the isnan() checks. This may trim the sequence down to a
806 * single (a == b) operation. Otherwise, the optimizer can transform
807 * whatever is left to !(a < b || b < a). Since some applications will
808 * open-code this sequence, these optimizations are needed anyway.
809 */
810 dest->def =
811 nir_ior(&b->nb,
812 nir_feq(&b->nb, src[0], src[1]),
813 nir_ior(&b->nb,
814 nir_fneu(&b->nb, src[0], src[0]),
815 nir_fneu(&b->nb, src[1], src[1])));
816
817 b->nb.exact = save_exact;
818 break;
819 }
820
821 case SpvOpFUnordLessThan:
822 case SpvOpFUnordGreaterThan:
823 case SpvOpFUnordLessThanEqual:
824 case SpvOpFUnordGreaterThanEqual: {
825 bool swap;
826 bool unused_exact;
827 unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
828 unsigned dst_bit_size = glsl_get_bit_size(dest_type);
829 nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
830 &unused_exact,
831 src_bit_size, dst_bit_size);
832
833 if (swap) {
834 nir_def *tmp = src[0];
835 src[0] = src[1];
836 src[1] = tmp;
837 }
838
839 const bool save_exact = b->nb.exact;
840
841 b->nb.exact = true;
842
843 /* Use the property FUnordLessThan(a, b) ≡ !FOrdGreaterThanEqual(a, b). */
844 switch (op) {
845 case nir_op_fge: op = nir_op_flt; break;
846 case nir_op_flt: op = nir_op_fge; break;
847 default: unreachable("Impossible opcode.");
848 }
849
850 dest->def =
851 nir_inot(&b->nb,
852 nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL));
853
854 b->nb.exact = save_exact;
855 break;
856 }
857
858 case SpvOpLessOrGreater:
859 case SpvOpFOrdNotEqual: {
860 /* For all the SpvOpFOrd* comparisons apart from NotEqual, the value
861 * from the ALU will probably already be false if the operands are not
862 * ordered so we don’t need to handle it specially.
863 */
864 const bool save_exact = b->nb.exact;
865
866 b->nb.exact = true;
867
868 /* This could also be implemented as (a < b || b < a). If one or both
869 * of the source are numbers, later optimization passes can easily
870 * eliminate the isnan() checks. This may trim the sequence down to a
871 * single (a != b) operation. Otherwise, the optimizer can transform
872 * whatever is left to (a < b || b < a). Since some applications will
873 * open-code this sequence, these optimizations are needed anyway.
874 */
875 dest->def =
876 nir_iand(&b->nb,
877 nir_fneu(&b->nb, src[0], src[1]),
878 nir_iand(&b->nb,
879 nir_feq(&b->nb, src[0], src[0]),
880 nir_feq(&b->nb, src[1], src[1])));
881
882 b->nb.exact = save_exact;
883 break;
884 }
885
886 case SpvOpUConvert:
887 case SpvOpConvertFToU:
888 case SpvOpConvertFToS:
889 case SpvOpConvertSToF:
890 case SpvOpConvertUToF:
891 case SpvOpSConvert:
892 case SpvOpFConvert:
893 case SpvOpSatConvertSToU:
894 case SpvOpSatConvertUToS: {
895 unsigned src_bit_size = src[0]->bit_size;
896 unsigned dst_bit_size = glsl_get_bit_size(dest_type);
897 nir_alu_type src_type = convert_op_src_type(opcode) | src_bit_size;
898 nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;
899
900 struct conversion_opts opts = {
901 .rounding_mode = nir_rounding_mode_undef,
902 .saturate = false,
903 };
904 vtn_foreach_decoration(b, dest_val, handle_conversion_opts, &opts);
905
906 if (opcode == SpvOpSatConvertSToU || opcode == SpvOpSatConvertUToS)
907 opts.saturate = true;
908
909 if (b->shader->info.stage == MESA_SHADER_KERNEL) {
910 if (opts.rounding_mode == nir_rounding_mode_undef && !opts.saturate) {
911 dest->def = nir_type_convert(&b->nb, src[0], src_type, dst_type,
912 nir_rounding_mode_undef);
913 } else {
914 dest->def = nir_convert_alu_types(&b->nb, dst_bit_size, src[0],
915 src_type, dst_type,
916 opts.rounding_mode, opts.saturate);
917 }
918 } else {
919 vtn_fail_if(opts.rounding_mode != nir_rounding_mode_undef &&
920 dst_type != nir_type_float16,
921 "Rounding modes are only allowed on conversions to "
922 "16-bit float types");
923 dest->def = nir_type_convert(&b->nb, src[0], src_type, dst_type,
924 opts.rounding_mode);
925 }
926 break;
927 }
928
929 case SpvOpBitFieldInsert:
930 case SpvOpBitFieldSExtract:
931 case SpvOpBitFieldUExtract:
932 case SpvOpShiftLeftLogical:
933 case SpvOpShiftRightArithmetic:
934 case SpvOpShiftRightLogical: {
935 bool swap;
936 bool exact;
937 unsigned src0_bit_size = glsl_get_bit_size(vtn_src[0]->type);
938 unsigned dst_bit_size = glsl_get_bit_size(dest_type);
939 nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, &exact,
940 src0_bit_size, dst_bit_size);
941
942 assert(!exact);
943
944 assert (op == nir_op_ushr || op == nir_op_ishr || op == nir_op_ishl ||
945 op == nir_op_bitfield_insert || op == nir_op_ubitfield_extract ||
946 op == nir_op_ibitfield_extract);
947
948 for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
949 unsigned src_bit_size =
950 nir_alu_type_get_type_size(nir_op_infos[op].input_types[i]);
951 if (src_bit_size == 0)
952 continue;
953 if (src_bit_size != src[i]->bit_size) {
954 assert(src_bit_size == 32);
955 /* Convert the Shift, Offset and Count operands to 32 bits, which is the bitsize
956 * supported by the NIR instructions. See discussion here:
957 *
958 * https://lists.freedesktop.org/archives/mesa-dev/2018-April/193026.html
959 */
960 src[i] = nir_u2u32(&b->nb, src[i]);
961 }
962 }
963 dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
964 break;
965 }
966
967 case SpvOpSignBitSet:
968 dest->def = nir_i2b(&b->nb,
969 nir_ushr(&b->nb, src[0], nir_imm_int(&b->nb, src[0]->bit_size - 1)));
970 break;
971
972 case SpvOpUCountTrailingZerosINTEL:
973 dest->def = nir_umin(&b->nb,
974 nir_find_lsb(&b->nb, src[0]),
975 nir_imm_int(&b->nb, 32u));
976 break;
977
978 case SpvOpBitCount: {
979 /* bit_count always returns int32, but the SPIR-V opcode just says the return
980 * value needs to be big enough to store the number of bits.
981 */
982 dest->def = nir_u2uN(&b->nb, nir_bit_count(&b->nb, src[0]), glsl_get_bit_size(dest_type));
983 break;
984 }
985
986 case SpvOpSDotKHR:
987 case SpvOpUDotKHR:
988 case SpvOpSUDotKHR:
989 case SpvOpSDotAccSatKHR:
990 case SpvOpUDotAccSatKHR:
991 case SpvOpSUDotAccSatKHR:
992 unreachable("Should have called vtn_handle_integer_dot instead.");
993
994 default: {
995 bool swap;
996 bool exact;
997 unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
998 unsigned dst_bit_size = glsl_get_bit_size(dest_type);
999 nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
1000 &exact,
1001 src_bit_size, dst_bit_size);
1002
1003 if (swap) {
1004 nir_def *tmp = src[0];
1005 src[0] = src[1];
1006 src[1] = tmp;
1007 }
1008
1009 switch (op) {
1010 case nir_op_ishl:
1011 case nir_op_ishr:
1012 case nir_op_ushr:
1013 if (src[1]->bit_size != 32)
1014 src[1] = nir_u2u32(&b->nb, src[1]);
1015 break;
1016 default:
1017 break;
1018 }
1019
1020 const bool save_exact = b->nb.exact;
1021
1022 if (exact)
1023 b->nb.exact = true;
1024
1025 dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
1026
1027 b->nb.exact = save_exact;
1028 break;
1029 } /* default */
1030 }
1031
1032 switch (opcode) {
1033 case SpvOpIAdd:
1034 case SpvOpIMul:
1035 case SpvOpISub:
1036 case SpvOpShiftLeftLogical:
1037 case SpvOpSNegate: {
1038 nir_alu_instr *alu = nir_instr_as_alu(dest->def->parent_instr);
1039 vtn_foreach_decoration(b, dest_val, handle_no_wrap, alu);
1040 break;
1041 }
1042 default:
1043 /* Do nothing. */
1044 break;
1045 }
1046
1047 if (mediump_16bit)
1048 vtn_mediump_upconvert_value(b, dest);
1049 vtn_push_ssa_value(b, w[2], dest);
1050
1051 b->nb.exact = b->exact;
1052 }
1053
/* Lower the SPV_KHR_integer_dot_product opcodes (SpvOp{S,U,SU}DotKHR and
 * their *AccSat variants) to NIR.  These are handled separately from
 * vtn_handle_alu because the scalar forms carry a trailing optional
 * "Packed Vector Format" operand, so the number of SSA sources must be
 * derived from the opcode rather than from the word count.
 *
 * w[1] is the result type id, w[2] the result id, and w[3]... the sources
 * (Vector 1, Vector 2, and — for the AccSat forms — the Accumulator).
 */
void
vtn_handle_integer_dot(struct vtn_builder *b, SpvOp opcode,
                       const uint32_t *w, unsigned count)
{
   struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
   const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
   const unsigned dest_size = glsl_get_bit_size(dest_type);

   vtn_handle_no_contraction(b, dest_val);

   /* Collect the various SSA sources.
    *
    * Due to the optional "Packed Vector Format" field, determine number of
    * inputs from the opcode.  This differs from vtn_handle_alu.
    */
   const unsigned num_inputs = (opcode == SpvOpSDotAccSatKHR ||
                                opcode == SpvOpUDotAccSatKHR ||
                                opcode == SpvOpSUDotAccSatKHR) ? 3 : 2;

   vtn_assert(count >= num_inputs + 3);

   struct vtn_ssa_value *vtn_src[3] = { NULL, };
   nir_def *src[3] = { NULL, };

   for (unsigned i = 0; i < num_inputs; i++) {
      vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
      src[i] = vtn_src[i]->def;

      vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
   }

   /* For all of the opcodes *except* SpvOpSUDotKHR and SpvOpSUDotAccSatKHR,
    * the SPV_KHR_integer_dot_product spec says:
    *
    *    _Vector 1_ and _Vector 2_ must have the same type.
    *
    * The practical requirement is the same bit-size and the same number of
    * components.
    */
   vtn_fail_if(glsl_get_bit_size(vtn_src[0]->type) !=
               glsl_get_bit_size(vtn_src[1]->type) ||
               glsl_get_vector_elements(vtn_src[0]->type) !=
               glsl_get_vector_elements(vtn_src[1]->type),
               "Vector 1 and vector 2 source of opcode %s must have the same "
               "type",
               spirv_op_to_string(opcode));

   if (num_inputs == 3) {
      /* The SPV_KHR_integer_dot_product spec says:
       *
       *    The type of Accumulator must be the same as Result Type.
       *
       * The handling of SpvOpSDotAccSatKHR and friends with the packed 4x8
       * types (far below) assumes these types have the same size.
       */
      vtn_fail_if(dest_type != vtn_src[2]->type,
                  "Accumulator type must be the same as Result Type for "
                  "opcode %s",
                  spirv_op_to_string(opcode));
   }

   /* Component bit-size of the packed data when one of the 32-bit packed
    * fast paths below is taken; 8 unless the 2x16 path overrides it.
    */
   unsigned packed_bit_size = 8;
   if (glsl_type_is_vector(vtn_src[0]->type)) {
      /* FINISHME: Is this actually as good or better for platforms that don't
       * have the special instructions (i.e., one or both of has_dot_4x8 or
       * has_sudot_4x8 is false)?
       */
      if (glsl_get_vector_elements(vtn_src[0]->type) == 4 &&
          glsl_get_bit_size(vtn_src[0]->type) == 8 &&
          glsl_get_bit_size(dest_type) <= 32) {
         /* Pack vec4 of 8-bit into a single 32-bit scalar so the
          * specialized NIR 4x8 dot-product opcodes can be used below.
          */
         src[0] = nir_pack_32_4x8(&b->nb, src[0]);
         src[1] = nir_pack_32_4x8(&b->nb, src[1]);
      } else if (glsl_get_vector_elements(vtn_src[0]->type) == 2 &&
                 glsl_get_bit_size(vtn_src[0]->type) == 16 &&
                 glsl_get_bit_size(dest_type) <= 32 &&
                 opcode != SpvOpSUDotKHR &&
                 opcode != SpvOpSUDotAccSatKHR) {
         /* Likewise for vec2 of 16-bit (no mixed-signedness 2x16 opcode
          * exists, hence the SUDot exclusions).
          */
         src[0] = nir_pack_32_2x16(&b->nb, src[0]);
         src[1] = nir_pack_32_2x16(&b->nb, src[1]);
         packed_bit_size = 16;
      }
   } else if (glsl_type_is_scalar(vtn_src[0]->type) &&
              glsl_type_is_32bit(vtn_src[0]->type)) {
      /* The SPV_KHR_integer_dot_product spec says:
       *
       *    When _Vector 1_ and _Vector 2_ are scalar integer types, _Packed
       *    Vector Format_ must be specified to select how the integers are to
       *    be interpreted as vectors.
       *
       * The "Packed Vector Format" value follows the last input.
       */
      vtn_assert(count == (num_inputs + 4));
      const SpvPackedVectorFormat pack_format = w[num_inputs + 3];
      vtn_fail_if(pack_format != SpvPackedVectorFormatPackedVectorFormat4x8BitKHR,
                  "Unsupported vector packing format %d for opcode %s",
                  pack_format, spirv_op_to_string(opcode));
   } else {
      vtn_fail_with_opcode("Invalid source types.", opcode);
   }

   nir_def *dest = NULL;

   if (src[0]->num_components > 1) {
      /* Generic multi-component path: widen each component pair, multiply,
       * and sum.  The per-opcode conversion functions select the source
       * signedness (sign- vs. zero-extension).
       */
      nir_def *(*src0_conversion)(nir_builder *, nir_def *, unsigned);
      nir_def *(*src1_conversion)(nir_builder *, nir_def *, unsigned);

      switch (opcode) {
      case SpvOpSDotKHR:
      case SpvOpSDotAccSatKHR:
         src0_conversion = nir_i2iN;
         src1_conversion = nir_i2iN;
         break;

      case SpvOpUDotKHR:
      case SpvOpUDotAccSatKHR:
         src0_conversion = nir_u2uN;
         src1_conversion = nir_u2uN;
         break;

      case SpvOpSUDotKHR:
      case SpvOpSUDotAccSatKHR:
         src0_conversion = nir_i2iN;
         src1_conversion = nir_u2uN;
         break;

      default:
         unreachable("Invalid opcode.");
      }

      /* The SPV_KHR_integer_dot_product spec says:
       *
       *    All components of the input vectors are sign-extended to the bit
       *    width of the result's type. The sign-extended input vectors are
       *    then multiplied component-wise and all components of the vector
       *    resulting from the component-wise multiplication are added
       *    together. The resulting value will equal the low-order N bits of
       *    the correct result R, where N is the result width and R is
       *    computed with enough precision to avoid overflow and underflow.
       */
      const unsigned vector_components =
         glsl_get_vector_elements(vtn_src[0]->type);

      for (unsigned i = 0; i < vector_components; i++) {
         nir_def *const src0 =
            src0_conversion(&b->nb, nir_channel(&b->nb, src[0], i), dest_size);

         nir_def *const src1 =
            src1_conversion(&b->nb, nir_channel(&b->nb, src[1], i), dest_size);

         nir_def *const mul_result = nir_imul(&b->nb, src0, src1);

         dest = (i == 0) ? mul_result : nir_iadd(&b->nb, dest, mul_result);
      }

      if (num_inputs == 3) {
         /* For SpvOpSDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
          *
          *    Signed integer dot product of _Vector 1_ and _Vector 2_ and
          *    signed saturating addition of the result with _Accumulator_.
          *
          * For SpvOpUDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
          *
          *    Unsigned integer dot product of _Vector 1_ and _Vector 2_ and
          *    unsigned saturating addition of the result with _Accumulator_.
          *
          * For SpvOpSUDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
          *
          *    Mixed-signedness integer dot product of _Vector 1_ and _Vector
          *    2_ and signed saturating addition of the result with
          *    _Accumulator_.
          */
         dest = (opcode == SpvOpUDotAccSatKHR)
            ? nir_uadd_sat(&b->nb, dest, src[2])
            : nir_iadd_sat(&b->nb, dest, src[2]);
      }
   } else {
      /* Packed path: the sources are single 32-bit scalars holding 4x8 (or
       * 2x16) packed data, either packed above or supplied packed by the
       * shader.  Use the specialized NIR dot-product opcodes.
       */
      assert(src[0]->num_components == 1 && src[1]->num_components == 1);
      assert(src[0]->bit_size == 32 && src[1]->bit_size == 32);

      nir_def *const zero = nir_imm_zero(&b->nb, 1, 32);
      bool is_signed = opcode == SpvOpSDotKHR || opcode == SpvOpSUDotKHR ||
                       opcode == SpvOpSDotAccSatKHR || opcode == SpvOpSUDotAccSatKHR;

      if (packed_bit_size == 16) {
         switch (opcode) {
         case SpvOpSDotKHR:
            dest = nir_sdot_2x16_iadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpUDotKHR:
            dest = nir_udot_2x16_uadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpSDotAccSatKHR:
            /* The fused iadd_sat form is only valid when the accumulator is
             * 32-bit; otherwise accumulate separately below.
             */
            if (dest_size == 32)
               dest = nir_sdot_2x16_iadd_sat(&b->nb, src[0], src[1], src[2]);
            else
               dest = nir_sdot_2x16_iadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpUDotAccSatKHR:
            if (dest_size == 32)
               dest = nir_udot_2x16_uadd_sat(&b->nb, src[0], src[1], src[2]);
            else
               dest = nir_udot_2x16_uadd(&b->nb, src[0], src[1], zero);
            break;
         default:
            unreachable("Invalid opcode.");
         }
      } else {
         switch (opcode) {
         case SpvOpSDotKHR:
            dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpUDotKHR:
            dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpSUDotKHR:
            dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpSDotAccSatKHR:
            if (dest_size == 32)
               dest = nir_sdot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
            else
               dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpUDotAccSatKHR:
            if (dest_size == 32)
               dest = nir_udot_4x8_uadd_sat(&b->nb, src[0], src[1], src[2]);
            else
               dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpSUDotAccSatKHR:
            if (dest_size == 32)
               dest = nir_sudot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
            else
               dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
            break;
         default:
            unreachable("Invalid opcode.");
         }
      }

      if (dest_size != 32) {
         /* When the accumulator is 32-bits, a NIR dot-product with saturate
          * is generated above.  In all other cases a regular dot-product is
          * generated above, and separate addition with saturate is generated
          * here.
          *
          * The SPV_KHR_integer_dot_product spec says:
          *
          *    If any of the multiplications or additions, with the exception
          *    of the final accumulation, overflow or underflow, the result of
          *    the instruction is undefined.
          *
          * Therefore it is safe to cast the dot-product result down to the
          * size of the accumulator before doing the addition.  Since the
          * result of the dot-product cannot overflow 32-bits, this is also
          * safe to cast up.
          */
         if (num_inputs == 3) {
            dest = is_signed
               ? nir_iadd_sat(&b->nb, nir_i2iN(&b->nb, dest, dest_size), src[2])
               : nir_uadd_sat(&b->nb, nir_u2uN(&b->nb, dest, dest_size), src[2]);
         } else {
            dest = is_signed
               ? nir_i2iN(&b->nb, dest, dest_size)
               : nir_u2uN(&b->nb, dest, dest_size);
         }
      }
   }

   vtn_push_nir_ssa(b, w[2], dest);

   /* Restore the builder's default exact setting. */
   b->nb.exact = b->exact;
}
1327
1328 void
vtn_handle_bitcast(struct vtn_builder * b,const uint32_t * w,unsigned count)1329 vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count)
1330 {
1331 vtn_assert(count == 4);
1332 /* From the definition of OpBitcast in the SPIR-V 1.2 spec:
1333 *
1334 * "If Result Type has the same number of components as Operand, they
1335 * must also have the same component width, and results are computed per
1336 * component.
1337 *
1338 * If Result Type has a different number of components than Operand, the
1339 * total number of bits in Result Type must equal the total number of
1340 * bits in Operand. Let L be the type, either Result Type or Operand’s
1341 * type, that has the larger number of components. Let S be the other
1342 * type, with the smaller number of components. The number of components
1343 * in L must be an integer multiple of the number of components in S.
1344 * The first component (that is, the only or lowest-numbered component)
1345 * of S maps to the first components of L, and so on, up to the last
1346 * component of S mapping to the last components of L. Within this
1347 * mapping, any single component of S (mapping to multiple components of
1348 * L) maps its lower-ordered bits to the lower-numbered components of L."
1349 */
1350
1351 struct vtn_type *type = vtn_get_type(b, w[1]);
1352 if (type->base_type == vtn_base_type_cooperative_matrix) {
1353 vtn_handle_cooperative_instruction(b, SpvOpBitcast, w, count);
1354 return;
1355 }
1356
1357 struct nir_def *src = vtn_get_nir_ssa(b, w[3]);
1358
1359 vtn_fail_if(src->num_components * src->bit_size !=
1360 glsl_get_vector_elements(type->type) * glsl_get_bit_size(type->type),
1361 "Source (%%%u) and destination (%%%u) of OpBitcast must have the same "
1362 "total number of bits", w[3], w[2]);
1363 nir_def *val =
1364 nir_bitcast_vector(&b->nb, src, glsl_get_bit_size(type->type));
1365 vtn_push_nir_ssa(b, w[2], val);
1366 }
1367