#include "CL/cl.h"

#include "nir.h"
#include "nir_builder.h"

#include "rusticl_nir.h"

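// both lowering passes in this file only need to look at intrinsic
// instructions, so they share this filter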
static bool
rusticl_lower_intrinsics_filter(const nir_instr* instr, const void* state)
{
   return instr->type == nir_instr_type_intrinsic;
}

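// lowers CL-specific intrinsics (work dim, global sizes/offsets, printf
// buffer address, image format/order queries, ...) to loads from the uniform
// variables whose locations are passed in via struct rusticl_lower_state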
static nir_def*
rusticl_lower_intrinsics_instr(
   nir_builder *b,
   nir_instr *instr,
   void* _state
) {
   nir_intrinsic_instr *intrins = nir_instr_as_intrinsic(instr);
   struct rusticl_lower_state *state = _state;

   switch (intrins->intrinsic) {
   case nir_intrinsic_image_deref_format:
   case nir_intrinsic_image_deref_order: {
      int32_t offset;
      nir_deref_instr *deref;
      nir_def *val;
      nir_variable *var;

      if (intrins->intrinsic == nir_intrinsic_image_deref_format) {
         offset = CL_SNORM_INT8;
         var = nir_find_variable_with_location(b->shader, nir_var_uniform, state->format_arr_loc);
      } else {
         offset = CL_R;
         var = nir_find_variable_with_location(b->shader, nir_var_uniform, state->order_arr_loc);
      }

      val = intrins->src[0].ssa;

      if (val->parent_instr->type == nir_instr_type_deref) {
         nir_deref_instr *deref = nir_instr_as_deref(val->parent_instr);
         nir_variable *var = nir_deref_instr_get_variable(deref);
         assert(var);
         val = nir_imm_intN_t(b, var->data.binding, val->bit_size);

         // we put write images after read images, so shift the index for
         // storage images past the sampled images
         if (glsl_type_is_image(var->type)) {
            val = nir_iadd_imm(b, val, b->shader->info.num_textures);
         }
      }

      deref = nir_build_deref_var(b, var);
      deref = nir_build_deref_array(b, deref, val);
      val = nir_u2uN(b, nir_load_deref(b, deref), 32);

      // we have to fix up the value base
      val = nir_iadd_imm(b, val, -offset);

      return val;
   }
   case nir_intrinsic_load_global_invocation_id:
      if (intrins->def.bit_size == 64)
         return nir_u2u64(b, nir_load_global_invocation_id(b, 32));
      return NULL;
   case nir_intrinsic_load_base_global_invocation_id:
      return nir_load_var(b, nir_find_variable_with_location(b->shader, nir_var_uniform, state->base_global_invoc_id_loc));
   case nir_intrinsic_load_base_workgroup_id:
      return nir_load_var(b, nir_find_variable_with_location(b->shader, nir_var_uniform, state->base_workgroup_id_loc));
   case nir_intrinsic_load_global_size:
      return nir_load_var(b, nir_find_variable_with_location(b->shader, nir_var_uniform, state->global_size_loc));
   case nir_intrinsic_load_num_workgroups:
      return nir_load_var(b, nir_find_variable_with_location(b->shader, nir_var_uniform, state->num_workgroups_loc));
   case nir_intrinsic_load_constant_base_ptr:
      return nir_load_var(b, nir_find_variable_with_location(b->shader, nir_var_uniform, state->const_buf_loc));
   case nir_intrinsic_load_printf_buffer_address:
      return nir_load_var(b, nir_find_variable_with_location(b->shader, nir_var_uniform, state->printf_buf_loc));
   case nir_intrinsic_load_work_dim:
      assert(nir_find_variable_with_location(b->shader, nir_var_uniform, state->work_dim_loc));
      return nir_u2uN(b, nir_load_var(b, nir_find_variable_with_location(b->shader, nir_var_uniform, state->work_dim_loc)),
                      intrins->def.bit_size);
   default:
      return NULL;
   }
}

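// runs the intrinsic lowering above over the whole shader and reports whether
// any instruction was replaced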
bool
rusticl_lower_intrinsics(nir_shader *nir, struct rusticl_lower_state* state)
{
   return nir_shader_lower_instructions(
      nir,
      rusticl_lower_intrinsics_filter,
      rusticl_lower_intrinsics_instr,
      state
   );
}

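// rewrites load_kernel_input into an equivalent load_ubo from binding 0,
// carrying over the original alignment and range information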
static nir_def*
rusticl_lower_input_instr(struct nir_builder *b, nir_instr *instr, void *_)
{
   nir_intrinsic_instr *intrins = nir_instr_as_intrinsic(instr);
   if (intrins->intrinsic != nir_intrinsic_load_kernel_input)
      return NULL;

   nir_def *ubo_idx = nir_imm_int(b, 0);
   nir_def *uniform_offset = intrins->src[0].ssa;

   assert(intrins->def.bit_size >= 8);
   nir_def *load_result =
      nir_load_ubo(b, intrins->num_components, intrins->def.bit_size,
                   ubo_idx, nir_iadd_imm(b, uniform_offset, nir_intrinsic_base(intrins)));

   nir_intrinsic_instr *load = nir_instr_as_intrinsic(load_result->parent_instr);

   nir_intrinsic_set_align_mul(load, nir_intrinsic_align_mul(intrins));
   nir_intrinsic_set_align_offset(load, nir_intrinsic_align_offset(intrins));
   nir_intrinsic_set_range_base(load, nir_intrinsic_base(intrins));
   nir_intrinsic_set_range(load, nir_intrinsic_range(intrins));

   return load_result;
}

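// lowers kernel inputs to a UBO: shifts all existing UBOs up by one binding
// and exposes the kernel inputs as a new "kernel_input" UBO at binding 0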
bool
rusticl_lower_inputs(nir_shader *shader)
{
   bool progress = false;

   assert(!shader->info.first_ubo_is_default_ubo);

   progress = nir_shader_lower_instructions(
      shader,
      rusticl_lower_intrinsics_filter,
      rusticl_lower_input_instr,
      NULL
   );

   nir_foreach_variable_with_modes(var, shader, nir_var_mem_ubo) {
      var->data.binding++;
      var->data.driver_location++;
   }
   shader->info.num_ubos++;

   if (shader->num_uniforms > 0) {
      const struct glsl_type *type = glsl_array_type(glsl_uint8_t_type(), shader->num_uniforms, 1);
      nir_variable *ubo = nir_variable_create(shader, nir_var_mem_ubo, type, "kernel_input");
      ubo->data.binding = 0;
      ubo->data.explicit_binding = 1;
   }

   shader->info.first_ubo_is_default_ubo = true;
   return progress;
}