/*
 * Copyright © Microsoft Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "util/u_math.h"
#include "nir.h"
#include "glsl_types.h"
#include "nir_builder.h"
#include "nir_deref.h"

#include "clc_nir.h"
#include "clc_compiler.h"
#include "../compiler/dxil_nir.h"

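/* Helper to load a system-value intrinsic's value from the work-properties
 * UBO at a fixed byte offset. The load is annotated as 256-byte aligned to
 * match the alignment the buffer is bound with (D3D12 constant buffers are
 * 256-byte aligned).
 */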
static nir_def *
load_ubo(nir_builder *b, nir_intrinsic_instr *intr, nir_variable *var, unsigned offset)
{
   return nir_load_ubo(b,
                       intr->def.num_components,
                       intr->def.bit_size,
                       nir_imm_int(b, var->data.binding),
                       nir_imm_int(b, offset),
                       .align_mul = 256,
                       .align_offset = offset,
                       .range_base = offset,
                       .range = intr->def.bit_size * intr->def.num_components / 8);
}

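/* Replace load_base_global_invocation_id with a UBO load of the kernel's
 * global offset from the work-properties data.
 */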
static bool
lower_load_base_global_invocation_id(nir_builder *b, nir_intrinsic_instr *intr,
                                     nir_variable *var)
{
   b->cursor = nir_after_instr(&intr->instr);

   nir_def *offset = load_ubo(b, intr, var, offsetof(struct clc_work_properties_data,
                                                     global_offset_x));
   nir_def_replace(&intr->def, offset);
   return true;
}

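/* Replace load_work_dim with a UBO load of the enqueued work dimension. */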
static bool
lower_load_work_dim(nir_builder *b, nir_intrinsic_instr *intr,
                    nir_variable *var)
{
   b->cursor = nir_after_instr(&intr->instr);

   nir_def *dim = load_ubo(b, intr, var, offsetof(struct clc_work_properties_data,
                                                  work_dim));
   nir_def_replace(&intr->def, dim);
   return true;
}

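/* Replace load_num_workgroups with a UBO load of the total workgroup counts. */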
static bool
lower_load_num_workgroups(nir_builder *b, nir_intrinsic_instr *intr,
                          nir_variable *var)
{
   b->cursor = nir_after_instr(&intr->instr);

   nir_def *count =
      load_ubo(b, intr, var, offsetof(struct clc_work_properties_data,
                                      group_count_total_x));
   nir_def_replace(&intr->def, count);
   return true;
}

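/* Replace load_base_workgroup_id with a UBO load of the workgroup ID offset. */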
static bool
lower_load_base_workgroup_id(nir_builder *b, nir_intrinsic_instr *intr,
                             nir_variable *var)
{
   b->cursor = nir_after_instr(&intr->instr);

   nir_def *offset =
      load_ubo(b, intr, var, offsetof(struct clc_work_properties_data,
                                      group_id_offset_x));
   nir_def_replace(&intr->def, offset);
   return true;
}

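/* Lower the system-value intrinsics in the entrypoint to UBO loads from the
 * work-properties buffer described by var (laid out as a struct
 * clc_work_properties_data).
 */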
bool
clc_nir_lower_system_values(nir_shader *nir, nir_variable *var)
{
   bool progress = false;

   foreach_list_typed(nir_function, func, node, &nir->functions) {
      if (!func->is_entrypoint)
         continue;
      assert(func->impl);

      nir_builder b = nir_builder_create(func->impl);

      nir_foreach_block(block, func->impl) {
         nir_foreach_instr_safe(instr, block) {
            if (instr->type != nir_instr_type_intrinsic)
               continue;

            nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

            switch (intr->intrinsic) {
            case nir_intrinsic_load_base_global_invocation_id:
               progress |= lower_load_base_global_invocation_id(&b, intr, var);
               break;
            case nir_intrinsic_load_work_dim:
               progress |= lower_load_work_dim(&b, intr, var);
               break;
            case nir_intrinsic_load_num_workgroups:
               progress |= lower_load_num_workgroups(&b, intr, var);
               break;
            case nir_intrinsic_load_base_workgroup_id:
               progress |= lower_load_base_workgroup_id(&b, intr, var);
               break;
            default:
               break;
            }
         }
      }
   }

   return progress;
}

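/* Rewrite a load_kernel_input as a load from the kernel-inputs UBO: the
 * address is a (binding, byte offset) vec2 cast to a UBO deref of a matching
 * unsigned integer type, preserving the original load's alignment.
 */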
static bool
lower_load_kernel_input(nir_builder *b, nir_intrinsic_instr *intr,
                        nir_variable *var)
{
   b->cursor = nir_before_instr(&intr->instr);

   unsigned bit_size = intr->def.bit_size;
   enum glsl_base_type base_type;

   switch (bit_size) {
   case 64:
      base_type = GLSL_TYPE_UINT64;
      break;
   case 32:
      base_type = GLSL_TYPE_UINT;
      break;
   case 16:
      base_type = GLSL_TYPE_UINT16;
      break;
   case 8:
      base_type = GLSL_TYPE_UINT8;
      break;
   default:
      unreachable("invalid bit size");
   }

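   /* Address the kernel-inputs UBO with a (binding, byte offset) vec2
    * pointer, cast to a deref of the loaded type.
    */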
   const struct glsl_type *type =
      glsl_vector_type(base_type, intr->def.num_components);
   nir_def *ptr = nir_vec2(b, nir_imm_int(b, var->data.binding),
                           nir_u2uN(b, intr->src[0].ssa, 32));
   nir_deref_instr *deref = nir_build_deref_cast(b, ptr, nir_var_mem_ubo, type,
                                                 bit_size / 8);
   deref->cast.align_mul = nir_intrinsic_align_mul(intr);
   deref->cast.align_offset = nir_intrinsic_align_offset(intr);

   nir_def *result = nir_load_deref(b, deref);
   nir_def_replace(&intr->def, result);
   return true;
}

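/* Lower all load_kernel_input intrinsics in the entrypoint to loads from the
 * kernel-inputs UBO described by var.
 */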
bool
clc_nir_lower_kernel_input_loads(nir_shader *nir, nir_variable *var)
{
   bool progress = false;

   foreach_list_typed(nir_function, func, node, &nir->functions) {
      if (!func->is_entrypoint)
         continue;
      assert(func->impl);

      nir_builder b = nir_builder_create(func->impl);

      nir_foreach_block(block, func->impl) {
         nir_foreach_instr_safe(instr, block) {
            if (instr->type != nir_instr_type_intrinsic)
               continue;

            nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

            if (intr->intrinsic == nir_intrinsic_load_kernel_input)
               progress |= lower_load_kernel_input(&b, intr, var);
         }
      }
   }

   return progress;
}

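/* Create the SSBO variable backing the printf buffer: a uint array bound at
 * the given UAV slot.
 */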
static nir_variable *
add_printf_var(struct nir_shader *nir, unsigned uav_id)
{
   /* This size is arbitrary; the minimum required by the spec is 1MB. */
   const unsigned max_printf_size = 1 * 1024 * 1024;
   const unsigned printf_array_size = max_printf_size / sizeof(unsigned);
   nir_variable *var =
      nir_variable_create(nir, nir_var_mem_ssbo,
                          glsl_array_type(glsl_uint_type(), printf_array_size,
                                          sizeof(unsigned)),
                          "printf");
   var->data.binding = uav_id;
   return var;
}

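/* Rewrite every load_printf_buffer_address to a deref of a printf SSBO
 * variable, which is created on first use. Returns true if the variable was
 * needed (i.e. any printf buffer address load was found).
 */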
bool
clc_lower_printf_base(nir_shader *nir, unsigned uav_id)
{
   nir_variable *printf_var = NULL;
   nir_def *printf_deref = NULL;
   nir_foreach_function_impl(impl, nir) {
      nir_builder b = nir_builder_at(nir_before_impl(impl));
      bool progress = false;

      nir_foreach_block(block, impl) {
         nir_foreach_instr_safe(instr, block) {
            if (instr->type != nir_instr_type_intrinsic)
               continue;
            nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
            if (intrin->intrinsic != nir_intrinsic_load_printf_buffer_address)
               continue;

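            /* Create the printf variable and a deref to it on first use; the
             * deref is emitted at the top of the impl.
             */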
            if (!printf_var) {
               printf_var = add_printf_var(nir, uav_id);
               nir_deref_instr *deref = nir_build_deref_var(&b, printf_var);
               printf_deref = &deref->def;
            }
            nir_def_rewrite_uses(&intrin->def, printf_deref);
            progress = true;
         }
      }

      if (progress)
         nir_metadata_preserve(impl, nir_metadata_loop_analysis |
                                     nir_metadata_block_index |
                                     nir_metadata_dominance);
      else
         nir_metadata_preserve(impl, nir_metadata_all);
   }

   return printf_var != NULL;
}

/* Find patterns of:
 * - deref_var for one of the kernel inputs
 * - load_deref to get a pointer to global/constant memory
 * - cast the pointer into a deref
 * - use a basic deref chain that terminates in a load/store/atomic
 *
 * When this pattern is found, replace the load_deref with a constant value,
 * based on which kernel argument is being loaded from. This can only be done
 * for chains that terminate in a pointer access, since the presence of null
 * pointers should be detected by actually performing the load and inspecting
 * the resulting pointer value.
 */
static bool
lower_deref_base_to_constant(nir_builder *b, nir_intrinsic_instr *intr, void *context)
{
   switch (intr->intrinsic) {
   case nir_intrinsic_load_deref:
   case nir_intrinsic_store_deref:
   case nir_intrinsic_deref_atomic:
   case nir_intrinsic_deref_atomic_swap:
      break;
   default:
      return false;
   }

   nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
   if (!nir_deref_mode_must_be(deref, nir_var_mem_global | nir_var_mem_constant))
      return false;

   nir_deref_path path;
   nir_deref_path_init(&path, deref, NULL);
   bool ret = false;

   if (path.path[0]->deref_type != nir_deref_type_cast)
      goto done;
   if (!nir_deref_mode_must_be(path.path[0], nir_var_mem_global | nir_var_mem_constant))
      goto done;

   nir_instr *cast_src = path.path[0]->parent.ssa->parent_instr;
   if (cast_src->type != nir_instr_type_intrinsic)
      goto done;

   nir_intrinsic_instr *cast_src_intr = nir_instr_as_intrinsic(cast_src);
   if (cast_src_intr->intrinsic != nir_intrinsic_load_deref)
      goto done;

   nir_deref_instr *load_deref_src = nir_src_as_deref(cast_src_intr->src[0]);
   if (load_deref_src->deref_type != nir_deref_type_var ||
       load_deref_src->modes != nir_var_uniform)
      goto done;

   nir_variable *var = load_deref_src->var;

   ret = true;
   b->cursor = nir_before_instr(&path.path[0]->instr);
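   /* Keep the low 32 bits of the loaded pointer (the offset) and replace the
    * high 32 bits with the kernel argument's binding, producing a pointer
    * whose base is a compile-time constant.
    */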
   nir_def *original_offset = nir_unpack_64_2x32_split_x(b, &cast_src_intr->def);
   nir_def *constant_ptr = nir_pack_64_2x32_split(b, original_offset,
                                                  nir_imm_int(b, var->data.binding));
   nir_deref_instr *new_path =
      nir_build_deref_cast_with_alignment(b, constant_ptr, path.path[0]->modes,
                                          path.path[0]->type,
                                          path.path[0]->cast.ptr_stride,
                                          path.path[0]->cast.align_mul,
                                          path.path[0]->cast.align_offset);

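   /* Rebuild the rest of the original deref chain on top of the new cast. */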
   for (unsigned i = 1; path.path[i]; ++i) {
      b->cursor = nir_after_instr(&path.path[i]->instr);
      new_path = nir_build_deref_follower(b, new_path, path.path[i]);
   }
   nir_src_rewrite(&intr->src[0], &new_path->def);

done:
   nir_deref_path_finish(&path);
   return ret;
}

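/* Apply lower_deref_base_to_constant to every intrinsic in the shader. */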
bool
clc_nir_lower_global_pointers_to_constants(nir_shader *nir)
{
   return nir_shader_intrinsics_pass(nir, lower_deref_base_to_constant,
                                     nir_metadata_control_flow |
                                     nir_metadata_loop_analysis, NULL);
}