xref: /aosp_15_r20/external/mesa3d/src/compiler/nir/nir_functions.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1*61046927SAndroid Build Coastguard Worker /*
2*61046927SAndroid Build Coastguard Worker  * Copyright © 2015 Intel Corporation
3*61046927SAndroid Build Coastguard Worker  *
4*61046927SAndroid Build Coastguard Worker  * Permission is hereby granted, free of charge, to any person obtaining a
5*61046927SAndroid Build Coastguard Worker  * copy of this software and associated documentation files (the "Software"),
6*61046927SAndroid Build Coastguard Worker  * to deal in the Software without restriction, including without limitation
7*61046927SAndroid Build Coastguard Worker  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8*61046927SAndroid Build Coastguard Worker  * and/or sell copies of the Software, and to permit persons to whom the
9*61046927SAndroid Build Coastguard Worker  * Software is furnished to do so, subject to the following conditions:
10*61046927SAndroid Build Coastguard Worker  *
11*61046927SAndroid Build Coastguard Worker  * The above copyright notice and this permission notice (including the next
12*61046927SAndroid Build Coastguard Worker  * paragraph) shall be included in all copies or substantial portions of the
13*61046927SAndroid Build Coastguard Worker  * Software.
14*61046927SAndroid Build Coastguard Worker  *
15*61046927SAndroid Build Coastguard Worker  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16*61046927SAndroid Build Coastguard Worker  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17*61046927SAndroid Build Coastguard Worker  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18*61046927SAndroid Build Coastguard Worker  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19*61046927SAndroid Build Coastguard Worker  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20*61046927SAndroid Build Coastguard Worker  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21*61046927SAndroid Build Coastguard Worker  * IN THE SOFTWARE.
22*61046927SAndroid Build Coastguard Worker  */
23*61046927SAndroid Build Coastguard Worker 
24*61046927SAndroid Build Coastguard Worker #include "nir.h"
25*61046927SAndroid Build Coastguard Worker #include "nir_builder.h"
26*61046927SAndroid Build Coastguard Worker #include "nir_control_flow.h"
27*61046927SAndroid Build Coastguard Worker #include "nir_vla.h"
28*61046927SAndroid Build Coastguard Worker 
29*61046927SAndroid Build Coastguard Worker /*
30*61046927SAndroid Build Coastguard Worker  * TODO: write a proper inliner for GPUs.
31*61046927SAndroid Build Coastguard Worker  * This heuristic just inlines small functions,
32*61046927SAndroid Build Coastguard Worker  * and tail calls get inlined as well.
33*61046927SAndroid Build Coastguard Worker  */
34*61046927SAndroid Build Coastguard Worker static bool
nir_function_can_inline(nir_function * function)35*61046927SAndroid Build Coastguard Worker nir_function_can_inline(nir_function *function)
36*61046927SAndroid Build Coastguard Worker {
37*61046927SAndroid Build Coastguard Worker    bool can_inline = true;
38*61046927SAndroid Build Coastguard Worker    if (!function->should_inline) {
39*61046927SAndroid Build Coastguard Worker       if (function->impl) {
40*61046927SAndroid Build Coastguard Worker          if (function->impl->num_blocks > 2)
41*61046927SAndroid Build Coastguard Worker             can_inline = false;
42*61046927SAndroid Build Coastguard Worker          if (function->impl->ssa_alloc > 45)
43*61046927SAndroid Build Coastguard Worker             can_inline = false;
44*61046927SAndroid Build Coastguard Worker       }
45*61046927SAndroid Build Coastguard Worker    }
46*61046927SAndroid Build Coastguard Worker    return can_inline;
47*61046927SAndroid Build Coastguard Worker }
48*61046927SAndroid Build Coastguard Worker 
49*61046927SAndroid Build Coastguard Worker static bool
function_ends_in_jump(nir_function_impl * impl)50*61046927SAndroid Build Coastguard Worker function_ends_in_jump(nir_function_impl *impl)
51*61046927SAndroid Build Coastguard Worker {
52*61046927SAndroid Build Coastguard Worker    nir_block *last_block = nir_impl_last_block(impl);
53*61046927SAndroid Build Coastguard Worker    return nir_block_ends_in_jump(last_block);
54*61046927SAndroid Build Coastguard Worker }
55*61046927SAndroid Build Coastguard Worker 
56*61046927SAndroid Build Coastguard Worker /* A cast is used to deref function in/out params. However the bindless
57*61046927SAndroid Build Coastguard Worker  * textures spec allows both uniforms and functions temps to be passed to a
58*61046927SAndroid Build Coastguard Worker  * function param defined the same way. To deal with this we need to update
59*61046927SAndroid Build Coastguard Worker  * this when we inline and know what variable mode we are dealing with.
60*61046927SAndroid Build Coastguard Worker  */
61*61046927SAndroid Build Coastguard Worker static void
fixup_cast_deref_mode(nir_deref_instr * deref)62*61046927SAndroid Build Coastguard Worker fixup_cast_deref_mode(nir_deref_instr *deref)
63*61046927SAndroid Build Coastguard Worker {
64*61046927SAndroid Build Coastguard Worker    nir_deref_instr *parent = nir_src_as_deref(deref->parent);
65*61046927SAndroid Build Coastguard Worker    if (parent && parent->modes & nir_var_uniform &&
66*61046927SAndroid Build Coastguard Worker        deref->modes & nir_var_function_temp) {
67*61046927SAndroid Build Coastguard Worker       deref->modes |= nir_var_uniform;
68*61046927SAndroid Build Coastguard Worker       deref->modes ^= nir_var_function_temp;
69*61046927SAndroid Build Coastguard Worker 
70*61046927SAndroid Build Coastguard Worker       nir_foreach_use(use, &deref->def) {
71*61046927SAndroid Build Coastguard Worker          if (nir_src_parent_instr(use)->type != nir_instr_type_deref)
72*61046927SAndroid Build Coastguard Worker             continue;
73*61046927SAndroid Build Coastguard Worker 
74*61046927SAndroid Build Coastguard Worker          /* Recurse into children */
75*61046927SAndroid Build Coastguard Worker          fixup_cast_deref_mode(nir_instr_as_deref(nir_src_parent_instr(use)));
76*61046927SAndroid Build Coastguard Worker       }
77*61046927SAndroid Build Coastguard Worker    }
78*61046927SAndroid Build Coastguard Worker }
79*61046927SAndroid Build Coastguard Worker 
80*61046927SAndroid Build Coastguard Worker void
nir_inline_function_impl(struct nir_builder * b,const nir_function_impl * impl,nir_def ** params,struct hash_table * shader_var_remap)81*61046927SAndroid Build Coastguard Worker nir_inline_function_impl(struct nir_builder *b,
82*61046927SAndroid Build Coastguard Worker                          const nir_function_impl *impl,
83*61046927SAndroid Build Coastguard Worker                          nir_def **params,
84*61046927SAndroid Build Coastguard Worker                          struct hash_table *shader_var_remap)
85*61046927SAndroid Build Coastguard Worker {
86*61046927SAndroid Build Coastguard Worker    nir_function_impl *copy = nir_function_impl_clone(b->shader, impl);
87*61046927SAndroid Build Coastguard Worker 
88*61046927SAndroid Build Coastguard Worker    exec_list_append(&b->impl->locals, &copy->locals);
89*61046927SAndroid Build Coastguard Worker 
90*61046927SAndroid Build Coastguard Worker    nir_foreach_block(block, copy) {
91*61046927SAndroid Build Coastguard Worker       nir_foreach_instr_safe(instr, block) {
92*61046927SAndroid Build Coastguard Worker          switch (instr->type) {
93*61046927SAndroid Build Coastguard Worker          case nir_instr_type_deref: {
94*61046927SAndroid Build Coastguard Worker             nir_deref_instr *deref = nir_instr_as_deref(instr);
95*61046927SAndroid Build Coastguard Worker 
96*61046927SAndroid Build Coastguard Worker             /* Note: This shouldn't change the mode of anything but the
97*61046927SAndroid Build Coastguard Worker              * replaced nir_intrinsic_load_param intrinsics handled later in
98*61046927SAndroid Build Coastguard Worker              * this switch table. Any incorrect modes should have already been
99*61046927SAndroid Build Coastguard Worker              * detected by previous nir_vaidate calls.
100*61046927SAndroid Build Coastguard Worker              */
101*61046927SAndroid Build Coastguard Worker             if (deref->deref_type == nir_deref_type_cast) {
102*61046927SAndroid Build Coastguard Worker                fixup_cast_deref_mode(deref);
103*61046927SAndroid Build Coastguard Worker                break;
104*61046927SAndroid Build Coastguard Worker             }
105*61046927SAndroid Build Coastguard Worker 
106*61046927SAndroid Build Coastguard Worker             if (deref->deref_type != nir_deref_type_var)
107*61046927SAndroid Build Coastguard Worker                break;
108*61046927SAndroid Build Coastguard Worker 
109*61046927SAndroid Build Coastguard Worker             /* We don't need to remap function variables.  We already cloned
110*61046927SAndroid Build Coastguard Worker              * them as part of nir_function_impl_clone and appended them to
111*61046927SAndroid Build Coastguard Worker              * b->impl->locals.
112*61046927SAndroid Build Coastguard Worker              */
113*61046927SAndroid Build Coastguard Worker             if (deref->var->data.mode == nir_var_function_temp)
114*61046927SAndroid Build Coastguard Worker                break;
115*61046927SAndroid Build Coastguard Worker 
116*61046927SAndroid Build Coastguard Worker             /* If no map is provided, we assume that there are either no
117*61046927SAndroid Build Coastguard Worker              * shader variables or they already live b->shader (this is the
118*61046927SAndroid Build Coastguard Worker              * case for function inlining within a single shader.
119*61046927SAndroid Build Coastguard Worker              */
120*61046927SAndroid Build Coastguard Worker             if (shader_var_remap == NULL)
121*61046927SAndroid Build Coastguard Worker                break;
122*61046927SAndroid Build Coastguard Worker 
123*61046927SAndroid Build Coastguard Worker             struct hash_entry *entry =
124*61046927SAndroid Build Coastguard Worker                _mesa_hash_table_search(shader_var_remap, deref->var);
125*61046927SAndroid Build Coastguard Worker             if (entry == NULL) {
126*61046927SAndroid Build Coastguard Worker                nir_variable *nvar = nir_variable_clone(deref->var, b->shader);
127*61046927SAndroid Build Coastguard Worker                nir_shader_add_variable(b->shader, nvar);
128*61046927SAndroid Build Coastguard Worker                entry = _mesa_hash_table_insert(shader_var_remap,
129*61046927SAndroid Build Coastguard Worker                                                deref->var, nvar);
130*61046927SAndroid Build Coastguard Worker             }
131*61046927SAndroid Build Coastguard Worker             deref->var = entry->data;
132*61046927SAndroid Build Coastguard Worker             break;
133*61046927SAndroid Build Coastguard Worker          }
134*61046927SAndroid Build Coastguard Worker 
135*61046927SAndroid Build Coastguard Worker          case nir_instr_type_intrinsic: {
136*61046927SAndroid Build Coastguard Worker             nir_intrinsic_instr *load = nir_instr_as_intrinsic(instr);
137*61046927SAndroid Build Coastguard Worker             if (load->intrinsic != nir_intrinsic_load_param)
138*61046927SAndroid Build Coastguard Worker                break;
139*61046927SAndroid Build Coastguard Worker 
140*61046927SAndroid Build Coastguard Worker             unsigned param_idx = nir_intrinsic_param_idx(load);
141*61046927SAndroid Build Coastguard Worker             assert(param_idx < impl->function->num_params);
142*61046927SAndroid Build Coastguard Worker             nir_def_replace(&load->def, params[param_idx]);
143*61046927SAndroid Build Coastguard Worker             break;
144*61046927SAndroid Build Coastguard Worker          }
145*61046927SAndroid Build Coastguard Worker 
146*61046927SAndroid Build Coastguard Worker          case nir_instr_type_jump:
147*61046927SAndroid Build Coastguard Worker             /* Returns have to be lowered for this to work */
148*61046927SAndroid Build Coastguard Worker             assert(nir_instr_as_jump(instr)->type != nir_jump_return);
149*61046927SAndroid Build Coastguard Worker             break;
150*61046927SAndroid Build Coastguard Worker 
151*61046927SAndroid Build Coastguard Worker          default:
152*61046927SAndroid Build Coastguard Worker             break;
153*61046927SAndroid Build Coastguard Worker          }
154*61046927SAndroid Build Coastguard Worker       }
155*61046927SAndroid Build Coastguard Worker    }
156*61046927SAndroid Build Coastguard Worker 
157*61046927SAndroid Build Coastguard Worker    bool nest_if = function_ends_in_jump(copy);
158*61046927SAndroid Build Coastguard Worker 
159*61046927SAndroid Build Coastguard Worker    /* Pluck the body out of the function and place it here */
160*61046927SAndroid Build Coastguard Worker    nir_cf_list body;
161*61046927SAndroid Build Coastguard Worker    nir_cf_list_extract(&body, &copy->body);
162*61046927SAndroid Build Coastguard Worker 
163*61046927SAndroid Build Coastguard Worker    if (nest_if) {
164*61046927SAndroid Build Coastguard Worker       nir_if *cf = nir_push_if(b, nir_imm_true(b));
165*61046927SAndroid Build Coastguard Worker       nir_cf_reinsert(&body, nir_after_cf_list(&cf->then_list));
166*61046927SAndroid Build Coastguard Worker       nir_pop_if(b, cf);
167*61046927SAndroid Build Coastguard Worker    } else {
168*61046927SAndroid Build Coastguard Worker       /* Insert a nop at the cursor so we can keep track of where things are as
169*61046927SAndroid Build Coastguard Worker        * we add/remove stuff from the CFG.
170*61046927SAndroid Build Coastguard Worker        */
171*61046927SAndroid Build Coastguard Worker       nir_intrinsic_instr *nop = nir_nop(b);
172*61046927SAndroid Build Coastguard Worker       nir_cf_reinsert(&body, nir_before_instr(&nop->instr));
173*61046927SAndroid Build Coastguard Worker       b->cursor = nir_instr_remove(&nop->instr);
174*61046927SAndroid Build Coastguard Worker    }
175*61046927SAndroid Build Coastguard Worker }
176*61046927SAndroid Build Coastguard Worker 
177*61046927SAndroid Build Coastguard Worker static bool inline_function_impl(nir_function_impl *impl, struct set *inlined);
178*61046927SAndroid Build Coastguard Worker 
inline_functions_pass(nir_builder * b,nir_instr * instr,void * cb_data)179*61046927SAndroid Build Coastguard Worker static bool inline_functions_pass(nir_builder *b,
180*61046927SAndroid Build Coastguard Worker                                   nir_instr *instr,
181*61046927SAndroid Build Coastguard Worker                                   void *cb_data)
182*61046927SAndroid Build Coastguard Worker {
183*61046927SAndroid Build Coastguard Worker    struct set *inlined = cb_data;
184*61046927SAndroid Build Coastguard Worker    if (instr->type != nir_instr_type_call)
185*61046927SAndroid Build Coastguard Worker       return false;
186*61046927SAndroid Build Coastguard Worker 
187*61046927SAndroid Build Coastguard Worker    nir_call_instr *call = nir_instr_as_call(instr);
188*61046927SAndroid Build Coastguard Worker    assert(call->callee->impl);
189*61046927SAndroid Build Coastguard Worker 
190*61046927SAndroid Build Coastguard Worker    if (b->shader->options->driver_functions &&
191*61046927SAndroid Build Coastguard Worker        b->shader->info.stage == MESA_SHADER_KERNEL) {
192*61046927SAndroid Build Coastguard Worker       bool last_instr = (instr == nir_block_last_instr(instr->block));
193*61046927SAndroid Build Coastguard Worker       if (!nir_function_can_inline(call->callee) && !last_instr) {
194*61046927SAndroid Build Coastguard Worker          return false;
195*61046927SAndroid Build Coastguard Worker       }
196*61046927SAndroid Build Coastguard Worker    }
197*61046927SAndroid Build Coastguard Worker 
198*61046927SAndroid Build Coastguard Worker    /* Make sure that the function we're calling is already inlined */
199*61046927SAndroid Build Coastguard Worker    inline_function_impl(call->callee->impl, inlined);
200*61046927SAndroid Build Coastguard Worker 
201*61046927SAndroid Build Coastguard Worker    b->cursor = nir_instr_remove(&call->instr);
202*61046927SAndroid Build Coastguard Worker 
203*61046927SAndroid Build Coastguard Worker    /* Rewrite all of the uses of the callee's parameters to use the call
204*61046927SAndroid Build Coastguard Worker     * instructions sources.  In order to ensure that the "load" happens
205*61046927SAndroid Build Coastguard Worker     * here and not later (for register sources), we make sure to convert it
206*61046927SAndroid Build Coastguard Worker     * to an SSA value first.
207*61046927SAndroid Build Coastguard Worker     */
208*61046927SAndroid Build Coastguard Worker    const unsigned num_params = call->num_params;
209*61046927SAndroid Build Coastguard Worker    NIR_VLA(nir_def *, params, num_params);
210*61046927SAndroid Build Coastguard Worker    for (unsigned i = 0; i < num_params; i++) {
211*61046927SAndroid Build Coastguard Worker       params[i] = call->params[i].ssa;
212*61046927SAndroid Build Coastguard Worker    }
213*61046927SAndroid Build Coastguard Worker 
214*61046927SAndroid Build Coastguard Worker    nir_inline_function_impl(b, call->callee->impl, params, NULL);
215*61046927SAndroid Build Coastguard Worker    return true;
216*61046927SAndroid Build Coastguard Worker }
217*61046927SAndroid Build Coastguard Worker 
218*61046927SAndroid Build Coastguard Worker static bool
inline_function_impl(nir_function_impl * impl,struct set * inlined)219*61046927SAndroid Build Coastguard Worker inline_function_impl(nir_function_impl *impl, struct set *inlined)
220*61046927SAndroid Build Coastguard Worker {
221*61046927SAndroid Build Coastguard Worker    if (_mesa_set_search(inlined, impl))
222*61046927SAndroid Build Coastguard Worker       return false; /* Already inlined */
223*61046927SAndroid Build Coastguard Worker 
224*61046927SAndroid Build Coastguard Worker    bool progress;
225*61046927SAndroid Build Coastguard Worker    progress = nir_function_instructions_pass(impl, inline_functions_pass,
226*61046927SAndroid Build Coastguard Worker                                              nir_metadata_none, inlined);
227*61046927SAndroid Build Coastguard Worker    if (progress) {
228*61046927SAndroid Build Coastguard Worker       /* Indices are completely messed up now */
229*61046927SAndroid Build Coastguard Worker       nir_index_ssa_defs(impl);
230*61046927SAndroid Build Coastguard Worker    }
231*61046927SAndroid Build Coastguard Worker 
232*61046927SAndroid Build Coastguard Worker    _mesa_set_add(inlined, impl);
233*61046927SAndroid Build Coastguard Worker 
234*61046927SAndroid Build Coastguard Worker    return progress;
235*61046927SAndroid Build Coastguard Worker }
236*61046927SAndroid Build Coastguard Worker 
237*61046927SAndroid Build Coastguard Worker /** A pass to inline all functions in a shader into their callers
238*61046927SAndroid Build Coastguard Worker  *
239*61046927SAndroid Build Coastguard Worker  * For most use-cases, function inlining is a multi-step process.  The general
240*61046927SAndroid Build Coastguard Worker  * pattern employed by SPIR-V consumers and others is as follows:
241*61046927SAndroid Build Coastguard Worker  *
242*61046927SAndroid Build Coastguard Worker  *  1. nir_lower_variable_initializers(shader, nir_var_function_temp)
243*61046927SAndroid Build Coastguard Worker  *
244*61046927SAndroid Build Coastguard Worker  *     This is needed because local variables from the callee are simply added
245*61046927SAndroid Build Coastguard Worker  *     to the locals list for the caller and the information about where the
246*61046927SAndroid Build Coastguard Worker  *     constant initializer logically happens is lost.  If the callee is
247*61046927SAndroid Build Coastguard Worker  *     called in a loop, this can cause the variable to go from being
248*61046927SAndroid Build Coastguard Worker  *     initialized once per loop iteration to being initialized once at the
249*61046927SAndroid Build Coastguard Worker  *     top of the caller and values to persist from one invocation of the
250*61046927SAndroid Build Coastguard Worker  *     callee to the next.  The simple solution to this problem is to get rid
251*61046927SAndroid Build Coastguard Worker  *     of constant initializers before function inlining.
252*61046927SAndroid Build Coastguard Worker  *
253*61046927SAndroid Build Coastguard Worker  *  2. nir_lower_returns(shader)
254*61046927SAndroid Build Coastguard Worker  *
255*61046927SAndroid Build Coastguard Worker  *     nir_inline_functions assumes that all functions end "naturally" by
256*61046927SAndroid Build Coastguard Worker  *     execution reaching the end of the function without any return
257*61046927SAndroid Build Coastguard Worker  *     instructions causing instant jumps to the end.  Thanks to NIR being
258*61046927SAndroid Build Coastguard Worker  *     structured, we can't represent arbitrary jumps to various points in the
259*61046927SAndroid Build Coastguard Worker  *     program which is what an early return in the callee would have to turn
260*61046927SAndroid Build Coastguard Worker  *     into when we inline it into the caller.  Instead, we require returns to
261*61046927SAndroid Build Coastguard Worker  *     be lowered which lets us just copy+paste the callee directly into the
262*61046927SAndroid Build Coastguard Worker  *     caller.
263*61046927SAndroid Build Coastguard Worker  *
264*61046927SAndroid Build Coastguard Worker  *  3. nir_inline_functions(shader)
265*61046927SAndroid Build Coastguard Worker  *
266*61046927SAndroid Build Coastguard Worker  *     This does the actual function inlining and the resulting shader will
267*61046927SAndroid Build Coastguard Worker  *     contain no call instructions.
268*61046927SAndroid Build Coastguard Worker  *
269*61046927SAndroid Build Coastguard Worker  *  4. nir_opt_deref(shader)
270*61046927SAndroid Build Coastguard Worker  *
271*61046927SAndroid Build Coastguard Worker  *     Most functions contain pointer parameters where the result of a deref
272*61046927SAndroid Build Coastguard Worker  *     instruction is passed in as a parameter, loaded via a load_param
273*61046927SAndroid Build Coastguard Worker  *     intrinsic, and then turned back into a deref via a cast.  Function
274*61046927SAndroid Build Coastguard Worker  *     inlining will get rid of the load_param but we are still left with a
275*61046927SAndroid Build Coastguard Worker  *     cast.  Running nir_opt_deref gets rid of the intermediate cast and
276*61046927SAndroid Build Coastguard Worker  *     results in a whole deref chain again.  This is currently required by a
277*61046927SAndroid Build Coastguard Worker  *     number of optimizations and lowering passes at least for certain
278*61046927SAndroid Build Coastguard Worker  *     variable modes.
279*61046927SAndroid Build Coastguard Worker  *
280*61046927SAndroid Build Coastguard Worker  *  5. Loop over the functions and delete all but the main entrypoint.
281*61046927SAndroid Build Coastguard Worker  *
282*61046927SAndroid Build Coastguard Worker  *     In the Intel Vulkan driver this looks like this:
283*61046927SAndroid Build Coastguard Worker  *
284*61046927SAndroid Build Coastguard Worker  *        nir_remove_non_entrypoints(nir);
285*61046927SAndroid Build Coastguard Worker  *
286*61046927SAndroid Build Coastguard Worker  *    While nir_inline_functions does get rid of all call instructions, it
287*61046927SAndroid Build Coastguard Worker  *    doesn't get rid of any functions because it doesn't know what the "root
288*61046927SAndroid Build Coastguard Worker  *    function" is.  Instead, it's up to the individual driver to know how to
289*61046927SAndroid Build Coastguard Worker  *    decide on a root function and delete the rest.  With SPIR-V,
290*61046927SAndroid Build Coastguard Worker  *    spirv_to_nir returns the root function and so we can just use == whereas
291*61046927SAndroid Build Coastguard Worker  *    with GL, you may have to look for a function named "main".
292*61046927SAndroid Build Coastguard Worker  *
293*61046927SAndroid Build Coastguard Worker  *  6. nir_lower_variable_initializers(shader, ~nir_var_function_temp)
294*61046927SAndroid Build Coastguard Worker  *
295*61046927SAndroid Build Coastguard Worker  *     Lowering constant initializers on inputs, outputs, global variables,
296*61046927SAndroid Build Coastguard Worker  *     etc. requires that we know the main entrypoint so that we know where to
297*61046927SAndroid Build Coastguard Worker  *     initialize them.  Otherwise, we would have to assume that anything
298*61046927SAndroid Build Coastguard Worker  *     could be a main entrypoint and initialize them at the start of every
299*61046927SAndroid Build Coastguard Worker  *     function but that would clearly be wrong if any of those functions were
300*61046927SAndroid Build Coastguard Worker  *     ever called within another function.  Simply requiring a single-
301*61046927SAndroid Build Coastguard Worker  *     entrypoint function shader is the best way to make it well-defined.
302*61046927SAndroid Build Coastguard Worker  */
303*61046927SAndroid Build Coastguard Worker bool
nir_inline_functions(nir_shader * shader)304*61046927SAndroid Build Coastguard Worker nir_inline_functions(nir_shader *shader)
305*61046927SAndroid Build Coastguard Worker {
306*61046927SAndroid Build Coastguard Worker    struct set *inlined = _mesa_pointer_set_create(NULL);
307*61046927SAndroid Build Coastguard Worker    bool progress = false;
308*61046927SAndroid Build Coastguard Worker 
309*61046927SAndroid Build Coastguard Worker    nir_foreach_function_impl(impl, shader) {
310*61046927SAndroid Build Coastguard Worker       progress = inline_function_impl(impl, inlined) || progress;
311*61046927SAndroid Build Coastguard Worker    }
312*61046927SAndroid Build Coastguard Worker 
313*61046927SAndroid Build Coastguard Worker    _mesa_set_destroy(inlined, NULL);
314*61046927SAndroid Build Coastguard Worker 
315*61046927SAndroid Build Coastguard Worker    return progress;
316*61046927SAndroid Build Coastguard Worker }
317*61046927SAndroid Build Coastguard Worker 
318*61046927SAndroid Build Coastguard Worker struct lower_link_state {
319*61046927SAndroid Build Coastguard Worker    struct hash_table *shader_var_remap;
320*61046927SAndroid Build Coastguard Worker    const nir_shader *link_shader;
321*61046927SAndroid Build Coastguard Worker    unsigned printf_index_offset;
322*61046927SAndroid Build Coastguard Worker };
323*61046927SAndroid Build Coastguard Worker 
324*61046927SAndroid Build Coastguard Worker static bool
lower_calls_vars_instr(struct nir_builder * b,nir_instr * instr,void * cb_data)325*61046927SAndroid Build Coastguard Worker lower_calls_vars_instr(struct nir_builder *b,
326*61046927SAndroid Build Coastguard Worker                        nir_instr *instr,
327*61046927SAndroid Build Coastguard Worker                        void *cb_data)
328*61046927SAndroid Build Coastguard Worker {
329*61046927SAndroid Build Coastguard Worker    struct lower_link_state *state = cb_data;
330*61046927SAndroid Build Coastguard Worker 
331*61046927SAndroid Build Coastguard Worker    switch (instr->type) {
332*61046927SAndroid Build Coastguard Worker    case nir_instr_type_deref: {
333*61046927SAndroid Build Coastguard Worker       nir_deref_instr *deref = nir_instr_as_deref(instr);
334*61046927SAndroid Build Coastguard Worker       if (deref->deref_type != nir_deref_type_var)
335*61046927SAndroid Build Coastguard Worker          return false;
336*61046927SAndroid Build Coastguard Worker       if (deref->var->data.mode == nir_var_function_temp)
337*61046927SAndroid Build Coastguard Worker          return false;
338*61046927SAndroid Build Coastguard Worker 
339*61046927SAndroid Build Coastguard Worker       assert(state->shader_var_remap);
340*61046927SAndroid Build Coastguard Worker       struct hash_entry *entry =
341*61046927SAndroid Build Coastguard Worker          _mesa_hash_table_search(state->shader_var_remap, deref->var);
342*61046927SAndroid Build Coastguard Worker       if (entry == NULL) {
343*61046927SAndroid Build Coastguard Worker          nir_variable *nvar = nir_variable_clone(deref->var, b->shader);
344*61046927SAndroid Build Coastguard Worker          nir_shader_add_variable(b->shader, nvar);
345*61046927SAndroid Build Coastguard Worker          entry = _mesa_hash_table_insert(state->shader_var_remap,
346*61046927SAndroid Build Coastguard Worker                                          deref->var, nvar);
347*61046927SAndroid Build Coastguard Worker       }
348*61046927SAndroid Build Coastguard Worker       deref->var = entry->data;
349*61046927SAndroid Build Coastguard Worker       break;
350*61046927SAndroid Build Coastguard Worker    }
351*61046927SAndroid Build Coastguard Worker    case nir_instr_type_call: {
352*61046927SAndroid Build Coastguard Worker       nir_call_instr *ncall = nir_instr_as_call(instr);
353*61046927SAndroid Build Coastguard Worker       if (!ncall->callee->name)
354*61046927SAndroid Build Coastguard Worker          return false;
355*61046927SAndroid Build Coastguard Worker 
356*61046927SAndroid Build Coastguard Worker       nir_function *func = nir_shader_get_function_for_name(b->shader, ncall->callee->name);
357*61046927SAndroid Build Coastguard Worker       if (func) {
358*61046927SAndroid Build Coastguard Worker          ncall->callee = func;
359*61046927SAndroid Build Coastguard Worker          break;
360*61046927SAndroid Build Coastguard Worker       }
361*61046927SAndroid Build Coastguard Worker 
362*61046927SAndroid Build Coastguard Worker       nir_function *new_func;
363*61046927SAndroid Build Coastguard Worker       new_func = nir_shader_get_function_for_name(state->link_shader, ncall->callee->name);
364*61046927SAndroid Build Coastguard Worker       if (new_func)
365*61046927SAndroid Build Coastguard Worker          ncall->callee = nir_function_clone(b->shader, new_func);
366*61046927SAndroid Build Coastguard Worker       break;
367*61046927SAndroid Build Coastguard Worker    }
368*61046927SAndroid Build Coastguard Worker    case nir_instr_type_intrinsic: {
369*61046927SAndroid Build Coastguard Worker       /* Reindex the offset of the printf intrinsic by the number of already
370*61046927SAndroid Build Coastguard Worker        * present printfs in the shader where functions are linked into.
371*61046927SAndroid Build Coastguard Worker        */
372*61046927SAndroid Build Coastguard Worker       if (state->printf_index_offset == 0)
373*61046927SAndroid Build Coastguard Worker          return false;
374*61046927SAndroid Build Coastguard Worker 
375*61046927SAndroid Build Coastguard Worker       nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
376*61046927SAndroid Build Coastguard Worker       if (intrin->intrinsic != nir_intrinsic_printf)
377*61046927SAndroid Build Coastguard Worker          return false;
378*61046927SAndroid Build Coastguard Worker 
379*61046927SAndroid Build Coastguard Worker       b->cursor = nir_before_instr(instr);
380*61046927SAndroid Build Coastguard Worker       nir_src_rewrite(&intrin->src[0],
381*61046927SAndroid Build Coastguard Worker                       nir_iadd_imm(b, intrin->src[0].ssa,
382*61046927SAndroid Build Coastguard Worker                                       state->printf_index_offset));
383*61046927SAndroid Build Coastguard Worker       break;
384*61046927SAndroid Build Coastguard Worker    }
385*61046927SAndroid Build Coastguard Worker    default:
386*61046927SAndroid Build Coastguard Worker       break;
387*61046927SAndroid Build Coastguard Worker    }
388*61046927SAndroid Build Coastguard Worker    return true;
389*61046927SAndroid Build Coastguard Worker }
390*61046927SAndroid Build Coastguard Worker 
391*61046927SAndroid Build Coastguard Worker static bool
lower_call_function_impl(struct nir_builder * b,nir_function * callee,const nir_function_impl * impl,struct lower_link_state * state)392*61046927SAndroid Build Coastguard Worker lower_call_function_impl(struct nir_builder *b,
393*61046927SAndroid Build Coastguard Worker                          nir_function *callee,
394*61046927SAndroid Build Coastguard Worker                          const nir_function_impl *impl,
395*61046927SAndroid Build Coastguard Worker                          struct lower_link_state *state)
396*61046927SAndroid Build Coastguard Worker {
397*61046927SAndroid Build Coastguard Worker    nir_function_impl *copy = nir_function_impl_clone(b->shader, impl);
398*61046927SAndroid Build Coastguard Worker    copy->function = callee;
399*61046927SAndroid Build Coastguard Worker    callee->impl = copy;
400*61046927SAndroid Build Coastguard Worker 
401*61046927SAndroid Build Coastguard Worker    return nir_function_instructions_pass(copy,
402*61046927SAndroid Build Coastguard Worker                                          lower_calls_vars_instr,
403*61046927SAndroid Build Coastguard Worker                                          nir_metadata_none,
404*61046927SAndroid Build Coastguard Worker                                          state);
405*61046927SAndroid Build Coastguard Worker }
406*61046927SAndroid Build Coastguard Worker 
407*61046927SAndroid Build Coastguard Worker static bool
function_link_pass(struct nir_builder * b,nir_instr * instr,void * cb_data)408*61046927SAndroid Build Coastguard Worker function_link_pass(struct nir_builder *b,
409*61046927SAndroid Build Coastguard Worker                    nir_instr *instr,
410*61046927SAndroid Build Coastguard Worker                    void *cb_data)
411*61046927SAndroid Build Coastguard Worker {
412*61046927SAndroid Build Coastguard Worker    struct lower_link_state *state = cb_data;
413*61046927SAndroid Build Coastguard Worker 
414*61046927SAndroid Build Coastguard Worker    if (instr->type != nir_instr_type_call)
415*61046927SAndroid Build Coastguard Worker       return false;
416*61046927SAndroid Build Coastguard Worker 
417*61046927SAndroid Build Coastguard Worker    nir_call_instr *call = nir_instr_as_call(instr);
418*61046927SAndroid Build Coastguard Worker    nir_function *func = NULL;
419*61046927SAndroid Build Coastguard Worker 
420*61046927SAndroid Build Coastguard Worker    if (!call->callee->name)
421*61046927SAndroid Build Coastguard Worker       return false;
422*61046927SAndroid Build Coastguard Worker 
423*61046927SAndroid Build Coastguard Worker    if (call->callee->impl)
424*61046927SAndroid Build Coastguard Worker       return false;
425*61046927SAndroid Build Coastguard Worker 
426*61046927SAndroid Build Coastguard Worker    func = nir_shader_get_function_for_name(state->link_shader, call->callee->name);
427*61046927SAndroid Build Coastguard Worker    if (!func || !func->impl) {
428*61046927SAndroid Build Coastguard Worker       return false;
429*61046927SAndroid Build Coastguard Worker    }
430*61046927SAndroid Build Coastguard Worker    return lower_call_function_impl(b, call->callee,
431*61046927SAndroid Build Coastguard Worker                                    func->impl,
432*61046927SAndroid Build Coastguard Worker                                    state);
433*61046927SAndroid Build Coastguard Worker }
434*61046927SAndroid Build Coastguard Worker 
435*61046927SAndroid Build Coastguard Worker bool
nir_link_shader_functions(nir_shader * shader,const nir_shader * link_shader)436*61046927SAndroid Build Coastguard Worker nir_link_shader_functions(nir_shader *shader,
437*61046927SAndroid Build Coastguard Worker                           const nir_shader *link_shader)
438*61046927SAndroid Build Coastguard Worker {
439*61046927SAndroid Build Coastguard Worker    void *ra_ctx = ralloc_context(NULL);
440*61046927SAndroid Build Coastguard Worker    struct hash_table *copy_vars = _mesa_pointer_hash_table_create(ra_ctx);
441*61046927SAndroid Build Coastguard Worker    bool progress = false, overall_progress = false;
442*61046927SAndroid Build Coastguard Worker 
443*61046927SAndroid Build Coastguard Worker    struct lower_link_state state = {
444*61046927SAndroid Build Coastguard Worker       .shader_var_remap = copy_vars,
445*61046927SAndroid Build Coastguard Worker       .link_shader = link_shader,
446*61046927SAndroid Build Coastguard Worker       .printf_index_offset = shader->printf_info_count,
447*61046927SAndroid Build Coastguard Worker    };
448*61046927SAndroid Build Coastguard Worker    /* do progress passes inside the pass */
449*61046927SAndroid Build Coastguard Worker    do {
450*61046927SAndroid Build Coastguard Worker       progress = false;
451*61046927SAndroid Build Coastguard Worker       nir_foreach_function_impl(impl, shader) {
452*61046927SAndroid Build Coastguard Worker          bool this_progress = nir_function_instructions_pass(impl,
453*61046927SAndroid Build Coastguard Worker                                                              function_link_pass,
454*61046927SAndroid Build Coastguard Worker                                                              nir_metadata_none,
455*61046927SAndroid Build Coastguard Worker                                                              &state);
456*61046927SAndroid Build Coastguard Worker          if (this_progress)
457*61046927SAndroid Build Coastguard Worker             nir_index_ssa_defs(impl);
458*61046927SAndroid Build Coastguard Worker          progress |= this_progress;
459*61046927SAndroid Build Coastguard Worker       }
460*61046927SAndroid Build Coastguard Worker       overall_progress |= progress;
461*61046927SAndroid Build Coastguard Worker    } while (progress);
462*61046927SAndroid Build Coastguard Worker 
463*61046927SAndroid Build Coastguard Worker    if (overall_progress && link_shader->printf_info_count > 0) {
464*61046927SAndroid Build Coastguard Worker       shader->printf_info = reralloc(shader, shader->printf_info,
465*61046927SAndroid Build Coastguard Worker                                      u_printf_info,
466*61046927SAndroid Build Coastguard Worker                                      shader->printf_info_count +
467*61046927SAndroid Build Coastguard Worker                                      link_shader->printf_info_count);
468*61046927SAndroid Build Coastguard Worker 
469*61046927SAndroid Build Coastguard Worker       for (unsigned i = 0; i < link_shader->printf_info_count; i++){
470*61046927SAndroid Build Coastguard Worker          const u_printf_info *src_info = &link_shader->printf_info[i];
471*61046927SAndroid Build Coastguard Worker          u_printf_info *dst_info = &shader->printf_info[shader->printf_info_count++];
472*61046927SAndroid Build Coastguard Worker 
473*61046927SAndroid Build Coastguard Worker          dst_info->num_args = src_info->num_args;
474*61046927SAndroid Build Coastguard Worker          dst_info->arg_sizes = ralloc_array(shader, unsigned, dst_info->num_args);
475*61046927SAndroid Build Coastguard Worker          memcpy(dst_info->arg_sizes, src_info->arg_sizes,
476*61046927SAndroid Build Coastguard Worker                 sizeof(dst_info->arg_sizes[0]) * dst_info->num_args);
477*61046927SAndroid Build Coastguard Worker 
478*61046927SAndroid Build Coastguard Worker          dst_info->string_size = src_info->string_size;
479*61046927SAndroid Build Coastguard Worker          dst_info->strings = ralloc_memdup(shader, src_info->strings,
480*61046927SAndroid Build Coastguard Worker                                            dst_info->string_size);
481*61046927SAndroid Build Coastguard Worker       }
482*61046927SAndroid Build Coastguard Worker    }
483*61046927SAndroid Build Coastguard Worker 
484*61046927SAndroid Build Coastguard Worker    ralloc_free(ra_ctx);
485*61046927SAndroid Build Coastguard Worker 
486*61046927SAndroid Build Coastguard Worker    return overall_progress;
487*61046927SAndroid Build Coastguard Worker }
488*61046927SAndroid Build Coastguard Worker 
489*61046927SAndroid Build Coastguard Worker static void
490*61046927SAndroid Build Coastguard Worker nir_mark_used_functions(struct nir_function *func, struct set *used_funcs);
491*61046927SAndroid Build Coastguard Worker 
/* Instruction callback that records the callee of each call instruction in
 * the set passed through data and then descends into that callee.
 *
 * A callee that is already in the set is not walked again.  Every function
 * is walked by whoever adds it to the set, so a second visit adds nothing;
 * skipping it keeps the walk bounded when a function is called from many
 * places or when the call graph contains a cycle.
 *
 * Returns false for non-call instructions and true for call instructions.
 */
