xref: /aosp_15_r20/external/mesa3d/src/compiler/nir/nir_lower_mediump.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright (C) 2020 Google, Inc.
3  * Copyright (C) 2021 Advanced Micro Devices, Inc.
4  *
5  * Permission is hereby granted, free of charge, to any person obtaining a
6  * copy of this software and associated documentation files (the "Software"),
7  * to deal in the Software without restriction, including without limitation
8  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
9  * and/or sell copies of the Software, and to permit persons to whom the
10  * Software is furnished to do so, subject to the following conditions:
11  *
12  * The above copyright notice and this permission notice (including the next
13  * paragraph) shall be included in all copies or substantial portions of the
14  * Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
19  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 
25 #include "nir.h"
26 #include "nir_builder.h"
27 
28 /**
29  * Return the intrinsic if it matches the mask in "modes", else return NULL.
30  */
31 static nir_intrinsic_instr *
get_io_intrinsic(nir_instr * instr,nir_variable_mode modes,nir_variable_mode * out_mode)32 get_io_intrinsic(nir_instr *instr, nir_variable_mode modes,
33                  nir_variable_mode *out_mode)
34 {
35    if (instr->type != nir_instr_type_intrinsic)
36       return NULL;
37 
38    nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
39 
40    switch (intr->intrinsic) {
41    case nir_intrinsic_load_input:
42    case nir_intrinsic_load_per_primitive_input:
43    case nir_intrinsic_load_input_vertex:
44    case nir_intrinsic_load_interpolated_input:
45    case nir_intrinsic_load_per_vertex_input:
46       *out_mode = nir_var_shader_in;
47       return modes & nir_var_shader_in ? intr : NULL;
48    case nir_intrinsic_load_output:
49    case nir_intrinsic_load_per_vertex_output:
50    case nir_intrinsic_store_output:
51    case nir_intrinsic_store_per_vertex_output:
52       *out_mode = nir_var_shader_out;
53       return modes & nir_var_shader_out ? intr : NULL;
54    default:
55       return NULL;
56    }
57 }
58 
/**
 * Recompute the IO "base" indices from scratch to remove holes or to fix
 * incorrect base values due to changes in IO locations by using IO locations
 * to assign new bases. The mapping from locations to bases becomes
 * monotonically increasing.
 */
bool
nir_recompute_io_bases(nir_shader *nir, nir_variable_mode modes)
{
   nir_function_impl *impl = nir_shader_get_entrypoint(nir);

   /* Bitmasks of used location slots, filled by the gathering pass below. */
   BITSET_DECLARE(inputs, NUM_TOTAL_VARYING_SLOTS);
   BITSET_DECLARE(per_prim_inputs, NUM_TOTAL_VARYING_SLOTS);  /* FS only */
   BITSET_DECLARE(dual_slot_inputs, NUM_TOTAL_VARYING_SLOTS); /* VS only */
   BITSET_DECLARE(outputs, NUM_TOTAL_VARYING_SLOTS);
   BITSET_ZERO(inputs);
   BITSET_ZERO(per_prim_inputs);
   BITSET_ZERO(dual_slot_inputs);
   BITSET_ZERO(outputs);

   /* Gather the bitmasks of used locations. */
   nir_foreach_block_safe(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         nir_variable_mode mode;
         nir_intrinsic_instr *intr = get_io_intrinsic(instr, modes, &mode);
         if (!intr)
            continue;

         nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
         unsigned num_slots = sem.num_slots;
         /* Two mediump 16-bit slots share one 32-bit slot, so halve the
          * slot count, rounding up and accounting for a start in the high
          * half of a slot.
          */
         if (sem.medium_precision)
            num_slots = (num_slots + sem.high_16bits + 1) / 2;

         if (mode == nir_var_shader_in) {
            for (unsigned i = 0; i < num_slots; i++) {
               /* Per-primitive FS inputs get a separate numbering space. */
               if (intr->intrinsic == nir_intrinsic_load_per_primitive_input)
                  BITSET_SET(per_prim_inputs, sem.location + i);
               else
                  BITSET_SET(inputs, sem.location + i);

               /* high_dvec2 marks inputs occupying two slots each
                * (presumably 64-bit dvec3/dvec4 VS inputs — confirm).
                */
               if (sem.high_dvec2)
                  BITSET_SET(dual_slot_inputs, sem.location + i);
            }
         } else if (!sem.dual_source_blend_index) {
            for (unsigned i = 0; i < num_slots; i++)
               BITSET_SET(outputs, sem.location + i);
         }
      }
   }

   /* Per-primitive input bases start after all normal inputs; dual-slot
    * inputs count twice.
    */
   const unsigned num_normal_inputs = BITSET_COUNT(inputs) + BITSET_COUNT(dual_slot_inputs);

   /* Renumber bases. */
   bool changed = false;

   nir_foreach_block_safe(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         nir_variable_mode mode;
         nir_intrinsic_instr *intr = get_io_intrinsic(instr, modes, &mode);
         if (!intr)
            continue;

         nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
         unsigned num_slots = sem.num_slots;
         if (sem.medium_precision)
            num_slots = (num_slots + sem.high_16bits + 1) / 2;

         if (mode == nir_var_shader_in) {
            if (intr->intrinsic == nir_intrinsic_load_per_primitive_input) {
               nir_intrinsic_set_base(intr,
                                      num_normal_inputs +
                                      BITSET_PREFIX_SUM(per_prim_inputs, sem.location));
            } else {
               /* The new base is the number of used slots below this
                * location (dual-slot inputs counted twice); high_dvec2
                * selects the second slot of a dual-slot input.
                */
               nir_intrinsic_set_base(intr,
                                      BITSET_PREFIX_SUM(inputs, sem.location) +
                                      BITSET_PREFIX_SUM(dual_slot_inputs, sem.location) +
                                      (sem.high_dvec2 ? 1 : 0));
            }
         } else if (sem.dual_source_blend_index) {
            /* The dual-source blend output is placed after all other outputs. */
            nir_intrinsic_set_base(intr,
                                   BITSET_PREFIX_SUM(outputs, NUM_TOTAL_VARYING_SLOTS));
         } else {
            nir_intrinsic_set_base(intr,
                                   BITSET_PREFIX_SUM(outputs, sem.location));
         }
         /* NOTE(review): "changed" is set even when the newly-assigned base
          * equals the old one, so this pass reports progress whenever any IO
          * intrinsic matches "modes".
          */
         changed = true;
      }
   }

   if (changed) {
      nir_metadata_preserve(impl, nir_metadata_control_flow);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   /* Update the shader-level IO slot counts to match the packed numbering. */
   if (modes & nir_var_shader_in)
      nir->num_inputs = BITSET_COUNT(inputs);
   if (modes & nir_var_shader_out)
      nir->num_outputs = BITSET_COUNT(outputs);

   return changed;
}
161 
/**
 * Lower mediump inputs and/or outputs to 16 bits.
 *
 * \param modes            Whether to lower inputs, outputs, or both.
 * \param varying_mask     Determines which varyings to skip (VS inputs,
 *    FS outputs, and patch varyings ignore this mask).
 * \param use_16bit_slots  Remap lowered slots to VARYING_SLOT_VARn_16BIT.
 */
