/*
 * Copyright © Microsoft Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "util/u_math.h"
#include "nir.h"
#include "glsl_types.h"
#include "nir_builder.h"
#include "nir_deref.h"

#include "clc_nir.h"
#include "clc_compiler.h"
#include "../compiler/dxil_nir.h"

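/* Helper for the system-value lowerings below: emit a load_ubo that reads
 * @offset bytes into the constant buffer described by @var, with the same
 * number of components and bit size as the original intrinsic's destination. */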
static nir_def *
load_ubo(nir_builder *b, nir_intrinsic_instr *intr, nir_variable *var, unsigned offset)
{
   return nir_load_ubo(b,
                       intr->def.num_components,
                       intr->def.bit_size,
                       nir_imm_int(b, var->data.binding),
                       nir_imm_int(b, offset),
                       .align_mul = 256,
                       .align_offset = offset,
                       .range_base = offset,
                       .range = intr->def.bit_size * intr->def.num_components / 8);
}

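/* The helpers below replace compute system-value intrinsics with reads from
 * the work-properties constant buffer (struct clc_work_properties_data): each
 * one loads the field backing the corresponding intrinsic and rewrites all
 * uses of the old def. */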
static bool
lower_load_base_global_invocation_id(nir_builder *b, nir_intrinsic_instr *intr,
                                     nir_variable *var)
{
   b->cursor = nir_after_instr(&intr->instr);

   nir_def *offset = load_ubo(b, intr, var, offsetof(struct clc_work_properties_data,
                                                     global_offset_x));
   nir_def_replace(&intr->def, offset);
   return true;
}

static bool
lower_load_work_dim(nir_builder *b, nir_intrinsic_instr *intr,
                    nir_variable *var)
{
   b->cursor = nir_after_instr(&intr->instr);

   nir_def *dim = load_ubo(b, intr, var, offsetof(struct clc_work_properties_data,
                                                  work_dim));
   nir_def_replace(&intr->def, dim);
   return true;
}

static bool
lower_load_num_workgroups(nir_builder *b, nir_intrinsic_instr *intr,
                          nir_variable *var)
{
   b->cursor = nir_after_instr(&intr->instr);

   nir_def *count =
      load_ubo(b, intr, var, offsetof(struct clc_work_properties_data,
                                      group_count_total_x));
   nir_def_replace(&intr->def, count);
   return true;
}

static bool
lower_load_base_workgroup_id(nir_builder *b, nir_intrinsic_instr *intr,
                             nir_variable *var)
{
   b->cursor = nir_after_instr(&intr->instr);

   nir_def *offset =
      load_ubo(b, intr, var, offsetof(struct clc_work_properties_data,
                                      group_id_offset_x));
   nir_def_replace(&intr->def, offset);
   return true;
}

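/* Lower the system-value intrinsics handled above (base global invocation ID,
 * work dim, workgroup count, base workgroup ID) in every entrypoint to loads
 * from the work-properties constant buffer @var. */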
bool
clc_nir_lower_system_values(nir_shader *nir, nir_variable *var)
{
   bool progress = false;

   foreach_list_typed(nir_function, func, node, &nir->functions) {
      if (!func->is_entrypoint)
         continue;
      assert(func->impl);

      nir_builder b = nir_builder_create(func->impl);

      nir_foreach_block(block, func->impl) {
         nir_foreach_instr_safe(instr, block) {
            if (instr->type != nir_instr_type_intrinsic)
               continue;

            nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

            switch (intr->intrinsic) {
            case nir_intrinsic_load_base_global_invocation_id:
               progress |= lower_load_base_global_invocation_id(&b, intr, var);
               break;
            case nir_intrinsic_load_work_dim:
               progress |= lower_load_work_dim(&b, intr, var);
               break;
            case nir_intrinsic_load_num_workgroups:
               progress |= lower_load_num_workgroups(&b, intr, var);
               break;
            case nir_intrinsic_load_base_workgroup_id:
               progress |= lower_load_base_workgroup_id(&b, intr, var);
               break;
            default: break;
            }
         }
      }
   }

   return progress;
}

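/* Rewrite a load_kernel_input as a UBO deref load: build a (binding, offset)
 * vec2 pointer into the kernel-arguments buffer @var, cast it to a deref of a
 * matching unsigned vector type, and load through it, preserving the original
 * alignment information. */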
static bool
lower_load_kernel_input(nir_builder *b, nir_intrinsic_instr *intr,
                        nir_variable *var)
{
   b->cursor = nir_before_instr(&intr->instr);

   unsigned bit_size = intr->def.bit_size;
   enum glsl_base_type base_type;

   switch (bit_size) {
   case 64:
      base_type = GLSL_TYPE_UINT64;
      break;
   case 32:
      base_type = GLSL_TYPE_UINT;
      break;
   case 16:
      base_type = GLSL_TYPE_UINT16;
      break;
   case 8:
      base_type = GLSL_TYPE_UINT8;
      break;
   default:
      unreachable("invalid bit size");
   }

   const struct glsl_type *type =
      glsl_vector_type(base_type, intr->def.num_components);
   nir_def *ptr = nir_vec2(b, nir_imm_int(b, var->data.binding),
                           nir_u2uN(b, intr->src[0].ssa, 32));
   nir_deref_instr *deref = nir_build_deref_cast(b, ptr, nir_var_mem_ubo, type,
                                                 bit_size / 8);
   deref->cast.align_mul = nir_intrinsic_align_mul(intr);
   deref->cast.align_offset = nir_intrinsic_align_offset(intr);

   nir_def *result =
      nir_load_deref(b, deref);
   nir_def_replace(&intr->def, result);
   return true;
}

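/* Lower every load_kernel_input in the entrypoints to a constant-buffer load
 * from the kernel-arguments buffer @var. */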
bool
clc_nir_lower_kernel_input_loads(nir_shader *nir, nir_variable *var)
{
   bool progress = false;

   foreach_list_typed(nir_function, func, node, &nir->functions) {
      if (!func->is_entrypoint)
         continue;
      assert(func->impl);

      nir_builder b = nir_builder_create(func->impl);

      nir_foreach_block(block, func->impl) {
         nir_foreach_instr_safe(instr, block) {
            if (instr->type != nir_instr_type_intrinsic)
               continue;

            nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

            if (intr->intrinsic == nir_intrinsic_load_kernel_input)
               progress |= lower_load_kernel_input(&b, intr, var);
         }
      }
   }

   return progress;
}

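/* Create the SSBO-backed uint array that backs printf output, bound at
 * @uav_id. */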
static nir_variable *
add_printf_var(struct nir_shader *nir, unsigned uav_id)
{
   /* This size is arbitrary; the minimum required by the spec is 1MB. */
   const unsigned max_printf_size = 1 * 1024 * 1024;
   const unsigned printf_array_size = max_printf_size / sizeof(unsigned);
   nir_variable *var =
      nir_variable_create(nir, nir_var_mem_ssbo,
                          glsl_array_type(glsl_uint_type(), printf_array_size, sizeof(unsigned)),
                          "printf");
   var->data.binding = uav_id;
   return var;
}

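/* Replace every load_printf_buffer_address with a deref of a lazily created
 * printf SSBO (see add_printf_var). Returns true if any printf buffer address
 * was actually requested. */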
bool
clc_lower_printf_base(nir_shader *nir, unsigned uav_id)
{
   nir_variable *printf_var = NULL;
   nir_def *printf_deref = NULL;
   nir_foreach_function_impl(impl, nir) {
      nir_builder b = nir_builder_at(nir_before_impl(impl));
      bool progress = false;

      nir_foreach_block(block, impl) {
         nir_foreach_instr_safe(instr, block) {
            if (instr->type != nir_instr_type_intrinsic)
               continue;
            nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
            if (intrin->intrinsic != nir_intrinsic_load_printf_buffer_address)
               continue;

            if (!printf_var) {
               printf_var = add_printf_var(nir, uav_id);
               nir_deref_instr *deref = nir_build_deref_var(&b, printf_var);
               printf_deref = &deref->def;
            }
            nir_def_rewrite_uses(&intrin->def, printf_deref);
            progress = true;
         }
      }

      if (progress)
         nir_metadata_preserve(impl, nir_metadata_loop_analysis |
                                     nir_metadata_block_index |
                                     nir_metadata_dominance);
      else
         nir_metadata_preserve(impl, nir_metadata_all);
   }

   return printf_var != NULL;
}

/* Find patterns of:
 * - deref_var for one of the kernel inputs
 * - load_deref to get a pointer to global/constant memory
 * - cast the pointer into a deref
 * - use a basic deref chain that terminates in a load/store/atomic
 *
 * When this pattern is found, replace the base (buffer index) half of the
 * loaded pointer with a constant, based on which kernel argument is being
 * loaded from, while keeping the dynamic offset half. This can only be done
 * for chains that terminate in a memory access, since the presence of null
 * pointers should be detected by actually performing the load and inspecting
 * the resulting pointer value.
 */
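/* Illustrative sketch of the rewrite (informal notation, not exact NIR
 * syntax):
 *
 *    %arg  = deref_var &kernel_input          // nir_var_uniform
 *    %ptr  = load_deref %arg                  // 64-bit pointer value
 *    %cast = deref_cast %ptr                  // global/constant
 *    ...   = load_deref %cast[...]
 *
 * becomes
 *
 *    %off  = unpack_64_2x32_split_x %ptr      // keep the dynamic offset
 *    %new  = pack_64_2x32_split %off, <binding of kernel_input>
 *    %cast = deref_cast %new                  // global/constant
 *    ...   = load_deref %cast[...]
 */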
static bool
lower_deref_base_to_constant(nir_builder *b, nir_intrinsic_instr *intr, void *context)
{
   switch (intr->intrinsic) {
   case nir_intrinsic_load_deref:
   case nir_intrinsic_store_deref:
   case nir_intrinsic_deref_atomic:
   case nir_intrinsic_deref_atomic_swap:
      break;
   default:
      return false;
   }

   nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
   if (!nir_deref_mode_must_be(deref, nir_var_mem_global | nir_var_mem_constant))
      return false;

   nir_deref_path path;
   nir_deref_path_init(&path, deref, NULL);
   bool ret = false;

   if (path.path[0]->deref_type != nir_deref_type_cast)
      goto done;
   if (!nir_deref_mode_must_be(path.path[0], nir_var_mem_global | nir_var_mem_constant))
      goto done;

   nir_instr *cast_src = path.path[0]->parent.ssa->parent_instr;
   if (cast_src->type != nir_instr_type_intrinsic)
      goto done;

   nir_intrinsic_instr *cast_src_intr = nir_instr_as_intrinsic(cast_src);
   if (cast_src_intr->intrinsic != nir_intrinsic_load_deref)
      goto done;

   nir_deref_instr *load_deref_src = nir_src_as_deref(cast_src_intr->src[0]);
   if (load_deref_src->deref_type != nir_deref_type_var ||
       load_deref_src->modes != nir_var_uniform)
      goto done;

   nir_variable *var = load_deref_src->var;

   ret = true;
   b->cursor = nir_before_instr(&path.path[0]->instr);
   nir_def *original_offset = nir_unpack_64_2x32_split_x(b, &cast_src_intr->def);
   nir_def *constant_ptr = nir_pack_64_2x32_split(b, original_offset, nir_imm_int(b, var->data.binding));
   nir_deref_instr *new_path = nir_build_deref_cast_with_alignment(b, constant_ptr, path.path[0]->modes, path.path[0]->type, path.path[0]->cast.ptr_stride,
                                                                   path.path[0]->cast.align_mul, path.path[0]->cast.align_offset);

   for (unsigned i = 1; path.path[i]; ++i) {
      b->cursor = nir_after_instr(&path.path[i]->instr);
      new_path = nir_build_deref_follower(b, new_path, path.path[i]);
   }
   nir_src_rewrite(&intr->src[0], &new_path->def);

done:
   nir_deref_path_finish(&path);
   return ret;
}

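/* Run lower_deref_base_to_constant over every intrinsic in the shader. */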
bool
clc_nir_lower_global_pointers_to_constants(nir_shader *nir)
{
   return nir_shader_intrinsics_pass(nir, lower_deref_base_to_constant,
                                     nir_metadata_control_flow |
                                     nir_metadata_loop_analysis, NULL);
}
339