xref: /aosp_15_r20/external/mesa3d/src/compiler/nir/nir_constant_expressions.py (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1import re
2from nir_opcodes import opcodes
3from nir_opcodes import type_has_size, type_size, type_sizes, type_base_type
4
def type_add_size(type_, size):
    """Return ``type_`` with ``size`` suffixed, unless it is already sized.

    When ``type_has_size(type_)`` reports that the type name already carries
    a bit size, the name is returned untouched; otherwise the size is
    appended as text (e.g. an unsized name like ``"float"`` with size 32
    becomes ``"float32"``).
    """
    return type_ if type_has_size(type_) else type_ + str(size)
9
def op_bit_sizes(op):
    """Return the sorted bit sizes shared by all unsized types of ``op``.

    The output type and every input type are inspected in that order.  Each
    type without an explicit size contributes ``set(type_sizes(t))`` and the
    result is the intersection of those sets, sorted ascending.  When every
    type is already sized there is no free bit size and ``None`` is returned.
    """
    candidates = [op.output_type] + list(op.input_types)
    unsized = [t for t in candidates if not type_has_size(t)]
    if not unsized:
        return None

    common = set(type_sizes(unsized[0]))
    for other in unsized[1:]:
        common &= set(type_sizes(other))
    return sorted(common)
23
def get_const_field(type_):
    """Return the nir_const_value field name that stores a ``type_`` value.

    * any 1-bit type -> ``'b'``
    * sized booleans -> the signed-integer field of the same width (``i<N>``)
    * ``float16``    -> ``"u16"`` (half floats are kept as raw bits)
    * anything else  -> first letter of the base type plus the size
      (e.g. ``f32``, ``u8``)
    """
    bits = type_size(type_)
    if bits == 1:
        return 'b'

    base = type_base_type(type_)
    if base == 'bool':
        return 'i' + str(bits)
    if type_ == "float16":
        return "u16"
    return base[0] + str(bits)
33
# Mako template for the generated constant-folding C source.  It emits shared
# C helpers, one typed <type><width>_vec struct per type/width pair, one
# static evaluate_<opcode>() per entry in `opcodes` (switching on bit_size when
# op_bit_sizes() reports free sizes), and the nir_eval_const_opcode()
# dispatcher that switches on nir_op.
#
# This is a regular (non-raw) string literal: any backslash intended for the C
# output, such as the Doxygen \brief command below, must be written doubled,
# otherwise Python turns e.g. "\b" into a backspace character.
#
# Inside evaluate_op the per-component loop index is `_i`; every store and
# denorm flush in that loop must index _dst_val with `_i`.
template = """\
/*
 * Copyright (C) 2014 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include <math.h>
#include "util/rounding.h" /* for _mesa_roundeven */
#include "util/half_float.h"
#include "util/double.h"
#include "util/softfloat.h"
#include "util/bigmath.h"
#include "util/format/format_utils.h"
#include "util/format_r11g11b10f.h"
#include "util/u_math.h"
#include "nir_constant_expressions.h"

/**
 * \\brief Checks if the provided value is a denorm and flushes it to zero.
 */
static void
constant_denorm_flush_to_zero(nir_const_value *value, unsigned bit_size)
{
    switch(bit_size) {
    case 64:
        if (0 == (value->u64 & 0x7ff0000000000000))
            value->u64 &= 0x8000000000000000;
        break;
    case 32:
        if (0 == (value->u32 & 0x7f800000))
            value->u32 &= 0x80000000;
        break;
    case 16:
        if (0 == (value->u16 & 0x7c00))
            value->u16 &= 0x8000;
    }
}

/**
 * Evaluate one component of packSnorm4x8.
 */
static uint8_t
pack_snorm_1x8(float x)
{
    /* From section 8.4 of the GLSL 4.30 spec:
     *
     *    packSnorm4x8
     *    ------------
     *    The conversion for component c of v to fixed point is done as
     *    follows:
     *
     *      packSnorm4x8: round(clamp(c, -1, +1) * 127.0)
     *
     * We must first cast the float to an int, because casting a negative
     * float to a uint is undefined.
     */
   return (uint8_t) (int)
          _mesa_roundevenf(CLAMP(x, -1.0f, +1.0f) * 127.0f);
}

/**
 * Evaluate one component of packSnorm2x16.
 */
static uint16_t
pack_snorm_1x16(float x)
{
    /* From section 8.4 of the GLSL ES 3.00 spec:
     *
     *    packSnorm2x16
     *    -------------
     *    The conversion for component c of v to fixed point is done as
     *    follows:
     *
     *      packSnorm2x16: round(clamp(c, -1, +1) * 32767.0)
     *
     * We must first cast the float to an int, because casting a negative
     * float to a uint is undefined.
     */
   return (uint16_t) (int)
          _mesa_roundevenf(CLAMP(x, -1.0f, +1.0f) * 32767.0f);
}

/**
 * Evaluate one component of unpackSnorm4x8.
 */
static float
unpack_snorm_1x8(uint8_t u)
{
    /* From section 8.4 of the GLSL 4.30 spec:
     *
     *    unpackSnorm4x8
     *    --------------
     *    The conversion for unpacked fixed-point value f to floating point is
     *    done as follows:
     *
     *       unpackSnorm4x8: clamp(f / 127.0, -1, +1)
     */
   return CLAMP((int8_t) u / 127.0f, -1.0f, +1.0f);
}

/**
 * Evaluate one component of unpackSnorm2x16.
 */
static float
unpack_snorm_1x16(uint16_t u)
{
    /* From section 8.4 of the GLSL ES 3.00 spec:
     *
     *    unpackSnorm2x16
     *    ---------------
     *    The conversion for unpacked fixed-point value f to floating point is
     *    done as follows:
     *
     *       unpackSnorm2x16: clamp(f / 32767.0, -1, +1)
     */
   return CLAMP((int16_t) u / 32767.0f, -1.0f, +1.0f);
}

/**
 * Evaluate one component packUnorm4x8.
 */
static uint8_t
pack_unorm_1x8(float x)
{
    /* From section 8.4 of the GLSL 4.30 spec:
     *
     *    packUnorm4x8
     *    ------------
     *    The conversion for component c of v to fixed point is done as
     *    follows:
     *
     *       packUnorm4x8: round(clamp(c, 0, +1) * 255.0)
     */
   return (uint8_t) (int)
          _mesa_roundevenf(CLAMP(x, 0.0f, 1.0f) * 255.0f);
}

/**
 * Evaluate one component packUnorm2x16.
 */
static uint16_t
pack_unorm_1x16(float x)
{
    /* From section 8.4 of the GLSL ES 3.00 spec:
     *
     *    packUnorm2x16
     *    -------------
     *    The conversion for component c of v to fixed point is done as
     *    follows:
     *
     *       packUnorm2x16: round(clamp(c, 0, +1) * 65535.0)
     */
   return (uint16_t) (int)
          _mesa_roundevenf(CLAMP(x, 0.0f, 1.0f) * 65535.0f);
}

/**
 * Evaluate one component of unpackUnorm4x8.
 */
static float
unpack_unorm_1x8(uint8_t u)
{
    /* From section 8.4 of the GLSL 4.30 spec:
     *
     *    unpackUnorm4x8
     *    --------------
     *    The conversion for unpacked fixed-point value f to floating point is
     *    done as follows:
     *
     *       unpackUnorm4x8: f / 255.0
     */
   return (float) u / 255.0f;
}

/**
 * Evaluate one component of unpackUnorm2x16.
 */
static float
unpack_unorm_1x16(uint16_t u)
{
    /* From section 8.4 of the GLSL ES 3.00 spec:
     *
     *    unpackUnorm2x16
     *    ---------------
     *    The conversion for unpacked fixed-point value f to floating point is
     *    done as follows:
     *
     *       unpackUnorm2x16: f / 65535.0
     */
   return (float) u / 65535.0f;
}

/**
 * Evaluate one component of packHalf2x16.
 */
static uint16_t
pack_half_1x16(float x)
{
   return _mesa_float_to_half(x);
}

/**
 * Evaluate one component of packHalf2x16, RTZ mode.
 */
static uint16_t
pack_half_1x16_rtz(float x)
{
   return _mesa_float_to_float16_rtz(x);
}

/**
 * Evaluate one component of unpackHalf2x16.
 */
static float
unpack_half_1x16(uint16_t u, bool ftz)
{
   if (0 == (u & 0x7c00) && ftz)
      u &= 0x8000;
   return _mesa_half_to_float(u);
}

/* Broadcom v3d specific instructions */
/**
 * Packs 2 2x16 floating split into a r11g11b10f:
 *
 * dst[10:0]  = float16_to_float11 (src0[15:0])
 * dst[21:11] = float16_to_float11 (src0[31:16])
 * dst[31:22] = float16_to_float10 (src1[15:0])
 */
static uint32_t pack_32_to_r11g11b10_v3d(const uint32_t src0,
                                         const uint32_t src1)
{
   float rgb[3] = {
      unpack_half_1x16((src0 & 0xffff), false),
      unpack_half_1x16((src0 >> 16), false),
      unpack_half_1x16((src1 & 0xffff), false),
   };

   return float3_to_r11g11b10f(rgb);
}

/**
  * The three methods below are basically wrappers over pack_s/unorm_1x8/1x16,
  * as they receives a uint16_t val instead of a float
  */
static inline uint8_t _mesa_half_to_snorm8(uint16_t val)
{
   return pack_snorm_1x8(_mesa_half_to_float(val));
}

static uint16_t _mesa_float_to_snorm16(uint32_t val)
{
   union fi aux;
   aux.ui = val;
   return pack_snorm_1x16(aux.f);
}

static uint16_t _mesa_float_to_unorm16(uint32_t val)
{
   union fi aux;
   aux.ui = val;
   return pack_unorm_1x16(aux.f);
}

static inline uint32_t float_pack16_v3d(uint32_t f32)
{
   return _mesa_float_to_half(uif(f32));
}

static inline uint32_t float_unpack16_v3d(uint32_t f16)
{
   return fui(_mesa_half_to_float(f16));
}

static inline uint32_t vfpack_v3d(uint32_t a, uint32_t b)
{
   return float_pack16_v3d(b) << 16 | float_pack16_v3d(a);
}

static inline uint32_t vfsat_v3d(uint32_t a)
{
   const uint32_t low = fui(SATURATE(_mesa_half_to_float(a & 0xffff)));
   const uint32_t high = fui(SATURATE(_mesa_half_to_float(a >> 16)));

   return vfpack_v3d(low, high);
}

static inline uint32_t fmul_v3d(uint32_t a, uint32_t b)
{
   return fui(uif(a) * uif(b));
}

static uint32_t vfmul_v3d(uint32_t a, uint32_t b)
{
   const uint32_t low = fmul_v3d(float_unpack16_v3d(a & 0xffff),
                                 float_unpack16_v3d(b & 0xffff));
   const uint32_t high = fmul_v3d(float_unpack16_v3d(a >> 16),
                                  float_unpack16_v3d(b >> 16));

   return vfpack_v3d(low, high);
}

/* Convert 2x16-bit floating point to 2x10-bit unorm */
static uint32_t pack_2x16_to_unorm_2x10(uint32_t src0)
{
   return vfmul_v3d(vfsat_v3d(src0), 0x03ff03ff);
}

/*
 * Convert 2x16-bit floating point to one 2-bit and one
 * 10-bit unorm
 */
static uint32_t pack_2x16_to_unorm_10_2(uint32_t src0)
{
   return vfmul_v3d(vfsat_v3d(src0), 0x000303ff);
}

static uint32_t
msad(uint32_t src0, uint32_t src1, uint32_t src2) {
   uint32_t res = src2;
   for (unsigned i = 0; i < 4; i++) {
      const uint8_t ref = src0 >> (i * 8);
      const uint8_t src = src1 >> (i * 8);
      if (ref != 0)
         res += MAX2(ref, src) - MIN2(ref, src);
   }
   return res;
}

/* Some typed vector structures to make things like src0.y work */
typedef int8_t int1_t;
typedef uint8_t uint1_t;
typedef float float16_t;
typedef float float32_t;
typedef double float64_t;
typedef bool bool1_t;
typedef bool bool8_t;
typedef bool bool16_t;
typedef bool bool32_t;
typedef bool bool64_t;
% for type in ["float", "int", "uint", "bool"]:
% for width in type_sizes(type):
struct ${type}${width}_vec {
   ${type}${width}_t x;
   ${type}${width}_t y;
   ${type}${width}_t z;
   ${type}${width}_t w;
   ${type}${width}_t e;
   ${type}${width}_t f;
   ${type}${width}_t g;
   ${type}${width}_t h;
   ${type}${width}_t i;
   ${type}${width}_t j;
   ${type}${width}_t k;
   ${type}${width}_t l;
   ${type}${width}_t m;
   ${type}${width}_t n;
   ${type}${width}_t o;
   ${type}${width}_t p;
};
% endfor
% endfor

<%def name="evaluate_op(op, bit_size, execution_mode)">
   <%
   output_type = type_add_size(op.output_type, bit_size)
   input_types = [type_add_size(type_, bit_size) for type_ in op.input_types]
   %>

   ## For each non-per-component input, create a variable srcN that
   ## contains x, y, z, and w elements which are filled in with the
   ## appropriately-typed values.
   % for j in range(op.num_inputs):
      % if op.input_sizes[j] == 0:
         <% continue %>
      % elif "src" + str(j) not in op.const_expr:
         ## Avoid unused variable warnings
         <% continue %>
      %endif

      const struct ${input_types[j]}_vec src${j} = {
      % for k in range(op.input_sizes[j]):
         % if input_types[j] == "int1":
             /* 1-bit integers use a 0/-1 convention */
             -(int1_t)_src[${j}][${k}].b,
         % elif input_types[j] == "float16":
            _mesa_half_to_float(_src[${j}][${k}].u16),
         % else:
            _src[${j}][${k}].${get_const_field(input_types[j])},
         % endif
      % endfor
      % for k in range(op.input_sizes[j], 16):
         0,
      % endfor
      };
   % endfor

   % if op.output_size == 0:
      ## For per-component instructions, we need to iterate over the
      ## components and apply the constant expression one component
      ## at a time.
      for (unsigned _i = 0; _i < num_components; _i++) {
         ## For each per-component input, create a variable srcN that
         ## contains the value of the current (_i'th) component.
         % for j in range(op.num_inputs):
            % if op.input_sizes[j] != 0:
               <% continue %>
            % elif "src" + str(j) not in op.const_expr:
               ## Avoid unused variable warnings
               <% continue %>
            % elif input_types[j] == "int1":
               /* 1-bit integers use a 0/-1 convention */
               const int1_t src${j} = -(int1_t)_src[${j}][_i].b;
            % elif input_types[j] == "float16":
               const float src${j} =
                  _mesa_half_to_float(_src[${j}][_i].u16);
            % else:
               const ${input_types[j]}_t src${j} =
                  _src[${j}][_i].${get_const_field(input_types[j])};
            % endif
         % endfor

         ## Create an appropriately-typed variable dst and assign the
         ## result of the const_expr to it.  If const_expr already contains
         ## writes to dst, just include const_expr directly.
         % if "dst" in op.const_expr:
            ${output_type}_t dst;

            ${op.const_expr}
         % else:
            ${output_type}_t dst = ${op.const_expr};
         % endif

         ## Store the current component of the actual destination to the
         ## value of dst.
         % if output_type == "int1" or output_type == "uint1":
            /* 1-bit integers get truncated */
            _dst_val[_i].b = dst & 1;
         % elif output_type.startswith("bool"):
            ## Sanitize the C value to a proper NIR 0/-1 bool
            _dst_val[_i].${get_const_field(output_type)} = -(int)dst;
         % elif output_type == "float16":
            if (nir_is_rounding_mode_rtz(execution_mode, 16)) {
               _dst_val[_i].u16 = _mesa_float_to_float16_rtz(dst);
            } else {
               _dst_val[_i].u16 = _mesa_float_to_float16_rtne(dst);
            }
         % else:
            _dst_val[_i].${get_const_field(output_type)} = dst;
         % endif

         % if op.name != "fquantize2f16" and type_base_type(output_type) == "float":
            % if type_has_size(output_type):
               if (nir_is_denorm_flush_to_zero(execution_mode, ${type_size(output_type)})) {
                  constant_denorm_flush_to_zero(&_dst_val[_i], ${type_size(output_type)});
               }
            % else:
               if (nir_is_denorm_flush_to_zero(execution_mode, ${bit_size})) {
                  constant_denorm_flush_to_zero(&_dst_val[_i], bit_size);
               }
            %endif
         % endif
      }
   % else:
      ## In the non-per-component case, create a struct dst with
      ## appropriately-typed elements x, y, z, and w and assign the result
      ## of the const_expr to all components of dst, or include the
      ## const_expr directly if it writes to dst already.
      struct ${output_type}_vec dst;

      % if "dst" in op.const_expr:
         ${op.const_expr}
      % else:
         ## Splat the value to all components.  This way expressions which
         ## write the same value to all components don't need to explicitly
         ## write to dest.
         dst.x = dst.y = dst.z = dst.w = ${op.const_expr};
      % endif

      ## For each component in the destination, copy the value of dst to
      ## the actual destination.
      % for k in range(op.output_size):
         % if output_type == "int1" or output_type == "uint1":
            /* 1-bit integers get truncated */
            _dst_val[${k}].b = dst.${"xyzwefghijklmnop"[k]} & 1;
         % elif output_type.startswith("bool"):
            ## Sanitize the C value to a proper NIR 0/-1 bool
            _dst_val[${k}].${get_const_field(output_type)} = -(int)dst.${"xyzwefghijklmnop"[k]};
         % elif output_type == "float16":
            if (nir_is_rounding_mode_rtz(execution_mode, 16)) {
               _dst_val[${k}].u16 = _mesa_float_to_float16_rtz(dst.${"xyzwefghijklmnop"[k]});
            } else {
               _dst_val[${k}].u16 = _mesa_float_to_float16_rtne(dst.${"xyzwefghijklmnop"[k]});
            }
         % else:
            _dst_val[${k}].${get_const_field(output_type)} = dst.${"xyzwefghijklmnop"[k]};
         % endif

         % if op.name != "fquantize2f16" and type_base_type(output_type) == "float":
            % if type_has_size(output_type):
               if (nir_is_denorm_flush_to_zero(execution_mode, ${type_size(output_type)})) {
                  constant_denorm_flush_to_zero(&_dst_val[${k}], ${type_size(output_type)});
               }
            % else:
               if (nir_is_denorm_flush_to_zero(execution_mode, ${bit_size})) {
                  constant_denorm_flush_to_zero(&_dst_val[${k}], bit_size);
               }
            % endif
         % endif
      % endfor
   % endif
</%def>

% for name, op in sorted(opcodes.items()):
% if op.name == "fsat":
#if defined(_MSC_VER) && (defined(_M_ARM64) || defined(_M_ARM64EC))
#pragma optimize("", off) /* Temporary work-around for MSVC compiler bug, present in VS2019 16.9.2 */
#endif
% endif
static void
evaluate_${name}(nir_const_value *_dst_val,
                 UNUSED unsigned num_components,
                 ${"UNUSED" if op_bit_sizes(op) is None else ""} unsigned bit_size,
                 UNUSED nir_const_value **_src,
                 UNUSED unsigned execution_mode)
{
   % if op_bit_sizes(op) is not None:
      switch (bit_size) {
      % for bit_size in op_bit_sizes(op):
      case ${bit_size}: {
         ${evaluate_op(op, bit_size, execution_mode)}
         break;
      }
      % endfor

      default:
         unreachable("unknown bit width");
      }
   % else:
      ${evaluate_op(op, 0, execution_mode)}
   % endif
}
% if op.name == "fsat":
#if defined(_MSC_VER) && (defined(_M_ARM64) || defined(_M_ARM64EC))
#pragma optimize("", on) /* Temporary work-around for MSVC compiler bug, present in VS2019 16.9.2 */
#endif
% endif
% endfor

void
nir_eval_const_opcode(nir_op op, nir_const_value *dest,
                      unsigned num_components, unsigned bit_width,
                      nir_const_value **src,
                      unsigned float_controls_execution_mode)
{
   switch (op) {
% for name in sorted(opcodes.keys()):
   case nir_op_${name}:
      evaluate_${name}(dest, num_components, bit_width, src, float_controls_execution_mode);
      return;
% endfor
   default:
      unreachable("shouldn't get here");
   }
}"""
617
618from mako.template import Template
619
# Expand the Mako template with the opcode table and the helper functions it
# calls, then write the generated C source to stdout.
_render_kwargs = dict(
    opcodes=opcodes,
    type_sizes=type_sizes,
    type_base_type=type_base_type,
    type_size=type_size,
    type_has_size=type_has_size,
    type_add_size=type_add_size,
    op_bit_sizes=op_bit_sizes,
    get_const_field=get_const_field,
)
print(Template(template).render(**_render_kwargs))
627