bool
nir_lower_mediump_io(nir_shader *nir, nir_variable_mode modes,
                     uint64_t varying_mask, bool use_16bit_slots)
{
   bool changed = false;
   nir_function_impl *impl = nir_shader_get_entrypoint(nir);
   assert(impl);

   nir_builder b = nir_builder_create(impl);

   nir_foreach_block_safe(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         nir_variable_mode mode;
         nir_intrinsic_instr *intr = get_io_intrinsic(instr, modes, &mode);
         if (!intr)
            continue;

         nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
         nir_def *(*convert)(nir_builder *, nir_def *);
         /* VS inputs and FS outputs are the shader pipeline boundaries;
          * every other IO matched by get_io_intrinsic is a varying.
          */
         bool is_varying = !(nir->info.stage == MESA_SHADER_VERTEX &&
                             mode == nir_var_shader_in) &&
                           !(nir->info.stage == MESA_SHADER_FRAGMENT &&
                             mode == nir_var_shader_out);

         /* varying_mask only filters generic varyings up to VAR31; patch
          * varyings (higher locations), VS inputs and FS outputs bypass it.
          */
         if (is_varying && sem.location <= VARYING_SLOT_VAR31 &&
             !(varying_mask & BITFIELD64_BIT(sem.location))) {
            continue; /* can't lower */
         }

         if (nir_intrinsic_has_src_type(intr)) {
            /* Stores. */
            nir_alu_type type = nir_intrinsic_src_type(intr);

            /* convert narrows the stored value; upconvert_op identifies a
             * 16->32 conversion feeding the store that proves the value is
             * really mediump.
             */
            nir_op upconvert_op;
            switch (type) {
            case nir_type_float32:
               convert = nir_f2fmp;
               upconvert_op = nir_op_f2f32;
               break;
            case nir_type_int32:
               convert = nir_i2imp;
               upconvert_op = nir_op_i2i32;
               break;
            case nir_type_uint32:
               convert = nir_i2imp;
               upconvert_op = nir_op_u2u32;
               break;
            default:
               continue; /* already lowered? */
            }

            /* Check that the output is mediump, or (for fragment shader
             * outputs) is a conversion from a mediump value, and lower it to
             * mediump.  Note that we don't automatically apply it to
             * gl_FragDepth, as GLSL ES declares it highp and so hardware such
             * as Adreno a6xx doesn't expect a half-float output for it.
             */
            nir_def *val = intr->src[0].ssa;
            bool is_fragdepth = (nir->info.stage == MESA_SHADER_FRAGMENT &&
                                 sem.location == FRAG_RESULT_DEPTH);
            if (!sem.medium_precision &&
                (is_varying || is_fragdepth || val->parent_instr->type != nir_instr_type_alu ||
                 nir_instr_as_alu(val->parent_instr)->op != upconvert_op)) {
               continue;
            }

            /* Convert the 32-bit store into a 16-bit store. */
            b.cursor = nir_before_instr(&intr->instr);
            nir_src_rewrite(&intr->src[0], convert(&b, intr->src[0].ssa));
            /* Clear the 32-bit size bit of the type and set the 16-bit one. */
            nir_intrinsic_set_src_type(intr, (type & ~32) | 16);
         } else {
            if (!sem.medium_precision)
               continue;

            /* Loads. */
            nir_alu_type type = nir_intrinsic_dest_type(intr);

            switch (type) {
            case nir_type_float32:
               convert = nir_f2f32;
               break;
            case nir_type_int32:
               convert = nir_i2i32;
               break;
            case nir_type_uint32:
               convert = nir_u2u32;
               break;
            default:
               continue; /* already lowered? */
            }

            /* Convert the 32-bit load into a 16-bit load and insert an
             * up-conversion back to 32 bits for the existing users.
             */
            b.cursor = nir_after_instr(&intr->instr);
            intr->def.bit_size = 16;
            nir_intrinsic_set_dest_type(intr, (type & ~32) | 16);
            nir_def *dst = convert(&b, &intr->def);
            nir_def_rewrite_uses_after(&intr->def, dst,
                                       dst->parent_instr);
         }

         /* Pack two lowered varyings per 32-bit VARn_16BIT slot: even VARn
          * indices take the low half, odd ones the high half.
          */
         if (use_16bit_slots && is_varying &&
             sem.location >= VARYING_SLOT_VAR0 &&
             sem.location <= VARYING_SLOT_VAR31) {
            unsigned index = sem.location - VARYING_SLOT_VAR0;

            sem.location = VARYING_SLOT_VAR0_16BIT + index / 2;
            sem.high_16bits = index % 2;
            nir_intrinsic_set_io_semantics(intr, sem);
         }
         changed = true;
      }
   }

   /* Locations were remapped, so the "base" indices must be recomputed. */
   if (changed && use_16bit_slots)
      nir_recompute_io_bases(nir, modes);

   if (changed) {
      nir_metadata_preserve(impl, nir_metadata_control_flow);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   return changed;
}
294 
295 /**
296  * Set the mediump precision bit for those shader inputs and outputs that are
297  * set in the "modes" mask. Non-generic varyings (that GLES3 doesn't have)
298  * are ignored. The "types" mask can be (nir_type_float | nir_type_int), etc.
299  */
300 bool
nir_force_mediump_io(nir_shader * nir,nir_variable_mode modes,nir_alu_type types)301 nir_force_mediump_io(nir_shader *nir, nir_variable_mode modes,
302                      nir_alu_type types)
303 {
304    bool changed = false;
305    nir_function_impl *impl = nir_shader_get_entrypoint(nir);
306    assert(impl);
307 
308    nir_foreach_block_safe(block, impl) {
309       nir_foreach_instr_safe(instr, block) {
310          nir_variable_mode mode;
311          nir_intrinsic_instr *intr = get_io_intrinsic(instr, modes, &mode);
312          if (!intr)
313             continue;
314 
315          nir_alu_type type;
316          if (nir_intrinsic_has_src_type(intr))
317             type = nir_intrinsic_src_type(intr);
318          else
319             type = nir_intrinsic_dest_type(intr);
320          if (!(type & types))
321             continue;
322 
323          nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
324 
325          if (nir->info.stage == MESA_SHADER_FRAGMENT &&
326              mode == nir_var_shader_out) {
327             /* Only accept FS outputs. */
328             if (sem.location < FRAG_RESULT_DATA0 &&
329                 sem.location != FRAG_RESULT_COLOR)
330                continue;
331          } else if (nir->info.stage == MESA_SHADER_VERTEX &&
332                     mode == nir_var_shader_in) {
333             /* Accept all VS inputs. */
334          } else {
335             /* Only accept generic varyings. */
336             if (sem.location < VARYING_SLOT_VAR0 ||
337                 sem.location > VARYING_SLOT_VAR31)
338                continue;
339          }
340 
341          sem.medium_precision = 1;
342          nir_intrinsic_set_io_semantics(intr, sem);
343          changed = true;
344       }
345    }
346 
347    if (changed) {
348       nir_metadata_preserve(impl, nir_metadata_control_flow);
349    } else {
350       nir_metadata_preserve(impl, nir_metadata_all);
351    }
352 
353    return changed;
354 }
355 
356 /**
357  * Remap 16-bit varying slots to the original 32-bit varying slots.
358  * This only changes IO semantics and bases.
359  */
360 bool
nir_unpack_16bit_varying_slots(nir_shader * nir,nir_variable_mode modes)361 nir_unpack_16bit_varying_slots(nir_shader *nir, nir_variable_mode modes)
362 {
363    bool changed = false;
364    nir_function_impl *impl = nir_shader_get_entrypoint(nir);
365    assert(impl);
366 
367    nir_foreach_block_safe(block, impl) {
368       nir_foreach_instr_safe(instr, block) {
369          nir_variable_mode mode;
370          nir_intrinsic_instr *intr = get_io_intrinsic(instr, modes, &mode);
371          if (!intr)
372             continue;
373 
374          nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
375 
376          if (sem.location < VARYING_SLOT_VAR0_16BIT ||
377              sem.location > VARYING_SLOT_VAR15_16BIT)
378             continue;
379 
380          sem.location = VARYING_SLOT_VAR0 +
381                         (sem.location - VARYING_SLOT_VAR0_16BIT) * 2 +
382                         sem.high_16bits;
383          sem.high_16bits = 0;
384          nir_intrinsic_set_io_semantics(intr, sem);
385          changed = true;
386       }
387    }
388 
389    if (changed)
390       nir_recompute_io_bases(nir, modes);
391 
392    if (changed) {
393       nir_metadata_preserve(impl, nir_metadata_control_flow);
394    } else {
395       nir_metadata_preserve(impl, nir_metadata_all);
396    }
397 
398    return changed;
399 }
400 
401 static bool
is_mediump_or_lowp(unsigned precision)402 is_mediump_or_lowp(unsigned precision)
403 {
404    return precision == GLSL_PRECISION_LOW || precision == GLSL_PRECISION_MEDIUM;
405 }
406 
407 static bool
try_lower_mediump_var(nir_variable * var,nir_variable_mode modes,struct set * set)408 try_lower_mediump_var(nir_variable *var, nir_variable_mode modes, struct set *set)
409 {
410    if (!(var->data.mode & modes) || !is_mediump_or_lowp(var->data.precision))
411       return false;
412 
413    if (set && _mesa_set_search(set, var))
414       return false;
415 
416    const struct glsl_type *new_type = glsl_type_to_16bit(var->type);
417    if (var->type == new_type)
418       return false;
419 
420    var->type = new_type;
421    return true;
422 }
423 
/* Rewrite derefs and load/store intrinsics in "impl" to match variables
 * whose types were already narrowed to 16 bits by try_lower_mediump_var.
 * "any_lowered" says whether the caller lowered any shader-level variables;
 * function temps owned by this impl are lowered here. Returns whether any
 * instruction was changed.
 */
