1 /*
2  * Copyright © 2016 Intel Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include <math.h>
25 #include "vtn_private.h"
26 #include "spirv_info.h"
27 
28 /*
29  * Normally, column vectors in SPIR-V correspond to a single NIR SSA
30  * definition. But for matrix multiplies, we want to do one routine for
31  * multiplying a matrix by a matrix and then pretend that vectors are matrices
32  * with one column. So we "wrap" these things, and unwrap the result before we
33  * send it off.
34  */
35 
36 static struct vtn_ssa_value *
37 wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
38 {
39    if (val == NULL)
40       return NULL;
41 
42    if (glsl_type_is_matrix(val->type))
43       return val;
44 
45    struct vtn_ssa_value *dest = vtn_zalloc(b, struct vtn_ssa_value);
46    dest->type = glsl_get_bare_type(val->type);
47    dest->elems = vtn_alloc_array(b, struct vtn_ssa_value *, 1);
48    dest->elems[0] = val;
49 
50    return dest;
51 }
52 
53 static struct vtn_ssa_value *
54 unwrap_matrix(struct vtn_ssa_value *val)
55 {
56    if (glsl_type_is_matrix(val->type))
57          return val;
58 
59    return val->elems[0];
60 }
61 
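/* Multiply two (possibly wrapped) matrices column by column: each column i of
 * the result is sum(src0[j] * src1[i][j] for all j), accumulated with ffma.
 * If both operands arrive transposed, compute B * A instead and transpose the
 * result, using transpose(A) * transpose(B) == transpose(B * A).
 */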
62 static struct vtn_ssa_value *
63 matrix_multiply(struct vtn_builder *b,
64                 struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
65 {
66 
67    struct vtn_ssa_value *src0 = wrap_matrix(b, _src0);
68    struct vtn_ssa_value *src1 = wrap_matrix(b, _src1);
69    struct vtn_ssa_value *src0_transpose = wrap_matrix(b, _src0->transposed);
70    struct vtn_ssa_value *src1_transpose = wrap_matrix(b, _src1->transposed);
71 
72    unsigned src0_rows = glsl_get_vector_elements(src0->type);
73    unsigned src0_columns = glsl_get_matrix_columns(src0->type);
74    unsigned src1_columns = glsl_get_matrix_columns(src1->type);
75 
76    const struct glsl_type *dest_type;
77    if (src1_columns > 1) {
78       dest_type = glsl_matrix_type(glsl_get_base_type(src0->type),
79                                    src0_rows, src1_columns);
80    } else {
81       dest_type = glsl_vector_type(glsl_get_base_type(src0->type), src0_rows);
82    }
83    struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
84 
85    dest = wrap_matrix(b, dest);
86 
87    bool transpose_result = false;
88    if (src0_transpose && src1_transpose) {
89       /* transpose(A) * transpose(B) = transpose(B * A) */
90       src1 = src0_transpose;
91       src0 = src1_transpose;
92       src0_transpose = NULL;
93       src1_transpose = NULL;
94       transpose_result = true;
95    }
96 
97    for (unsigned i = 0; i < src1_columns; i++) {
98       /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
99       dest->elems[i]->def =
100          nir_fmul(&b->nb, src0->elems[src0_columns - 1]->def,
101                   nir_channel(&b->nb, src1->elems[i]->def, src0_columns - 1));
102       for (int j = src0_columns - 2; j >= 0; j--) {
103          dest->elems[i]->def =
104             nir_ffma(&b->nb, src0->elems[j]->def,
105                              nir_channel(&b->nb, src1->elems[i]->def, j),
106                              dest->elems[i]->def);
107       }
108    }
109 
110    dest = unwrap_matrix(dest);
111 
112    if (transpose_result)
113       dest = vtn_ssa_transpose(b, dest);
114 
115    return dest;
116 }
117 
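/* Multiply every column of a matrix by a scalar, using integer or float
 * multiplies depending on the matrix base type.
 */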
118 static struct vtn_ssa_value *
119 mat_times_scalar(struct vtn_builder *b,
120                  struct vtn_ssa_value *mat,
121                  nir_def *scalar)
122 {
123    struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
124    for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
125       if (glsl_base_type_is_integer(glsl_get_base_type(mat->type)))
126          dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
127       else
128          dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
129    }
130 
131    return dest;
132 }
133 
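/* Narrow a 32-bit scalar or vector to its 16-bit mediump form.  Values that
 * are already 16 bits, and booleans, are returned unchanged.
 */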
134 nir_def *
135 vtn_mediump_downconvert(struct vtn_builder *b, enum glsl_base_type base_type, nir_def *def)
136 {
137    if (def->bit_size == 16)
138       return def;
139 
140    switch (base_type) {
141    case GLSL_TYPE_FLOAT:
142       return nir_f2fmp(&b->nb, def);
143    case GLSL_TYPE_INT:
144    case GLSL_TYPE_UINT:
145       return nir_i2imp(&b->nb, def);
146    /* Workaround for 3DMark Wild Life which has RelaxedPrecision on
147     * OpLogical* operations (which is forbidden by spec).
148     */
149    case GLSL_TYPE_BOOL:
150       return def;
151    default:
152       unreachable("bad relaxed precision input type");
153    }
154 }
155 
156 struct vtn_ssa_value *
157 vtn_mediump_downconvert_value(struct vtn_builder *b, struct vtn_ssa_value *src)
158 {
159    if (!src)
160       return src;
161 
162    struct vtn_ssa_value *srcmp = vtn_create_ssa_value(b, src->type);
163 
164    if (src->transposed) {
165       srcmp->transposed = vtn_mediump_downconvert_value(b, src->transposed);
166    } else {
167       enum glsl_base_type base_type = glsl_get_base_type(src->type);
168 
169       if (glsl_type_is_vector_or_scalar(src->type)) {
170          srcmp->def = vtn_mediump_downconvert(b, base_type, src->def);
171       } else {
172          assert(glsl_get_base_type(src->type) == GLSL_TYPE_FLOAT);
173          for (int i = 0; i < glsl_get_matrix_columns(src->type); i++)
174             srcmp->elems[i]->def = vtn_mediump_downconvert(b, base_type, src->elems[i]->def);
175       }
176    }
177 
178    return srcmp;
179 }
180 
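/* Handle the ALU opcodes whose operands may be matrices.  Called from
 * vtn_handle_alu() below when either source is a matrix.
 */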
181 static struct vtn_ssa_value *
182 vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
183                       struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
184 {
185    switch (opcode) {
186    case SpvOpFNegate: {
187       struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
188       unsigned cols = glsl_get_matrix_columns(src0->type);
189       for (unsigned i = 0; i < cols; i++)
190          dest->elems[i]->def = nir_fneg(&b->nb, src0->elems[i]->def);
191       return dest;
192    }
193 
194    case SpvOpFAdd: {
195       struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
196       unsigned cols = glsl_get_matrix_columns(src0->type);
197       for (unsigned i = 0; i < cols; i++)
198          dest->elems[i]->def =
199             nir_fadd(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
200       return dest;
201    }
202 
203    case SpvOpFSub: {
204       struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
205       unsigned cols = glsl_get_matrix_columns(src0->type);
206       for (unsigned i = 0; i < cols; i++)
207          dest->elems[i]->def =
208             nir_fsub(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
209       return dest;
210    }
211 
212    case SpvOpTranspose:
213       return vtn_ssa_transpose(b, src0);
214 
215    case SpvOpMatrixTimesScalar:
216       if (src0->transposed) {
217          return vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
218                                                          src1->def));
219       } else {
220          return mat_times_scalar(b, src0, src1->def);
221       }
222       break;
223 
224    case SpvOpVectorTimesMatrix:
225    case SpvOpMatrixTimesVector:
226    case SpvOpMatrixTimesMatrix:
227       if (opcode == SpvOpVectorTimesMatrix) {
228          return matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
229       } else {
230          return matrix_multiply(b, src0, src1);
231       }
232       break;
233 
234    default: vtn_fail_with_opcode("unknown matrix opcode", opcode);
235    }
236 }
237 
238 static nir_alu_type
239 convert_op_src_type(SpvOp opcode)
240 {
241    switch (opcode) {
242    case SpvOpFConvert:
243    case SpvOpConvertFToS:
244    case SpvOpConvertFToU:
245       return nir_type_float;
246    case SpvOpSConvert:
247    case SpvOpConvertSToF:
248    case SpvOpSatConvertSToU:
249       return nir_type_int;
250    case SpvOpUConvert:
251    case SpvOpConvertUToF:
252    case SpvOpSatConvertUToS:
253       return nir_type_uint;
254    default:
255       unreachable("Unhandled conversion op");
256    }
257 }
258 
259 static nir_alu_type
260 convert_op_dst_type(SpvOp opcode)
261 {
262    switch (opcode) {
263    case SpvOpFConvert:
264    case SpvOpConvertSToF:
265    case SpvOpConvertUToF:
266       return nir_type_float;
267    case SpvOpSConvert:
268    case SpvOpConvertFToS:
269    case SpvOpSatConvertUToS:
270       return nir_type_int;
271    case SpvOpUConvert:
272    case SpvOpConvertFToU:
273    case SpvOpSatConvertSToU:
274       return nir_type_uint;
275    default:
276       unreachable("Unhandled conversion op");
277    }
278 }
279 
280 nir_op
281 vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
282                                 SpvOp opcode, bool *swap, bool *exact,
283                                 unsigned src_bit_size, unsigned dst_bit_size)
284 {
285    /* Indicates that the first two arguments should be swapped.  This is
286     * used for implementing greater-than and less-than-or-equal.
287     */
288    *swap = false;
289 
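   /* Indicates that the resulting instruction must be marked exact (no
    * fast-math optimizations); the ordered / unordered floating-point
    * comparisons require this.
    */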
290    *exact = false;
291 
292    switch (opcode) {
293    case SpvOpSNegate:            return nir_op_ineg;
294    case SpvOpFNegate:            return nir_op_fneg;
295    case SpvOpNot:                return nir_op_inot;
296    case SpvOpIAdd:               return nir_op_iadd;
297    case SpvOpFAdd:               return nir_op_fadd;
298    case SpvOpISub:               return nir_op_isub;
299    case SpvOpFSub:               return nir_op_fsub;
300    case SpvOpIMul:               return nir_op_imul;
301    case SpvOpFMul:               return nir_op_fmul;
302    case SpvOpUDiv:               return nir_op_udiv;
303    case SpvOpSDiv:               return nir_op_idiv;
304    case SpvOpFDiv:               return nir_op_fdiv;
305    case SpvOpUMod:               return nir_op_umod;
306    case SpvOpSMod:               return nir_op_imod;
307    case SpvOpFMod:               return nir_op_fmod;
308    case SpvOpSRem:               return nir_op_irem;
309    case SpvOpFRem:               return nir_op_frem;
310 
311    case SpvOpShiftRightLogical:     return nir_op_ushr;
312    case SpvOpShiftRightArithmetic:  return nir_op_ishr;
313    case SpvOpShiftLeftLogical:      return nir_op_ishl;
314    case SpvOpLogicalOr:             return nir_op_ior;
315    case SpvOpLogicalEqual:          return nir_op_ieq;
316    case SpvOpLogicalNotEqual:       return nir_op_ine;
317    case SpvOpLogicalAnd:            return nir_op_iand;
318    case SpvOpLogicalNot:            return nir_op_inot;
319    case SpvOpBitwiseOr:             return nir_op_ior;
320    case SpvOpBitwiseXor:            return nir_op_ixor;
321    case SpvOpBitwiseAnd:            return nir_op_iand;
322    case SpvOpSelect:                return nir_op_bcsel;
323    case SpvOpIEqual:                return nir_op_ieq;
324 
325    case SpvOpBitFieldInsert:        return nir_op_bitfield_insert;
326    case SpvOpBitFieldSExtract:      return nir_op_ibitfield_extract;
327    case SpvOpBitFieldUExtract:      return nir_op_ubitfield_extract;
328    case SpvOpBitReverse:            return nir_op_bitfield_reverse;
329 
330    case SpvOpUCountLeadingZerosINTEL: return nir_op_uclz;
331    /* SpvOpUCountTrailingZerosINTEL is handled elsewhere. */
332    case SpvOpAbsISubINTEL:          return nir_op_uabs_isub;
333    case SpvOpAbsUSubINTEL:          return nir_op_uabs_usub;
334    case SpvOpIAddSatINTEL:          return nir_op_iadd_sat;
335    case SpvOpUAddSatINTEL:          return nir_op_uadd_sat;
336    case SpvOpIAverageINTEL:         return nir_op_ihadd;
337    case SpvOpUAverageINTEL:         return nir_op_uhadd;
338    case SpvOpIAverageRoundedINTEL:  return nir_op_irhadd;
339    case SpvOpUAverageRoundedINTEL:  return nir_op_urhadd;
340    case SpvOpISubSatINTEL:          return nir_op_isub_sat;
341    case SpvOpUSubSatINTEL:          return nir_op_usub_sat;
342    case SpvOpIMul32x16INTEL:        return nir_op_imul_32x16;
343    case SpvOpUMul32x16INTEL:        return nir_op_umul_32x16;
344 
345    /* The ordered / unordered comparison operators need extra handling beyond
346     * the logical operator returned here, since they must also check whether
347     * the operands are ordered.
348     */
349    case SpvOpFOrdEqual:                            *exact = true;  return nir_op_feq;
350    case SpvOpFUnordEqual:                          *exact = true;  return nir_op_feq;
351    case SpvOpINotEqual:                                            return nir_op_ine;
352    case SpvOpLessOrGreater:                        /* Deprecated, use OrdNotEqual */
353    case SpvOpFOrdNotEqual:                         *exact = true;  return nir_op_fneu;
354    case SpvOpFUnordNotEqual:                       *exact = true;  return nir_op_fneu;
355    case SpvOpULessThan:                                            return nir_op_ult;
356    case SpvOpSLessThan:                                            return nir_op_ilt;
357    case SpvOpFOrdLessThan:                         *exact = true;  return nir_op_flt;
358    case SpvOpFUnordLessThan:                       *exact = true;  return nir_op_flt;
359    case SpvOpUGreaterThan:          *swap = true;                  return nir_op_ult;
360    case SpvOpSGreaterThan:          *swap = true;                  return nir_op_ilt;
361    case SpvOpFOrdGreaterThan:       *swap = true;  *exact = true;  return nir_op_flt;
362    case SpvOpFUnordGreaterThan:     *swap = true;  *exact = true;  return nir_op_flt;
363    case SpvOpULessThanEqual:        *swap = true;                  return nir_op_uge;
364    case SpvOpSLessThanEqual:        *swap = true;                  return nir_op_ige;
365    case SpvOpFOrdLessThanEqual:     *swap = true;  *exact = true;  return nir_op_fge;
366    case SpvOpFUnordLessThanEqual:   *swap = true;  *exact = true;  return nir_op_fge;
367    case SpvOpUGreaterThanEqual:                                    return nir_op_uge;
368    case SpvOpSGreaterThanEqual:                                    return nir_op_ige;
369    case SpvOpFOrdGreaterThanEqual:                 *exact = true;  return nir_op_fge;
370    case SpvOpFUnordGreaterThanEqual:               *exact = true;  return nir_op_fge;
371 
372    /* Conversions: */
373    case SpvOpQuantizeToF16:         return nir_op_fquantize2f16;
374    case SpvOpUConvert:
375    case SpvOpConvertFToU:
376    case SpvOpConvertFToS:
377    case SpvOpConvertSToF:
378    case SpvOpConvertUToF:
379    case SpvOpSConvert:
380    case SpvOpFConvert: {
381       nir_alu_type src_type = convert_op_src_type(opcode) | src_bit_size;
382       nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;
383       return nir_type_conversion_op(src_type, dst_type, nir_rounding_mode_undef);
384    }
385 
386    case SpvOpPtrCastToGeneric:   return nir_op_mov;
387    case SpvOpGenericCastToPtr:   return nir_op_mov;
388 
389    case SpvOpIsNormal:     return nir_op_fisnormal;
390    case SpvOpIsFinite:     return nir_op_fisfinite;
391 
392    default:
393       vtn_fail("No NIR equivalent: %u", opcode);
394    }
395 }
396 
397 static void
398 handle_fp_fast_math(struct vtn_builder *b, UNUSED struct vtn_value *val,
399                  UNUSED int member, const struct vtn_decoration *dec,
400                  UNUSED void *_void)
401 {
402    vtn_assert(dec->scope == VTN_DEC_DECORATION);
403    if (dec->decoration != SpvDecorationFPFastMathMode)
404       return;
405 
406    SpvFPFastMathModeMask can_fast_math =
407       SpvFPFastMathModeAllowRecipMask |
408       SpvFPFastMathModeAllowContractMask |
409       SpvFPFastMathModeAllowReassocMask |
410       SpvFPFastMathModeAllowTransformMask;
411 
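   /* Unless the decoration allows all of the relaxations in can_fast_math,
    * fall back to exact math for this instruction.
    */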
412    if ((dec->operands[0] & can_fast_math) != can_fast_math)
413       b->nb.exact = true;
414 
415    /* Decoration overrides defaults */
416    b->nb.fp_fast_math = 0;
417    if (!(dec->operands[0] & SpvFPFastMathModeNSZMask))
418       b->nb.fp_fast_math |=
419          FLOAT_CONTROLS_SIGNED_ZERO_PRESERVE_FP16 |
420          FLOAT_CONTROLS_SIGNED_ZERO_PRESERVE_FP32 |
421          FLOAT_CONTROLS_SIGNED_ZERO_PRESERVE_FP64;
422    if (!(dec->operands[0] & SpvFPFastMathModeNotNaNMask))
423       b->nb.fp_fast_math |=
424          FLOAT_CONTROLS_NAN_PRESERVE_FP16 |
425          FLOAT_CONTROLS_NAN_PRESERVE_FP32 |
426          FLOAT_CONTROLS_NAN_PRESERVE_FP64;
427    if (!(dec->operands[0] & SpvFPFastMathModeNotInfMask))
428       b->nb.fp_fast_math |=
429          FLOAT_CONTROLS_INF_PRESERVE_FP16 |
430          FLOAT_CONTROLS_INF_PRESERVE_FP32 |
431          FLOAT_CONTROLS_INF_PRESERVE_FP64;
432 }
433 
434 void
435 vtn_handle_fp_fast_math(struct vtn_builder *b, struct vtn_value *val)
436 {
437    /* Take the NaN/Inf/SZ preserve bits from the execution mode and set them
438     * on the builder, so the generated instructions can pick them up from there.
439     * We only care about some of them, check nir_alu_instr for details.
440     * We also copy all bit widths, because we can't easily get the correct one
441     * here.
442     */
443 #define FLOAT_CONTROLS2_BITS (FLOAT_CONTROLS_SIGNED_ZERO_INF_NAN_PRESERVE_FP16 | \
444                               FLOAT_CONTROLS_SIGNED_ZERO_INF_NAN_PRESERVE_FP32 | \
445                               FLOAT_CONTROLS_SIGNED_ZERO_INF_NAN_PRESERVE_FP64)
446    static_assert(FLOAT_CONTROLS2_BITS == BITSET_MASK(9),
447       "enum float_controls and fp_fast_math out of sync!");
448    b->nb.fp_fast_math = b->shader->info.float_controls_execution_mode &
449       FLOAT_CONTROLS2_BITS;
450    vtn_foreach_decoration(b, val, handle_fp_fast_math, NULL);
451 #undef FLOAT_CONTROLS2_BITS
452 }
453 
454 static void
455 handle_no_contraction(struct vtn_builder *b, UNUSED struct vtn_value *val,
456                       UNUSED int member, const struct vtn_decoration *dec,
457                       UNUSED void *_void)
458 {
459    vtn_assert(dec->scope == VTN_DEC_DECORATION);
460    if (dec->decoration != SpvDecorationNoContraction)
461       return;
462 
463    b->nb.exact = true;
464 }
465 
466 void
467 vtn_handle_no_contraction(struct vtn_builder *b, struct vtn_value *val)
468 {
469    vtn_foreach_decoration(b, val, handle_no_contraction, NULL);
470 }
471 
472 nir_rounding_mode
473 vtn_rounding_mode_to_nir(struct vtn_builder *b, SpvFPRoundingMode mode)
474 {
475    switch (mode) {
476    case SpvFPRoundingModeRTE:
477       return nir_rounding_mode_rtne;
478    case SpvFPRoundingModeRTZ:
479       return nir_rounding_mode_rtz;
480    case SpvFPRoundingModeRTP:
481       vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
482                   "FPRoundingModeRTP is only supported in kernels");
483       return nir_rounding_mode_ru;
484    case SpvFPRoundingModeRTN:
485       vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
486                   "FPRoundingModeRTN is only supported in kernels");
487       return nir_rounding_mode_rd;
488    default:
489       vtn_fail("Unsupported rounding mode: %s",
490                spirv_fproundingmode_to_string(mode));
491       break;
492    }
493 }
494 
495 struct conversion_opts {
496    nir_rounding_mode rounding_mode;
497    bool saturate;
498 };
499 
500 static void
501 handle_conversion_opts(struct vtn_builder *b, UNUSED struct vtn_value *val,
502                        UNUSED int member,
503                        const struct vtn_decoration *dec, void *_opts)
504 {
505    struct conversion_opts *opts = _opts;
506 
507    switch (dec->decoration) {
508    case SpvDecorationFPRoundingMode:
509       opts->rounding_mode = vtn_rounding_mode_to_nir(b, dec->operands[0]);
510       break;
511 
512    case SpvDecorationSaturatedConversion:
513       vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
514                   "Saturated conversions are only allowed in kernels");
515       opts->saturate = true;
516       break;
517 
518    default:
519       break;
520    }
521 }
522 
523 static void
524 handle_no_wrap(UNUSED struct vtn_builder *b, UNUSED struct vtn_value *val,
525                UNUSED int member,
526                const struct vtn_decoration *dec, void *_alu)
527 {
528    nir_alu_instr *alu = _alu;
529    switch (dec->decoration) {
530    case SpvDecorationNoSignedWrap:
531       alu->no_signed_wrap = true;
532       break;
533    case SpvDecorationNoUnsignedWrap:
534       alu->no_unsigned_wrap = true;
535       break;
536    default:
537       /* Do nothing. */
538       break;
539    }
540 }
541 
542 static void
543 vtn_value_is_relaxed_precision_cb(struct vtn_builder *b,
544                           struct vtn_value *val, int member,
545                           const struct vtn_decoration *dec, void *void_ctx)
546 {
547    bool *relaxed_precision = void_ctx;
548    switch (dec->decoration) {
549    case SpvDecorationRelaxedPrecision:
550       *relaxed_precision = true;
551       break;
552 
553    default:
554       break;
555    }
556 }
557 
558 bool
559 vtn_value_is_relaxed_precision(struct vtn_builder *b, struct vtn_value *val)
560 {
561    bool result = false;
562    vtn_foreach_decoration(b, val,
563                           vtn_value_is_relaxed_precision_cb, &result);
564    return result;
565 }
566 
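/* Returns true if a RelaxedPrecision result should be computed at 16-bit
 * precision.  Derivative opcodes are additionally gated by the
 * mediump_16bit_derivatives option.
 */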
567 static bool
568 vtn_alu_op_mediump_16bit(struct vtn_builder *b, SpvOp opcode, struct vtn_value *dest_val)
569 {
570    if (!b->options->mediump_16bit_alu || !vtn_value_is_relaxed_precision(b, dest_val))
571       return false;
572 
573    switch (opcode) {
574    case SpvOpDPdx:
575    case SpvOpDPdy:
576    case SpvOpDPdxFine:
577    case SpvOpDPdyFine:
578    case SpvOpDPdxCoarse:
579    case SpvOpDPdyCoarse:
580    case SpvOpFwidth:
581    case SpvOpFwidthFine:
582    case SpvOpFwidthCoarse:
583       return b->options->mediump_16bit_derivatives;
584    default:
585       return true;
586    }
587 }
588 
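/* Widen a 16-bit mediump value back to 32 bits.  Values that are not 16 bits
 * are returned unchanged.
 */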
589 static nir_def *
590 vtn_mediump_upconvert(struct vtn_builder *b, enum glsl_base_type base_type, nir_def *def)
591 {
592    if (def->bit_size != 16)
593       return def;
594 
595    switch (base_type) {
596    case GLSL_TYPE_FLOAT:
597       return nir_f2f32(&b->nb, def);
598    case GLSL_TYPE_INT:
599       return nir_i2i32(&b->nb, def);
600    case GLSL_TYPE_UINT:
601       return nir_u2u32(&b->nb, def);
602    default:
603       unreachable("bad relaxed precision output type");
604    }
605 }
606 
607 void
608 vtn_mediump_upconvert_value(struct vtn_builder *b, struct vtn_ssa_value *value)
609 {
610    enum glsl_base_type base_type = glsl_get_base_type(value->type);
611 
612    if (glsl_type_is_vector_or_scalar(value->type)) {
613       value->def = vtn_mediump_upconvert(b, base_type, value->def);
614    } else {
615       for (int i = 0; i < glsl_get_matrix_columns(value->type); i++)
616          value->elems[i]->def = vtn_mediump_upconvert(b, base_type, value->elems[i]->def);
617    }
618 }
619 
620 void
621 vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
622                const uint32_t *w, unsigned count)
623 {
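   /* w[1] is the result type, w[2] the result id, and w[3]..w[count - 1] the
    * operand ids.
    */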
624    struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
625    const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
626 
627    if (glsl_type_is_cmat(dest_type)) {
628       vtn_handle_cooperative_alu(b, dest_val, dest_type, opcode, w, count);
629       return;
630    }
631 
632    vtn_handle_no_contraction(b, dest_val);
633    vtn_handle_fp_fast_math(b, dest_val);
634    bool mediump_16bit = vtn_alu_op_mediump_16bit(b, opcode, dest_val);
635 
636    /* Collect the various SSA sources */
637    const unsigned num_inputs = count - 3;
638    struct vtn_ssa_value *vtn_src[4] = { NULL, };
639    for (unsigned i = 0; i < num_inputs; i++) {
640       vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
641       if (mediump_16bit)
642          vtn_src[i] = vtn_mediump_downconvert_value(b, vtn_src[i]);
643    }
644 
645    if (glsl_type_is_matrix(vtn_src[0]->type) ||
646        (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
647       struct vtn_ssa_value *dest = vtn_handle_matrix_alu(b, opcode, vtn_src[0], vtn_src[1]);
648 
649       if (mediump_16bit)
650          vtn_mediump_upconvert_value(b, dest);
651 
652       vtn_push_ssa_value(b, w[2], dest);
653       b->nb.exact = b->exact;
654       return;
655    }
656 
657    struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
658    nir_def *src[4] = { NULL, };
659    for (unsigned i = 0; i < num_inputs; i++) {
660       vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
661       src[i] = vtn_src[i]->def;
662    }
663 
664    switch (opcode) {
665    case SpvOpAny:
666       dest->def = nir_bany(&b->nb, src[0]);
667       break;
668 
669    case SpvOpAll:
670       dest->def = nir_ball(&b->nb, src[0]);
671       break;
672 
673    case SpvOpOuterProduct: {
674       for (unsigned i = 0; i < src[1]->num_components; i++) {
675          dest->elems[i]->def =
676             nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
677       }
678       break;
679    }
680 
681    case SpvOpDot:
682       dest->def = nir_fdot(&b->nb, src[0], src[1]);
683       break;
684 
685    case SpvOpIAddCarry:
686       vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
687       dest->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
688       dest->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
689       break;
690 
691    case SpvOpISubBorrow:
692       vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
693       dest->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
694       dest->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
695       break;
696 
697    case SpvOpUMulExtended: {
698       vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
699       if (src[0]->bit_size == 32) {
700          nir_def *umul = nir_umul_2x32_64(&b->nb, src[0], src[1]);
701          dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
702          dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
703       } else {
704          dest->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
705          dest->elems[1]->def = nir_umul_high(&b->nb, src[0], src[1]);
706       }
707       break;
708    }
709 
710    case SpvOpSMulExtended: {
711       vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
712       if (src[0]->bit_size == 32) {
713          nir_def *umul = nir_imul_2x32_64(&b->nb, src[0], src[1]);
714          dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
715          dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
716       } else {
717          dest->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
718          dest->elems[1]->def = nir_imul_high(&b->nb, src[0], src[1]);
719       }
720       break;
721    }
722 
723    case SpvOpDPdx:
724       dest->def = nir_ddx(&b->nb, src[0]);
725       break;
726    case SpvOpDPdxFine:
727       dest->def = nir_ddx_fine(&b->nb, src[0]);
728       break;
729    case SpvOpDPdxCoarse:
730       dest->def = nir_ddx_coarse(&b->nb, src[0]);
731       break;
732    case SpvOpDPdy:
733       dest->def = nir_ddy(&b->nb, src[0]);
734       break;
735    case SpvOpDPdyFine:
736       dest->def = nir_ddy_fine(&b->nb, src[0]);
737       break;
738    case SpvOpDPdyCoarse:
739       dest->def = nir_ddy_coarse(&b->nb, src[0]);
740       break;
741 
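   /* fwidth(p) = |dFdx(p)| + |dFdy(p)| */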
742    case SpvOpFwidth:
743       dest->def = nir_fadd(&b->nb,
744                                nir_fabs(&b->nb, nir_ddx(&b->nb, src[0])),
745                                nir_fabs(&b->nb, nir_ddy(&b->nb, src[0])));
746       break;
747    case SpvOpFwidthFine:
748       dest->def = nir_fadd(&b->nb,
749                                nir_fabs(&b->nb, nir_ddx_fine(&b->nb, src[0])),
750                                nir_fabs(&b->nb, nir_ddy_fine(&b->nb, src[0])));
751       break;
752    case SpvOpFwidthCoarse:
753       dest->def = nir_fadd(&b->nb,
754                                nir_fabs(&b->nb, nir_ddx_coarse(&b->nb, src[0])),
755                                nir_fabs(&b->nb, nir_ddy_coarse(&b->nb, src[0])));
756       break;
757 
758    case SpvOpVectorTimesScalar:
759       /* The builder will take care of splatting for us. */
760       dest->def = nir_fmul(&b->nb, src[0], src[1]);
761       break;
762 
763    case SpvOpIsNan: {
764       const bool save_exact = b->nb.exact;
765 
766       b->nb.exact = true;
767       dest->def = nir_fneu(&b->nb, src[0], src[0]);
768       b->nb.exact = save_exact;
769       break;
770    }
771 
772    case SpvOpOrdered: {
773       const bool save_exact = b->nb.exact;
774 
775       b->nb.exact = true;
776       dest->def = nir_iand(&b->nb, nir_feq(&b->nb, src[0], src[0]),
777                                    nir_feq(&b->nb, src[1], src[1]));
778       b->nb.exact = save_exact;
779       break;
780    }
781 
782    case SpvOpUnordered: {
783       const bool save_exact = b->nb.exact;
784 
785       b->nb.exact = true;
786       dest->def = nir_ior(&b->nb, nir_fneu(&b->nb, src[0], src[0]),
787                                   nir_fneu(&b->nb, src[1], src[1]));
788       b->nb.exact = save_exact;
789       break;
790    }
791 
792    case SpvOpIsInf: {
793       nir_def *inf = nir_imm_floatN_t(&b->nb, INFINITY, src[0]->bit_size);
794       dest->def = nir_ieq(&b->nb, nir_fabs(&b->nb, src[0]), inf);
795       break;
796    }
797 
798    case SpvOpFUnordEqual: {
799       const bool save_exact = b->nb.exact;
800 
801       b->nb.exact = true;
802 
803       /* This could also be implemented as !(a < b || b < a).  If one or both
804        * of the source are numbers, later optimization passes can easily
805        * eliminate the isnan() checks.  This may trim the sequence down to a
806        * single (a == b) operation.  Otherwise, the optimizer can transform
807        * whatever is left to !(a < b || b < a).  Since some applications will
808        * open-code this sequence, these optimizations are needed anyway.
809        */
810       dest->def =
811          nir_ior(&b->nb,
812                  nir_feq(&b->nb, src[0], src[1]),
813                  nir_ior(&b->nb,
814                          nir_fneu(&b->nb, src[0], src[0]),
815                          nir_fneu(&b->nb, src[1], src[1])));
816 
817       b->nb.exact = save_exact;
818       break;
819    }
820 
821    case SpvOpFUnordLessThan:
822    case SpvOpFUnordGreaterThan:
823    case SpvOpFUnordLessThanEqual:
824    case SpvOpFUnordGreaterThanEqual: {
825       bool swap;
826       bool unused_exact;
827       unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
828       unsigned dst_bit_size = glsl_get_bit_size(dest_type);
829       nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
830                                                   &unused_exact,
831                                                   src_bit_size, dst_bit_size);
832 
833       if (swap) {
834          nir_def *tmp = src[0];
835          src[0] = src[1];
836          src[1] = tmp;
837       }
838 
839       const bool save_exact = b->nb.exact;
840 
841       b->nb.exact = true;
842 
843       /* Use the property FUnordLessThan(a, b) ≡ !FOrdGreaterThanEqual(a, b). */
844       switch (op) {
845       case nir_op_fge: op = nir_op_flt; break;
846       case nir_op_flt: op = nir_op_fge; break;
847       default: unreachable("Impossible opcode.");
848       }
849 
850       dest->def =
851          nir_inot(&b->nb,
852                   nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL));
853 
854       b->nb.exact = save_exact;
855       break;
856    }
857 
858    case SpvOpLessOrGreater:
859    case SpvOpFOrdNotEqual: {
860       /* For all the SpvOpFOrd* comparisons apart from NotEqual, the value
861        * from the ALU will probably already be false if the operands are not
862        * ordered so we don’t need to handle it specially.
863        */
864       const bool save_exact = b->nb.exact;
865 
866       b->nb.exact = true;
867 
868       /* This could also be implemented as (a < b || b < a).  If one or both
869        * of the source are numbers, later optimization passes can easily
870        * eliminate the isnan() checks.  This may trim the sequence down to a
871        * single (a != b) operation.  Otherwise, the optimizer can transform
872        * whatever is left to (a < b || b < a).  Since some applications will
873        * open-code this sequence, these optimizations are needed anyway.
874        */
875       dest->def =
876          nir_iand(&b->nb,
877                   nir_fneu(&b->nb, src[0], src[1]),
878                   nir_iand(&b->nb,
879                           nir_feq(&b->nb, src[0], src[0]),
880                           nir_feq(&b->nb, src[1], src[1])));
881 
882       b->nb.exact = save_exact;
883       break;
884    }
885 
886    case SpvOpUConvert:
887    case SpvOpConvertFToU:
888    case SpvOpConvertFToS:
889    case SpvOpConvertSToF:
890    case SpvOpConvertUToF:
891    case SpvOpSConvert:
892    case SpvOpFConvert:
893    case SpvOpSatConvertSToU:
894    case SpvOpSatConvertUToS: {
895       unsigned src_bit_size = src[0]->bit_size;
896       unsigned dst_bit_size = glsl_get_bit_size(dest_type);
897       nir_alu_type src_type = convert_op_src_type(opcode) | src_bit_size;
898       nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;
899 
900       struct conversion_opts opts = {
901          .rounding_mode = nir_rounding_mode_undef,
902          .saturate = false,
903       };
904       vtn_foreach_decoration(b, dest_val, handle_conversion_opts, &opts);
905 
906       if (opcode == SpvOpSatConvertSToU || opcode == SpvOpSatConvertUToS)
907          opts.saturate = true;
908 
909       if (b->shader->info.stage == MESA_SHADER_KERNEL) {
910          if (opts.rounding_mode == nir_rounding_mode_undef && !opts.saturate) {
911             dest->def = nir_type_convert(&b->nb, src[0], src_type, dst_type,
912                                          nir_rounding_mode_undef);
913          } else {
914             dest->def = nir_convert_alu_types(&b->nb, dst_bit_size, src[0],
915                                               src_type, dst_type,
916                                               opts.rounding_mode, opts.saturate);
917          }
918       } else {
919          vtn_fail_if(opts.rounding_mode != nir_rounding_mode_undef &&
920                      dst_type != nir_type_float16,
921                      "Rounding modes are only allowed on conversions to "
922                      "16-bit float types");
923          dest->def = nir_type_convert(&b->nb, src[0], src_type, dst_type,
924                                       opts.rounding_mode);
925       }
926       break;
927    }
928 
929    case SpvOpBitFieldInsert:
930    case SpvOpBitFieldSExtract:
931    case SpvOpBitFieldUExtract:
932    case SpvOpShiftLeftLogical:
933    case SpvOpShiftRightArithmetic:
934    case SpvOpShiftRightLogical: {
935       bool swap;
936       bool exact;
937       unsigned src0_bit_size = glsl_get_bit_size(vtn_src[0]->type);
938       unsigned dst_bit_size = glsl_get_bit_size(dest_type);
939       nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, &exact,
940                                                   src0_bit_size, dst_bit_size);
941 
942       assert(!exact);
943 
944       assert (op == nir_op_ushr || op == nir_op_ishr || op == nir_op_ishl ||
945               op == nir_op_bitfield_insert || op == nir_op_ubitfield_extract ||
946               op == nir_op_ibitfield_extract);
947 
948       for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
949          unsigned src_bit_size =
950             nir_alu_type_get_type_size(nir_op_infos[op].input_types[i]);
951          if (src_bit_size == 0)
952             continue;
953          if (src_bit_size != src[i]->bit_size) {
954             assert(src_bit_size == 32);
955             /* Convert the Shift, Offset and Count operands to 32 bits, which is the bit size
956              * supported by the NIR instructions. See discussion here:
957              *
958              * https://lists.freedesktop.org/archives/mesa-dev/2018-April/193026.html
959              */
960             src[i] = nir_u2u32(&b->nb, src[i]);
961          }
962       }
963       dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
964       break;
965    }
966 
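   /* Shift the sign bit down to bit 0 and convert it to a boolean. */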
967    case SpvOpSignBitSet:
968       dest->def = nir_i2b(&b->nb,
969          nir_ushr(&b->nb, src[0], nir_imm_int(&b->nb, src[0]->bit_size - 1)));
970       break;
971 
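   /* nir_find_lsb() returns -1 for a zero input; the unsigned minimum with 32
    * clamps that to 32, the expected trailing-zero count for zero.
    */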
972    case SpvOpUCountTrailingZerosINTEL:
973       dest->def = nir_umin(&b->nb,
974                                nir_find_lsb(&b->nb, src[0]),
975                                nir_imm_int(&b->nb, 32u));
976       break;
977 
978    case SpvOpBitCount: {
979       /* bit_count always returns int32, but the SPIR-V opcode just says the return
980        * value needs to be big enough to store the number of bits.
981        */
982       dest->def = nir_u2uN(&b->nb, nir_bit_count(&b->nb, src[0]), glsl_get_bit_size(dest_type));
983       break;
984    }
985 
986    case SpvOpSDotKHR:
987    case SpvOpUDotKHR:
988    case SpvOpSUDotKHR:
989    case SpvOpSDotAccSatKHR:
990    case SpvOpUDotAccSatKHR:
991    case SpvOpSUDotAccSatKHR:
992       unreachable("Should have called vtn_handle_integer_dot instead.");
993 
994    default: {
995       bool swap;
996       bool exact;
997       unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
998       unsigned dst_bit_size = glsl_get_bit_size(dest_type);
999       nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
1000                                                   &exact,
1001                                                   src_bit_size, dst_bit_size);
1002 
1003       if (swap) {
1004          nir_def *tmp = src[0];
1005          src[0] = src[1];
1006          src[1] = tmp;
1007       }
1008 
1009       switch (op) {
1010       case nir_op_ishl:
1011       case nir_op_ishr:
1012       case nir_op_ushr:
1013          if (src[1]->bit_size != 32)
1014             src[1] = nir_u2u32(&b->nb, src[1]);
1015          break;
1016       default:
1017          break;
1018       }
1019 
1020       const bool save_exact = b->nb.exact;
1021 
1022       if (exact)
1023          b->nb.exact = true;
1024 
1025       dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
1026 
1027       b->nb.exact = save_exact;
1028       break;
1029    } /* default */
1030    }
1031 
1032    switch (opcode) {
1033    case SpvOpIAdd:
1034    case SpvOpIMul:
1035    case SpvOpISub:
1036    case SpvOpShiftLeftLogical:
1037    case SpvOpSNegate: {
1038       nir_alu_instr *alu = nir_instr_as_alu(dest->def->parent_instr);
1039       vtn_foreach_decoration(b, dest_val, handle_no_wrap, alu);
1040       break;
1041    }
1042    default:
1043       /* Do nothing. */
1044       break;
1045    }
1046 
1047    if (mediump_16bit)
1048       vtn_mediump_upconvert_value(b, dest);
1049    vtn_push_ssa_value(b, w[2], dest);
1050 
1051    b->nb.exact = b->exact;
1052 }
1053 
1054 void
1055 vtn_handle_integer_dot(struct vtn_builder *b, SpvOp opcode,
1056                        const uint32_t *w, unsigned count)
1057 {
1058    struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
1059    const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
1060    const unsigned dest_size = glsl_get_bit_size(dest_type);
1061 
1062    vtn_handle_no_contraction(b, dest_val);
1063 
1064    /* Collect the various SSA sources.
1065     *
1066     * Due to the optional "Packed Vector Format" field, determine number of
1067     * inputs from the opcode.  This differs from vtn_handle_alu.
1068     */
1069    const unsigned num_inputs = (opcode == SpvOpSDotAccSatKHR ||
1070                                 opcode == SpvOpUDotAccSatKHR ||
1071                                 opcode == SpvOpSUDotAccSatKHR) ? 3 : 2;
1072 
1073    vtn_assert(count >= num_inputs + 3);
1074 
1075    struct vtn_ssa_value *vtn_src[3] = { NULL, };
1076    nir_def *src[3] = { NULL, };
1077 
1078    for (unsigned i = 0; i < num_inputs; i++) {
1079       vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
1080       src[i] = vtn_src[i]->def;
1081 
1082       vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
1083    }
1084 
1085    /* For all of the opcodes *except* SpvOpSUDotKHR and SpvOpSUDotAccSatKHR,
1086     * the SPV_KHR_integer_dot_product spec says:
1087     *
1088     *    _Vector 1_ and _Vector 2_ must have the same type.
1089     *
1090     * The practical requirement is the same bit-size and the same number of
1091     * components.
1092     */
1093    vtn_fail_if(glsl_get_bit_size(vtn_src[0]->type) !=
1094                glsl_get_bit_size(vtn_src[1]->type) ||
1095                glsl_get_vector_elements(vtn_src[0]->type) !=
1096                glsl_get_vector_elements(vtn_src[1]->type),
1097                "Vector 1 and vector 2 source of opcode %s must have the same "
1098                "type",
1099                spirv_op_to_string(opcode));
1100 
1101    if (num_inputs == 3) {
1102       /* The SPV_KHR_integer_dot_product spec says:
1103        *
1104        *    The type of Accumulator must be the same as Result Type.
1105        *
1106        * The handling of SpvOpSDotAccSatKHR and friends with the packed 4x8
1107        * types (far below) assumes these types have the same size.
1108        */
1109       vtn_fail_if(dest_type != vtn_src[2]->type,
1110                   "Accumulator type must be the same as Result Type for "
1111                   "opcode %s",
1112                   spirv_op_to_string(opcode));
1113    }
1114 
1115    unsigned packed_bit_size = 8;
1116    if (glsl_type_is_vector(vtn_src[0]->type)) {
1117       /* FINISHME: Is this actually as good or better for platforms that don't
1118        * have the special instructions (i.e., one or both of has_dot_4x8 or
1119        * has_sudot_4x8 is false)?
1120        */
1121       if (glsl_get_vector_elements(vtn_src[0]->type) == 4 &&
1122           glsl_get_bit_size(vtn_src[0]->type) == 8 &&
1123           glsl_get_bit_size(dest_type) <= 32) {
1124          src[0] = nir_pack_32_4x8(&b->nb, src[0]);
1125          src[1] = nir_pack_32_4x8(&b->nb, src[1]);
1126       } else if (glsl_get_vector_elements(vtn_src[0]->type) == 2 &&
1127                  glsl_get_bit_size(vtn_src[0]->type) == 16 &&
1128                  glsl_get_bit_size(dest_type) <= 32 &&
1129                  opcode != SpvOpSUDotKHR &&
1130                  opcode != SpvOpSUDotAccSatKHR) {
1131          src[0] = nir_pack_32_2x16(&b->nb, src[0]);
1132          src[1] = nir_pack_32_2x16(&b->nb, src[1]);
1133          packed_bit_size = 16;
1134       }
1135    } else if (glsl_type_is_scalar(vtn_src[0]->type) &&
1136               glsl_type_is_32bit(vtn_src[0]->type)) {
1137       /* The SPV_KHR_integer_dot_product spec says:
1138        *
1139        *    When _Vector 1_ and _Vector 2_ are scalar integer types, _Packed
1140        *    Vector Format_ must be specified to select how the integers are to
1141        *    be interpreted as vectors.
1142        *
1143        * The "Packed Vector Format" value follows the last input.
1144        */
1145       vtn_assert(count == (num_inputs + 4));
1146       const SpvPackedVectorFormat pack_format = w[num_inputs + 3];
1147       vtn_fail_if(pack_format != SpvPackedVectorFormatPackedVectorFormat4x8BitKHR,
1148                   "Unsupported vector packing format %d for opcode %s",
1149                   pack_format, spirv_op_to_string(opcode));
1150    } else {
1151       vtn_fail_with_opcode("Invalid source types.", opcode);
1152    }
1153 
1154    nir_def *dest = NULL;
1155 
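   /* Vector sources that were not packed above are expanded into per-component
    * multiplies and adds; packed 32-bit scalars use the NIR dot-product
    * intrinsics below instead.
    */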
1156    if (src[0]->num_components > 1) {
1157       nir_def *(*src0_conversion)(nir_builder *, nir_def *, unsigned);
1158       nir_def *(*src1_conversion)(nir_builder *, nir_def *, unsigned);
1159 
1160       switch (opcode) {
1161       case SpvOpSDotKHR:
1162       case SpvOpSDotAccSatKHR:
1163          src0_conversion = nir_i2iN;
1164          src1_conversion = nir_i2iN;
1165          break;
1166 
1167       case SpvOpUDotKHR:
1168       case SpvOpUDotAccSatKHR:
1169          src0_conversion = nir_u2uN;
1170          src1_conversion = nir_u2uN;
1171          break;
1172 
1173       case SpvOpSUDotKHR:
1174       case SpvOpSUDotAccSatKHR:
1175          src0_conversion = nir_i2iN;
1176          src1_conversion = nir_u2uN;
1177          break;
1178 
1179       default:
1180          unreachable("Invalid opcode.");
1181       }
1182 
1183       /* The SPV_KHR_integer_dot_product spec says:
1184        *
1185        *    All components of the input vectors are sign-extended to the bit
1186        *    width of the result's type. The sign-extended input vectors are
1187        *    then multiplied component-wise and all components of the vector
1188        *    resulting from the component-wise multiplication are added
1189        *    together. The resulting value will equal the low-order N bits of
1190        *    the correct result R, where N is the result width and R is
1191        *    computed with enough precision to avoid overflow and underflow.
1192        */
1193       const unsigned vector_components =
1194          glsl_get_vector_elements(vtn_src[0]->type);
1195 
1196       for (unsigned i = 0; i < vector_components; i++) {
1197          nir_def *const src0 =
1198             src0_conversion(&b->nb, nir_channel(&b->nb, src[0], i), dest_size);
1199 
1200          nir_def *const src1 =
1201             src1_conversion(&b->nb, nir_channel(&b->nb, src[1], i), dest_size);
1202 
1203          nir_def *const mul_result = nir_imul(&b->nb, src0, src1);
1204 
1205          dest = (i == 0) ? mul_result : nir_iadd(&b->nb, dest, mul_result);
1206       }
1207 
1208       if (num_inputs == 3) {
1209          /* For SpvOpSDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
1210           *
1211           *    Signed integer dot product of _Vector 1_ and _Vector 2_ and
1212           *    signed saturating addition of the result with _Accumulator_.
1213           *
1214           * For SpvOpUDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
1215           *
1216           *    Unsigned integer dot product of _Vector 1_ and _Vector 2_ and
1217           *    unsigned saturating addition of the result with _Accumulator_.
1218           *
1219           * For SpvOpSUDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
1220           *
1221           *    Mixed-signedness integer dot product of _Vector 1_ and _Vector
1222           *    2_ and signed saturating addition of the result with
1223           *    _Accumulator_.
1224           */
1225          dest = (opcode == SpvOpUDotAccSatKHR)
1226             ? nir_uadd_sat(&b->nb, dest, src[2])
1227             : nir_iadd_sat(&b->nb, dest, src[2]);
1228       }
1229    } else {
1230       assert(src[0]->num_components == 1 && src[1]->num_components == 1);
1231       assert(src[0]->bit_size == 32 && src[1]->bit_size == 32);
1232 
1233       nir_def *const zero = nir_imm_zero(&b->nb, 1, 32);
1234       bool is_signed = opcode == SpvOpSDotKHR || opcode == SpvOpSUDotKHR ||
1235                        opcode == SpvOpSDotAccSatKHR || opcode == SpvOpSUDotAccSatKHR;
1236 
1237       if (packed_bit_size == 16) {
1238          switch (opcode) {
1239          case SpvOpSDotKHR:
1240             dest = nir_sdot_2x16_iadd(&b->nb, src[0], src[1], zero);
1241             break;
1242          case SpvOpUDotKHR:
1243             dest = nir_udot_2x16_uadd(&b->nb, src[0], src[1], zero);
1244             break;
1245          case SpvOpSDotAccSatKHR:
1246             if (dest_size == 32)
1247                dest = nir_sdot_2x16_iadd_sat(&b->nb, src[0], src[1], src[2]);
1248             else
1249                dest = nir_sdot_2x16_iadd(&b->nb, src[0], src[1], zero);
1250             break;
1251          case SpvOpUDotAccSatKHR:
1252             if (dest_size == 32)
1253                dest = nir_udot_2x16_uadd_sat(&b->nb, src[0], src[1], src[2]);
1254             else
1255                dest = nir_udot_2x16_uadd(&b->nb, src[0], src[1], zero);
1256             break;
1257          default:
1258             unreachable("Invalid opcode.");
1259          }
1260       } else {
1261          switch (opcode) {
1262          case SpvOpSDotKHR:
1263             dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
1264             break;
1265          case SpvOpUDotKHR:
1266             dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
1267             break;
1268          case SpvOpSUDotKHR:
1269             dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
1270             break;
1271          case SpvOpSDotAccSatKHR:
1272             if (dest_size == 32)
1273                dest = nir_sdot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
1274             else
1275                dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
1276             break;
1277          case SpvOpUDotAccSatKHR:
1278             if (dest_size == 32)
1279                dest = nir_udot_4x8_uadd_sat(&b->nb, src[0], src[1], src[2]);
1280             else
1281                dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
1282             break;
1283          case SpvOpSUDotAccSatKHR:
1284             if (dest_size == 32)
1285                dest = nir_sudot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
1286             else
1287                dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
1288             break;
1289          default:
1290             unreachable("Invalid opcode.");
1291          }
1292       }
1293 
1294       if (dest_size != 32) {
1295          /* When the accumulator is 32-bits, a NIR dot-product with saturate
1296           * is generated above.  In all other cases a regular dot-product is
1297           * generated above, and separate addition with saturate is generated
1298           * here.
1299           *
1300           * The SPV_KHR_integer_dot_product spec says:
1301           *
1302           *    If any of the multiplications or additions, with the exception
1303           *    of the final accumulation, overflow or underflow, the result of
1304           *    the instruction is undefined.
1305           *
1306           * Therefore it is safe to cast the dot-product result down to the
1307           * size of the accumulator before doing the addition.  Since the
1308           * result of the dot-product cannot overflow 32-bits, this is also
1309           * safe to cast up.
1310           */
1311          if (num_inputs == 3) {
1312             dest = is_signed
1313                ? nir_iadd_sat(&b->nb, nir_i2iN(&b->nb, dest, dest_size), src[2])
1314                : nir_uadd_sat(&b->nb, nir_u2uN(&b->nb, dest, dest_size), src[2]);
1315          } else {
1316             dest = is_signed
1317                ? nir_i2iN(&b->nb, dest, dest_size)
1318                : nir_u2uN(&b->nb, dest, dest_size);
1319          }
1320       }
1321    }
1322 
1323    vtn_push_nir_ssa(b, w[2], dest);
1324 
1325    b->nb.exact = b->exact;
1326 }
1327 
1328 void
1329 vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count)
1330 {
1331    vtn_assert(count == 4);
1332    /* From the definition of OpBitcast in the SPIR-V 1.2 spec:
1333     *
1334     *    "If Result Type has the same number of components as Operand, they
1335     *    must also have the same component width, and results are computed per
1336     *    component.
1337     *
1338     *    If Result Type has a different number of components than Operand, the
1339     *    total number of bits in Result Type must equal the total number of
1340     *    bits in Operand. Let L be the type, either Result Type or Operand’s
1341     *    type, that has the larger number of components. Let S be the other
1342     *    type, with the smaller number of components. The number of components
1343     *    in L must be an integer multiple of the number of components in S.
1344     *    The first component (that is, the only or lowest-numbered component)
1345     *    of S maps to the first components of L, and so on, up to the last
1346     *    component of S mapping to the last components of L. Within this
1347     *    mapping, any single component of S (mapping to multiple components of
1348     *    L) maps its lower-ordered bits to the lower-numbered components of L."
1349     */
1350 
1351    struct vtn_type *type = vtn_get_type(b, w[1]);
1352    if (type->base_type == vtn_base_type_cooperative_matrix) {
1353       vtn_handle_cooperative_instruction(b, SpvOpBitcast, w, count);
1354       return;
1355    }
1356 
1357    struct nir_def *src = vtn_get_nir_ssa(b, w[3]);
1358 
1359    vtn_fail_if(src->num_components * src->bit_size !=
1360                glsl_get_vector_elements(type->type) * glsl_get_bit_size(type->type),
1361                "Source (%%%u) and destination (%%%u) of OpBitcast must have the same "
1362                "total number of bits", w[3], w[2]);
1363    nir_def *val =
1364       nir_bitcast_vector(&b->nb, src, glsl_get_bit_size(type->type));
1365    vtn_push_nir_ssa(b, w[2], val);
1366 }
1367