xref: /aosp_15_r20/external/mesa3d/src/gallium/frontends/rusticl/rusticl_nir.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 #include "CL/cl.h"
2 
3 #include "nir.h"
4 #include "nir_builder.h"
5 
6 #include "rusticl_nir.h"
7 
8 static bool
rusticl_lower_intrinsics_filter(const nir_instr * instr,const void * state)9 rusticl_lower_intrinsics_filter(const nir_instr* instr, const void* state)
10 {
11     return instr->type == nir_instr_type_intrinsic;
12 }
13 
14 static nir_def*
rusticl_lower_intrinsics_instr(nir_builder * b,nir_instr * instr,void * _state)15 rusticl_lower_intrinsics_instr(
16     nir_builder *b,
17     nir_instr *instr,
18     void* _state
19 ) {
20     nir_intrinsic_instr *intrins = nir_instr_as_intrinsic(instr);
21     struct rusticl_lower_state *state = _state;
22 
23     switch (intrins->intrinsic) {
24     case nir_intrinsic_image_deref_format:
25     case nir_intrinsic_image_deref_order: {
26         int32_t offset;
27         nir_deref_instr *deref;
28         nir_def *val;
29         nir_variable *var;
30 
31         if (intrins->intrinsic == nir_intrinsic_image_deref_format) {
32             offset = CL_SNORM_INT8;
33             var = nir_find_variable_with_location(b->shader, nir_var_uniform, state->format_arr_loc);
34         } else {
35             offset = CL_R;
36             var = nir_find_variable_with_location(b->shader, nir_var_uniform, state->order_arr_loc);
37         }
38 
39         val = intrins->src[0].ssa;
40 
41         if (val->parent_instr->type == nir_instr_type_deref) {
42             nir_deref_instr *deref = nir_instr_as_deref(val->parent_instr);
43             nir_variable *var = nir_deref_instr_get_variable(deref);
44             assert(var);
45             val = nir_imm_intN_t(b, var->data.binding, val->bit_size);
46         }
47 
48         // we put write images after read images
49         if (glsl_type_is_image(var->type)) {
50             val = nir_iadd_imm(b, val, b->shader->info.num_textures);
51         }
52 
53         deref = nir_build_deref_var(b, var);
54         deref = nir_build_deref_array(b, deref, val);
55         val = nir_u2uN(b, nir_load_deref(b, deref), 32);
56 
57         // we have to fix up the value base
58         val = nir_iadd_imm(b, val, -offset);
59 
60         return val;
61     }
62     case nir_intrinsic_load_global_invocation_id:
63         if (intrins->def.bit_size == 64)
64             return nir_u2u64(b, nir_load_global_invocation_id(b, 32));
65         return NULL;
66     case nir_intrinsic_load_base_global_invocation_id:
67         return nir_load_var(b, nir_find_variable_with_location(b->shader, nir_var_uniform, state->base_global_invoc_id_loc));
68     case nir_intrinsic_load_base_workgroup_id:
69         return nir_load_var(b, nir_find_variable_with_location(b->shader, nir_var_uniform, state->base_workgroup_id_loc));
70     case nir_intrinsic_load_global_size:
71         return nir_load_var(b, nir_find_variable_with_location(b->shader, nir_var_uniform, state->global_size_loc));
72     case nir_intrinsic_load_num_workgroups:
73         return nir_load_var(b, nir_find_variable_with_location(b->shader, nir_var_uniform, state->num_workgroups_loc));
74     case nir_intrinsic_load_constant_base_ptr:
75         return nir_load_var(b, nir_find_variable_with_location(b->shader, nir_var_uniform, state->const_buf_loc));
76     case nir_intrinsic_load_printf_buffer_address:
77         return nir_load_var(b, nir_find_variable_with_location(b->shader, nir_var_uniform, state->printf_buf_loc));
78     case nir_intrinsic_load_work_dim:
79         assert(nir_find_variable_with_location(b->shader, nir_var_uniform, state->work_dim_loc));
80         return nir_u2uN(b, nir_load_var(b, nir_find_variable_with_location(b->shader, nir_var_uniform, state->work_dim_loc)),
81                         intrins->def.bit_size);
82     default:
83         return NULL;
84     }
85 }
86 
87 bool
rusticl_lower_intrinsics(nir_shader * nir,struct rusticl_lower_state * state)88 rusticl_lower_intrinsics(nir_shader *nir, struct rusticl_lower_state* state)
89 {
90     return nir_shader_lower_instructions(
91         nir,
92         rusticl_lower_intrinsics_filter,
93         rusticl_lower_intrinsics_instr,
94         state
95     );
96 }
97 
98 static nir_def*
rusticl_lower_input_instr(struct nir_builder * b,nir_instr * instr,void * _)99 rusticl_lower_input_instr(struct nir_builder *b, nir_instr *instr, void *_)
100 {
101    nir_intrinsic_instr *intrins = nir_instr_as_intrinsic(instr);
102    if (intrins->intrinsic != nir_intrinsic_load_kernel_input)
103       return NULL;
104 
105    nir_def *ubo_idx = nir_imm_int(b, 0);
106    nir_def *uniform_offset = intrins->src[0].ssa;
107 
108    assert(intrins->def.bit_size >= 8);
109    nir_def *load_result =
110       nir_load_ubo(b, intrins->num_components, intrins->def.bit_size,
111                    ubo_idx, nir_iadd_imm(b, uniform_offset, nir_intrinsic_base(intrins)));
112 
113    nir_intrinsic_instr *load = nir_instr_as_intrinsic(load_result->parent_instr);
114 
115    nir_intrinsic_set_align_mul(load, nir_intrinsic_align_mul(intrins));
116    nir_intrinsic_set_align_offset(load, nir_intrinsic_align_offset(intrins));
117    nir_intrinsic_set_range_base(load, nir_intrinsic_base(intrins));
118    nir_intrinsic_set_range(load, nir_intrinsic_range(intrins));
119 
120    return load_result;
121 }
122 
123 bool
rusticl_lower_inputs(nir_shader * shader)124 rusticl_lower_inputs(nir_shader *shader)
125 {
126    bool progress = false;
127 
128    assert(!shader->info.first_ubo_is_default_ubo);
129 
130    progress = nir_shader_lower_instructions(
131       shader,
132       rusticl_lower_intrinsics_filter,
133       rusticl_lower_input_instr,
134       NULL
135    );
136 
137    nir_foreach_variable_with_modes(var, shader, nir_var_mem_ubo) {
138       var->data.binding++;
139       var->data.driver_location++;
140    }
141    shader->info.num_ubos++;
142 
143    if (shader->num_uniforms > 0) {
144       const struct glsl_type *type = glsl_array_type(glsl_uint8_t_type(), shader->num_uniforms, 1);
145       nir_variable *ubo = nir_variable_create(shader, nir_var_mem_ubo, type, "kernel_input");
146       ubo->data.binding = 0;
147       ubo->data.explicit_binding = 1;
148    }
149 
150    shader->info.first_ubo_is_default_ubo = true;
151    return progress;
152 }
153