static bool
nir_lower_mediump_vars_impl(nir_function_impl *impl, nir_variable_mode modes,
                            bool any_lowered)
{
   bool progress = false;

   if (modes & nir_var_function_temp) {
      nir_foreach_function_temp_variable(var, impl) {
         any_lowered = try_lower_mediump_var(var, modes, NULL) || any_lowered;
      }
   }
   /* No variable types changed, so no instructions need fixing. */
   if (!any_lowered)
      return false;

   nir_builder b = nir_builder_create(impl);

   nir_foreach_block(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         switch (instr->type) {
         case nir_instr_type_deref: {
            nir_deref_instr *deref = nir_instr_as_deref(instr);

            /* Re-derive each deref's type from its (possibly narrowed)
             * parent so the chain matches the new variable type.
             */
            if (deref->modes & modes) {
               switch (deref->deref_type) {
               case nir_deref_type_var:
                  deref->type = deref->var->type;
                  break;
               case nir_deref_type_array:
               case nir_deref_type_array_wildcard:
                  deref->type = glsl_get_array_element(nir_deref_instr_parent(deref)->type);
                  break;
               case nir_deref_type_struct:
                  deref->type = glsl_get_struct_field(nir_deref_instr_parent(deref)->type, deref->strct.index);
                  break;
               default:
                  nir_print_instr(instr, stderr);
                  unreachable("unsupported deref type");
               }
            }

            break;
         }

         case nir_instr_type_intrinsic: {
            nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
            switch (intrin->intrinsic) {
            case nir_intrinsic_load_deref: {

               if (intrin->def.bit_size != 32)
                  break;

               nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
               if (glsl_get_bit_size(deref->type) != 16)
                  break;

               /* The variable is now 16-bit: narrow the load and insert an
                * up-conversion back to 32 bits for the existing users.
                */
               intrin->def.bit_size = 16;

               b.cursor = nir_after_instr(&intrin->instr);
               nir_def *replace = NULL;
               switch (glsl_get_base_type(deref->type)) {
               case GLSL_TYPE_FLOAT16:
                  replace = nir_f2f32(&b, &intrin->def);
                  break;
               case GLSL_TYPE_INT16:
                  replace = nir_i2i32(&b, &intrin->def);
                  break;
               case GLSL_TYPE_UINT16:
                  replace = nir_u2u32(&b, &intrin->def);
                  break;
               default:
                  unreachable("Invalid 16-bit type");
               }

               nir_def_rewrite_uses_after(&intrin->def,
                                          replace,
                                          replace->parent_instr);
               progress = true;
               break;
            }

            case nir_intrinsic_store_deref: {
               nir_def *data = intrin->src[1].ssa;
               if (data->bit_size != 32)
                  break;

               nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
               if (glsl_get_bit_size(deref->type) != 16)
                  break;

               /* Down-convert the 32-bit data to 16 bits before the store. */
               b.cursor = nir_before_instr(&intrin->instr);
               nir_def *replace = NULL;
               switch (glsl_get_base_type(deref->type)) {
               case GLSL_TYPE_FLOAT16:
                  replace = nir_f2fmp(&b, data);
                  break;
               case GLSL_TYPE_INT16:
               case GLSL_TYPE_UINT16:
                  replace = nir_i2imp(&b, data);
                  break;
               default:
                  unreachable("Invalid 16-bit type");
               }

               nir_src_rewrite(&intrin->src[1], replace);
               progress = true;
               break;
            }

            case nir_intrinsic_copy_deref: {
               nir_deref_instr *dst = nir_src_as_deref(intrin->src[0]);
               nir_deref_instr *src = nir_src_as_deref(intrin->src[1]);
               /* If we convert one side of a copy and not the other, that
                * would be very bad.
                */
               if (nir_deref_mode_may_be(dst, modes) ||
                   nir_deref_mode_may_be(src, modes)) {
                  assert(nir_deref_mode_must_be(dst, modes));
                  assert(nir_deref_mode_must_be(src, modes));
               }
               break;
            }

            default:
               break;
            }
            break;
         }

         default:
            break;
         }
      }
   }

   if (progress) {
      nir_metadata_preserve(impl, nir_metadata_control_flow);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   return progress;
}
566 
567 bool
nir_lower_mediump_vars(nir_shader * shader,nir_variable_mode modes)568 nir_lower_mediump_vars(nir_shader *shader, nir_variable_mode modes)
569 {
570    bool progress = false;
571 
572    if (modes & ~nir_var_function_temp) {
573       /* Don't lower GLES mediump atomic ops to 16-bit -- no hardware is expecting that. */
574       struct set *no_lower_set = _mesa_pointer_set_create(NULL);
575       nir_foreach_block(block, nir_shader_get_entrypoint(shader)) {
576          nir_foreach_instr(instr, block) {
577             if (instr->type != nir_instr_type_intrinsic)
578                continue;
579             nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
580             switch (intr->intrinsic) {
581             case nir_intrinsic_deref_atomic:
582             case nir_intrinsic_deref_atomic_swap: {
583                nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
584                nir_variable *var = nir_deref_instr_get_variable(deref);
585 
586                /* If we have atomic derefs that we can't track, then don't lower any mediump.  */
587                if (!var)
588                   return false;
589 
590                _mesa_set_add(no_lower_set, var);
591                break;
592             }
593 
594             default:
595                break;
596             }
597          }
598       }
599 
600       nir_foreach_variable_in_shader(var, shader) {
601          progress = try_lower_mediump_var(var, modes, no_lower_set) || progress;
602       }
603 
604       ralloc_free(no_lower_set);
605    }
606 
607    nir_foreach_function_impl(impl, shader) {
608       if (nir_lower_mediump_vars_impl(impl, modes, progress))
609          progress = true;
610    }
611 
612    return progress;
613 }
614 
/**
 * Fix types of source operands of texture opcodes according to
 * the constraints by inserting the appropriate conversion opcodes.
 *
 * For example, if the type of derivatives must be equal to texture
 * coordinates and the type of the texture bias must be 32-bit, there
 * will be 2 constraints describing that.
 */
