/*
 * Copyright © 2018 Valve Corporation
 *
 * SPDX-License-Identifier: MIT
 */

#include "aco_instruction_selection.h"

#include "common/ac_nir.h"
#include "common/sid.h"

#include "nir_control_flow.h"
#include "nir_builder.h"

#include <vector>

namespace aco {

namespace {

/* Check whether the given SSA def is only used by cross-lane instructions. */
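/* Such values can safely be kept in a VGPR even when they are uniform, because
 * the cross-lane intrinsics accepted here take their operands from VGPRs anyway
 * (see the handling of LDS load results in init_context()). */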
bool
only_used_by_cross_lane_instrs(nir_def* ssa, bool follow_phis = true)
{
   nir_foreach_use (src, ssa) {
      switch (nir_src_parent_instr(src)->type) {
      case nir_instr_type_alu: {
         nir_alu_instr* alu = nir_instr_as_alu(nir_src_parent_instr(src));
         if (alu->op != nir_op_unpack_64_2x32_split_x && alu->op != nir_op_unpack_64_2x32_split_y)
            return false;
         if (!only_used_by_cross_lane_instrs(&alu->def, follow_phis))
            return false;

         continue;
      }
      case nir_instr_type_intrinsic: {
         nir_intrinsic_instr* intrin = nir_instr_as_intrinsic(nir_src_parent_instr(src));
         if (intrin->intrinsic != nir_intrinsic_read_invocation &&
             intrin->intrinsic != nir_intrinsic_read_first_invocation &&
             intrin->intrinsic != nir_intrinsic_lane_permute_16_amd)
            return false;

         continue;
      }
      case nir_instr_type_phi: {
         /* Don't follow more than one phi; this avoids infinite loops. */
         if (!follow_phis)
            return false;

         nir_phi_instr* phi = nir_instr_as_phi(nir_src_parent_instr(src));
         if (!only_used_by_cross_lane_instrs(&phi->def, false))
            return false;

         continue;
      }
      default: return false;
      }
   }

   return true;
}

/* If one side of a divergent IF ends in a branch and the other doesn't, we
 * might have to emit the contents of the side without the branch at the merge
 * block instead. This way, any SGPR live-out of the branch-less side can be
 * used without creating a linear phi in the invert or merge block.
 *
 * This also removes any unreachable merge blocks.
 */
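/* Sketch of the transformation when the then-side ends in a jump:
 *
 *    if (divergent_cond) { A; break; } else { B; }
 *    C;
 *
 * becomes
 *
 *    if (divergent_cond) { A; break; } else { }
 *    B;
 *    C;
 */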
bool
sanitize_if(nir_function_impl* impl, nir_if* nif)
{
   nir_block* then_block = nir_if_last_then_block(nif);
   nir_block* else_block = nir_if_last_else_block(nif);
   bool then_jump = nir_block_ends_in_jump(then_block);
   bool else_jump = nir_block_ends_in_jump(else_block);
   if (!then_jump && !else_jump)
      return false;

   /* If the continue-from block is empty, return early as there is nothing to
    * move.
    */
   if (nir_cf_list_is_empty_block(then_jump ? &nif->else_list : &nif->then_list))
      return false;

   /* Even though this if statement has a jump on one side, we may still have
    * phis afterwards.  Single-source phis can be produced by loop unrolling
    * or dead control-flow passes and are perfectly legal.  Run a quick phi
    * removal on the block after the if to clean up any such phis.
    */
   nir_opt_remove_phis_block(nir_cf_node_as_block(nir_cf_node_next(&nif->cf_node)));

   /* Finally, move the continue-from branch after the if-statement. */
   nir_block* last_continue_from_blk = then_jump ? else_block : then_block;
   nir_block* first_continue_from_blk =
      then_jump ? nir_if_first_else_block(nif) : nir_if_first_then_block(nif);

   /* We don't need to repair SSA. nir_remove_after_cf_node() replaces any uses with undef. */
   if (then_jump && else_jump)
      nir_remove_after_cf_node(&nif->cf_node);

   nir_cf_list tmp;
   nir_cf_extract(&tmp, nir_before_block(first_continue_from_blk),
                  nir_after_block(last_continue_from_blk));
   nir_cf_reinsert(&tmp, nir_after_cf_node(&nif->cf_node));

   return true;
}

bool
sanitize_cf_list(nir_function_impl* impl, struct exec_list* cf_list)
{
   bool progress = false;
   foreach_list_typed (nir_cf_node, cf_node, node, cf_list) {
      switch (cf_node->type) {
      case nir_cf_node_block: break;
      case nir_cf_node_if: {
         nir_if* nif = nir_cf_node_as_if(cf_node);
         progress |= sanitize_cf_list(impl, &nif->then_list);
         progress |= sanitize_cf_list(impl, &nif->else_list);
         progress |= sanitize_if(impl, nif);
         break;
      }
      case nir_cf_node_loop: {
         nir_loop* loop = nir_cf_node_as_loop(cf_node);
         assert(!nir_loop_has_continue_construct(loop));
         progress |= sanitize_cf_list(impl, &loop->body);

         /* NIR seems to allow loops without any break; even though the loop exit then has no
          * predecessors, SSA defs from the loop header are still live there. Handle this without
          * complicating the ACO IR by creating a dummy break.
          */
         if (nir_cf_node_cf_tree_next(&loop->cf_node)->predecessors->entries == 0) {
            nir_builder b = nir_builder_create(impl);
            b.cursor = nir_after_block_before_jump(nir_loop_last_block(loop));

            nir_def *cond = nir_imm_false(&b);
            /* We don't use block divergence information, so this is enough. */
            cond->divergent = false;

            nir_push_if(&b, cond);
            nir_jump(&b, nir_jump_break);
            nir_pop_if(&b, NULL);

            progress = true;
         }
         break;
      }
      case nir_cf_node_function: unreachable("Invalid cf type");
      }
   }

   return progress;
}

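/* If the offset is computed by an iadd, try to prove via NIR's range analysis
 * that the addition cannot wrap around and mark it as no_unsigned_wrap (nuw).
 * Instruction selection can then use the two addends separately, e.g. folding
 * a constant addend into an immediate offset, without changing the address. */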
void
apply_nuw_to_ssa(isel_context* ctx, nir_def* ssa)
{
   nir_scalar scalar;
   scalar.def = ssa;
   scalar.comp = 0;

   if (!nir_scalar_is_alu(scalar) || nir_scalar_alu_op(scalar) != nir_op_iadd)
      return;

   nir_alu_instr* add = nir_instr_as_alu(ssa->parent_instr);

   if (add->no_unsigned_wrap)
      return;

   nir_scalar src0 = nir_scalar_chase_alu_src(scalar, 0);
   nir_scalar src1 = nir_scalar_chase_alu_src(scalar, 1);

   if (nir_scalar_is_const(src0)) {
      nir_scalar tmp = src0;
      src0 = src1;
      src1 = tmp;
   }

   uint32_t src1_ub = nir_unsigned_upper_bound(ctx->shader, ctx->range_ht, src1, &ctx->ub_config);
   add->no_unsigned_wrap =
      !nir_addition_might_overflow(ctx->shader, ctx->range_ht, src0, src1_ub, &ctx->ub_config);
}

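/* Apply the nuw marking above to the offset sources of memory intrinsics
 * (for some of them only when the offset is uniform); which source holds the
 * offset differs per intrinsic. */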
void
apply_nuw_to_offsets(isel_context* ctx, nir_function_impl* impl)
{
   nir_foreach_block (block, impl) {
      nir_foreach_instr (instr, block) {
         if (instr->type != nir_instr_type_intrinsic)
            continue;
         nir_intrinsic_instr* intrin = nir_instr_as_intrinsic(instr);

         switch (intrin->intrinsic) {
         case nir_intrinsic_load_constant:
         case nir_intrinsic_load_uniform:
         case nir_intrinsic_load_push_constant:
            if (!nir_src_is_divergent(intrin->src[0]))
               apply_nuw_to_ssa(ctx, intrin->src[0].ssa);
            break;
         case nir_intrinsic_load_ubo:
         case nir_intrinsic_load_ssbo:
            if (!nir_src_is_divergent(intrin->src[1]))
               apply_nuw_to_ssa(ctx, intrin->src[1].ssa);
            break;
         case nir_intrinsic_store_ssbo:
            if (!nir_src_is_divergent(intrin->src[2]))
               apply_nuw_to_ssa(ctx, intrin->src[2].ssa);
            break;
         case nir_intrinsic_load_scratch: apply_nuw_to_ssa(ctx, intrin->src[0].ssa); break;
         case nir_intrinsic_store_scratch:
         case nir_intrinsic_load_smem_amd: apply_nuw_to_ssa(ctx, intrin->src[1].ssa); break;
         default: break;
         }
      }
   }
}

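/* Map a NIR value to an ACO register class. Examples: a 1-bit boolean becomes
 * a lane mask (s1 on wave32, s2 on wave64), a divergent 32-bit vec3 becomes v3,
 * and a uniform 64-bit scalar becomes s2. */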
RegClass
get_reg_class(isel_context* ctx, RegType type, unsigned components, unsigned bitsize)
{
   if (bitsize == 1)
      return RegClass(RegType::sgpr, ctx->program->lane_mask.size() * components);
   else
      return RegClass::get(type, components * bitsize / 8u);
}

void
setup_tcs_info(isel_context* ctx)
{
   ctx->tcs_in_out_eq = ctx->program->info.vs.tcs_in_out_eq;
   ctx->tcs_temp_only_inputs = ctx->program->info.vs.tcs_temp_only_input_mask;
}

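/* Note that config->lds_size is expressed in units of the hardware allocation
 * granule (dev.lds_encoding_granule bytes), not in bytes. */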
void
setup_lds_size(isel_context* ctx, nir_shader* nir)
{
   /* TCS and GFX9 GS are special cases, already in units of the allocation granule. */
   if (ctx->stage.has(SWStage::TCS))
      ctx->program->config->lds_size = ctx->program->info.tcs.num_lds_blocks;
   else if (ctx->stage.hw == AC_HW_LEGACY_GEOMETRY_SHADER && ctx->options->gfx_level >= GFX9)
      ctx->program->config->lds_size = ctx->program->info.gfx9_gs_ring_lds_size;
   else
      ctx->program->config->lds_size =
         DIV_ROUND_UP(nir->info.shared_size, ctx->program->dev.lds_encoding_granule);

   /* Make sure we fit within the available LDS space. */
   assert((ctx->program->config->lds_size * ctx->program->dev.lds_encoding_granule) <=
          ctx->program->dev.lds_limit);
}

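/* Final NIR preparation before instruction selection: LCSSA makes every value
 * that is live across a loop exit pass through a phi at the exit, and phi
 * scalarization splits vector phis into per-component ones. */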
void
setup_nir(isel_context* ctx, nir_shader* nir)
{
   nir_convert_to_lcssa(nir, true, false);
   if (nir_lower_phis_to_scalar(nir, true)) {
      nir_copy_prop(nir);
      nir_opt_dce(nir);
   }

   nir_function_impl* func = nir_shader_get_entrypoint(nir);
   nir_index_ssa_defs(func);
}

} /* end namespace */

void
init_context(isel_context* ctx, nir_shader* shader)
{
   nir_function_impl* impl = nir_shader_get_entrypoint(shader);
   ctx->shader = shader;

   /* Init NIR range analysis. */
   ctx->range_ht = _mesa_pointer_hash_table_create(NULL);
   ctx->ub_config.min_subgroup_size = ctx->program->wave_size;
   ctx->ub_config.max_subgroup_size = ctx->program->wave_size;
   ctx->ub_config.max_workgroup_invocations = 2048;
   ctx->ub_config.max_workgroup_count[0] = 65535;
   ctx->ub_config.max_workgroup_count[1] = 65535;
   ctx->ub_config.max_workgroup_count[2] = 65535;
   ctx->ub_config.max_workgroup_size[0] = 2048;
   ctx->ub_config.max_workgroup_size[1] = 2048;
   ctx->ub_config.max_workgroup_size[2] = 2048;

   ac_nir_opt_shared_append(shader);

   nir_divergence_analysis(shader);
   if (nir_opt_uniform_atomics(shader, false) && nir_lower_int64(shader))
      nir_divergence_analysis(shader);

   apply_nuw_to_offsets(ctx, impl);

   /* sanitize control flow */
   sanitize_cf_list(impl, &impl->body);
   nir_metadata_preserve(impl, nir_metadata_none);

   /* we'll need these for isel */
   nir_metadata_require(impl, nir_metadata_block_index | nir_metadata_dominance);

   if (ctx->options->dump_preoptir) {
      fprintf(stderr, "NIR shader before instruction selection:\n");
      nir_print_shader(shader, stderr);
   }

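   /* Reserve one ACO Temp id per NIR SSA def up front; the register class of
    * NIR def i is stored at regclasses[i], i.e. temp_rc[first_temp_id + i]. */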
   ctx->first_temp_id = ctx->program->peekAllocationId();
   ctx->program->allocateRange(impl->ssa_alloc);
   RegClass* regclasses = ctx->program->temp_rc.data() + ctx->first_temp_id;

   /* TODO: make this recursive to improve compile times */
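   /* Iterate to a fixed point: a phi's register class depends on those of its
    * sources, which may only be assigned on a later iteration (e.g. loop
    * back-edges), so repeat until no phi changes its class anymore. */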
   bool done = false;
   while (!done) {
      done = true;
      nir_foreach_block (block, impl) {
         nir_foreach_instr (instr, block) {
            switch (instr->type) {
            case nir_instr_type_alu: {
               nir_alu_instr* alu_instr = nir_instr_as_alu(instr);
               RegType type = alu_instr->def.divergent ? RegType::vgpr : RegType::sgpr;

               /* packed 16-bit instructions have to be VGPR */
               if (alu_instr->def.num_components == 2 &&
                   nir_op_infos[alu_instr->op].output_size == 0)
                  type = RegType::vgpr;

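               /* The first group below (simple conversions and mov) keeps the
                * divergence-based type. The second group is VALU-only and is
                * always forced to VGPR. The third group can stay on SGPRs on
                * GFX11.5+ when the operands are at most 32 bits wide, and is
                * forced to VGPR otherwise. */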
               switch (alu_instr->op) {
               case nir_op_f2i16:
               case nir_op_f2u16:
               case nir_op_f2i32:
               case nir_op_f2u32:
               case nir_op_b2i8:
               case nir_op_b2i16:
               case nir_op_b2i32:
               case nir_op_b2b32:
               case nir_op_b2f16:
               case nir_op_b2f32:
               case nir_op_mov: break;
               case nir_op_fmulz:
               case nir_op_ffmaz:
               case nir_op_f2f64:
               case nir_op_u2f64:
               case nir_op_i2f64:
               case nir_op_pack_unorm_2x16:
               case nir_op_pack_snorm_2x16:
               case nir_op_pack_uint_2x16:
               case nir_op_pack_sint_2x16:
               case nir_op_ldexp:
               case nir_op_frexp_sig:
               case nir_op_frexp_exp:
               case nir_op_cube_amd:
               case nir_op_msad_4x8:
               case nir_op_mqsad_4x8:
               case nir_op_udot_4x8_uadd:
               case nir_op_sdot_4x8_iadd:
               case nir_op_sudot_4x8_iadd:
               case nir_op_udot_4x8_uadd_sat:
               case nir_op_sdot_4x8_iadd_sat:
               case nir_op_sudot_4x8_iadd_sat:
               case nir_op_udot_2x16_uadd:
               case nir_op_sdot_2x16_iadd:
               case nir_op_udot_2x16_uadd_sat:
               case nir_op_sdot_2x16_iadd_sat: type = RegType::vgpr; break;
               case nir_op_fmul:
               case nir_op_ffma:
               case nir_op_fadd:
               case nir_op_fsub:
               case nir_op_fmax:
               case nir_op_fmin:
               case nir_op_fsat:
               case nir_op_fneg:
               case nir_op_fabs:
               case nir_op_fsign:
               case nir_op_i2f16:
               case nir_op_i2f32:
               case nir_op_u2f16:
               case nir_op_u2f32:
               case nir_op_f2f16:
               case nir_op_f2f16_rtz:
               case nir_op_f2f16_rtne:
               case nir_op_f2f32:
               case nir_op_fquantize2f16:
               case nir_op_ffract:
               case nir_op_ffloor:
               case nir_op_fceil:
               case nir_op_ftrunc:
               case nir_op_fround_even:
               case nir_op_frcp:
               case nir_op_frsq:
               case nir_op_fsqrt:
               case nir_op_fexp2:
               case nir_op_flog2:
               case nir_op_fsin_amd:
               case nir_op_fcos_amd:
               case nir_op_pack_half_2x16_rtz_split:
               case nir_op_pack_half_2x16_split:
               case nir_op_unpack_half_2x16_split_x:
               case nir_op_unpack_half_2x16_split_y: {
                  if (ctx->program->gfx_level < GFX11_5 ||
                      alu_instr->src[0].src.ssa->bit_size > 32) {
                     type = RegType::vgpr;
                     break;
                  }
                  FALLTHROUGH;
               }
               default:
                  for (unsigned i = 0; i < nir_op_infos[alu_instr->op].num_inputs; i++) {
                     if (regclasses[alu_instr->src[i].src.ssa->index].type() == RegType::vgpr)
                        type = RegType::vgpr;
                  }
                  break;
               }

               RegClass rc =
                  get_reg_class(ctx, type, alu_instr->def.num_components, alu_instr->def.bit_size);
               regclasses[alu_instr->def.index] = rc;
               break;
            }
            case nir_instr_type_load_const: {
               unsigned num_components = nir_instr_as_load_const(instr)->def.num_components;
               unsigned bit_size = nir_instr_as_load_const(instr)->def.bit_size;
               RegClass rc = get_reg_class(ctx, RegType::sgpr, num_components, bit_size);
               regclasses[nir_instr_as_load_const(instr)->def.index] = rc;
               break;
            }
            case nir_instr_type_intrinsic: {
               nir_intrinsic_instr* intrinsic = nir_instr_as_intrinsic(instr);
               if (!nir_intrinsic_infos[intrinsic->intrinsic].has_dest)
                  break;
               if (intrinsic->intrinsic == nir_intrinsic_strict_wqm_coord_amd) {
                  regclasses[intrinsic->def.index] =
                     RegClass::get(RegType::vgpr, intrinsic->def.num_components * 4 +
                                                     nir_intrinsic_base(intrinsic))
                        .as_linear();
                  break;
               }
               RegType type = RegType::sgpr;
               switch (intrinsic->intrinsic) {
               case nir_intrinsic_load_push_constant:
               case nir_intrinsic_load_workgroup_id:
               case nir_intrinsic_load_num_workgroups:
               case nir_intrinsic_load_sbt_base_amd:
               case nir_intrinsic_load_subgroup_id:
               case nir_intrinsic_load_num_subgroups:
               case nir_intrinsic_load_first_vertex:
               case nir_intrinsic_load_base_instance:
               case nir_intrinsic_vote_all:
               case nir_intrinsic_vote_any:
               case nir_intrinsic_read_first_invocation:
               case nir_intrinsic_as_uniform:
               case nir_intrinsic_read_invocation:
               case nir_intrinsic_first_invocation:
               case nir_intrinsic_ballot:
               case nir_intrinsic_ballot_relaxed:
               case nir_intrinsic_bindless_image_samples:
               case nir_intrinsic_load_scalar_arg_amd:
               case nir_intrinsic_load_lds_ngg_scratch_base_amd:
               case nir_intrinsic_load_lds_ngg_gs_out_vertex_base_amd:
               case nir_intrinsic_load_smem_amd:
               case nir_intrinsic_unit_test_uniform_amd: type = RegType::sgpr; break;
               case nir_intrinsic_load_sample_id:
               case nir_intrinsic_load_input:
               case nir_intrinsic_load_per_primitive_input:
               case nir_intrinsic_load_output:
               case nir_intrinsic_load_input_vertex:
               case nir_intrinsic_load_per_vertex_input:
               case nir_intrinsic_load_per_vertex_output:
               case nir_intrinsic_load_vertex_id_zero_base:
               case nir_intrinsic_load_barycentric_sample:
               case nir_intrinsic_load_barycentric_pixel:
               case nir_intrinsic_load_barycentric_model:
               case nir_intrinsic_load_barycentric_centroid:
               case nir_intrinsic_load_barycentric_at_offset:
               case nir_intrinsic_load_interpolated_input:
               case nir_intrinsic_load_frag_coord:
               case nir_intrinsic_load_frag_shading_rate:
               case nir_intrinsic_load_sample_pos:
               case nir_intrinsic_load_local_invocation_id:
               case nir_intrinsic_load_local_invocation_index:
               case nir_intrinsic_load_subgroup_invocation:
               case nir_intrinsic_load_tess_coord:
               case nir_intrinsic_write_invocation_amd:
               case nir_intrinsic_mbcnt_amd:
               case nir_intrinsic_lane_permute_16_amd:
               case nir_intrinsic_dpp16_shift_amd:
               case nir_intrinsic_load_instance_id:
               case nir_intrinsic_ssbo_atomic:
               case nir_intrinsic_ssbo_atomic_swap:
               case nir_intrinsic_global_atomic_amd:
               case nir_intrinsic_global_atomic_swap_amd:
               case nir_intrinsic_bindless_image_atomic:
               case nir_intrinsic_bindless_image_atomic_swap:
               case nir_intrinsic_bindless_image_size:
               case nir_intrinsic_shared_atomic:
               case nir_intrinsic_shared_atomic_swap:
               case nir_intrinsic_load_scratch:
               case nir_intrinsic_load_invocation_id:
               case nir_intrinsic_load_primitive_id:
               case nir_intrinsic_load_typed_buffer_amd:
               case nir_intrinsic_load_buffer_amd:
               case nir_intrinsic_load_initial_edgeflags_amd:
               case nir_intrinsic_gds_atomic_add_amd:
               case nir_intrinsic_bvh64_intersect_ray_amd:
               case nir_intrinsic_load_vector_arg_amd:
               case nir_intrinsic_ordered_xfb_counter_add_gfx11_amd:
               case nir_intrinsic_cmat_muladd_amd:
               case nir_intrinsic_unit_test_divergent_amd: type = RegType::vgpr; break;
               case nir_intrinsic_load_shared:
               case nir_intrinsic_load_shared2_amd:
                  /* When the result of these loads is only used by cross-lane instructions,
                   * it is beneficial to use a VGPR destination, because that allows the
                   * s_waitcnt to be placed further down, which decreases latency.
                   */
                  if (only_used_by_cross_lane_instrs(&intrinsic->def)) {
                     type = RegType::vgpr;
                     break;
                  }
                  FALLTHROUGH;
               case nir_intrinsic_shuffle:
               case nir_intrinsic_quad_broadcast:
               case nir_intrinsic_quad_swap_horizontal:
               case nir_intrinsic_quad_swap_vertical:
               case nir_intrinsic_quad_swap_diagonal:
               case nir_intrinsic_quad_swizzle_amd:
               case nir_intrinsic_masked_swizzle_amd:
               case nir_intrinsic_rotate:
               case nir_intrinsic_inclusive_scan:
               case nir_intrinsic_exclusive_scan:
               case nir_intrinsic_reduce:
               case nir_intrinsic_load_ubo:
               case nir_intrinsic_load_ssbo:
               case nir_intrinsic_load_global_amd:
                  type = intrinsic->def.divergent ? RegType::vgpr : RegType::sgpr;
                  break;
               case nir_intrinsic_ddx:
               case nir_intrinsic_ddy:
               case nir_intrinsic_ddx_fine:
               case nir_intrinsic_ddy_fine:
               case nir_intrinsic_ddx_coarse:
               case nir_intrinsic_ddy_coarse:
                  type = RegType::vgpr;
                  break;
               case nir_intrinsic_load_view_index:
                  type = ctx->stage == fragment_fs ? RegType::vgpr : RegType::sgpr;
                  break;
               default:
                  for (unsigned i = 0; i < nir_intrinsic_infos[intrinsic->intrinsic].num_srcs;
                       i++) {
                     if (regclasses[intrinsic->src[i].ssa->index].type() == RegType::vgpr)
                        type = RegType::vgpr;
                  }
                  break;
               }
               RegClass rc =
                  get_reg_class(ctx, type, intrinsic->def.num_components, intrinsic->def.bit_size);
               regclasses[intrinsic->def.index] = rc;
               break;
            }
            case nir_instr_type_tex: {
               nir_tex_instr* tex = nir_instr_as_tex(instr);
               RegType type = tex->def.divergent ? RegType::vgpr : RegType::sgpr;

               if (tex->op == nir_texop_texture_samples) {
                  assert(!tex->def.divergent);
               }

               RegClass rc = get_reg_class(ctx, type, tex->def.num_components, tex->def.bit_size);
               regclasses[tex->def.index] = rc;
               break;
            }
            case nir_instr_type_undef: {
               unsigned num_components = nir_instr_as_undef(instr)->def.num_components;
               unsigned bit_size = nir_instr_as_undef(instr)->def.bit_size;
               RegClass rc = get_reg_class(ctx, RegType::sgpr, num_components, bit_size);
               regclasses[nir_instr_as_undef(instr)->def.index] = rc;
               break;
            }
            case nir_instr_type_phi: {
               nir_phi_instr* phi = nir_instr_as_phi(instr);
               RegType type = RegType::sgpr;
               unsigned num_components = phi->def.num_components;
               assert((phi->def.bit_size != 1 || num_components == 1) &&
                      "Multiple components not supported on boolean phis.");

               if (phi->def.divergent) {
                  type = RegType::vgpr;
               } else {
                  nir_foreach_phi_src (src, phi) {
                     if (regclasses[src->src.ssa->index].type() == RegType::vgpr)
                        type = RegType::vgpr;
                  }
               }

               RegClass rc = get_reg_class(ctx, type, num_components, phi->def.bit_size);
               if (rc != regclasses[phi->def.index])
                  done = false;
               regclasses[phi->def.index] = rc;
               break;
            }
            default: break;
            }
         }
      }
   }

   ctx->program->config->spi_ps_input_ena = ctx->program->info.ps.spi_ps_input_ena;
   ctx->program->config->spi_ps_input_addr = ctx->program->info.ps.spi_ps_input_addr;

   /* align and copy constant data */
   while (ctx->program->constant_data.size() % 4u)
      ctx->program->constant_data.push_back(0);
   ctx->constant_data_offset = ctx->program->constant_data.size();
   ctx->program->constant_data.insert(ctx->program->constant_data.end(),
                                      (uint8_t*)shader->constant_data,
                                      (uint8_t*)shader->constant_data + shader->constant_data_size);

   BITSET_CLEAR_RANGE(ctx->output_args, 0, BITSET_SIZE(ctx->output_args));
}

void
cleanup_context(isel_context* ctx)
{
   _mesa_hash_table_destroy(ctx->range_ht, NULL);
}

isel_context
setup_isel_context(Program* program, unsigned shader_count, struct nir_shader* const* shaders,
                   ac_shader_config* config, const struct aco_compiler_options* options,
                   const struct aco_shader_info* info, const struct ac_shader_args* args,
                   SWStage sw_stage)
{
   for (unsigned i = 0; i < shader_count; i++) {
      switch (shaders[i]->info.stage) {
      case MESA_SHADER_VERTEX: sw_stage = sw_stage | SWStage::VS; break;
      case MESA_SHADER_TESS_CTRL: sw_stage = sw_stage | SWStage::TCS; break;
      case MESA_SHADER_TESS_EVAL: sw_stage = sw_stage | SWStage::TES; break;
      case MESA_SHADER_GEOMETRY: sw_stage = sw_stage | SWStage::GS; break;
      case MESA_SHADER_FRAGMENT: sw_stage = sw_stage | SWStage::FS; break;
      case MESA_SHADER_KERNEL:
      case MESA_SHADER_COMPUTE: sw_stage = sw_stage | SWStage::CS; break;
      case MESA_SHADER_TASK: sw_stage = sw_stage | SWStage::TS; break;
      case MESA_SHADER_MESH: sw_stage = sw_stage | SWStage::MS; break;
      case MESA_SHADER_RAYGEN:
      case MESA_SHADER_CLOSEST_HIT:
      case MESA_SHADER_MISS:
      case MESA_SHADER_CALLABLE:
      case MESA_SHADER_INTERSECTION:
      case MESA_SHADER_ANY_HIT: sw_stage = SWStage::RT; break;
      default: unreachable("Shader stage not implemented");
      }
   }

   init_program(program, Stage{info->hw_stage, sw_stage}, info, options->gfx_level, options->family,
                options->wgp_mode, config);

   isel_context ctx = {};
   ctx.program = program;
   ctx.args = args;
   ctx.options = options;
   ctx.stage = program->stage;

   program->workgroup_size = program->info.workgroup_size;
   assert(program->workgroup_size);

   /* Mesh shading only works on GFX10.3+. */
   ASSERTED bool mesh_shading = ctx.stage.has(SWStage::TS) || ctx.stage.has(SWStage::MS);
   assert(!mesh_shading || ctx.program->gfx_level >= GFX10_3);

   setup_tcs_info(&ctx);

   calc_min_waves(program);

   unsigned scratch_size = 0;
   for (unsigned i = 0; i < shader_count; i++) {
      nir_shader* nir = shaders[i];
      setup_nir(&ctx, nir);
      setup_lds_size(&ctx, nir);
   }

   for (unsigned i = 0; i < shader_count; i++)
      scratch_size = std::max(scratch_size, shaders[i]->scratch_size);

   ctx.program->config->scratch_bytes_per_wave = scratch_size * ctx.program->wave_size;

   unsigned nir_num_blocks = 0;
   for (unsigned i = 0; i < shader_count; i++)
      nir_num_blocks += nir_shader_get_entrypoint(shaders[i])->num_blocks;
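   /* Heuristic reservation: lowering divergent control flow introduces extra
    * (e.g. invert/merge) blocks, so reserve roughly two ACO blocks per NIR
    * block to reduce reallocations. */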
   ctx.program->blocks.reserve(nir_num_blocks * 2);
   ctx.block = ctx.program->create_and_insert_block();
   ctx.block->kind = block_kind_top_level;

   return ctx;
}

} // namespace aco