static bool mark_used_pass_cb(struct nir_builder *b,
                              nir_instr *instr, void *data)
{
   struct set *used_funcs = data;

   if (instr->type != nir_instr_type_call)
      return false;

   nir_function *callee = nir_instr_as_call(instr)->callee;

   /* Already recorded: its body has been, or is being, walked. */
   if (_mesa_set_search(used_funcs, callee))
      return true;

   _mesa_set_add(used_funcs, callee);
   nir_mark_used_functions(callee, used_funcs);
   return true;
}
505*61046927SAndroid Build Coastguard Worker 
506*61046927SAndroid Build Coastguard Worker static void
nir_mark_used_functions(struct nir_function * func,struct set * used_funcs)507*61046927SAndroid Build Coastguard Worker nir_mark_used_functions(struct nir_function *func, struct set *used_funcs)
508*61046927SAndroid Build Coastguard Worker {
509*61046927SAndroid Build Coastguard Worker    if (func->impl) {
510*61046927SAndroid Build Coastguard Worker       nir_function_instructions_pass(func->impl,
511*61046927SAndroid Build Coastguard Worker                                      mark_used_pass_cb,
512*61046927SAndroid Build Coastguard Worker                                      nir_metadata_none,
513*61046927SAndroid Build Coastguard Worker                                      used_funcs);
514*61046927SAndroid Build Coastguard Worker    }
515*61046927SAndroid Build Coastguard Worker }
516*61046927SAndroid Build Coastguard Worker 
517*61046927SAndroid Build Coastguard Worker void
nir_cleanup_functions(nir_shader * nir)518*61046927SAndroid Build Coastguard Worker nir_cleanup_functions(nir_shader *nir)
519*61046927SAndroid Build Coastguard Worker {
520*61046927SAndroid Build Coastguard Worker    if (!nir->options->driver_functions) {
521*61046927SAndroid Build Coastguard Worker       nir_remove_non_entrypoints(nir);
522*61046927SAndroid Build Coastguard Worker       return;
523*61046927SAndroid Build Coastguard Worker    }
524*61046927SAndroid Build Coastguard Worker 
525*61046927SAndroid Build Coastguard Worker    struct set *used_funcs = _mesa_set_create(NULL, _mesa_hash_pointer,
526*61046927SAndroid Build Coastguard Worker                                              _mesa_key_pointer_equal);
527*61046927SAndroid Build Coastguard Worker    foreach_list_typed_safe(nir_function, func, node, &nir->functions) {
528*61046927SAndroid Build Coastguard Worker       if (func->is_entrypoint) {
529*61046927SAndroid Build Coastguard Worker          _mesa_set_add(used_funcs, func);
530*61046927SAndroid Build Coastguard Worker          nir_mark_used_functions(func, used_funcs);
531*61046927SAndroid Build Coastguard Worker       }
532*61046927SAndroid Build Coastguard Worker    }
533*61046927SAndroid Build Coastguard Worker    foreach_list_typed_safe(nir_function, func, node, &nir->functions) {
534*61046927SAndroid Build Coastguard Worker       if (!_mesa_set_search(used_funcs, func))
535*61046927SAndroid Build Coastguard Worker          exec_node_remove(&func->node);
536*61046927SAndroid Build Coastguard Worker    }
537*61046927SAndroid Build Coastguard Worker    _mesa_set_destroy(used_funcs, NULL);
538*61046927SAndroid Build Coastguard Worker }
539