static bool
legalize_16bit_sampler_srcs(nir_builder *b, nir_instr *instr, void *data)
{
   bool progress = false;
   nir_tex_src_type_constraint *constraints = data;

   if (instr->type != nir_instr_type_tex)
      return false;

   nir_tex_instr *tex = nir_instr_as_tex(instr);
   /* map[src_type] = index of that src in tex->src[], or -1 if absent. */
   int8_t map[nir_num_tex_src_types];
   memset(map, -1, sizeof(map));

   /* Create a mapping from src_type to src[i]. */
   for (unsigned i = 0; i < tex->num_srcs; i++)
      map[tex->src[i].src_type] = i;

   /* Legalize src types. */
   for (unsigned i = 0; i < tex->num_srcs; i++) {
      nir_tex_src_type_constraint c = constraints[tex->src[i].src_type];

      if (!c.legalize_type)
         continue;

      /* Determine the required bit size for the src: either a fixed size,
       * or the size of another src that this one must match.
       */
      unsigned bit_size;
      if (c.bit_size) {
         bit_size = c.bit_size;
      } else {
         if (map[c.match_src] == -1)
            continue; /* e.g. txs */

         bit_size = tex->src[map[c.match_src]].src.ssa->bit_size;
      }

      /* Check if the type is legal. */
      if (bit_size == tex->src[i].src.ssa->bit_size)
         continue;

      /* Fix the bit size. */
      bool is_sint = nir_tex_instr_src_type(tex, i) == nir_type_int;
      bool is_uint = nir_tex_instr_src_type(tex, i) == nir_type_uint;
      nir_def *(*convert)(nir_builder *, nir_def *);

      /* Pick the conversion matching the src's base type (float if neither
       * signed nor unsigned int).
       */
      switch (bit_size) {
      case 16:
         convert = is_sint ? nir_i2i16 : is_uint ? nir_u2u16
                                                 : nir_f2f16;
         break;
      case 32:
         convert = is_sint ? nir_i2i32 : is_uint ? nir_u2u32
                                                 : nir_f2f32;
         break;
      default:
         assert(!"unexpected bit size");
         continue;
      }

      b->cursor = nir_before_instr(&tex->instr);
      nir_src_rewrite(&tex->src[i].src, convert(b, tex->src[i].src.ssa));
      progress = true;
   }

   return progress;
}
688 
689 bool
nir_legalize_16bit_sampler_srcs(nir_shader * nir,nir_tex_src_type_constraints constraints)690 nir_legalize_16bit_sampler_srcs(nir_shader *nir,
691                                 nir_tex_src_type_constraints constraints)
692 {
693    return nir_shader_instructions_pass(nir, legalize_16bit_sampler_srcs,
694                                        nir_metadata_control_flow,
695                                        constraints);
696 }
697 
698 static bool
const_is_f16(nir_scalar scalar)699 const_is_f16(nir_scalar scalar)
700 {
701    double value = nir_scalar_as_float(scalar);
702    uint16_t fp16_val = _mesa_float_to_half(value);
703    bool is_denorm = (fp16_val & 0x7fff) != 0 && (fp16_val & 0x7fff) <= 0x3ff;
704    return value == _mesa_half_to_float(fp16_val) && !is_denorm;
705 }
706 
707 static bool
const_is_u16(nir_scalar scalar)708 const_is_u16(nir_scalar scalar)
709 {
710    uint64_t value = nir_scalar_as_uint(scalar);
711    return value == (uint16_t)value;
712 }
713 
714 static bool
const_is_i16(nir_scalar scalar)715 const_is_i16(nir_scalar scalar)
716 {
717    int64_t value = nir_scalar_as_int(scalar);
718    return value == (int16_t)value;
719 }
720 
/* Return whether every component of "ssa" can be narrowed to 16 bits:
 * each component must be undef, a constant representable in 16 bits, or
 * the result of a 16->32 up-conversion that can simply be undone.
 *
 * If sext_matters is false, signed and unsigned 32-bit sources are
 * interchangeable as long as each constant fits in 16 bits either way.
 */
