xref: /aosp_15_r20/external/mesa3d/src/compiler/nir/nir_opcodes.py (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1#
2# Copyright (C) 2014 Connor Abbott
3#
4# Permission is hereby granted, free of charge, to any person obtaining a
5# copy of this software and associated documentation files (the "Software"),
6# to deal in the Software without restriction, including without limitation
7# the rights to use, copy, modify, merge, publish, distribute, sublicense,
8# and/or sell copies of the Software, and to permit persons to whom the
9# Software is furnished to do so, subject to the following conditions:
10#
11# The above copyright notice and this permission notice (including the next
12# paragraph) shall be included in all copies or substantial portions of the
13# Software.
14#
15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21# IN THE SOFTWARE.
22#
23# Authors:
24#    Connor Abbott ([email protected])
25
26import re
27
28# Class that represents all the information we have about the opcode
29# NOTE: this must be kept in sync with nir_op_info
30
class Opcode(object):
   """Class that represents all the information we have about the opcode
   NOTE: this must be kept in sync with nir_op_info
   """
   def __init__(self, name, output_size, output_type, input_sizes,
                input_types, is_conversion, algebraic_properties, const_expr,
                description):
      """Parameters:

      - name is the name of the opcode (prepend nir_op_ for the enum name)
      - output_size and every entry of input_sizes must be 0..5, 8 or 16; an
        input size of 0 is only accepted when output_size is 0 as well
      - all types are strings that get nir_type_ prepended to them
      - input_sizes and input_types are non-empty lists of equal length, one
        entry per source
      - is_conversion is true if this opcode represents a type conversion
      - algebraic_properties is a space-separated string, where nir_op_is_ is
        prepended before each entry
      - const_expr is an expression or series of statements that computes the
        constant value of the opcode given the constant values of its inputs.
      - description is an optional description of the opcode for
        documentation.

      Constant expressions are formed from the variables src0, src1, ...,
      src(N-1), where N is the number of arguments.  The output of the
      expression should be stored in the dst variable.  Per-component input
      and output variables will be scalars and non-per-component input and
      output variables will be a struct with fields named x, y, z, and w
      all of the correct type.  Input and output variables can be assumed
      to already be of the correct type and need no conversion.  In
      particular, the conversion from the C bool type to/from  NIR_TRUE and
      NIR_FALSE happens automatically.

      For per-component instructions, the entire expression will be
      executed once for each component.  For non-per-component
      instructions, the expression is expected to store the correct values
      in dst.x, dst.y, etc.  If "dst" does not exist anywhere in the
      constant expression, an assignment to dst will happen automatically
      and the result will be equivalent to "dst = <expression>" for
      per-component instructions and "dst.x = dst.y = ... = <expression>"
      for non-per-component instructions.

      Any violated constraint raises AssertionError.
      """
      assert isinstance(name, str)
      assert isinstance(output_size, int)
      assert isinstance(output_type, str)
      assert isinstance(input_sizes, list)
      assert isinstance(input_types, list)
      # Reject an empty source list with a clear assertion instead of an
      # IndexError from indexing element 0.
      assert len(input_sizes) > 0
      assert len(input_types) > 0
      # Check every element, not only the first one, so that a malformed
      # later entry cannot slip into the opcode table.
      assert all(isinstance(size, int) for size in input_sizes)
      assert all(isinstance(type_, str) for type_ in input_types)
      assert isinstance(is_conversion, bool)
      assert isinstance(algebraic_properties, str)
      assert isinstance(const_expr, str)
      assert len(input_sizes) == len(input_types)
      assert 0 <= output_size <= 5 or (output_size == 8) or (output_size == 16)
      for size in input_sizes:
         assert 0 <= size <= 5 or (size == 8) or (size == 16)
         # A sized output requires every input to be sized too.
         if output_size != 0:
            assert size != 0
      self.name = name
      self.num_inputs = len(input_sizes)
      self.output_size = output_size
      self.output_type = output_type
      self.input_sizes = input_sizes
      self.input_types = input_types
      self.is_conversion = is_conversion
      self.algebraic_properties = algebraic_properties
      self.const_expr = const_expr
      self.description = description
95
# helper variables for strings
#
# Each name is simply "t" followed by the type string it holds.

# Unsized base types.
tfloat = "float"
tint = "int"
tuint = "uint"
tbool = "bool"

# Sized booleans.
tbool1 = "bool1"
tbool8 = "bool8"
tbool16 = "bool16"
tbool32 = "bool32"

# Sized floats.
tfloat16 = "float16"
tfloat32 = "float32"
tfloat64 = "float64"

# Sized signed integers.
tint16 = "int16"
tint32 = "int32"
tint64 = "int64"

# Sized unsigned integers.
tuint8 = "uint8"
tuint16 = "uint16"
tuint32 = "uint32"
tuint64 = "uint64"
115
# Splits a type string into its base type and an optional bit size.
_TYPE_SPLIT_RE = re.compile(r'(?P<type>int|uint|float|bool)(?P<bits>\d+)?')

# Bit sizes reported for unsized types; anything not listed here gets the
# integer set.
_UNSIZED_BOOL_SIZES = (1, 8, 16, 32)
_UNSIZED_FLOAT_SIZES = (16, 32, 64)
_UNSIZED_INT_SIZES = (1, 8, 16, 32, 64)

def _split_type(type_):
    """Match type_ against _TYPE_SPLIT_RE and return the match object.

    Raises AssertionError if the string does not start with a known base
    type.
    """
    parsed = _TYPE_SPLIT_RE.match(type_)
    assert parsed is not None, 'Invalid NIR type string: "{}"'.format(type_)
    return parsed

def type_has_size(type_):
    """Return True if the type string carries an explicit bit size."""
    return _split_type(type_).group('bits') is not None

def type_size(type_):
    """Return the bit size of a sized type string as an int.

    Raises AssertionError if the string is invalid or has no bit size.
    """
    bits = _split_type(type_).group('bits')
    assert bits is not None, \
           'NIR type string has no bit size: "{}"'.format(type_)
    return int(bits)

def type_sizes(type_):
    """Return a new list with the bit sizes the type string can take.

    A sized type yields just its own size; an unsized one yields the list
    of sizes associated with its base type.
    """
    if type_has_size(type_):
        return [type_size(type_)]
    if type_ == 'bool':
        return list(_UNSIZED_BOOL_SIZES)
    if type_ == 'float':
        return list(_UNSIZED_FLOAT_SIZES)
    return list(_UNSIZED_INT_SIZES)

def type_base_type(type_):
    """Return the base type ("int", "uint", "float" or "bool") of a type string."""
    return _split_type(type_).group('type')
144
# Algebraic property flags.  Each string ends with a space so that several
# flags can be combined by plain concatenation, e.g.
# `_2src_commutative + associative`.
#
# Operation where the first two sources are commutative.
#
# For 2-source operations, this is just mathematical commutativity.  Some
# 3-source operations, like ffma, are only commutative in the first two
# sources.
_2src_commutative = "2src_commutative "
associative = "associative "
selection = "selection "
derivative = "derivative "

# global dictionary of opcodes, keyed by opcode name; filled by opcode()
opcodes = {}
157
def opcode(name, output_size, output_type, input_sizes, input_types,
           is_conversion, algebraic_properties, const_expr, description = ""):
   """Build an Opcode from the arguments and store it in ``opcodes``.

   The arguments are forwarded unchanged to the Opcode constructor.
   Registering the same name twice raises AssertionError.
   """
   assert name not in opcodes
   info = Opcode(name, output_size, output_type, input_sizes, input_types,
                 is_conversion, algebraic_properties, const_expr, description)
   opcodes[name] = info
164
def unop_convert(name, out_type, in_type, const_expr, description = ""):
   """Register a one-source opcode whose output type may differ from its
   input type.  Sizes are 0 on both sides and is_conversion is False.
   """
   opcode(name,
          output_size=0, output_type=out_type,
          input_sizes=[0], input_types=[in_type],
          is_conversion=False, algebraic_properties="",
          const_expr=const_expr, description=description)
167
def unop(name, ty, const_expr, description = "", algebraic_properties = ""):
   """Register a one-source opcode whose source and result share type ``ty``.

   Sizes are 0 on both sides and is_conversion is False.
   """
   opcode(name,
          output_size=0, output_type=ty,
          input_sizes=[0], input_types=[ty],
          is_conversion=False, algebraic_properties=algebraic_properties,
          const_expr=const_expr, description=description)
171
def unop_horiz(name, output_size, output_type, input_size, input_type,
               const_expr, description = ""):
   """Register a one-source opcode with explicit output and input sizes.

   No algebraic properties are set and is_conversion is False.
   """
   opcode(name,
          output_size=output_size, output_type=output_type,
          input_sizes=[input_size], input_types=[input_type],
          is_conversion=False, algebraic_properties="",
          const_expr=const_expr, description=description)
176
def unop_reduce(name, output_size, output_type, input_type, prereduce_expr,
                reduce_expr, final_expr, description = ""):
   """Register the 2-, 3- and 4-component variants of a reduction opcode.

   The opcodes are named ``name + "2"``, ``name + "3"`` and ``name + "4"``.
   Their constant expressions are assembled from three format templates:

   - prereduce_expr is applied to each component ({src}) and the result is
     wrapped in parentheses
   - reduce_expr combines two partial results ({src0}, {src1})
   - final_expr is applied to the parenthesised overall result ({src})

   Components are combined as (x, y), ((x, y), z) and ((x, y), (z, w)).
   """
   def combine(lhs, rhs):
      return reduce_expr.format(src0=lhs, src1=rhs)

   def finish(reduced):
      return final_expr.format(src="(" + reduced + ")")

   x, y, z, w = ["(" + prereduce_expr.format(src="src0." + comp) + ")"
                 for comp in "xyzw"]

   # The x/y pair is shared by all three variants.
   xy = combine(x, y)
   unop_horiz(name + "2", output_size, output_type, 2, input_type,
              finish(xy), description)
   unop_horiz(name + "3", output_size, output_type, 3, input_type,
              finish(combine(xy, z)), description)
   unop_horiz(name + "4", output_size, output_type, 4, input_type,
              finish(combine(xy, combine(z, w))), description)
196
def unop_numeric_convert(name, out_type, in_type, const_expr, description = ""):
   """Register a one-source opcode flagged as a type conversion.

   Identical to unop_convert() except that is_conversion is True.
   """
   opcode(name,
          output_size=0, output_type=out_type,
          input_sizes=[0], input_types=[in_type],
          is_conversion=True, algebraic_properties="",
          const_expr=const_expr, description=description)
199
# Plain copy of the source value.
unop("mov", tuint, "src0")

# Negation and bitwise inversion.  ineg leaves the minimum integer value
# unchanged instead of negating it.
unop("ineg", tint, "src0 == u_intN_min(bit_size) ? src0 : -src0")
unop("fneg", tfloat, "-src0")
unop("inot", tint, "~src0", description = "Invert every bit of the integer")

# The constant expression has one branch for 64-bit sources (double
# literals) and one for everything else (float literals).
unop("fsign", tfloat, ("bit_size == 64 ? " +
                       "(isnan(src0) ? 0.0  : ((src0 == 0.0 ) ? src0 : (src0 > 0.0 ) ? 1.0  : -1.0 )) : " +
                       "(isnan(src0) ? 0.0f : ((src0 == 0.0f) ? src0 : (src0 > 0.0f) ? 1.0f : -1.0f))"),
     description = """
Roughly implements the OpenGL / Vulkan rules for ``sign(float)``.
The ``GLSL.std.450 FSign`` instruction is defined as:

    Result is 1.0 if x > 0, 0.0 if x = 0, or -1.0 if x < 0.

If the source is equal to zero, there is a preference for the result to have
the same sign, but this is not required (it is required by OpenCL).  If the
source is not a number, there is a preference for the result to be +0.0, but
this is not required (it is required by OpenCL).  If the source is not a
number, and the result is not +0.0, the result should definitely **not** be
NaN.

The values returned for constant folding match the behavior required by
OpenCL.
     """)

# Sign, absolute value, saturate, reciprocal, square roots and base-2
# exponential / logarithm.
unop("isign", tint, "(src0 == 0) ? 0 : ((src0 > 0) ? 1 : -1)")
unop("iabs", tint, "(src0 < 0) ? -src0 : src0")
unop("fabs", tfloat, "fabs(src0)")
unop("fsat", tfloat, ("fmin(fmax(src0, 0.0), 1.0)"))
unop("frcp", tfloat, "bit_size == 64 ? 1.0 / src0 : 1.0f / src0")
unop("frsq", tfloat, "bit_size == 64 ? 1.0 / sqrt(src0) : 1.0f / sqrtf(src0)")
unop("fsqrt", tfloat, "bit_size == 64 ? sqrt(src0) : sqrtf(src0)")
unop("fexp2", tfloat, "exp2f(src0)")
unop("flog2", tfloat, "log2f(src0)")
235
# Generate all of the numeric conversion opcodes
#
# Opcode names are built as <first letter of source type>2<first letter of
# destination type><destination bit size>, e.g. "f2f16", with an optional
# rounding-mode suffix for float -> float16.
for src_t in [tint, tuint, tfloat, tbool]:
   # Destination base types that are generated for each source base type.
   if src_t == tbool:
      dst_types = [tfloat, tint, tbool]
   elif src_t == tint:
      dst_types = [tfloat, tint]
   elif src_t == tuint:
      dst_types = [tfloat, tuint]
   elif src_t == tfloat:
      dst_types = [tint, tuint, tfloat]

   for dst_t in dst_types:
      for dst_bit_size in type_sizes(dst_t):
          if dst_bit_size == 16 and dst_t == tfloat and src_t == tfloat:
              # float -> float16 comes in three flavours: explicit
              # round-to-nearest-even, explicit round-toward-zero, and an
              # unsuffixed one that picks between them from execution_mode.
              rnd_modes = ['_rtne', '_rtz', '']
              for rnd_mode in rnd_modes:
                  if rnd_mode == '_rtne':
                      conv_expr = """
                      if (bit_size > 32) {
                         dst = _mesa_half_to_float(_mesa_double_to_float16_rtne(src0));
                      } else if (bit_size > 16) {
                         dst = _mesa_half_to_float(_mesa_float_to_float16_rtne(src0));
                      } else {
                         dst = src0;
                      }
                      """
                  elif rnd_mode == '_rtz':
                      conv_expr = """
                      if (bit_size > 32) {
                         dst = _mesa_half_to_float(_mesa_double_to_float16_rtz(src0));
                      } else if (bit_size > 16) {
                         dst = _mesa_half_to_float(_mesa_float_to_float16_rtz(src0));
                      } else {
                         dst = src0;
                      }
                      """
                  else:
                      conv_expr = """
                      if (bit_size > 32) {
                         if (nir_is_rounding_mode_rtz(execution_mode, 16))
                            dst = _mesa_half_to_float(_mesa_double_to_float16_rtz(src0));
                         else
                            dst = _mesa_half_to_float(_mesa_double_to_float16_rtne(src0));
                      } else if (bit_size > 16) {
                         if (nir_is_rounding_mode_rtz(execution_mode, 16))
                            dst = _mesa_half_to_float(_mesa_float_to_float16_rtz(src0));
                         else
                            dst = _mesa_half_to_float(_mesa_float_to_float16_rtne(src0));
                      } else {
                         dst = src0;
                      }
                      """

                  unop_numeric_convert("{0}2{1}{2}{3}".format(src_t[0],
                                                              dst_t[0],
                                                              dst_bit_size,
                                                              rnd_mode),
                                       dst_t + str(dst_bit_size),
                                       src_t, conv_expr)
          elif dst_bit_size == 32 and dst_t == tfloat and src_t == tfloat:
              # float -> float32 only needs special handling when narrowing
              # from a wider source under round-toward-zero.
              conv_expr = """
              if (bit_size > 32 && nir_is_rounding_mode_rtz(execution_mode, 32)) {
                 dst = _mesa_double_to_float_rtz(src0);
              } else {
                 dst = src0;
              }
              """
              unop_numeric_convert("{0}2{1}{2}".format(src_t[0], dst_t[0],
                                                       dst_bit_size),
                                   dst_t + str(dst_bit_size), src_t, conv_expr)
          else:
              # Everything else is a plain assignment, except conversions
              # to bool which test the source against zero.
              conv_expr = "src0 != 0" if dst_t == tbool else "src0"
              unop_numeric_convert("{0}2{1}{2}".format(src_t[0], dst_t[0],
                                                       dst_bit_size),
                                   dst_t + str(dst_bit_size), src_t, conv_expr)
311
def unop_numeric_convert_mp(base, src_t, dst_t):
    """Register the medium-precision ("mp") twin of a 16-bit conversion.

    The new opcode is named ``base + "mp"`` and reuses the constant
    expression of the already registered ``base + "16"`` opcode, which must
    exist (KeyError otherwise).

    NOTE(review): despite their names, src_t is forwarded as the *output*
    type and dst_t as the *input* type of unop_numeric_convert(); the
    callers in this file pass the 16-bit type first.
    """
    like = base + "16"
    folded = opcodes[like].const_expr
    doc = """
Special opcode that is the same as :nir:alu-op:`{}` except that it is safe to
remove it if the result is immediately converted back to 32 bits again. This is
generated as part of the precision lowering pass. ``mp`` stands for medium
precision.
                         """.format(like)
    unop_numeric_convert(base + "mp", src_t, dst_t, folded, description = doc)
321
# Medium-precision conversions.  Arguments are (base name, 16-bit result
# type, 32-bit source type); each one copies the constant expression of the
# matching "<base>16" opcode.
unop_numeric_convert_mp("f2f", tfloat16, tfloat32)
unop_numeric_convert_mp("i2i", tint16, tint32)
# u2ump isn't defined, because the behavior is equal to i2imp
unop_numeric_convert_mp("f2i", tint16, tfloat32)
unop_numeric_convert_mp("f2u", tuint16, tfloat32)
unop_numeric_convert_mp("i2f", tfloat16, tint32)
unop_numeric_convert_mp("u2f", tfloat16, tuint32)
329
# Unary floating-point rounding operations.
#
# Each expression picks the double-precision C function for 64-bit sources
# and the single-precision one otherwise.


unop("ftrunc", tfloat, "bit_size == 64 ? trunc(src0) : truncf(src0)")
unop("fceil", tfloat, "bit_size == 64 ? ceil(src0) : ceilf(src0)")
unop("ffloor", tfloat, "bit_size == 64 ? floor(src0) : floorf(src0)")
unop("ffract", tfloat, "src0 - (bit_size == 64 ? floor(src0) : floorf(src0))")
unop("fround_even", tfloat, "bit_size == 64 ? _mesa_roundeven(src0) : _mesa_roundevenf(src0)")

# Values with a magnitude below 2^-14 become a signed zero; everything else
# is round-tripped through a half float.
unop("fquantize2f16", tfloat, "(fabs(src0) < ldexpf(1.0, -14)) ? copysignf(0.0f, src0) : _mesa_half_to_float(_mesa_float_to_half(src0))")

# Trigonometric operations.


unop("fsin", tfloat, "bit_size == 64 ? sin(src0) : sinf(src0)")
unop("fcos", tfloat, "bit_size == 64 ? cos(src0) : cosf(src0)")

# dfrexp
# frexp_exp stores the exponent written by frexp(); frexp_sig stores the
# value frexp() returns and discards the exponent.
unop_convert("frexp_exp", tint32, tfloat, "frexp(src0, &dst);")
unop_convert("frexp_sig", tfloat, tfloat, "int n; dst = frexp(src0, &n);")
350
# Partial derivatives.
#
# Description template; the placeholders are filled with the derivative
# precision ("fine", "coarse", ...) and the upper-cased axis name.
deriv_template = """
Calculate the screen-space partial derivative using {} derivatives of the input
with respect to the {}-axis. The constant folding is trivial as the derivative
of a constant is 0 if the constant is not Inf or NaN.
"""

# Registers fddx/fddy and their _fine / _coarse variants.  All six share the
# same constant expression and carry the "derivative" algebraic property.
for mode, suffix in [("either fine or coarse", ""), ("fine", "_fine"), ("coarse", "_coarse")]:
    for axis in ["x", "y"]:
        unop(f"fdd{axis}{suffix}", tfloat, "isfinite(src0) ? 0.0 : NAN",
             algebraic_properties = derivative,
             description = deriv_template.format(mode, axis.upper()))
363
364# Floating point pack and unpack operations.
365
def pack_2x16(fmt, in_type):
   """Register ``pack_<fmt>_2x16``: two ``in_type`` components packed into
   one uint32, x in the low 16 bits and y in the high 16 bits.

   Every occurrence of "fmt" in the constant-expression template is
   replaced by the format name, so the expression calls pack_<fmt>_1x16().
   """
   template = """
dst.x = (uint32_t) pack_fmt_1x16(src0.x);
dst.x |= ((uint32_t) pack_fmt_1x16(src0.y)) << 16;
"""
   op_name = "".join(("pack_", fmt, "_2x16"))
   unop_horiz(op_name, 1, tuint32, 2, in_type, template.replace("fmt", fmt))
371
def pack_4x8(fmt):
   """Register ``pack_<fmt>_4x8``: four float32 components packed into one
   uint32, one byte each, with x in the lowest byte and w in the highest.

   Every occurrence of "fmt" in the constant-expression template is
   replaced by the format name, so the expression calls pack_<fmt>_1x8().
   """
   template = """
dst.x = (uint32_t) pack_fmt_1x8(src0.x);
dst.x |= ((uint32_t) pack_fmt_1x8(src0.y)) << 8;
dst.x |= ((uint32_t) pack_fmt_1x8(src0.z)) << 16;
dst.x |= ((uint32_t) pack_fmt_1x8(src0.w)) << 24;
"""
   op_name = "".join(("pack_", fmt, "_4x8"))
   unop_horiz(op_name, 1, tuint32, 4, tfloat32, template.replace("fmt", fmt))
379
def unpack_2x16(fmt):
   """Register ``unpack_<fmt>_2x16``: one uint32 unpacked into two float32
   components, x from the low 16 bits and y from the high 16 bits.

   Every occurrence of "fmt" in the constant-expression template is
   replaced by the format name, so the expression calls unpack_<fmt>_1x16().
   """
   # The high half must be shifted *down* before the uint16_t cast.  A left
   # shift by 16 leaves the low 16 bits zero, so the cast would always
   # yield 0 and dst.y would never depend on the upper half of the source.
   unop_horiz("unpack_" + fmt + "_2x16", 2, tfloat32, 1, tuint32, """
dst.x = unpack_fmt_1x16((uint16_t)(src0.x & 0xffff));
dst.y = unpack_fmt_1x16((uint16_t)(src0.x >> 16));
""".replace("fmt", fmt))
385
def unpack_4x8(fmt):
   """Register ``unpack_<fmt>_4x8``: one uint32 unpacked into four float32
   components, x from the lowest byte up to w from the highest byte.

   Every occurrence of "fmt" in the constant-expression template is
   replaced by the format name, so the expression calls unpack_<fmt>_1x8().
   """
   template = """
dst.x = unpack_fmt_1x8((uint8_t)(src0.x & 0xff));
dst.y = unpack_fmt_1x8((uint8_t)((src0.x >> 8) & 0xff));
dst.z = unpack_fmt_1x8((uint8_t)((src0.x >> 16) & 0xff));
dst.w = unpack_fmt_1x8((uint8_t)(src0.x >> 24));
"""
   op_name = "".join(("unpack_", fmt, "_4x8"))
   unop_horiz(op_name, 4, tfloat32, 1, tuint32, template.replace("fmt", fmt))
393
394
# Instantiate the pack/unpack helpers for the snorm, unorm and half formats.
# There is no unpack_2x16("half") call here; unpack_half_2x16 is registered
# separately with its own constant expression.
pack_2x16("snorm", tfloat)
pack_4x8("snorm")
pack_2x16("unorm", tfloat)
pack_4x8("unorm")
pack_2x16("half", tfloat32)
unpack_2x16("snorm")
unpack_4x8("snorm")
unpack_2x16("unorm")
unpack_4x8("unorm")
404
# Clamping packs: each component is converted to a 16-bit range before
# being placed in its half of the result.
unop_horiz("pack_uint_2x16", 1, tuint32, 2, tuint32, """
dst.x = _mesa_unsigned_to_unsigned(src0.x, 16);
dst.x |= _mesa_unsigned_to_unsigned(src0.y, 16) << 16;
""", description = """
Convert two unsigned integers into a packed unsigned short (clamp is applied).
""")

unop_horiz("pack_sint_2x16", 1, tint32, 2, tint32, """
dst.x = _mesa_signed_to_signed(src0.x, 16) & 0xffff;
dst.x |= _mesa_signed_to_signed(src0.y, 16) << 16;
""", description = """
Convert two signed integers into a packed signed short (clamp is applied).
""")

# Non-clamping packs of 32-bit components.
unop_horiz("pack_uvec2_to_uint", 1, tuint32, 2, tuint32, """
dst.x = (src0.x & 0xffff) | (src0.y << 16);
""")

unop_horiz("pack_uvec4_to_uint", 1, tuint32, 4, tuint32, """
dst.x = (src0.x <<  0) |
        (src0.y <<  8) |
        (src0.z << 16) |
        (src0.w << 24);
""")

# Bit-exact packs of narrow components into a wider integer; x always ends
# up in the least significant position.
unop_horiz("pack_32_4x8", 1, tuint32, 4, tuint8,
           "dst.x = src0.x | ((uint32_t)src0.y << 8) | ((uint32_t)src0.z << 16) | ((uint32_t)src0.w << 24);")

unop_horiz("pack_32_2x16", 1, tuint32, 2, tuint16,
           "dst.x = src0.x | ((uint32_t)src0.y << 16);")

unop_horiz("pack_64_2x32", 1, tuint64, 2, tuint32,
           "dst.x = src0.x | ((uint64_t)src0.y << 32);")

unop_horiz("pack_64_4x16", 1, tuint64, 4, tuint16,
           "dst.x = src0.x | ((uint64_t)src0.y << 16) | ((uint64_t)src0.z << 32) | ((uint64_t)src0.w << 48);")

# The matching unpacks: each destination component receives the source
# shifted down to its position.
unop_horiz("unpack_64_2x32", 2, tuint32, 1, tuint64,
           "dst.x = src0.x; dst.y = src0.x >> 32;")

unop_horiz("unpack_64_4x16", 4, tuint16, 1, tuint64,
           "dst.x = src0.x; dst.y = src0.x >> 16; dst.z = src0.x >> 32; dst.w = src0.x >> 48;")

unop_horiz("unpack_32_2x16", 2, tuint16, 1, tuint32,
           "dst.x = src0.x; dst.y = src0.x >> 16;")

unop_horiz("unpack_32_4x8", 4, tuint8, 1, tuint32,
           "dst.x = src0.x; dst.y = src0.x >> 8; dst.z = src0.x >> 16; dst.w = src0.x >> 24;")

# Half-float unpack; the second argument tells unpack_half_1x16 whether
# 16-bit denorms are flushed to zero in the current execution mode.
unop_horiz("unpack_half_2x16", 2, tfloat32, 1, tuint32, """
dst.x = unpack_half_1x16((uint16_t)(src0.x & 0xffff), nir_is_denorm_flush_to_zero(execution_mode, 16));
dst.y = unpack_half_1x16((uint16_t)(src0.x >> 16), nir_is_denorm_flush_to_zero(execution_mode, 16));
""")
458
# Lowered floating point unpacking operations.
#
# The _split_x variants read the low half of the source and the _split_y
# variants read the high half.

unop_convert("unpack_half_2x16_split_x", tfloat32, tuint32,
             "unpack_half_1x16((uint16_t)(src0 & 0xffff), nir_is_denorm_flush_to_zero(execution_mode, 16))")
unop_convert("unpack_half_2x16_split_y", tfloat32, tuint32,
             "unpack_half_1x16((uint16_t)(src0 >> 16), nir_is_denorm_flush_to_zero(execution_mode, 16))")


unop_convert("unpack_32_2x16_split_x", tuint16, tuint32, "src0")
unop_convert("unpack_32_2x16_split_y", tuint16, tuint32, "src0 >> 16")

unop_convert("unpack_64_2x32_split_x", tuint32, tuint64, "src0")
unop_convert("unpack_64_2x32_split_y", tuint32, tuint64, "src0 >> 32")
472
# Bit operations, part of ARB_gpu_shader5.


unop("bitfield_reverse", tuint32, """
/* we're not winning any awards for speed here, but that's ok */
dst = 0;
for (unsigned bit = 0; bit < 32; bit++)
   dst |= ((src0 >> bit) & 1) << (31 - bit);
""")
# Number of set bits in the source.
unop_convert("bit_count", tuint32, tuint, """
dst = 0;
for (unsigned bit = 0; bit < bit_size; bit++) {
   if ((src0 >> bit) & 1)
      dst++;
}
""")

# Index of the highest set bit, or -1 if the source is zero.
unop_convert("ufind_msb", tint32, tuint, """
dst = -1;
for (int bit = bit_size - 1; bit >= 0; bit--) {
   if ((src0 >> bit) & 1) {
      dst = bit;
      break;
   }
}
""")

# Smallest left shift that brings a set bit to bit 31, or -1 if none does.
# NOTE(review): the mask is the fixed 32-bit top bit although the source
# type is unsized and the loop runs up to bit_size -- confirm the intended
# behavior for sources that are not 32 bits wide.
unop_convert("ufind_msb_rev", tint32, tuint, """
dst = -1;
for (int bit = 0; bit < bit_size; bit++) {
   if ((src0 << bit) & 0x80000000) {
      dst = bit;
      break;
   }
}
""")

# Count of zero bits above the highest set bit (bit_size for a zero source).
unop("uclz", tuint32, """
int bit;
for (bit = bit_size - 1; bit >= 0; bit--) {
   if ((src0 & (1u << bit)) != 0)
      break;
}
dst = (unsigned)(bit_size - bit - 1);
""")

unop("ifind_msb", tint32, """
dst = -1;
for (int bit = bit_size - 1; bit >= 0; bit--) {
   /* If src0 < 0, we're looking for the first 0 bit.
    * if src0 >= 0, we're looking for the first 1 bit.
    */
   if ((((src0 >> bit) & 1) && (src0 >= 0)) ||
      (!((src0 >> bit) & 1) && (src0 < 0))) {
      dst = bit;
      break;
   }
}
""")

unop("ifind_msb_rev", tint32, """
dst = -1;
/* We are looking for the highest bit that's not the same as the sign bit. */
uint32_t sign = src0 & 0x80000000u;
for (int bit = 0; bit < 32; bit++) {
   if (((src0 << bit) & 0x80000000u) != sign) {
      dst = bit;
      break;
   }
}
""")

# Index of the lowest set bit, or -1 if the source is zero.
unop_convert("find_lsb", tint32, tint, """
dst = -1;
for (unsigned bit = 0; bit < bit_size; bit++) {
   if ((src0 >> bit) & 1) {
      dst = bit;
      break;
   }
}
""")

# Registers fsum2, fsum3 and fsum4.
unop_reduce("fsum", 1, tfloat, tfloat, "{src}", "{src0} + {src1}", "{src}",
            description = "Sum of vector components")
557
def binop_convert(name, out_type, in_type1, alg_props, const_expr, description="", in_type2=None):
   """Register a two-source opcode whose output type may differ from its
   source types.

   The second source uses in_type1 unless in_type2 is given.  Sizes are 0
   everywhere and is_conversion is False.
   """
   second_type = in_type1 if in_type2 is None else in_type2
   opcode(name,
          output_size=0, output_type=out_type,
          input_sizes=[0, 0], input_types=[in_type1, second_type],
          is_conversion=False, algebraic_properties=alg_props,
          const_expr=const_expr, description=description)
563
def binop(name, ty, alg_props, const_expr, description = ""):
   """Register a two-source opcode where both sources and the result share
   the type ``ty``.
   """
   binop_convert(name, out_type=ty, in_type1=ty, alg_props=alg_props,
                 const_expr=const_expr, description=description)
566
def binop_compare(name, ty, alg_props, const_expr, description = "", ty2=None):
   """Register a two-source opcode with a ``bool1`` result.

   The first source has type ``ty``; the second has ``ty2`` if given,
   otherwise ``ty``.
   """
   binop_convert(name, out_type=tbool1, in_type1=ty, alg_props=alg_props,
                 const_expr=const_expr, description=description,
                 in_type2=ty2)
569
def binop_compare8(name, ty, alg_props, const_expr, description = "", ty2=None):
   """Register a two-source opcode with a ``bool8`` result.

   The first source has type ``ty``; the second has ``ty2`` if given,
   otherwise ``ty``.
   """
   binop_convert(name, out_type=tbool8, in_type1=ty, alg_props=alg_props,
                 const_expr=const_expr, description=description,
                 in_type2=ty2)
572
def binop_compare16(name, ty, alg_props, const_expr, description = "", ty2=None):
   """Register a two-source opcode with a ``bool16`` result.

   The first source has type ``ty``; the second has ``ty2`` if given,
   otherwise ``ty``.
   """
   binop_convert(name, out_type=tbool16, in_type1=ty, alg_props=alg_props,
                 const_expr=const_expr, description=description,
                 in_type2=ty2)
575
def binop_compare32(name, ty, alg_props, const_expr, description = "", ty2=None):
   """Register a two-source opcode with a ``bool32`` result.

   The first source has type ``ty``; the second has ``ty2`` if given,
   otherwise ``ty``.
   """
   binop_convert(name, out_type=tbool32, in_type1=ty, alg_props=alg_props,
                 const_expr=const_expr, description=description,
                 in_type2=ty2)
578
def binop_compare_all_sizes(name, ty, alg_props, const_expr, description = "", ty2=None):
   """Register a comparison once per boolean result size.

   ``name`` itself gets the ``bool1`` result; ``name + "8"``,
   ``name + "16"`` and ``name + "32"`` get the correspondingly sized
   boolean results.
   """
   binop_compare(name, ty, alg_props, const_expr, description, ty2)
   sized_variants = (
      ("8", binop_compare8),
      ("16", binop_compare16),
      ("32", binop_compare32),
   )
   for bits, register in sized_variants:
      register(name + bits, ty, alg_props, const_expr, description, ty2)
584
def binop_horiz(name, out_size, out_type, src1_size, src1_type, src2_size,
                src2_type, const_expr, description = ""):
   """Register a two-source opcode with explicit sizes for the output and
   for each source.

   No algebraic properties are set and is_conversion is False.
   """
   opcode(name,
          output_size=out_size, output_type=out_type,
          input_sizes=[src1_size, src2_size],
          input_types=[src1_type, src2_type],
          is_conversion=False, algebraic_properties="",
          const_expr=const_expr, description=description)
589
def binop_reduce(name, output_size, output_type, src_type, prereduce_expr,
                 reduce_expr, final_expr, suffix="", description = ""):
   """Register the 2-, 3-, 4-, 5-, 8- and 16-component variants of a
   two-source reduction opcode.

   Each opcode is named ``name + <component count> + suffix`` and is marked
   2src_commutative.  The constant expression is assembled from three
   format templates:

   - prereduce_expr combines matching components of both sources
     ({src0}, {src1}); the result is wrapped in parentheses
   - reduce_expr combines two partial results ({src0}, {src1})
   - final_expr is applied to the parenthesised overall result ({src})

   Registration order is 2, 4, 8, 16, then 3, then 5.
   """
   def combine(lhs, rhs):
      return reduce_expr.format(src0=lhs, src1=rhs)

   def register(width, reduced):
      opcode(name + str(width) + suffix, output_size, output_type,
             [width, width], [src_type, src_type], False, _2src_commutative,
             final_expr.format(src="(" + reduced + ")"), description)

   # One pre-reduced term per component, for up to 16 components.
   terms = ["(" + prereduce_expr.format(src0="src0." + comp,
                                        src1="src1." + comp) + ")"
            for comp in "xyzwefghijklmnop"]

   def balanced(width):
      # Build the pairwise tree bottom-up for a power-of-two width.  At
      # every level the higher-indexed operand goes first.
      level = terms[:width]
      while len(level) > 1:
         level = [combine(level[i + 1], level[i])
                  for i in range(0, len(level), 2)]
      return level[0]

   for width in (2, 4, 8, 16):
      register(width, balanced(width))
   # The odd widths do not form a balanced tree and are spelled out.
   register(3, combine(combine(terms[2], terms[1]), terms[0]))
   register(5, combine(terms[4], balanced(4)))
615
def binop_reduce_all_sizes(name, output_size, src_type, prereduce_expr,
                           reduce_expr, final_expr, description = ""):
   """Register a two-source reduction once per boolean result size.

   ``name`` itself is registered with a ``bool1`` result.  For the
   ``bool8``, ``bool16`` and ``bool32`` results the first character of
   ``name`` is replaced by "b8", "b16" and "b32" respectively.  All other
   arguments are forwarded to binop_reduce().
   """
   # description must be passed by keyword: binop_reduce() takes `suffix`
   # before `description`, so a positional argument here would be appended
   # to every opcode name instead of documenting the opcode.
   binop_reduce(name, output_size, tbool1, src_type,
                prereduce_expr, reduce_expr, final_expr,
                description = description)
   binop_reduce("b8" + name[1:], output_size, tbool8, src_type,
                prereduce_expr, reduce_expr, final_expr,
                description = description)
   binop_reduce("b16" + name[1:], output_size, tbool16, src_type,
                prereduce_expr, reduce_expr, final_expr,
                description = description)
   binop_reduce("b32" + name[1:], output_size, tbool32, src_type,
                prereduce_expr, reduce_expr, final_expr,
                description = description)
626
# Addition.  Under round-toward-zero the float add goes through dedicated
# helpers; non-64-bit sources are added as doubles and then narrowed.
binop("fadd", tfloat, _2src_commutative + associative,"""
if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) {
   if (bit_size == 64)
      dst = _mesa_double_add_rtz(src0, src1);
   else
      dst = _mesa_double_to_float_rtz((double)src0 + (double)src1);
} else {
   dst = src0 + src1;
}
""")
binop("iadd", tint, _2src_commutative + associative, "(uint64_t)src0 + (uint64_t)src1")
# Saturating integer add/subtract: the result is replaced by the type's
# maximum or minimum when the plain result wrapped around.
binop("iadd_sat", tint, _2src_commutative, """
      src1 > 0 ?
         (src0 + src1 < src0 ? u_intN_max(bit_size) : src0 + src1) :
         (src0 < src0 + src1 ? u_intN_min(bit_size) : src0 + src1)
""")
binop("uadd_sat", tuint, _2src_commutative,
      "(src0 + src1) < src0 ? u_uintN_max(sizeof(src0) * 8) : (src0 + src1)")
binop("isub_sat", tint, "", """
      src1 < 0 ?
         (src0 - src1 < src0 ? u_intN_max(bit_size) : src0 - src1) :
         (src0 < src0 - src1 ? u_intN_min(bit_size) : src0 - src1)
""")
binop("usub_sat", tuint, "", "src0 < src1 ? 0 : src0 - src1")

# Subtraction, with the same round-toward-zero handling as fadd.
binop("fsub", tfloat, "", """
if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) {
   if (bit_size == 64)
      dst = _mesa_double_sub_rtz(src0, src1);
   else
      dst = _mesa_double_to_float_rtz((double)src0 - (double)src1);
} else {
   dst = src0 - src1;
}
""")
binop("isub", tint, "", "src0 - src1")
# Absolute difference: the smaller source is always subtracted from the
# larger one.  uabs_isub takes signed sources and produces an unsigned
# result.
binop_convert("uabs_isub", tuint, tint, "", """
              src1 > src0 ? (uint64_t) src1 - (uint64_t) src0
                          : (uint64_t) src0 - (uint64_t) src1
""")
binop("uabs_usub", tuint, "", "(src1 > src0) ? (src1 - src0) : (src0 - src1)")
668
669binop("fmul", tfloat, _2src_commutative + associative, """
670if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) {
671   if (bit_size == 64)
672      dst = _mesa_double_mul_rtz(src0, src1);
673   else
674      dst = _mesa_double_to_float_rtz((double)src0 * (double)src1);
675} else {
676   dst = src0 * src1;
677}
678""")
679
680binop("fmulz", tfloat32, _2src_commutative + associative, """
681if (src0 == 0.0 || src1 == 0.0)
682   dst = 0.0;
683else if (nir_is_rounding_mode_rtz(execution_mode, 32))
684   dst = _mesa_double_to_float_rtz((double)src0 * (double)src1);
685else
686   dst = src0 * src1;
687""", description = """
688Unlike :nir:alu-op:`fmul`, anything (even infinity or NaN) multiplied by zero is
689always zero. ``fmulz(0.0, inf)`` and ``fmulz(0.0, nan)`` must be +/-0.0, even
690if ``INF_PRESERVE/NAN_PRESERVE`` is not used. If ``SIGNED_ZERO_PRESERVE`` is
691used, then the result must be a positive zero if either operand is zero.
692""")
693
694
695binop("imul", tint, _2src_commutative + associative, """
696   /* Use 64-bit multiplies to prevent overflow of signed arithmetic */
697   dst = (uint64_t)src0 * (uint64_t)src1;
698""", description = "Low 32-bits of signed/unsigned integer multiply")
699
700binop_convert("imul_2x32_64", tint64, tint32, _2src_commutative,
701              "(int64_t)src0 * (int64_t)src1",
702              description = "Multiply signed 32-bit integers, 64-bit result")
703binop_convert("umul_2x32_64", tuint64, tuint32, _2src_commutative,
704              "(uint64_t)src0 * (uint64_t)src1",
705              description = "Multiply unsigned 32-bit integers, 64-bit result")
706
707binop("imul_high", tint, _2src_commutative, """
708if (bit_size == 64) {
709   /* We need to do a full 128-bit x 128-bit multiply in order for the sign
710    * extension to work properly.  The casts are kind-of annoying but needed
711    * to prevent compiler warnings.
712    */
713   uint32_t src0_u32[4] = {
714      src0,
715      (int64_t)src0 >> 32,
716      (int64_t)src0 >> 63,
717      (int64_t)src0 >> 63,
718   };
719   uint32_t src1_u32[4] = {
720      src1,
721      (int64_t)src1 >> 32,
722      (int64_t)src1 >> 63,
723      (int64_t)src1 >> 63,
724   };
725   uint32_t prod_u32[4];
726   ubm_mul_u32arr(prod_u32, src0_u32, src1_u32);
727   dst = (uint64_t)prod_u32[2] | ((uint64_t)prod_u32[3] << 32);
728} else {
729   /* First, sign-extend to 64-bit, then convert to unsigned to prevent
730    * potential overflow of signed multiply */
731   dst = ((uint64_t)(int64_t)src0 * (uint64_t)(int64_t)src1) >> bit_size;
732}
733""", description = "High 32-bits of signed integer multiply")
734
# Upper half of an unsigned product.  Same structure as imul_high, but the
# 64-bit case only needs two 32-bit limbs per source (no sign extension).
binop(
   "umul_high",
   tuint,
   _2src_commutative,
   """
if (bit_size == 64) {
   /* The casts are kind-of annoying but needed to prevent compiler warnings. */
   uint32_t src0_u32[2] = { src0, (uint64_t)src0 >> 32 };
   uint32_t src1_u32[2] = { src1, (uint64_t)src1 >> 32 };
   uint32_t prod_u32[4];
   ubm_mul_u32arr(prod_u32, src0_u32, src1_u32);
   dst = (uint64_t)prod_u32[2] | ((uint64_t)prod_u32[3] << 32);
} else {
   dst = ((uint64_t)src0 * (uint64_t)src1) >> bit_size;
}
""",
   description="High 32-bits of unsigned integer multiply",
)
747
# Multiplies only the low (bit_size / 2) bits of each source.
binop(
   "umul_low",
   tuint32,
   _2src_commutative,
   """
uint64_t mask = (1 << (bit_size / 2)) - 1;
dst = ((uint64_t)src0 & mask) * ((uint64_t)src1 & mask);
""",
   description="Low 32-bits of unsigned integer multiply",
)
752
# 32-bit value times the low 16 bits of the second source; the two variants
# differ only in how those 16 bits are extended.
binop(
   "imul_32x16",
   tint32,
   "",
   "src0 * (int16_t) src1",
   description="Multiply 32-bits with low 16-bits, with sign extension",
)
binop(
   "umul_32x16",
   tuint32,
   "",
   "src0 * (uint16_t) src1",
   description="Multiply 32-bits with low 16-bits, with zero extension",
)
757
binop("fdiv", tfloat, "", "src0 / src1")

# Integer division folds a zero divisor to 0.
#
# The signed form additionally special-cases a divisor of -1: in C,
# INT_MIN / -1 overflows (undefined behaviour, and a hardware trap on x86
# hosts), which would otherwise be evaluated at constant-folding time for
# 32- and 64-bit operands.  Negating in unsigned arithmetic gives the same
# result as src0 / -1 for every other input and wraps for the minimum value.
binop("idiv", tint, "",
      "src1 == 0 ? 0 : (src1 == -1 ? (0 - (uint64_t)src0) : (src0 / src1))")
binop("udiv", tuint, "", "src1 == 0 ? 0 : (src0 / src1)")
761
# Carry-out of an unsigned add.
#
# The test is written as "src1 > MAX - src0", with MAX derived from the
# operand width, instead of "src0 + src1 < src0".  For 8- and 16-bit operands
# C promotes both sources to int before adding, so the sum never wraps and the
# shorter form would never report a carry.  For 32- and 64-bit operands both
# forms agree.
binop_convert("uadd_carry", tuint, tuint, _2src_commutative,
              "(uint64_t)src1 > ((UINT64_MAX >> (64 - sizeof(src0) * 8)) - (uint64_t)src0)",
              description = """
Return an integer (1 or 0) representing the carry resulting from the
addition of the two unsigned arguments.
              """)

# Borrow-out of an unsigned subtract: set exactly when src0 < src1.
binop_convert("usub_borrow", tuint, tuint, "", "src0 < src1", description = """
Return an integer (1 or 0) representing the borrow resulting from the
subtraction of the two unsigned arguments.
              """)
773
# Halving add, hadd: (a + b) >> 1 computed without intermediate overflow.
#
# Splitting the sum into its shared and differing bits:
#
# x + y = x - (x & ~y) + (x & ~y) + y - (~x & y) + (~x & y)
#       =      (x & y) + (x & ~y) +      (x & y) + (~x & y)
#       = 2 *  (x & y) + (x & ~y) +                (~x & y)
#       =     ((x & y) << 1) + (x ^ y)
#
# The bottom bit of (x & y) << 1 is always zero, so the shift distributes:
#
# (x + y) >> 1 = (((x & y) << 1) + (x ^ y)) >> 1
#              =   (x & y) +      ((x ^ y)  >> 1)
_hadd_expr_unused = None  # placeholder removed below; see note
del _hadd_expr_unused
binop("ihadd", tint, _2src_commutative,
      "(src0 & src1) + ((src0 ^ src1) >> 1)")
binop("uhadd", tuint, _2src_commutative,
      "(src0 & src1) + ((src0 ^ src1) >> 1)")
786
# Rounding halving add, rhadd: (a + b + 1) >> 1 without intermediate overflow.
#
# x + y + 1 = x + (~x & y) - (~x & y) + y + (x & ~y) - (x & ~y) + 1
#           =      (x | y) - (~x & y) +      (x | y) - (x & ~y) + 1
#           = 2 *  (x | y) - ((~x & y) +               (x & ~y)) + 1
#           =     ((x | y) << 1) - (x ^ y) + 1
#
# The bottom bit of (x | y) << 1 is always zero, so the shift distributes:
#
# (x + y + 1) >> 1 = (x | y) + ((-(x ^ y) + 1) >> 1)
#                  = (x | y) -  ((x ^ y)       >> 1)
binop("irhadd", tint, _2src_commutative,
      "(src0 | src1) - ((src0 ^ src1) >> 1)")
binop("urhadd", tuint, _2src_commutative,
      "(src0 | src1) - ((src0 ^ src1) >> 1)")
799
binop("umod", tuint, "", "src1 == 0 ? 0 : src0 % src1")

# For signed integers, there are several different possible definitions of
# "modulus" or "remainder".  We follow the conventions used by LLVM and
# SPIR-V.  The irem opcode implements the standard C/C++ signed "%"
# operation while the imod opcode implements the more mathematical
# "modulus" operation.  For details on the difference, see
#
# http://mathforum.org/library/drmath/view/52343.html
#
# Both signed forms fold a divisor of -1 directly to 0.  Anything modulo -1
# is 0, but evaluating INT_MIN % -1 in C is undefined behaviour (and traps on
# x86 hosts), so the "%" must not be reached for that divisor.

binop("irem", tint, "", "(src1 == 0 || src1 == -1) ? 0 : src0 % src1")
binop("imod", tint, "",
      "(src1 == 0 || src1 == -1) ? 0 : ((src0 % src1 == 0 || (src0 >= 0) == (src1 >= 0)) ?"
      "                 src0 % src1 : src0 % src1 + src1)")
binop("fmod", tfloat, "", "src0 - src1 * floorf(src0 / src1)")
binop("frem", tfloat, "", "src0 - src1 * truncf(src0 / src1)")
816
#
# Comparisons
#

# Integer-aware comparisons; each returns a boolean (0 or ~0).

# Floating-point: plain forms first, then the forms that are also true when
# either operand is NaN (the "u" suffix variants).
binop_compare_all_sizes("flt", tfloat, "", "src0 < src1")
binop_compare_all_sizes("fge", tfloat, "", "src0 >= src1")
binop_compare_all_sizes(
   "fltu", tfloat, "", "isnan(src0) || isnan(src1) || src0 < src1")
binop_compare_all_sizes(
   "fgeu", tfloat, "", "isnan(src0) || isnan(src1) || src0 >= src1")
binop_compare_all_sizes("feq", tfloat, _2src_commutative, "src0 == src1")
binop_compare_all_sizes("fneu", tfloat, _2src_commutative, "src0 != src1")
binop_compare_all_sizes(
   "fequ", tfloat, _2src_commutative,
   "isnan(src0) || isnan(src1) || src0 == src1")
binop_compare_all_sizes(
   "fneo", tfloat, _2src_commutative,
   "!isnan(src0) && !isnan(src1) && src0 != src1")
binop_compare_all_sizes(
   "funord", tfloat, _2src_commutative, "isnan(src0) || isnan(src1)")
binop_compare_all_sizes(
   "ford", tfloat, _2src_commutative, "!isnan(src0) && !isnan(src1)")

# Signed and unsigned integer comparisons.
binop_compare_all_sizes("ilt", tint, "", "src0 < src1")
binop_compare_all_sizes("ige", tint, "", "src0 >= src1")
binop_compare_all_sizes("ieq", tint, _2src_commutative, "src0 == src1")
binop_compare_all_sizes("ine", tint, _2src_commutative, "src0 != src1")
binop_compare_all_sizes("ult", tuint, "", "src0 < src1")
binop_compare_all_sizes("uge", tuint, "", "src0 >= src1")

# Single-bit tests: bit (src1 mod bit_size) of src0 is set / clear.
binop_compare_all_sizes(
   "bitnz", tuint, "",
   "((uint64_t)src0 >> (src1 & (bit_size - 1)) & 0x1) == 0x1",
   "only uses the least significant bits like SM5 shifts", tuint32)

binop_compare_all_sizes(
   "bitz", tuint, "",
   "((uint64_t)src0 >> (src1 & (bit_size - 1)) & 0x1) == 0x0",
   "only uses the least significant bits like SM5 shifts", tuint32)
846
# Integer-aware GLSL-style vector comparisons: a per-component test combined
# with && ("all") or || ("any").

binop_reduce_all_sizes(
   "ball_fequal", 1, tfloat,
   "{src0} == {src1}", "{src0} && {src1}", "{src}")
binop_reduce_all_sizes(
   "bany_fnequal", 1, tfloat,
   "{src0} != {src1}", "{src0} || {src1}", "{src}")
binop_reduce_all_sizes(
   "ball_iequal", 1, tint,
   "{src0} == {src1}", "{src0} && {src1}", "{src}")
binop_reduce_all_sizes(
   "bany_inequal", 1, tint,
   "{src0} != {src1}", "{src0} || {src1}", "{src}")

# Non-integer-aware GLSL-style comparisons: same reductions, but the result
# is 0.0 or 1.0.

binop_reduce(
   "fall_equal", 1, tfloat32, tfloat32,
   "{src0} == {src1}", "{src0} && {src1}", "{src} ? 1.0f : 0.0f")
binop_reduce(
   "fany_nequal", 1, tfloat32, tfloat32,
   "{src0} != {src1}", "{src0} || {src1}", "{src} ? 1.0f : 0.0f")
864
# Comparisons for integer-less hardware: true is 1.0 and false is 0.0.

# Set on Less Than
binop("slt", tfloat, "", "(src0 < src1) ? 1.0f : 0.0f")
# Set on Greater or Equal
binop("sge", tfloat, "", "(src0 >= src1) ? 1.0f : 0.0f")
# Set on Equal
binop("seq", tfloat, _2src_commutative, "(src0 == src1) ? 1.0f : 0.0f")
# Set on Not Equal
binop("sne", tfloat, _2src_commutative, "(src0 != src1) ? 1.0f : 0.0f")
872
# Text appended to the description of every shift opcode below.
shift_note = """
SPIRV shifts are undefined for shift-operands >= bitsize,
but SM5 shifts are defined to use only the least significant bits.
The NIR definition is according to the SM5 specification.
"""

# The shift amount is always a 32-bit unsigned source and is masked to the
# width of src0 before use.
opcode(
   "ishl", 0, tint, [0, 0], [tint, tuint32], False, "",
   "(uint64_t)src0 << (src1 & (sizeof(src0) * 8 - 1))",
   description="Left shift." + shift_note,
)
opcode(
   "ishr", 0, tint, [0, 0], [tint, tuint32], False, "",
   "src0 >> (src1 & (sizeof(src0) * 8 - 1))",
   description="Signed right-shift." + shift_note,
)
opcode(
   "ushr", 0, tuint, [0, 0], [tuint, tuint32], False, "",
   "src0 >> (src1 & (sizeof(src0) * 8 - 1))",
   description="Unsigned right-shift." + shift_note,
)
888
# Rotates: both shift amounts are masked to the width of src0.
opcode(
   "urol", 0, tuint, [0, 0], [tuint, tuint32], False, "",
   """
   uint32_t rotate_mask = sizeof(src0) * 8 - 1;
   dst = (src0 << (src1 & rotate_mask)) |
         (src0 >> (-src1 & rotate_mask));
""",
)
opcode(
   "uror", 0, tuint, [0, 0], [tuint, tuint32], False, "",
   """
   uint32_t rotate_mask = sizeof(src0) * 8 - 1;
   dst = (src0 >> (src1 & rotate_mask)) |
         (src0 << (-src1 & rotate_mask));
""",
)

# Same shape as uror, but the right-shifted and left-shifted halves come from
# two different sources (src1 and src0), with src2 as the shift amount.
opcode(
   "shfr", 0, tuint32, [0, 0, 0], [tuint32, tuint32, tuint32], False, "",
   """
   uint32_t rotate_mask = sizeof(src0) * 8 - 1;
   dst = (src1 >> (src2 & rotate_mask)) |
         (src0 << (-src2 & rotate_mask));
""",
)
905
# Description template shared by the bitwise opcodes; {0} is the operation.
bitwise_description = """
Bitwise {0}, also used as a boolean {0} for hardware supporting integers.
"""

binop(
   "iand", tuint, _2src_commutative + associative, "src0 & src1",
   description=bitwise_description.format("AND"),
)
binop(
   "ior", tuint, _2src_commutative + associative, "src0 | src1",
   description=bitwise_description.format("OR"),
)
binop(
   "ixor", tuint, _2src_commutative + associative, "src0 ^ src1",
   description=bitwise_description.format("XOR"),
)
916
917
# Dot product: per-component multiply reduced with addition.  Registered
# twice, with output size 1 and, under the "_replicated" suffix, output
# size 0.
binop_reduce(
   "fdot", 1, tfloat, tfloat,
   "{src0} * {src1}", "{src0} + {src1}", "{src}")

binop_reduce(
   "fdot", 0, tfloat, tfloat,
   "{src0} * {src1}", "{src0} + {src1}", "{src}",
   suffix="_replicated")

# Dot product of a 3-component src0 with a 4-component src1, where src1.w is
# added unscaled.
opcode(
   "fdph", 1, tfloat, [3, 4], [tfloat, tfloat], False, "",
   "src0.x * src1.x + src0.y * src1.y + src0.z * src1.z + src1.w")
opcode(
   "fdph_replicated", 0, tfloat, [3, 4], [tfloat, tfloat], False, "",
   "src0.x * src1.x + src0.y * src1.y + src0.z * src1.z + src1.w")
929
# Signed zeroes: the C fmin/fmax functions leave the result for (+0, -0)
# implementation-defined, whereas SPIR-V requires:
#
#   fmin(-0, +0) = -0
#   fmax(+0, -0) = +0
#
# The NIR opcodes follow SPIR-V.  They are also commutative, which means the
# swapped operand order must give the same answer:
#
#   fmin(+0, -0) = -0
#   fmax(-0, +0) = +0
#
# Constant folding therefore handles equal sources separately: it takes the
# min/max of their bit patterns, which orders the two zeroes and leaves every
# other value unchanged.
for op, macro in (("fmin", "MIN2"), ("fmax", "MAX2")):
   # {0} is the integer MIN2/MAX2 macro, {1} the C function name; the 64-bit
   # branch uses the dui/uid helpers and the other branch fui/uif.
   binop(
      op,
      tfloat,
      _2src_commutative + associative,
      "bit_size == 64 ? "
      "(src0 == src1 ? uid({0}((int64_t)dui(src0), (int64_t)dui(src1))) : {1}(src0, src1)) :"
      "(src0 == src1 ? uif({0}((int32_t)fui(src0), (int32_t)fui(src1))) : {1}f(src0, src1))".format(macro, op),
   )
950
# Integer min/max via the MIN2/MAX2 macros.
binop("imin", tint, _2src_commutative + associative,
      "MIN2(src0, src1)")
binop("umin", tuint, _2src_commutative + associative,
      "MIN2(src0, src1)")
binop("imax", tint, _2src_commutative + associative,
      "MAX2(src0, src1)")
binop("umax", tuint, _2src_commutative + associative,
      "MAX2(src0, src1)")

# pow() for 64-bit operands, powf() otherwise.
binop("fpow", tfloat, "",
      "bit_size == 64 ? pow(src0, src1) : powf(src0, src1)")
957
# "Split" pack opcodes: each source supplies one slice of the result, with
# src0 in the low bits.  Every upper slice is cast to the destination width
# *before* it is shifted, so the shift is never performed in a promoted
# (signed) int.

binop_horiz("pack_half_2x16_split", 1, tuint32, 1, tfloat32, 1, tfloat32,
            "pack_half_1x16(src0.x) | ((uint32_t)(pack_half_1x16(src1.x)) << 16)")

# The cast is applied to the helper's result, matching pack_half_2x16_split
# above, rather than to the already-shifted value.
binop_horiz("pack_half_2x16_rtz_split", 1, tuint32, 1, tfloat32, 1, tfloat32,
            "pack_half_1x16_rtz(src0.x) | ((uint32_t)(pack_half_1x16_rtz(src1.x)) << 16)")

binop_convert("pack_64_2x32_split", tuint64, tuint32, "",
              "src0 | ((uint64_t)src1 << 32)")

binop_convert("pack_32_2x16_split", tuint32, tuint16, "",
              "src0 | ((uint32_t)src1 << 16)")

opcode("pack_32_4x8_split", 0, tuint32, [0, 0, 0, 0], [tuint8, tuint8, tuint8, tuint8],
       False, "",
       "src0 | ((uint32_t)src1 << 8) | ((uint32_t)src2 << 16) | ((uint32_t)src3 << 24)")
973
# Bitfield mask: "bits" ones shifted left by "offset".
binop_convert(
   "bfm",
   tuint32,
   tint32,
   "",
   """
int bits = src0 & 0x1F;
int offset = src1 & 0x1F;
dst = ((1u << bits) - 1) << offset;
""",
   description="""
Implements the behavior of the first operation of the SM5 "bfi" assembly
and that of the "bfi1" i965 instruction. That is, the bits and offset values
are from the low five bits of src0 and src1, respectively.
""",
)

# ldexp with a 32-bit integer exponent.  Any result for which isnormal() is
# false is replaced by a zero carrying the sign of src0.
opcode(
   "ldexp",
   0,
   tfloat,
   [0, 0],
   [tfloat, tint32],
   False,
   "",
   """
dst = (bit_size == 64) ? ldexp(src0, src1) : ldexpf(src0, src1);
/* flush denormals to zero. */
if (!isnormal(dst))
   dst = copysignf(0.0f, src0);
""",
)
990
binop_horiz(
   "vec2", 2, tuint, 1, tuint, 1, tuint,
   """
dst.x = src0.x;
dst.y = src1.x;
""",
   description="""
Combines the first component of each input to make a 2-component vector.
""",
)

# Extraction: shift field number src1 of src0 down to bit 0, then truncate
# to the field width (the signed variants truncate to a signed type).
# Bytes:
binop("extract_u8", tuint, "", "(uint8_t)(src0 >> (src1 * 8))")
binop("extract_i8", tint, "", "(int8_t)(src0 >> (src1 * 8))")

# Words:
binop("extract_u16", tuint, "", "(uint16_t)(src0 >> (src1 * 16))")
binop("extract_i16", tint, "", "(int16_t)(src0 >> (src1 * 16))")

# Insertion: mask src0 to the field width and shift it up to field src1.
binop("insert_u8", tuint, "", "(src0 & 0xff) << (src1 * 8)")
binop("insert_u16", tuint, "", "(src0 & 0xffff) << (src1 * 16)")
1009
1010
def triop(name, ty, alg_props, const_expr, description=""):
   """Register a three-source opcode whose sources and result all use ty.

   Forwards to opcode() with an output size of 0, three input sizes of 0
   and False as the sixth argument.
   """
   src_sizes = [0] * 3
   src_types = [ty] * 3
   opcode(name, 0, ty, src_sizes, src_types, False, alg_props, const_expr,
          description)
def triop_horiz(name, output_size, src1_size, src2_size, src3_size, const_expr,
                description=""):
   """Register a three-source tuint opcode with explicit component counts.

   The output and each source get their own size; no algebraic properties
   are passed (empty string) and the sixth opcode() argument is False.
   """
   src_sizes = [src1_size, src2_size, src3_size]
   src_types = [tuint] * 3
   opcode(name, output_size, tuint, src_sizes, src_types, False, "",
          const_expr, description)
1019
# Fused multiply-add.  The round-toward-zero path picks a _mesa_*_fma_rtz
# helper by bit size; otherwise the C fmaf()/fma() functions are used.
triop(
   "ffma",
   tfloat,
   _2src_commutative,
   """
if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) {
   if (bit_size == 64)
      dst = _mesa_double_fma_rtz(src0, src1, src2);
   else if (bit_size == 32)
      dst = _mesa_float_fma_rtz(src0, src1, src2);
   else
      dst = _mesa_double_to_float_rtz(_mesa_double_fma_rtz(src0, src1, src2));
} else {
   if (bit_size == 32)
      dst = fmaf(src0, src1, src2);
   else
      dst = fma(src0, src1, src2);
}
""",
)

# 32-bit only variant in which a zero multiplicand short-circuits the
# product (see the description).
triop(
   "ffmaz",
   tfloat32,
   _2src_commutative,
   """
if (src0 == 0.0 || src1 == 0.0)
   dst = 0.0 + src2;
else if (nir_is_rounding_mode_rtz(execution_mode, 32))
   dst = _mesa_float_fma_rtz(src0, src1, src2);
else
   dst = fmaf(src0, src1, src2);
""",
   description="""
Floating-point multiply-add with modified zero handling.

Unlike :nir:alu-op:`ffma`, anything (even infinity or NaN) multiplied by zero is
always zero. ``ffmaz(0.0, inf, src2)`` and ``ffmaz(0.0, nan, src2)`` must be
``+/-0.0 + src2``, even if ``INF_PRESERVE/NAN_PRESERVE`` is not used. If
``SIGNED_ZERO_PRESERVE`` is used, then the result must be a positive
zero plus src2 if either src0 or src1 is zero.
""",
)
1052
# Linear interpolation between src0 and src1 with weight src2.
triop("flrp", tfloat, "", "src0 * (1 - src2) + src1 * src2")

triop(
   "iadd3", tint, _2src_commutative + associative, "src0 + src1 + src2",
   description="Ternary addition",
)

triop(
   "imad", tint, _2src_commutative + associative, "src0 * src1 + src2",
   description="Integer multiply-add",
)
1060
# Description template for the select opcodes; the two fields are the kind
# of condition and its false/true encoding.
csel_description = """
A vector conditional select instruction (like ?:, but operating per-
component on vectors). The condition is {} bool ({}).
"""

triop(
   "fcsel", tfloat32, selection, "(src0 != 0.0f) ? src1 : src2",
   description=csel_description.format("a floating point", "0.0 vs 1.0"),
)

# One select per boolean width; only the condition type differs.
opcode(
   "bcsel", 0, tuint, [0, 0, 0],
   [tbool1, tuint, tuint], False, selection, "src0 ? src1 : src2",
   description=csel_description.format("a 1-bit", "0 vs 1"),
)
opcode(
   "b8csel", 0, tuint, [0, 0, 0],
   [tbool8, tuint, tuint], False, selection, "src0 ? src1 : src2",
   description=csel_description.format("an 8-bit", "0 vs ~0"),
)
opcode(
   "b16csel", 0, tuint, [0, 0, 0],
   [tbool16, tuint, tuint], False, selection, "src0 ? src1 : src2",
   description=csel_description.format("a 16-bit", "0 vs ~0"),
)
opcode(
   "b32csel", 0, tuint, [0, 0, 0],
   [tbool32, tuint, tuint], False, selection, "src0 ? src1 : src2",
   description=csel_description.format("a 32-bit", "0 vs ~0"),
)

# Selects whose condition is a comparison of src0 against zero.
triop("i32csel_gt", tint32, selection, "(src0 > 0) ? src1 : src2")
triop("i32csel_ge", tint32, selection, "(src0 >= 0) ? src1 : src2")

triop("fcsel_gt", tfloat32, selection, "(src0 > 0.0f) ? src1 : src2")
triop("fcsel_ge", tfloat32, selection, "(src0 >= 0.0f) ? src1 : src2")
1086
# Bitfield insert driven by a mask: "insert" is shifted left until it lines
# up with the lowest set bit of the mask, then merged into "base".
triop(
   "bfi",
   tuint32,
   "",
   """
unsigned mask = src0, insert = src1, base = src2;
if (mask == 0) {
   dst = base;
} else {
   unsigned tmp = mask;
   while (!(tmp & 1)) {
      tmp >>= 1;
      insert <<= 1;
   }
   dst = (base & ~mask) | (insert & mask);
}
""",
   description="SM5 bfi assembly",
)


# Per-bit select: bits of src1 where src0 is set, bits of src2 elsewhere.
triop("bitfield_select", tuint, "", "(src0 & src1) | (~src0 & src2)")
1103
# SM5 ubfe/ibfe assembly.  Offset and bit count are both taken from the five
# least significant bits of their source.
opcode(
   "ubfe",
   0,
   tuint32,
   [0, 0, 0],
   [tuint32, tuint32, tuint32],
   False,
   "",
   """
unsigned base = src0;
unsigned offset = src1 & 0x1F;
unsigned bits = src2 & 0x1F;
if (bits == 0) {
   dst = 0;
} else if (offset + bits < 32) {
   dst = (base << (32 - bits - offset)) >> (32 - bits);
} else {
   dst = base >> offset;
}
""",
)
# Signed counterpart: identical except that base is a signed int.
opcode(
   "ibfe",
   0,
   tint32,
   [0, 0, 0],
   [tint32, tuint32, tuint32],
   False,
   "",
   """
int base = src0;
unsigned offset = src1 & 0x1F;
unsigned bits = src2 & 0x1F;
if (bits == 0) {
   dst = 0;
} else if (offset + bits < 32) {
   dst = (base << (32 - bits - offset)) >> (32 - bits);
} else {
   dst = base >> offset;
}
""",
)
1131
# GLSL bitfieldExtract().  Negative or out-of-range offset/bits fold to 0.
opcode(
   "ubitfield_extract",
   0,
   tuint32,
   [0, 0, 0],
   [tuint32, tint32, tint32],
   False,
   "",
   """
unsigned base = src0;
int offset = src1, bits = src2;
if (bits == 0) {
   dst = 0;
} else if (bits < 0 || offset < 0 || offset + bits > 32) {
   dst = 0; /* undefined per the spec */
} else {
   dst = (base >> offset) & ((1ull << bits) - 1);
}
""",
)
opcode(
   "ibitfield_extract",
   0,
   tint32,
   [0, 0, 0],
   [tint32, tint32, tint32],
   False,
   "",
   """
int base = src0;
int offset = src1, bits = src2;
if (bits == 0) {
   dst = 0;
} else if (offset < 0 || bits < 0 || offset + bits > 32) {
   dst = 0;
} else {
   dst = (base << (32 - offset - bits)) >> (32 - bits); /* use sign-extending shift */
}
""",
)
1157
triop(
   "msad_4x8",
   tuint32,
   "",
   """
dst = msad(src0, src1, src2);
""",
   description="""
Masked sum of absolute differences with accumulation. Equivalent to AMD's v_msad_u8
instruction and DXIL's MSAD.

The first two sources contain packed 8-bit unsigned integers, the instruction
will calculate the absolute difference of integers when src0's is non-zero, and
then add them together. There is also a third source which is a 32-bit unsigned
integer and added to the result.
""",
)

# Four msad() results at once: src1's two components are joined into one
# 64-bit value that is shifted right by 0, 8, 16 and 24 bits, and each result
# gets its own component of src2 as third argument.
opcode(
   "mqsad_4x8",
   4,
   tuint32,
   [1, 2, 4],
   [tuint32, tuint32, tuint32],
   False,
   "",
   """
uint64_t src = src1.x | ((uint64_t)src1.y << 32);
dst.x = msad(src0.x, src, src2.x);
dst.y = msad(src0.x, src >> 8, src2.y);
dst.z = msad(src0.x, src >> 16, src2.z);
dst.w = msad(src0.x, src >> 24, src2.w);
""",
)
1177
# Builds a 3-component vector from the first component of each input.

triop_horiz(
   "vec3", 3, 1, 1, 1,
   """
dst.x = src0.x;
dst.y = src1.x;
dst.z = src2.x;
""",
)
1185
def quadop_horiz(name, output_size, src1_size, src2_size, src3_size,
                 src4_size, const_expr):
   """Register a four-source tuint opcode with explicit component counts.

   No algebraic properties are passed (empty string), the sixth opcode()
   argument is False and no description is supplied.
   """
   src_sizes = [src1_size, src2_size, src3_size, src4_size]
   src_types = [tuint] * 4
   opcode(name, output_size, tuint, src_sizes, src_types, False, "",
          const_expr)
1192
# Insert the low "bits" bits of src1 into src0 at "offset".  A zero bit count
# returns the base unchanged; negative or out-of-range arguments fold to 0.
opcode(
   "bitfield_insert",
   0,
   tuint32,
   [0, 0, 0, 0],
   [tuint32, tuint32, tint32, tint32],
   False,
   "",
   """
unsigned base = src0, insert = src1;
int offset = src2, bits = src3;
if (bits == 0) {
   dst = base;
} else if (offset < 0 || bits < 0 || bits + offset > 32) {
   dst = 0;
} else {
   unsigned mask = ((1ull << bits) - 1) << offset;
   dst = (base & ~mask) | ((insert << offset) & mask);
}
""",
)
1206
# Builds a 4-component vector from the first component of each input.
quadop_horiz(
   "vec4", 4, 1, 1, 1, 1,
   """
dst.x = src0.x;
dst.y = src1.x;
dst.z = src2.x;
dst.w = src3.x;
""",
)
1213
# 5-component vector constructor.  The folding code is one
# "dst.<c> = src<i>.x;" line per destination component, in the order x, y, z,
# w, e, preceded by a single newline.
opcode(
   "vec5",
   5,
   tuint,
   [1] * 5,
   [tuint] * 5,
   False,
   "",
   "\n" + "".join(
      "dst.{} = src{}.x;\n".format(comp, idx)
      for idx, comp in enumerate("xyzwe")
   ),
)
1223
# 8-component vector constructor.  The folding code is one
# "dst.<c> = src<i>.x;" line per destination component, in the order x, y, z,
# w, e, f, g, h, preceded by a single newline.
opcode(
   "vec8",
   8,
   tuint,
   [1] * 8,
   [tuint] * 8,
   False,
   "",
   "\n" + "".join(
      "dst.{} = src{}.x;\n".format(comp, idx)
      for idx, comp in enumerate("xyzwefgh")
   ),
)
1236
# 16-component vector constructor.  The folding code is one
# "dst.<c> = src<i>.x;" line per destination component, in the order x, y, z,
# w, e .. p, preceded by a single newline.
opcode(
   "vec16",
   16,
   tuint,
   [1] * 16,
   [tuint] * 16,
   False,
   "",
   "\n" + "".join(
      "dst.{} = src{}.x;\n".format(comp, idx)
      for idx, comp in enumerate("xyzwefghijklmnop")
   ),
)
1257
# Integer multiply intended for address calculation.  It behaves like imul
# except that the result is undefined on overflow, where overflow is judged
# against the size of the variable being dereferenced.
#
# Because of that relaxed definition, an optimization pass may propagate
# bounds (ie, from a load/store intrinsic) back to the sources and use a
# lower precision integer multiply.  That helps on hw with 24b or perhaps
# 16b integer multiply instructions.
binop(
   "amul",
   tint,
   _2src_commutative + associative,
   "src0 * src1",
)
1269
# ir3-specific instruction that maps directly to mul-add shift high mix,
# (IMADSH_MIX16 i.e. al * bh << 16 + c). It is used for lowering integer
# multiplication (imul) on the Freedreno backend.
opcode(
   "imadsh_mix16",
   0,
   tint32,
   [0, 0, 0],
   [tint32, tint32, tint32],
   False,
   "",
   """
dst = ((((src0 & 0x0000ffff) << 16) * (src1 & 0xffff0000)) >> 16) + src2;
""",
)

# ir3-specific instruction that maps directly to ir3 mad.s24.
#
# 24b multiply into 32b result (with sign extension) plus 32b int
triop(
   "imad24_ir3",
   tint32,
   _2src_commutative,
   "(((int32_t)src0 << 8) >> 8) * (((int32_t)src1 << 8) >> 8) + src2",
)
1283
# r600/gcn specific instruction that evaluates unnormalized cube texture
# coordinates and the face index.
# The actual texture coordinates are then derived from the result as
#    dst.yx / abs(dst.z) + 1.5
unop_horiz(
   "cube_amd",
   4,
   tfloat32,
   3,
   tfloat32,
   """
   dst.x = dst.y = dst.z = 0.0;
   float absX = fabsf(src0.x);
   float absY = fabsf(src0.y);
   float absZ = fabsf(src0.z);

   if (absX >= absY && absX >= absZ) { dst.z = 2 * src0.x; }
   if (absY >= absX && absY >= absZ) { dst.z = 2 * src0.y; }
   if (absZ >= absX && absZ >= absY) { dst.z = 2 * src0.z; }

   if (src0.x >= 0 && absX >= absY && absX >= absZ) {
      dst.y = -src0.z; dst.x = -src0.y; dst.w = 0;
   }
   if (src0.x < 0 && absX >= absY && absX >= absZ) {
      dst.y = src0.z; dst.x = -src0.y; dst.w = 1;
   }
   if (src0.y >= 0 && absY >= absX && absY >= absZ) {
      dst.y = src0.x; dst.x = src0.z; dst.w = 2;
   }
   if (src0.y < 0 && absY >= absX && absY >= absZ) {
      dst.y = src0.x; dst.x = -src0.z; dst.w = 3;
   }
   if (src0.z >= 0 && absZ >= absX && absZ >= absY) {
      dst.y = src0.x; dst.x = -src0.y; dst.w = 4;
   }
   if (src0.z < 0 && absZ >= absX && absZ >= absY) {
      dst.y = -src0.x; dst.x = -src0.y; dst.w = 5;
   }
""",
)
1317
# r600/gcn specific sin and cos.
# These trigonometric functions need some lowering because the supported
# input values are expected to be normalized by dividing by (2 * pi).
unop("fsin_amd", tfloat,
     "sinf(6.2831853 * src0)")
unop("fcos_amd", tfloat,
     "cosf(6.2831853 * src0)")

# Midgard specific sin and cos.
# These expect their inputs to be divided by pi.
unop("fsin_mdg", tfloat,
     "sinf(3.141592653589793 * src0)")
unop("fcos_mdg", tfloat,
     "cosf(3.141592653589793 * src0)")

# AGX specific sin with the input expressed in quadrants. Used in the
# lowering for fsin/fcos. This corresponds to a sequence of 3 ALU ops in the
# backend (where the angle is further decomposed by quadrant, sinc is
# computed, and the angle is multiplied back for sin). Lowering fsin/fcos to
# fsin_agx requires some additional ALU that NIR may be able to optimize.
unop("fsin_agx", tfloat,
     "sinf(src0 * (6.2831853/4.0))")
1335
# AGX specific bitfield extraction from a pair of 32bit registers.
# src0,src1: the two registers
# src2: bit position of the LSB of the bitfield
# src3: number of bits in the bitfield if src3 > 0
#       src3 = 0 is equivalent to src3 = 32
# NOTE: src3 is a nir constant by contract
#
# The mask is built with a 64-bit one ("1ull") so that widths of 31 and 32
# bits do not shift a signed int out of range; the result is then truncated
# into the 32-bit mask.
opcode("extr_agx", 0, tuint32,
       [0, 0, 0, 0], [tuint32, tuint32, tuint32, tuint32], False, "", """
    uint32_t mask = 0xFFFFFFFF;
    uint8_t shift = src2 & 0x7F;
    if (src3 != 0) {
       mask = (1ull << src3) - 1;
    }
    if (shift >= 64) {
        dst = 0;
    } else {
        dst = (((((uint64_t) src1) << 32) | (uint64_t) src0) >> shift) & mask;
    }
""")
1355
# AGX multiply-shift-add. Corresponds to the iadd/isub/imad/imsub
# instructions. The shift must be <= 4 (domain restriction) and, for
# performance, should be constant.
opcode(
   "imadshl_agx", 0, tint, [0, 0, 0, 0], [tint, tint, tint, tint], False,
   "", "(src0 * src1) + (src2 << src3)",
)
opcode(
   "imsubshl_agx", 0, tint, [0, 0, 0, 0], [tint, tint, tint, tint], False,
   "", "(src0 * src1) - (src2 << src3)",
)
1363
# Bit interleave of two 16-bit sources: bit i of src0 lands on bit 2*i of the
# result and bit i of src1 on bit 2*i + 1.
#
# Both sources are widened to uint32_t before masking and shifting.  As
# uint16_t they would be promoted to signed int, and moving bit 15 of src1 up
# to bit 31 would overflow that int.
binop_convert("interleave_agx", tuint32, tuint16, "", """
      dst = 0;
      for (unsigned bit = 0; bit < 16; bit++) {
          dst |= ((uint32_t)src0 & (1u << bit)) << bit;
          dst |= ((uint32_t)src1 & (1u << bit)) << (bit + 1);
      }""", description="""
      Interleave bits of 16-bit integers to calculate a 32-bit integer. This can
      be used as-is for Morton encoding.
      """)
1373
# NVIDIA PRMT.  Each of the four selector nibbles in src0 picks one byte out
# of src1 (values 0-3) or src2 (values 4-7); if bit 3 of the nibble is set,
# the picked byte is replaced by eight copies of its top bit.
opcode(
   "prmt_nv",
   0,
   tuint32,
   [0, 0, 0],
   [tuint32, tuint32, tuint32],
   False,
   "",
   """
    dst = 0;
    for (unsigned i = 0; i < 4; i++) {
        uint8_t byte = (src0 >> (i * 4)) & 0x7;
        uint8_t x = byte < 4 ? (src1 >> (byte * 8))
                             : (src2 >> ((byte - 4) * 8));
        if ((src0 >> (i * 4)) & 0x8)
            x = ((int8_t)x) >> 7;
        dst |= ((uint32_t)x) << i * 8;
    }""",
)
1386
# 24b multiply into 32b result (with sign extension)
binop(
   "imul24", tint32, _2src_commutative + associative,
   "(((int32_t)src0 << 8) >> 8) * (((int32_t)src1 << 8) >> 8)",
)

# unsigned 24b multiply into 32b result plus 32b int
triop(
   "umad24", tuint32, _2src_commutative,
   "(((uint32_t)src0 << 8) >> 8) * (((uint32_t)src1 << 8) >> 8) + src2",
)

# unsigned 24b multiply into 32b result uint
# NOTE(review): registered with tint32 although umad24 and umul24_relaxed
# use tuint32 -- confirm whether that is intentional.
binop(
   "umul24", tint32, _2src_commutative + associative,
   "(((uint32_t)src0 << 8) >> 8) * (((uint32_t)src1 << 8) >> 8)",
)

# Relaxed versions of the above: the inputs are assumed to already be in the
# 24bit range, so nothing is clamped.
binop(
   "imul24_relaxed", tint32, _2src_commutative + associative,
   "src0 * src1",
)
triop(
   "umad24_relaxed", tuint32, _2src_commutative,
   "src0 * src1 + src2",
)
binop(
   "umul24_relaxed", tuint32, _2src_commutative + associative,
   "src0 * src1",
)
1403
# Float classification predicates.  fisfinite32 differs from fisfinite only
# in producing a 32-bit boolean instead of a 1-bit one.
unop_convert(
   "fisnormal", tbool1, tfloat, "isnormal(src0)")
unop_convert(
   "fisfinite", tbool1, tfloat, "isfinite(src0)")
unop_convert(
   "fisfinite32", tbool32, tfloat, "isfinite(src0)")
1407
# vc4-specific opcodes.  Each one treats its 32-bit sources as four packed
# 8-bit channels and processes the channels one at a time.

# Saturated vector add for 4 8bit ints.
binop(
   "usadd_4x8_vc4",
   tint32,
   _2src_commutative + associative,
   """
dst = 0;
for (int i = 0; i < 32; i += 8) {
   dst |= MIN2(((src0 >> i) & 0xff) + ((src1 >> i) & 0xff), 0xff) << i;
}
""",
)

# Saturated vector subtract for 4 8bit ints.
binop(
   "ussub_4x8_vc4",
   tint32,
   "",
   """
dst = 0;
for (int i = 0; i < 32; i += 8) {
   int src0_chan = (src0 >> i) & 0xff;
   int src1_chan = (src1 >> i) & 0xff;
   if (src0_chan > src1_chan)
      dst |= (src0_chan - src1_chan) << i;
}
""",
)

# vector min for 4 8bit ints.
binop(
   "umin_4x8_vc4",
   tint32,
   _2src_commutative + associative,
   """
dst = 0;
for (int i = 0; i < 32; i += 8) {
   dst |= MIN2((src0 >> i) & 0xff, (src1 >> i) & 0xff) << i;
}
""",
)

# vector max for 4 8bit ints.
binop(
   "umax_4x8_vc4",
   tint32,
   _2src_commutative + associative,
   """
dst = 0;
for (int i = 0; i < 32; i += 8) {
   dst |= MAX2((src0 >> i) & 0xff, (src1 >> i) & 0xff) << i;
}
""",
)

# unorm multiply: (a * b) / 255.
binop(
   "umul_unorm_4x8_vc4",
   tuint32,
   _2src_commutative + associative,
   """
dst = 0;
for (int i = 0; i < 32; i += 8) {
   uint32_t src0_chan = (src0 >> i) & 0xff;
   uint32_t src1_chan = (src1 >> i) & 0xff;
   dst |= ((src0_chan * src1_chan) / 255) << i;
}
""",
)
1454
# v3d-specific opcodes

# v3d-specific (v71) instruction that packs bits of 2 2x16 floating point into
# r11g11b10 bits, rounding to nearest even, so
#  dst[10:0]  = float16_to_float11 (src0[15:0])
#  dst[21:11] = float16_to_float11 (src0[31:16])
#  dst[31:22] = float16_to_float10 (src1[15:0])
binop_convert(
   "pack_32_to_r11g11b10_v3d", tuint32, tuint32, "",
   "pack_32_to_r11g11b10_v3d(src0, src1)",
)

# v3d-specific (v71) instruction that packs 2x32 bit to 2x16 bit integer.
# Unlike pack_32_2x16_split, the sources are themselves 32bit: it receives
# two 32-bit integers and packs the lower halfword of each as 2x16 into one
# 32-bit integer.
binop_horiz(
   "pack_2x32_to_2x16_v3d", 1, tuint32, 1, tuint32, 1, tuint32,
   "(src0.x & 0xffff) | (src1.x << 16)",
)

# v3d-specific (v71) instruction that packs bits of 2 2x16 integers into
# r10g10b10a2:
#   dst[9:0]   = src0[9:0]
#   dst[19:10] = src0[25:16]
#   dst[29:20] = src1[9:0]
#   dst[31:30] = src1[17:16]
binop_convert(
   "pack_uint_32_to_r10g10b10a2_v3d", tuint32, tuint32, "",
   "(src0 & 0x3ff) | ((src0 >> 16) & 0x3ff) << 10 | (src1 & 0x3ff) << 20 | ((src1 >> 16) & 0x3ff) << 30",
)

# v3d-specific (v71) instruction that packs 2 2x16 bit integers into 4x8 bits:
#   dst[7:0]   = src0[7:0]
#   dst[15:8]  = src0[23:16]
#   dst[23:16] = src1[7:0]
#   dst[31:24] = src1[23:16]
opcode(
   "pack_4x16_to_4x8_v3d", 0, tuint32, [0, 0], [tuint32, tuint32],
   False, "",
   "(src0 & 0x000000ff) | (src0 & 0x00ff0000) >> 8 | (src1 & 0x000000ff) << 16 | (src1 & 0x00ff0000) << 8",
)

# v3d-specific (v71) instructions to convert 2x16 floating point to 2x8 bit
# unorm/snorm
unop(
   "pack_2x16_to_unorm_2x8_v3d", tuint32,
   "_mesa_half_to_unorm(src0 & 0xffff, 8) | (_mesa_half_to_unorm(src0 >> 16, 8) << 16)",
)
unop(
   "pack_2x16_to_snorm_2x8_v3d", tuint32,
   "_mesa_half_to_snorm(src0 & 0xffff, 8) | ((uint32_t)(_mesa_half_to_snorm(src0 >> 16, 8)) << 16)",
)

# v3d-specific (v71) instructions to convert 32-bit floating point to 16 bit
# unorm/snorm
unop("f2unorm_16_v3d", tuint32,
     "_mesa_float_to_unorm16(src0)")
unop("f2snorm_16_v3d", tuint32,
     "_mesa_float_to_snorm16(src0)")

# v3d-specific (v71) instructions to convert 2x16 bit floating points to
# 2x10 bit unorm
unop("pack_2x16_to_unorm_2x10_v3d", tuint32,
     "pack_2x16_to_unorm_2x10(src0)")

# v3d-specific (v71) instructions to convert 2x16 bit floating points to one
# 2-bit and one 10 bit unorm
unop("pack_2x16_to_unorm_10_2_v3d", tuint32,
     "pack_2x16_to_unorm_10_2(src0)")
1506
# Mali-specific opcodes:
#   fsat_signed_mali applies fmax(-1.0) followed by fmin(1.0) to src0.
#   fclamp_pos_mali  applies fmax(0.0) to src0.
unop("fsat_signed_mali", tfloat, "fmin(fmax(src0, -1.0), 1.0)")
unop("fclamp_pos_mali", tfloat, "fmax(src0, 0.0)")
1510
# Midgard-specific select: a 32-bit boolean condition (src0) picks between
# two floating-point sources (src1 if true, src2 otherwise).
opcode("b32fcsel_mdg",
       0, tuint,
       [0, 0, 0], [tbool32, tfloat, tfloat],
       False, selection,
       "src0 ? src1 : src2",
       description = csel_description.format("a 32-bit", "0 vs ~0") + """
       This Midgard-specific variant takes floating-point sources, rather than
       integer sources. That includes support for floating point modifiers in
       the backend.
       """)
1518
# DXIL-specific double [un]pack.
# DXIL doesn't support generic [un]pack instructions, so we want those
# lowered to bit ops. HLSL doesn't support 64-bit bitcasts to/from double,
# only [un]pack. Technically DXIL does, but since they can't be generated
# from HLSL, we want to match what would be coming from DXC.
# These are essentially the standard [un]pack, except that they don't get
# lowered, so the backend can handle them and turn them into
# MakeDouble/SplitDouble.
unop_horiz(
   "pack_double_2x32_dxil",
   1, tuint64,
   2, tuint32,
   "dst.x = src0.x | ((uint64_t)src0.y << 32);")
unop_horiz(
   "unpack_double_2x32_dxil",
   2, tuint32,
   1, tuint64,
   "dst.x = src0.x; dst.y = src0.x >> 32;")
1530
# src0 and src1 are i8vec4 packed in an int32, and src2 is an int32.  The int8
# components are sign-extended to 32-bits, and a dot-product is performed on
# the resulting vectors.  src2 is added to the result of the dot-product.
#
# The dot-product itself always fits in an int32_t (its magnitude is at most
# 4 * 128 * 128), but adding src2 to it may not.  The final add is therefore
# done on uint32_t so that it wraps around instead of being a signed overflow,
# which is undefined behavior in C.
opcode("sdot_4x8_iadd", 0, tint32, [0, 0, 0], [tuint32, tuint32, tint32],
       False, _2src_commutative, """
   const int32_t v0x = (int8_t)(src0      );
   const int32_t v0y = (int8_t)(src0 >>  8);
   const int32_t v0z = (int8_t)(src0 >> 16);
   const int32_t v0w = (int8_t)(src0 >> 24);
   const int32_t v1x = (int8_t)(src1      );
   const int32_t v1y = (int8_t)(src1 >>  8);
   const int32_t v1z = (int8_t)(src1 >> 16);
   const int32_t v1w = (int8_t)(src1 >> 24);

   const int32_t dp = (v0x * v1x) + (v0y * v1y) + (v0z * v1z) + (v0w * v1w);

   dst = (int32_t)((uint32_t)dp + (uint32_t)src2);
""")
1547
# Unsigned counterpart of sdot_4x8_iadd: the 8-bit components of src0 and
# src1 are zero-extended to 32 bits, and src2 is added to their dot-product.
# All arithmetic is done on uint32_t.
opcode("udot_4x8_uadd",
       0, tuint32,
       [0, 0, 0], [tuint32, tuint32, tuint32],
       False, _2src_commutative,
       """
   const uint32_t v0x = (uint8_t)(src0      );
   const uint32_t v0y = (uint8_t)(src0 >>  8);
   const uint32_t v0z = (uint8_t)(src0 >> 16);
   const uint32_t v0w = (uint8_t)(src0 >> 24);
   const uint32_t v1x = (uint8_t)(src1      );
   const uint32_t v1y = (uint8_t)(src1 >>  8);
   const uint32_t v1z = (uint8_t)(src1 >> 16);
   const uint32_t v1w = (uint8_t)(src1 >> 24);

   dst = (v0x * v1x) + (v0y * v1y) + (v0z * v1z) + (v0w * v1w) + src2;
""")
1562
# Mixed-sign 4x8 dot-product with accumulate: src0 is an i8vec4 packed in an
# int32 (components sign-extended), src1 is a u8vec4 packed in an int32
# (components zero-extended), and src2 is an int32 that is added to the
# dot-product of the two extended vectors.
#
# NOTE: Unlike many of the other dp4a opcodes, the mixed signs of source 0
# and source 1 mean that this opcode is not 2-source commutative.
opcode("sudot_4x8_iadd",
       0, tint32,
       [0, 0, 0], [tuint32, tuint32, tint32],
       False, "",
       """
   const int32_t v0x = (int8_t)(src0      );
   const int32_t v0y = (int8_t)(src0 >>  8);
   const int32_t v0z = (int8_t)(src0 >> 16);
   const int32_t v0w = (int8_t)(src0 >> 24);
   const uint32_t v1x = (uint8_t)(src1      );
   const uint32_t v1y = (uint8_t)(src1 >>  8);
   const uint32_t v1z = (uint8_t)(src1 >> 16);
   const uint32_t v1w = (uint8_t)(src1 >> 24);

   dst = (v0x * v1x) + (v0y * v1y) + (v0z * v1z) + (v0w * v1w) + src2;
""")
1583
# Like sdot_4x8_iadd, but the result is clamped to the range
# [-0x80000000, 0x7fffffff].  The sum is formed in 64 bits so that it can be
# compared against INT32_MIN/INT32_MAX before being narrowed.
opcode("sdot_4x8_iadd_sat",
       0, tint32,
       [0, 0, 0], [tuint32, tuint32, tint32],
       False, _2src_commutative,
       """
   const int64_t v0x = (int8_t)(src0      );
   const int64_t v0y = (int8_t)(src0 >>  8);
   const int64_t v0z = (int8_t)(src0 >> 16);
   const int64_t v0w = (int8_t)(src0 >> 24);
   const int64_t v1x = (int8_t)(src1      );
   const int64_t v1y = (int8_t)(src1 >>  8);
   const int64_t v1z = (int8_t)(src1 >> 16);
   const int64_t v1w = (int8_t)(src1 >> 24);

   const int64_t tmp = (v0x * v1x) + (v0y * v1y) + (v0z * v1z) + (v0w * v1w) + src2;

   dst = tmp >= INT32_MAX ? INT32_MAX : (tmp <= INT32_MIN ? INT32_MIN : tmp);
""")
1600
# Like udot_4x8_uadd, but the result is clamped to the range [0, 0xffffffff].
#
# src2 is declared as tint32 here, so it is cast to uint32_t before being
# added to the 64-bit sum.  Without the cast, an accumulator with bit 31 set
# would be sign-extended to 64 bits and the folded result would be wrong
# (e.g. a zero dot-product plus 0x80000000 would saturate to 0xffffffff).
# The declared opcode types are intentionally left unchanged; only the
# constant-folding expression is affected.
opcode("udot_4x8_uadd_sat", 0, tint32, [0, 0, 0], [tuint32, tuint32, tint32],
       False, _2src_commutative, """
   const uint64_t v0x = (uint8_t)(src0      );
   const uint64_t v0y = (uint8_t)(src0 >>  8);
   const uint64_t v0z = (uint8_t)(src0 >> 16);
   const uint64_t v0w = (uint8_t)(src0 >> 24);
   const uint64_t v1x = (uint8_t)(src1      );
   const uint64_t v1y = (uint8_t)(src1 >>  8);
   const uint64_t v1z = (uint8_t)(src1 >> 16);
   const uint64_t v1w = (uint8_t)(src1 >> 24);

   const uint64_t tmp = (v0x * v1x) + (v0y * v1y) + (v0z * v1z) + (v0w * v1w) + (uint32_t)src2;

   dst = tmp >= UINT32_MAX ? UINT32_MAX : tmp;
""")
1617
# Like sudot_4x8_iadd, but the result is clamped to the range
# [-0x80000000, 0x7fffffff].
#
# NOTE: Unlike many of the other dp4a opcodes, the mixed signs of source 0
# and source 1 mean that this opcode is not 2-source commutative.
opcode("sudot_4x8_iadd_sat",
       0, tint32,
       [0, 0, 0], [tuint32, tuint32, tint32],
       False, "",
       """
   const int64_t v0x = (int8_t)(src0      );
   const int64_t v0y = (int8_t)(src0 >>  8);
   const int64_t v0z = (int8_t)(src0 >> 16);
   const int64_t v0w = (int8_t)(src0 >> 24);
   const uint64_t v1x = (uint8_t)(src1      );
   const uint64_t v1y = (uint8_t)(src1 >>  8);
   const uint64_t v1z = (uint8_t)(src1 >> 16);
   const uint64_t v1w = (uint8_t)(src1 >> 24);

   const int64_t tmp = (v0x * v1x) + (v0y * v1y) + (v0z * v1z) + (v0w * v1w) + src2;

   dst = tmp >= INT32_MAX ? INT32_MAX : (tmp <= INT32_MIN ? INT32_MIN : tmp);
""")
1637
# src0 and src1 are i16vec2 packed in an int32, and src2 is an int32.  The int16
# components are sign-extended to 32-bits, and a dot-product is performed on
# the resulting vectors.  src2 is added to the result of the dot-product.
#
# Each individual product fits in an int32_t (its magnitude is at most 2^30),
# but the sum of the two products can reach 2^31 and adding src2 can overflow
# as well.  The additions are therefore done on uint32_t so that they wrap
# around instead of being a signed overflow, which is undefined behavior in C.
opcode("sdot_2x16_iadd", 0, tint32, [0, 0, 0], [tuint32, tuint32, tint32],
       False, _2src_commutative, """
   const int32_t v0x = (int16_t)(src0      );
   const int32_t v0y = (int16_t)(src0 >> 16);
   const int32_t v1x = (int16_t)(src1      );
   const int32_t v1y = (int16_t)(src1 >> 16);

   dst = (int32_t)((uint32_t)(v0x * v1x) + (uint32_t)(v0y * v1y) + (uint32_t)src2);
""")
1650
# Unsigned counterpart of sdot_2x16_iadd: the 16-bit components of src0 and
# src1 are zero-extended to 32 bits, and src2 is added to their dot-product.
# All arithmetic is done on uint32_t.
opcode("udot_2x16_uadd",
       0, tuint32,
       [0, 0, 0], [tuint32, tuint32, tuint32],
       False, _2src_commutative,
       """
   const uint32_t v0x = (uint16_t)(src0      );
   const uint32_t v0y = (uint16_t)(src0 >> 16);
   const uint32_t v1x = (uint16_t)(src1      );
   const uint32_t v1y = (uint16_t)(src1 >> 16);

   dst = (v0x * v1x) + (v0y * v1y) + src2;
""")
1661
# Like sdot_2x16_iadd, but the result is clamped to the range
# [-0x80000000, 0x7fffffff].  The sum is formed in 64 bits so that it can be
# compared against INT32_MIN/INT32_MAX before being narrowed.
opcode("sdot_2x16_iadd_sat",
       0, tint32,
       [0, 0, 0], [tuint32, tuint32, tint32],
       False, _2src_commutative,
       """
   const int64_t v0x = (int16_t)(src0      );
   const int64_t v0y = (int16_t)(src0 >> 16);
   const int64_t v1x = (int16_t)(src1      );
   const int64_t v1y = (int16_t)(src1 >> 16);

   const int64_t tmp = (v0x * v1x) + (v0y * v1y) + src2;

   dst = tmp >= INT32_MAX ? INT32_MAX : (tmp <= INT32_MIN ? INT32_MIN : tmp);
""")
1674
# Like udot_2x16_uadd, but the result is clamped to the range [0, 0xffffffff].
#
# src2 is declared as tint32 here, so it is cast to uint32_t before being
# added to the 64-bit sum.  Without the cast, an accumulator with bit 31 set
# would be sign-extended to 64 bits and the folded result would be wrong
# (e.g. a zero dot-product plus 0x80000000 would saturate to 0xffffffff).
# The declared opcode types are intentionally left unchanged; only the
# constant-folding expression is affected.
opcode("udot_2x16_uadd_sat", 0, tint32, [0, 0, 0], [tuint32, tuint32, tint32],
       False, _2src_commutative, """
   const uint64_t v0x = (uint16_t)(src0      );
   const uint64_t v0y = (uint16_t)(src0 >> 16);
   const uint64_t v1x = (uint16_t)(src1      );
   const uint64_t v1y = (uint16_t)(src1 >> 16);

   const uint64_t tmp = (v0x * v1x) + (v0y * v1y) + (uint32_t)src2;

   dst = tmp >= UINT32_MAX ? UINT32_MAX : tmp;
""")
1687