static bool
can_opt_16bit_src(nir_def *ssa, nir_alu_type src_type, bool sext_matters)
{
   bool opt_f16 = src_type == nir_type_float32;
   bool opt_u16 = src_type == nir_type_uint32 && sext_matters;
   bool opt_i16 = src_type == nir_type_int32 && sext_matters;
   bool opt_i16_u16 = (src_type == nir_type_uint32 || src_type == nir_type_int32) && !sext_matters;

   bool can_opt = opt_f16 || opt_u16 || opt_i16 || opt_i16_u16;
   for (unsigned i = 0; can_opt && i < ssa->num_components; i++) {
      nir_scalar comp = nir_scalar_resolved(ssa, i);
      if (nir_scalar_is_undef(comp))
         continue; /* undef can take any bit size */
      else if (nir_scalar_is_const(comp)) {
         /* The constant must be exactly representable at 16 bits. */
         if (opt_f16)
            can_opt &= const_is_f16(comp);
         else if (opt_u16)
            can_opt &= const_is_u16(comp);
         else if (opt_i16)
            can_opt &= const_is_i16(comp);
         else if (opt_i16_u16)
            can_opt &= (const_is_u16(comp) || const_is_i16(comp));
      } else if (nir_scalar_is_alu(comp)) {
         /* Only an up-conversion from exactly 16 bits, of a kind matching
          * src_type, can be undone.
          */
         nir_alu_instr *alu = nir_instr_as_alu(comp.def->parent_instr);
         if (alu->src[0].src.ssa->bit_size != 16)
            return false;

         if (alu->op == nir_op_f2f32)
            can_opt &= opt_f16;
         else if (alu->op == nir_op_i2i32)
            can_opt &= opt_i16 || opt_i16_u16;
         else if (alu->op == nir_op_u2u32)
            can_opt &= opt_u16 || opt_i16_u16;
         else
            return false;
      } else {
         return false;
      }
   }

   return can_opt;
}
763 
764 static void
opt_16bit_src(nir_builder * b,nir_instr * instr,nir_src * src,nir_alu_type src_type)765 opt_16bit_src(nir_builder *b, nir_instr *instr, nir_src *src, nir_alu_type src_type)
766 {
767    b->cursor = nir_before_instr(instr);
768 
769    nir_scalar new_comps[NIR_MAX_VEC_COMPONENTS];
770    for (unsigned i = 0; i < src->ssa->num_components; i++) {
771       nir_scalar comp = nir_scalar_resolved(src->ssa, i);
772 
773       if (nir_scalar_is_undef(comp))
774          new_comps[i] = nir_get_scalar(nir_undef(b, 1, 16), 0);
775       else if (nir_scalar_is_const(comp)) {
776          nir_def *constant;
777          if (src_type == nir_type_float32)
778             constant = nir_imm_float16(b, nir_scalar_as_float(comp));
779          else
780             constant = nir_imm_intN_t(b, nir_scalar_as_uint(comp), 16);
781          new_comps[i] = nir_get_scalar(constant, 0);
782       } else {
783          /* conversion instruction */
784          new_comps[i] = nir_scalar_chase_alu_src(comp, 0);
785       }
786    }
787 
788    nir_def *new_vec = nir_vec_scalars(b, new_comps, src->ssa->num_components);
789 
790    nir_src_rewrite(src, new_vec);
791 }
792 
793 static bool
opt_16bit_store_data(nir_builder * b,nir_intrinsic_instr * instr)794 opt_16bit_store_data(nir_builder *b, nir_intrinsic_instr *instr)
795 {
796    nir_alu_type src_type = nir_intrinsic_src_type(instr);
797    nir_src *data_src = &instr->src[3];
798 
799    b->cursor = nir_before_instr(&instr->instr);
800 
801    if (!can_opt_16bit_src(data_src->ssa, src_type, true))
802       return false;
803 
804    opt_16bit_src(b, &instr->instr, data_src, src_type);
805 
806    nir_intrinsic_set_src_type(instr, (src_type & ~32) | 16);
807 
808    return true;
809 }
810 
/**
 * Try to narrow a 32-bit texture/image destination to 16 bits.
 *
 * This only succeeds when *every* use of "ssa" is a 32->16 conversion (or a
 * pack of two 16-bit halves) whose rounding/saturation semantics are
 * compatible with the pass options.  On success, the conversion ALU
 * instructions are rewritten to plain movs / raw 16-bit packs and the def's
 * bit size is changed to 16 in place; the caller is responsible for updating
 * the producing instruction's declared destination type.
 *
 * \param ssa        the 32-bit destination def to narrow
 * \param dest_type  the producer's declared destination type (32-bit variant)
 * \param exec_mode  shader float_controls execution mode, used to derive the
 *                   implicit rounding mode of plain f2f16 conversions
 * \param options    pass options (requested rounding mode, saturation policy,
 *                   which base types may be narrowed)
 * \return true if all uses were rewritten and the def was narrowed
 */
static bool
opt_16bit_destination(nir_def *ssa, nir_alu_type dest_type, unsigned exec_mode,
                      struct nir_opt_16bit_tex_image_options *options)
{
   /* Which flavors of narrowing are acceptable for this dest_type under the
    * requested options.  Saturating integer narrowing is mutually exclusive
    * with plain truncating i2i16/u2u16.
    */
   bool opt_f2f16 = dest_type == nir_type_float32;
   bool opt_i2i16 = (dest_type == nir_type_int32 || dest_type == nir_type_uint32) &&
                    !options->integer_dest_saturates;
   bool opt_i2i16_sat = dest_type == nir_type_int32 && options->integer_dest_saturates;
   bool opt_u2u16_sat = dest_type == nir_type_uint32 && options->integer_dest_saturates;

   /* rdm is the rounding mode the caller accepts; src_rdm is the rounding
    * mode a plain f2f16 implicitly has under the shader's float controls.
    */
   nir_rounding_mode rdm = options->rounding_mode;
   nir_rounding_mode src_rdm =
      nir_get_rounding_mode_from_float_controls(exec_mode, nir_type_float16);

   /* Phase 1: validate.  Bail unless every use is a compatible conversion. */
   nir_foreach_use(use, ssa) {
      nir_instr *instr = nir_src_parent_instr(use);
      if (instr->type != nir_instr_type_alu)
         return false;

      nir_alu_instr *alu = nir_instr_as_alu(instr);

      switch (alu->op) {
      case nir_op_pack_half_2x16_split:
         /* Both halves must come from this def, otherwise the other source
          * would still be 32-bit after the rewrite below.
          */
         if (alu->src[0].src.ssa != alu->src[1].src.ssa)
            return false;
         FALLTHROUGH;
      case nir_op_pack_half_2x16:
         /* pack_half rounding is undefined */
         if (!opt_f2f16)
            return false;
         break;
      case nir_op_pack_half_2x16_rtz_split:
         if (alu->src[0].src.ssa != alu->src[1].src.ssa)
            return false;
         FALLTHROUGH;
      case nir_op_f2f16_rtz:
         /* Explicit RTZ conversion: only removable if RTZ was requested. */
         if (rdm != nir_rounding_mode_rtz || !opt_f2f16)
            return false;
         break;
      case nir_op_f2f16_rtne:
         /* Explicit RTNE conversion: only removable if RTNE was requested. */
         if (rdm != nir_rounding_mode_rtne || !opt_f2f16)
            return false;
         break;
      case nir_op_f2f16:
      case nir_op_f2fmp:
         /* Plain f2f16 rounds per the shader's float controls; that implicit
          * mode must match the requested one (or be unconstrained).
          */
         if (src_rdm != rdm && src_rdm != nir_rounding_mode_undef)
            return false;
         if (!opt_f2f16)
            return false;
         break;
      case nir_op_i2i16:
      case nir_op_i2imp:
      case nir_op_u2u16:
         if (!opt_i2i16)
            return false;
         break;
      case nir_op_pack_sint_2x16:
         if (!opt_i2i16_sat)
            return false;
         break;
      case nir_op_pack_uint_2x16:
         if (!opt_u2u16_sat)
            return false;
         break;
      default:
         return false;
      }
   }

   /* All uses are the same conversions. Replace them with mov. */
   nir_foreach_use(use, ssa) {
      nir_alu_instr *alu = nir_instr_as_alu(nir_src_parent_instr(use));
      switch (alu->op) {
      case nir_op_f2f16_rtne:
      case nir_op_f2f16_rtz:
      case nir_op_f2f16:
      case nir_op_f2fmp:
      case nir_op_i2i16:
      case nir_op_i2imp:
      case nir_op_u2u16:
         /* The source will be 16-bit after narrowing, so the conversion
          * degenerates to a copy.
          */
         alu->op = nir_op_mov;
         break;
      case nir_op_pack_half_2x16_rtz_split:
      case nir_op_pack_half_2x16_split:
         /* The halves are already 16-bit floats; pack them raw. */
         alu->op = nir_op_pack_32_2x16_split;
         break;
      case nir_op_pack_32_2x16_split:
         /* Split opcodes have two operands, so the iteration
          * for the second use will already observe the
          * updated opcode.
          */
         break;
      case nir_op_pack_half_2x16:
      case nir_op_pack_sint_2x16:
      case nir_op_pack_uint_2x16:
         /* Conversion/saturation already happened; pack raw 16-bit lanes. */
         alu->op = nir_op_pack_32_2x16;
         break;
      default:
         unreachable("unsupported conversion op");
      };
   }

   /* Narrow the def in place; the producer's dest_type is fixed up by the
    * caller.
    */
   ssa->bit_size = 16;
   return true;
}
916 
917 static bool
opt_16bit_image_dest(nir_intrinsic_instr * instr,unsigned exec_mode,struct nir_opt_16bit_tex_image_options * options)918 opt_16bit_image_dest(nir_intrinsic_instr *instr, unsigned exec_mode,
919                      struct nir_opt_16bit_tex_image_options *options)
920 {
921    nir_alu_type dest_type = nir_intrinsic_dest_type(instr);
922 
923    if (!(nir_alu_type_get_base_type(dest_type) & options->opt_image_dest_types))
924       return false;
925 
926    if (!opt_16bit_destination(&instr->def, dest_type, exec_mode, options))
927       return false;
928 
929    nir_intrinsic_set_dest_type(instr, (dest_type & ~32) | 16);
930 
931    return true;
932 }
933 
934 static bool
opt_16bit_tex_dest(nir_tex_instr * tex,unsigned exec_mode,struct nir_opt_16bit_tex_image_options * options)935 opt_16bit_tex_dest(nir_tex_instr *tex, unsigned exec_mode,
936                    struct nir_opt_16bit_tex_image_options *options)
937 {
938    /* Skip sparse residency */
939    if (tex->is_sparse)
940       return false;
941 
942    if (tex->op != nir_texop_tex &&
943        tex->op != nir_texop_txb &&
944        tex->op != nir_texop_txd &&
945        tex->op != nir_texop_txl &&
946        tex->op != nir_texop_txf &&
947        tex->op != nir_texop_txf_ms &&
948        tex->op != nir_texop_tg4 &&
949        tex->op != nir_texop_tex_prefetch &&
950        tex->op != nir_texop_fragment_fetch_amd)
951       return false;
952 
953    if (!(nir_alu_type_get_base_type(tex->dest_type) & options->opt_tex_dest_types))
954       return false;
955 
956    if (!opt_16bit_destination(&tex->def, tex->dest_type, exec_mode, options))
957       return false;
958 
959    tex->dest_type = (tex->dest_type & ~32) | 16;
960    return true;
961 }
962 
963 static bool
opt_16bit_tex_srcs(nir_builder * b,nir_tex_instr * tex,struct nir_opt_tex_srcs_options * options)964 opt_16bit_tex_srcs(nir_builder *b, nir_tex_instr *tex,
965                    struct nir_opt_tex_srcs_options *options)
966 {
967    if (tex->op != nir_texop_tex &&
968        tex->op != nir_texop_txb &&
969        tex->op != nir_texop_txd &&
970        tex->op != nir_texop_txl &&
971        tex->op != nir_texop_txf &&
972        tex->op != nir_texop_txf_ms &&
973        tex->op != nir_texop_tg4 &&
974        tex->op != nir_texop_tex_prefetch &&
975        tex->op != nir_texop_fragment_fetch_amd &&
976        tex->op != nir_texop_fragment_mask_fetch_amd)
977       return false;
978 
979    if (!(options->sampler_dims & BITFIELD_BIT(tex->sampler_dim)))
980       return false;
981 
982    if (nir_tex_instr_src_index(tex, nir_tex_src_backend1) >= 0)
983       return false;
984 
985    unsigned opt_srcs = 0;
986    for (unsigned i = 0; i < tex->num_srcs; i++) {
987       /* Filter out sources that should be ignored. */
988       if (!(BITFIELD_BIT(tex->src[i].src_type) & options->src_types))
989          continue;
990 
991       nir_src *src = &tex->src[i].src;
992 
993       nir_alu_type src_type = nir_tex_instr_src_type(tex, i) | src->ssa->bit_size;
994 
995       /* Zero-extension (u16) and sign-extension (i16) have
996        * the same behavior here - txf returns 0 if bit 15 is set
997        * because it's out of bounds and the higher bits don't
998        * matter.
999        */
1000       if (!can_opt_16bit_src(src->ssa, src_type, false))
1001          return false;
1002 
1003       opt_srcs |= (1 << i);
1004    }
1005 
1006    u_foreach_bit(i, opt_srcs) {
1007       nir_src *src = &tex->src[i].src;
1008       nir_alu_type src_type = nir_tex_instr_src_type(tex, i) | src->ssa->bit_size;
1009       opt_16bit_src(b, &tex->instr, src, src_type);
1010    }
1011 
1012    return !!opt_srcs;
1013 }
1014 
1015 static bool
opt_16bit_image_srcs(nir_builder * b,nir_intrinsic_instr * instr,int lod_idx)1016 opt_16bit_image_srcs(nir_builder *b, nir_intrinsic_instr *instr, int lod_idx)
1017 {
1018    enum glsl_sampler_dim dim = nir_intrinsic_image_dim(instr);
1019    bool is_ms = (dim == GLSL_SAMPLER_DIM_MS || dim == GLSL_SAMPLER_DIM_SUBPASS_MS);
1020    nir_src *coords = &instr->src[1];
1021    nir_src *sample = is_ms ? &instr->src[2] : NULL;
1022    nir_src *lod = lod_idx >= 0 ? &instr->src[lod_idx] : NULL;
1023 
1024    if (dim == GLSL_SAMPLER_DIM_BUF ||
1025        !can_opt_16bit_src(coords->ssa, nir_type_int32, false) ||
1026        (sample && !can_opt_16bit_src(sample->ssa, nir_type_int32, false)) ||
1027        (lod && !can_opt_16bit_src(lod->ssa, nir_type_int32, false)))
1028       return false;
1029 
1030    opt_16bit_src(b, &instr->instr, coords, nir_type_int32);
1031    if (sample)
1032       opt_16bit_src(b, &instr->instr, sample, nir_type_int32);
1033    if (lod)
1034       opt_16bit_src(b, &instr->instr, lod, nir_type_int32);
1035 
1036    return true;
1037 }
1038 
1039 static bool
opt_16bit_tex_image(nir_builder * b,nir_instr * instr,void * params)1040 opt_16bit_tex_image(nir_builder *b, nir_instr *instr, void *params)
1041 {
1042    struct nir_opt_16bit_tex_image_options *options = params;
1043    unsigned exec_mode = b->shader->info.float_controls_execution_mode;
1044    bool progress = false;
1045 
1046    if (instr->type == nir_instr_type_intrinsic) {
1047       nir_intrinsic_instr *intrinsic = nir_instr_as_intrinsic(instr);
1048 
1049       switch (intrinsic->intrinsic) {
1050       case nir_intrinsic_bindless_image_store:
1051       case nir_intrinsic_image_deref_store:
1052       case nir_intrinsic_image_store:
1053          if (options->opt_image_store_data)
1054             progress |= opt_16bit_store_data(b, intrinsic);
1055          if (options->opt_image_srcs)
1056             progress |= opt_16bit_image_srcs(b, intrinsic, 4);
1057          break;
1058       case nir_intrinsic_bindless_image_load:
1059       case nir_intrinsic_image_deref_load:
1060       case nir_intrinsic_image_load:
1061          if (options->opt_image_dest_types)
1062             progress |= opt_16bit_image_dest(intrinsic, exec_mode, options);
1063          if (options->opt_image_srcs)
1064             progress |= opt_16bit_image_srcs(b, intrinsic, 3);
1065          break;
1066       case nir_intrinsic_bindless_image_sparse_load:
1067       case nir_intrinsic_image_deref_sparse_load:
1068       case nir_intrinsic_image_sparse_load:
1069          if (options->opt_image_srcs)
1070             progress |= opt_16bit_image_srcs(b, intrinsic, 3);
1071          break;
1072       case nir_intrinsic_bindless_image_atomic:
1073       case nir_intrinsic_bindless_image_atomic_swap:
1074       case nir_intrinsic_image_deref_atomic:
1075       case nir_intrinsic_image_deref_atomic_swap:
1076       case nir_intrinsic_image_atomic:
1077       case nir_intrinsic_image_atomic_swap:
1078          if (options->opt_image_srcs)
1079             progress |= opt_16bit_image_srcs(b, intrinsic, -1);
1080          break;
1081       default:
1082          break;
1083       }
1084    } else if (instr->type == nir_instr_type_tex) {
1085       nir_tex_instr *tex = nir_instr_as_tex(instr);
1086 
1087       if (options->opt_tex_dest_types)
1088          progress |= opt_16bit_tex_dest(tex, exec_mode, options);
1089 
1090       for (unsigned i = 0; i < options->opt_srcs_options_count; i++) {
1091          progress |= opt_16bit_tex_srcs(b, tex, &options->opt_srcs_options[i]);
1092       }
1093    }
1094 
1095    return progress;
1096 }
1097 
1098 bool
nir_opt_16bit_tex_image(nir_shader * nir,struct nir_opt_16bit_tex_image_options * options)1099 nir_opt_16bit_tex_image(nir_shader *nir,
1100                         struct nir_opt_16bit_tex_image_options *options)
1101 {
1102    return nir_shader_instructions_pass(nir,
1103                                        opt_16bit_tex_image,
1104                                        nir_metadata_control_flow,
1105                                        options);
1106 }
1107