xref: /aosp_15_r20/external/mesa3d/src/gallium/frontends/lavapipe/lvp_nir_lower_ray_queries.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright © 2023 Valve Corporation
3  * SPDX-License-Identifier: MIT
4  */
5 
6 #include "nir/nir.h"
7 #include "nir/nir_builder.h"
8 
9 #include "lvp_nir_ray_tracing.h"
10 #include "lvp_acceleration_structure.h"
11 #include "lvp_private.h"
12 
13 #include "spirv/spirv.h"
14 
15 #include "util/hash_table.h"
16 
17 typedef struct {
18    nir_variable *variable;
19    unsigned array_length;
20 } rq_variable;
21 
22 static rq_variable *
rq_variable_create(void * ctx,nir_shader * shader,unsigned array_length,const struct glsl_type * type,const char * name)23 rq_variable_create(void *ctx, nir_shader *shader, unsigned array_length,
24                    const struct glsl_type *type, const char *name)
25 {
26    rq_variable *result = ralloc(ctx, rq_variable);
27    result->array_length = array_length;
28 
29    const struct glsl_type *variable_type = type;
30    if (array_length != 1)
31       variable_type = glsl_array_type(type, array_length, glsl_get_explicit_stride(type));
32 
33    result->variable = nir_variable_create(shader, nir_var_shader_temp, variable_type, name);
34 
35    return result;
36 }
37 
38 static nir_def *
nir_load_array(nir_builder * b,nir_variable * array,nir_def * index)39 nir_load_array(nir_builder *b, nir_variable *array, nir_def *index)
40 {
41    return nir_load_deref(b, nir_build_deref_array(b, nir_build_deref_var(b, array), index));
42 }
43 
44 static void
nir_store_array(nir_builder * b,nir_variable * array,nir_def * index,nir_def * value,unsigned writemask)45 nir_store_array(nir_builder *b, nir_variable *array, nir_def *index, nir_def *value,
46                 unsigned writemask)
47 {
48    nir_store_deref(b, nir_build_deref_array(b, nir_build_deref_var(b, array), index), value,
49                    writemask);
50 }
51 
52 static nir_deref_instr *
rq_deref_var(nir_builder * b,nir_def * index,rq_variable * var)53 rq_deref_var(nir_builder *b, nir_def *index, rq_variable *var)
54 {
55    if (var->array_length == 1)
56       return nir_build_deref_var(b, var->variable);
57 
58    return nir_build_deref_array(b, nir_build_deref_var(b, var->variable), index);
59 }
60 
61 static nir_def *
rq_load_var(nir_builder * b,nir_def * index,rq_variable * var)62 rq_load_var(nir_builder *b, nir_def *index, rq_variable *var)
63 {
64    if (var->array_length == 1)
65       return nir_load_var(b, var->variable);
66 
67    return nir_load_array(b, var->variable, index);
68 }
69 
70 static void
rq_store_var(nir_builder * b,nir_def * index,rq_variable * var,nir_def * value,unsigned writemask)71 rq_store_var(nir_builder *b, nir_def *index, rq_variable *var, nir_def *value,
72              unsigned writemask)
73 {
74    if (var->array_length == 1) {
75       nir_store_var(b, var->variable, value, writemask);
76    } else {
77       nir_store_array(b, var->variable, index, value, writemask);
78    }
79 }
80 
81 static void
rq_copy_var(nir_builder * b,nir_def * index,rq_variable * dst,rq_variable * src,unsigned mask)82 rq_copy_var(nir_builder *b, nir_def *index, rq_variable *dst, rq_variable *src, unsigned mask)
83 {
84    rq_store_var(b, index, dst, rq_load_var(b, index, src), mask);
85 }
86 
87 static nir_def *
rq_load_array(nir_builder * b,nir_def * index,rq_variable * var,nir_def * array_index)88 rq_load_array(nir_builder *b, nir_def *index, rq_variable *var, nir_def *array_index)
89 {
90    if (var->array_length == 1)
91       return nir_load_array(b, var->variable, array_index);
92 
93    return nir_load_deref(
94       b,
95       nir_build_deref_array(
96          b, nir_build_deref_array(b, nir_build_deref_var(b, var->variable), index), array_index));
97 }
98 
99 static void
rq_store_array(nir_builder * b,nir_def * index,rq_variable * var,nir_def * array_index,nir_def * value,unsigned writemask)100 rq_store_array(nir_builder *b, nir_def *index, rq_variable *var, nir_def *array_index,
101                nir_def *value, unsigned writemask)
102 {
103    if (var->array_length == 1) {
104       nir_store_array(b, var->variable, array_index, value, writemask);
105    } else {
106       nir_store_deref(
107          b,
108          nir_build_deref_array(
109             b, nir_build_deref_array(b, nir_build_deref_var(b, var->variable), index), array_index),
110          value, writemask);
111    }
112 }
113 
114 struct ray_query_traversal_vars {
115    rq_variable *origin;
116    rq_variable *direction;
117 
118    rq_variable *bvh_base;
119    rq_variable *current_node;
120 
121    rq_variable *stack_base;
122    rq_variable *stack_ptr;
123    rq_variable *stack;
124 };
125 
126 struct ray_query_intersection_vars {
127    rq_variable *primitive_id;
128    rq_variable *geometry_id_and_flags;
129    rq_variable *instance_addr;
130    rq_variable *intersection_type;
131    rq_variable *opaque;
132    rq_variable *frontface;
133    rq_variable *sbt_offset_and_flags;
134    rq_variable *barycentrics;
135    rq_variable *t;
136 };
137 
138 struct ray_query_vars {
139    rq_variable *root_bvh_base;
140    rq_variable *flags;
141    rq_variable *cull_mask;
142    rq_variable *origin;
143    rq_variable *tmin;
144    rq_variable *direction;
145 
146    rq_variable *incomplete;
147 
148    struct ray_query_intersection_vars closest;
149    struct ray_query_intersection_vars candidate;
150 
151    struct ray_query_traversal_vars trav;
152 };
153 
/*
 * Build the string "<base_name><name>" as a ralloc allocation owned by `ctx`.
 *
 * Expects `ctx` and `base_name` to be in scope where the macro is expanded.
 * ralloc_asprintf sizes and fills the buffer in one step. The previous
 * strcpy/strcat on an unchecked ralloc_size() result would write through NULL
 * if the allocation failed.
 */
#define VAR_NAME(name) ralloc_asprintf(ctx, "%s%s", base_name, (name))
156 
157 static struct ray_query_traversal_vars
init_ray_query_traversal_vars(void * ctx,nir_shader * shader,unsigned array_length,const char * base_name)158 init_ray_query_traversal_vars(void *ctx, nir_shader *shader, unsigned array_length,
159                               const char *base_name)
160 {
161    struct ray_query_traversal_vars result;
162 
163    const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
164 
165    result.origin = rq_variable_create(ctx, shader, array_length, vec3_type, VAR_NAME("_origin"));
166    result.direction =
167       rq_variable_create(ctx, shader, array_length, vec3_type, VAR_NAME("_direction"));
168 
169    result.bvh_base =
170       rq_variable_create(ctx, shader, array_length, glsl_uint64_t_type(), VAR_NAME("_bvh_base"));
171    result.current_node =
172       rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_current_node"));
173    result.stack_base =
174       rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_stack_base"));
175    result.stack_ptr = rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_stack_ptr"));
176    result.stack = rq_variable_create(ctx, shader, array_length, glsl_array_type(glsl_uint_type(), 24 * 2, 0), VAR_NAME("_stack"));
177    return result;
178 }
179 
180 static struct ray_query_intersection_vars
init_ray_query_intersection_vars(void * ctx,nir_shader * shader,unsigned array_length,const char * base_name)181 init_ray_query_intersection_vars(void *ctx, nir_shader *shader, unsigned array_length,
182                                  const char *base_name)
183 {
184    struct ray_query_intersection_vars result;
185 
186    const struct glsl_type *vec2_type = glsl_vector_type(GLSL_TYPE_FLOAT, 2);
187 
188    result.primitive_id =
189       rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_primitive_id"));
190    result.geometry_id_and_flags = rq_variable_create(ctx, shader, array_length, glsl_uint_type(),
191                                                      VAR_NAME("_geometry_id_and_flags"));
192    result.instance_addr = rq_variable_create(ctx, shader, array_length, glsl_uint64_t_type(),
193                                              VAR_NAME("_instance_addr"));
194    result.intersection_type = rq_variable_create(ctx, shader, array_length, glsl_uint_type(),
195                                                  VAR_NAME("_intersection_type"));
196    result.opaque =
197       rq_variable_create(ctx, shader, array_length, glsl_bool_type(), VAR_NAME("_opaque"));
198    result.frontface =
199       rq_variable_create(ctx, shader, array_length, glsl_bool_type(), VAR_NAME("_frontface"));
200    result.sbt_offset_and_flags = rq_variable_create(ctx, shader, array_length, glsl_uint_type(),
201                                                     VAR_NAME("_sbt_offset_and_flags"));
202    result.barycentrics =
203       rq_variable_create(ctx, shader, array_length, vec2_type, VAR_NAME("_barycentrics"));
204    result.t = rq_variable_create(ctx, shader, array_length, glsl_float_type(), VAR_NAME("_t"));
205 
206    return result;
207 }
208 
209 static void
init_ray_query_vars(nir_shader * shader,unsigned array_length,struct ray_query_vars * dst,const char * base_name)210 init_ray_query_vars(nir_shader *shader, unsigned array_length, struct ray_query_vars *dst,
211                     const char *base_name)
212 {
213    void *ctx = dst;
214    const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
215 
216    dst->root_bvh_base = rq_variable_create(dst, shader, array_length, glsl_uint64_t_type(),
217                                            VAR_NAME("_root_bvh_base"));
218    dst->flags = rq_variable_create(dst, shader, array_length, glsl_uint_type(), VAR_NAME("_flags"));
219    dst->cull_mask =
220       rq_variable_create(dst, shader, array_length, glsl_uint_type(), VAR_NAME("_cull_mask"));
221    dst->origin = rq_variable_create(dst, shader, array_length, vec3_type, VAR_NAME("_origin"));
222    dst->tmin = rq_variable_create(dst, shader, array_length, glsl_float_type(), VAR_NAME("_tmin"));
223    dst->direction =
224       rq_variable_create(dst, shader, array_length, vec3_type, VAR_NAME("_direction"));
225 
226    dst->incomplete =
227       rq_variable_create(dst, shader, array_length, glsl_bool_type(), VAR_NAME("_incomplete"));
228 
229    dst->closest = init_ray_query_intersection_vars(dst, shader, array_length, VAR_NAME("_closest"));
230    dst->candidate =
231       init_ray_query_intersection_vars(dst, shader, array_length, VAR_NAME("_candidate"));
232 
233    dst->trav = init_ray_query_traversal_vars(dst, shader, array_length, VAR_NAME("_top"));
234 }
235 
236 #undef VAR_NAME
237 
238 static void
lower_ray_query(nir_shader * shader,nir_variable * ray_query,struct hash_table * ht)239 lower_ray_query(nir_shader *shader, nir_variable *ray_query, struct hash_table *ht)
240 {
241    struct ray_query_vars *vars = ralloc(ht, struct ray_query_vars);
242 
243    unsigned array_length = 1;
244    if (glsl_type_is_array(ray_query->type))
245       array_length = glsl_get_length(ray_query->type);
246 
247    init_ray_query_vars(shader, array_length, vars, ray_query->name == NULL ? "" : ray_query->name);
248 
249    _mesa_hash_table_insert(ht, ray_query, vars);
250 }
251 
252 static void
copy_candidate_to_closest(nir_builder * b,nir_def * index,struct ray_query_vars * vars)253 copy_candidate_to_closest(nir_builder *b, nir_def *index, struct ray_query_vars *vars)
254 {
255    rq_copy_var(b, index, vars->closest.barycentrics, vars->candidate.barycentrics, 0x3);
256    rq_copy_var(b, index, vars->closest.geometry_id_and_flags, vars->candidate.geometry_id_and_flags,
257                0x1);
258    rq_copy_var(b, index, vars->closest.instance_addr, vars->candidate.instance_addr, 0x1);
259    rq_copy_var(b, index, vars->closest.intersection_type, vars->candidate.intersection_type, 0x1);
260    rq_copy_var(b, index, vars->closest.opaque, vars->candidate.opaque, 0x1);
261    rq_copy_var(b, index, vars->closest.frontface, vars->candidate.frontface, 0x1);
262    rq_copy_var(b, index, vars->closest.sbt_offset_and_flags, vars->candidate.sbt_offset_and_flags,
263                0x1);
264    rq_copy_var(b, index, vars->closest.primitive_id, vars->candidate.primitive_id, 0x1);
265    rq_copy_var(b, index, vars->closest.t, vars->candidate.t, 0x1);
266 }
267 
268 static void
insert_terminate_on_first_hit(nir_builder * b,nir_def * index,struct ray_query_vars * vars,bool break_on_terminate)269 insert_terminate_on_first_hit(nir_builder *b, nir_def *index, struct ray_query_vars *vars,
270                               bool break_on_terminate)
271 {
272    nir_def *terminate_on_first_hit =
273       nir_test_mask(b, rq_load_var(b, index, vars->flags), SpvRayFlagsTerminateOnFirstHitKHRMask);
274    nir_push_if(b, terminate_on_first_hit);
275    {
276       rq_store_var(b, index, vars->incomplete, nir_imm_false(b), 0x1);
277       if (break_on_terminate)
278          nir_jump(b, nir_jump_break);
279    }
280    nir_pop_if(b, NULL);
281 }
282 
283 static void
lower_rq_confirm_intersection(nir_builder * b,nir_def * index,nir_intrinsic_instr * instr,struct ray_query_vars * vars)284 lower_rq_confirm_intersection(nir_builder *b, nir_def *index, nir_intrinsic_instr *instr,
285                               struct ray_query_vars *vars)
286 {
287    copy_candidate_to_closest(b, index, vars);
288    insert_terminate_on_first_hit(b, index, vars, false);
289 }
290 
291 static void
lower_rq_generate_intersection(nir_builder * b,nir_def * index,nir_intrinsic_instr * instr,struct ray_query_vars * vars)292 lower_rq_generate_intersection(nir_builder *b, nir_def *index, nir_intrinsic_instr *instr,
293                                struct ray_query_vars *vars)
294 {
295    nir_push_if(b, nir_iand(b, nir_fge(b, rq_load_var(b, index, vars->closest.t), instr->src[1].ssa),
296                            nir_fge(b, instr->src[1].ssa, rq_load_var(b, index, vars->tmin))));
297    {
298       copy_candidate_to_closest(b, index, vars);
299       insert_terminate_on_first_hit(b, index, vars, false);
300       rq_store_var(b, index, vars->closest.t, instr->src[1].ssa, 0x1);
301    }
302    nir_pop_if(b, NULL);
303 }
304 
/*
 * Internal intersection kinds stored in ray_query_intersection_vars.
 *
 * The values are spelled out because lower_rq_load() returns them as-is for
 * committed queries and subtracts one for candidate queries
 * (triangle -> 0, aabb -> 1).
 */
enum rq_intersection_type {
   intersection_type_none = 0,
   intersection_type_triangle = 1,
   intersection_type_aabb = 2,
};
310 
311 static void
lower_rq_initialize(nir_builder * b,nir_def * index,nir_intrinsic_instr * instr,struct ray_query_vars * vars)312 lower_rq_initialize(nir_builder *b, nir_def *index, nir_intrinsic_instr *instr,
313                     struct ray_query_vars *vars)
314 {
315    rq_store_var(b, index, vars->flags, instr->src[2].ssa, 0x1);
316    rq_store_var(b, index, vars->cull_mask, nir_ishl_imm(b, instr->src[3].ssa, 24), 0x1);
317 
318    rq_store_var(b, index, vars->origin, instr->src[4].ssa, 0x7);
319    rq_store_var(b, index, vars->trav.origin, instr->src[4].ssa, 0x7);
320 
321    rq_store_var(b, index, vars->tmin, instr->src[5].ssa, 0x1);
322 
323    rq_store_var(b, index, vars->direction, instr->src[6].ssa, 0x7);
324    rq_store_var(b, index, vars->trav.direction, instr->src[6].ssa, 0x7);
325 
326    rq_store_var(b, index, vars->closest.t, instr->src[7].ssa, 0x1);
327    rq_store_var(b, index, vars->closest.intersection_type, nir_imm_int(b, intersection_type_none),
328                 0x1);
329 
330    nir_def *accel_struct = instr->src[1].ssa;
331    nir_def *bvh_base = accel_struct;
332    if (bvh_base->bit_size != 64) {
333       assert(bvh_base->num_components >= 2);
334       bvh_base = nir_load_ubo(
335          b, 1, 64, nir_channel(b, accel_struct, 0),
336          nir_imul_imm(b, nir_channel(b, accel_struct, 1), sizeof(struct lp_descriptor)), .range = ~0);
337    }
338 
339    rq_store_var(b, index, vars->root_bvh_base, bvh_base, 0x1);
340    rq_store_var(b, index, vars->trav.bvh_base, bvh_base, 1);
341 
342    rq_store_var(b, index, vars->trav.current_node, nir_imm_int(b, LVP_BVH_ROOT_NODE), 0x1);
343    rq_store_var(b, index, vars->trav.stack_ptr, nir_imm_int(b, 0), 0x1);
344    rq_store_var(b, index, vars->trav.stack_base, nir_imm_int(b, -1), 0x1);
345 
346    rq_store_var(b, index, vars->incomplete, nir_ine_imm(b, bvh_base, 0), 0x1);
347 }
348 
349 static nir_def *
lower_rq_load(nir_builder * b,nir_def * index,nir_intrinsic_instr * instr,struct ray_query_vars * vars)350 lower_rq_load(nir_builder *b, nir_def *index, nir_intrinsic_instr *instr,
351               struct ray_query_vars *vars)
352 {
353    bool committed = nir_intrinsic_committed(instr);
354    struct ray_query_intersection_vars *intersection = committed ? &vars->closest : &vars->candidate;
355 
356    uint32_t column = nir_intrinsic_column(instr);
357 
358    nir_ray_query_value value = nir_intrinsic_ray_query_value(instr);
359    switch (value) {
360    case nir_ray_query_value_flags:
361       return rq_load_var(b, index, vars->flags);
362    case nir_ray_query_value_intersection_barycentrics:
363       return rq_load_var(b, index, intersection->barycentrics);
364    case nir_ray_query_value_intersection_candidate_aabb_opaque:
365       return nir_iand(b, rq_load_var(b, index, vars->candidate.opaque),
366                       nir_ieq_imm(b, rq_load_var(b, index, vars->candidate.intersection_type),
367                                   intersection_type_aabb));
368    case nir_ray_query_value_intersection_front_face:
369       return rq_load_var(b, index, intersection->frontface);
370    case nir_ray_query_value_intersection_geometry_index:
371       return nir_iand_imm(b, rq_load_var(b, index, intersection->geometry_id_and_flags), 0xFFFFFF);
372    case nir_ray_query_value_intersection_instance_custom_index: {
373       nir_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr);
374       return nir_iand_imm(b,
375                           nir_build_load_global(b, 1, 32,
376                                                 nir_iadd_imm(b, instance_node_addr,
377                                                              offsetof(struct lvp_bvh_instance_node,
378                                                                       custom_instance_and_mask))),
379                           0xFFFFFF);
380    }
381    case nir_ray_query_value_intersection_instance_id: {
382       nir_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr);
383       return nir_build_load_global(
384          b, 1, 32,
385          nir_iadd_imm(b, instance_node_addr, offsetof(struct lvp_bvh_instance_node, instance_id)));
386    }
387    case nir_ray_query_value_intersection_instance_sbt_index:
388       return nir_iand_imm(b, rq_load_var(b, index, intersection->sbt_offset_and_flags), 0xFFFFFF);
389    case nir_ray_query_value_intersection_object_ray_direction: {
390       nir_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr);
391       nir_def *wto_matrix[3];
392       lvp_load_wto_matrix(b, instance_node_addr, wto_matrix);
393       return lvp_mul_vec3_mat(b, rq_load_var(b, index, vars->direction), wto_matrix, false);
394    }
395    case nir_ray_query_value_intersection_object_ray_origin: {
396       nir_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr);
397       nir_def *wto_matrix[3];
398       lvp_load_wto_matrix(b, instance_node_addr, wto_matrix);
399       return lvp_mul_vec3_mat(b, rq_load_var(b, index, vars->origin), wto_matrix, true);
400    }
401    case nir_ray_query_value_intersection_object_to_world: {
402       nir_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr);
403       nir_def *rows[3];
404       for (unsigned r = 0; r < 3; ++r)
405          rows[r] = nir_build_load_global(
406             b, 4, 32,
407             nir_iadd_imm(b, instance_node_addr,
408                          offsetof(struct lvp_bvh_instance_node, otw_matrix) + r * 16));
409 
410       return nir_vec3(b, nir_channel(b, rows[0], column), nir_channel(b, rows[1], column),
411                       nir_channel(b, rows[2], column));
412    }
413    case nir_ray_query_value_intersection_primitive_index:
414       return rq_load_var(b, index, intersection->primitive_id);
415    case nir_ray_query_value_intersection_t:
416       return rq_load_var(b, index, intersection->t);
417    case nir_ray_query_value_intersection_type: {
418       nir_def *intersection_type = rq_load_var(b, index, intersection->intersection_type);
419       if (!committed)
420          intersection_type = nir_iadd_imm(b, intersection_type, -1);
421 
422       return intersection_type;
423    }
424    case nir_ray_query_value_intersection_world_to_object: {
425       nir_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr);
426 
427       nir_def *wto_matrix[3];
428       lvp_load_wto_matrix(b, instance_node_addr, wto_matrix);
429 
430       nir_def *vals[3];
431       for (unsigned i = 0; i < 3; ++i)
432          vals[i] = nir_channel(b, wto_matrix[i], column);
433 
434       return nir_vec(b, vals, 3);
435    }
436    case nir_ray_query_value_tmin:
437       return rq_load_var(b, index, vars->tmin);
438    case nir_ray_query_value_world_ray_direction:
439       return rq_load_var(b, index, vars->direction);
440    case nir_ray_query_value_world_ray_origin:
441       return rq_load_var(b, index, vars->origin);
442    case nir_ray_query_value_intersection_triangle_vertex_positions:
443       return lvp_load_vertex_position(
444          b, rq_load_var(b, index, intersection->instance_addr),
445          rq_load_var(b, index, intersection->primitive_id), column);
446    default:
447       unreachable("Invalid nir_ray_query_value!");
448    }
449 
450    return NULL;
451 }
452 
453 struct traversal_data {
454    struct ray_query_vars *vars;
455    nir_def *index;
456 };
457 
458 static void
handle_candidate_aabb(nir_builder * b,struct lvp_leaf_intersection * intersection,const struct lvp_ray_traversal_args * args,const struct lvp_ray_flags * ray_flags)459 handle_candidate_aabb(nir_builder *b, struct lvp_leaf_intersection *intersection,
460                       const struct lvp_ray_traversal_args *args,
461                       const struct lvp_ray_flags *ray_flags)
462 {
463    struct traversal_data *data = args->data;
464    struct ray_query_vars *vars = data->vars;
465    nir_def *index = data->index;
466 
467    rq_store_var(b, index, vars->candidate.primitive_id, intersection->primitive_id, 1);
468    rq_store_var(b, index, vars->candidate.geometry_id_and_flags,
469                 intersection->geometry_id_and_flags, 1);
470    rq_store_var(b, index, vars->candidate.opaque, intersection->opaque, 0x1);
471    rq_store_var(b, index, vars->candidate.intersection_type, nir_imm_int(b, intersection_type_aabb),
472                 0x1);
473 
474    nir_jump(b, nir_jump_break);
475 }
476 
477 static void
handle_candidate_triangle(nir_builder * b,struct lvp_triangle_intersection * intersection,const struct lvp_ray_traversal_args * args,const struct lvp_ray_flags * ray_flags)478 handle_candidate_triangle(nir_builder *b, struct lvp_triangle_intersection *intersection,
479                           const struct lvp_ray_traversal_args *args,
480                           const struct lvp_ray_flags *ray_flags)
481 {
482    struct traversal_data *data = args->data;
483    struct ray_query_vars *vars = data->vars;
484    nir_def *index = data->index;
485 
486    rq_store_var(b, index, vars->candidate.barycentrics, intersection->barycentrics, 3);
487    rq_store_var(b, index, vars->candidate.primitive_id, intersection->base.primitive_id, 1);
488    rq_store_var(b, index, vars->candidate.geometry_id_and_flags,
489                 intersection->base.geometry_id_and_flags, 1);
490    rq_store_var(b, index, vars->candidate.t, intersection->t, 0x1);
491    rq_store_var(b, index, vars->candidate.opaque, intersection->base.opaque, 0x1);
492    rq_store_var(b, index, vars->candidate.frontface, intersection->frontface, 0x1);
493    rq_store_var(b, index, vars->candidate.intersection_type,
494                 nir_imm_int(b, intersection_type_triangle), 0x1);
495 
496    nir_push_if(b, intersection->base.opaque);
497    {
498       copy_candidate_to_closest(b, index, vars);
499       insert_terminate_on_first_hit(b, index, vars, true);
500    }
501    nir_push_else(b, NULL);
502    {
503       nir_jump(b, nir_jump_break);
504    }
505    nir_pop_if(b, NULL);
506 }
507 
508 static nir_def *
lower_rq_proceed(nir_builder * b,nir_def * index,struct ray_query_vars * vars)509 lower_rq_proceed(nir_builder *b, nir_def *index, struct ray_query_vars *vars)
510 {
511    nir_variable *inv_dir =
512       nir_local_variable_create(b->impl, glsl_vector_type(GLSL_TYPE_FLOAT, 3), "inv_dir");
513    nir_store_var(b, inv_dir, nir_frcp(b, rq_load_var(b, index, vars->trav.direction)), 0x7);
514 
515    struct lvp_ray_traversal_vars trav_vars = {
516       .tmax = rq_deref_var(b, index, vars->closest.t),
517       .origin = rq_deref_var(b, index, vars->trav.origin),
518       .dir = rq_deref_var(b, index, vars->trav.direction),
519       .inv_dir = nir_build_deref_var(b, inv_dir),
520       .bvh_base = rq_deref_var(b, index, vars->trav.bvh_base),
521       .current_node = rq_deref_var(b, index, vars->trav.current_node),
522       .stack_ptr = rq_deref_var(b, index, vars->trav.stack_ptr),
523       .stack_base = rq_deref_var(b, index, vars->trav.stack_base),
524       .stack = rq_deref_var(b, index, vars->trav.stack),
525       .instance_addr = rq_deref_var(b, index, vars->candidate.instance_addr),
526       .sbt_offset_and_flags = rq_deref_var(b, index, vars->candidate.sbt_offset_and_flags),
527    };
528 
529    struct traversal_data data = {
530       .vars = vars,
531       .index = index,
532    };
533 
534    struct lvp_ray_traversal_args args = {
535       .root_bvh_base = rq_load_var(b, index, vars->root_bvh_base),
536       .flags = rq_load_var(b, index, vars->flags),
537       .cull_mask = rq_load_var(b, index, vars->cull_mask),
538       .origin = rq_load_var(b, index, vars->origin),
539       .tmin = rq_load_var(b, index, vars->tmin),
540       .dir = rq_load_var(b, index, vars->direction),
541       .vars = trav_vars,
542       .aabb_cb = handle_candidate_aabb,
543       .triangle_cb = handle_candidate_triangle,
544       .data = &data,
545    };
546 
547    nir_push_if(b, rq_load_var(b, index, vars->incomplete));
548    {
549       nir_def *incomplete = lvp_build_ray_traversal(b, &args);
550       rq_store_var(b, index, vars->incomplete,
551                    nir_iand(b, rq_load_var(b, index, vars->incomplete), incomplete), 1);
552    }
553    nir_pop_if(b, NULL);
554 
555    return rq_load_var(b, index, vars->incomplete);
556 }
557 
558 static void
lower_rq_terminate(nir_builder * b,nir_def * index,nir_intrinsic_instr * instr,struct ray_query_vars * vars)559 lower_rq_terminate(nir_builder *b, nir_def *index, nir_intrinsic_instr *instr,
560                    struct ray_query_vars *vars)
561 {
562    rq_store_var(b, index, vars->incomplete, nir_imm_false(b), 0x1);
563 }
564 
565 bool
lvp_nir_lower_ray_queries(struct nir_shader * shader)566 lvp_nir_lower_ray_queries(struct nir_shader *shader)
567 {
568    bool progress = false;
569    struct hash_table *query_ht = _mesa_pointer_hash_table_create(NULL);
570 
571    nir_foreach_variable_in_list (var, &shader->variables) {
572       if (!var->data.ray_query)
573          continue;
574 
575       lower_ray_query(shader, var, query_ht);
576 
577       progress = true;
578    }
579 
580    nir_foreach_function (function, shader) {
581       if (!function->impl)
582          continue;
583 
584       nir_builder builder = nir_builder_create(function->impl);
585 
586       nir_foreach_variable_in_list (var, &function->impl->locals) {
587          if (!var->data.ray_query)
588             continue;
589 
590          lower_ray_query(shader, var, query_ht);
591 
592          progress = true;
593       }
594 
595       nir_foreach_block (block, function->impl) {
596          nir_foreach_instr_safe (instr, block) {
597             if (instr->type != nir_instr_type_intrinsic)
598                continue;
599 
600             nir_intrinsic_instr *intrinsic = nir_instr_as_intrinsic(instr);
601 
602             if (!nir_intrinsic_is_ray_query(intrinsic->intrinsic))
603                continue;
604 
605             nir_deref_instr *ray_query_deref =
606                nir_instr_as_deref(intrinsic->src[0].ssa->parent_instr);
607             nir_def *index = NULL;
608 
609             if (ray_query_deref->deref_type == nir_deref_type_array) {
610                index = ray_query_deref->arr.index.ssa;
611                ray_query_deref = nir_instr_as_deref(ray_query_deref->parent.ssa->parent_instr);
612             }
613 
614             assert(ray_query_deref->deref_type == nir_deref_type_var);
615 
616             struct ray_query_vars *vars =
617                (struct ray_query_vars *)_mesa_hash_table_search(query_ht, ray_query_deref->var)
618                   ->data;
619 
620             builder.cursor = nir_before_instr(instr);
621 
622             nir_def *new_dest = NULL;
623 
624             switch (intrinsic->intrinsic) {
625             case nir_intrinsic_rq_confirm_intersection:
626                lower_rq_confirm_intersection(&builder, index, intrinsic, vars);
627                break;
628             case nir_intrinsic_rq_generate_intersection:
629                lower_rq_generate_intersection(&builder, index, intrinsic, vars);
630                break;
631             case nir_intrinsic_rq_initialize:
632                lower_rq_initialize(&builder, index, intrinsic, vars);
633                break;
634             case nir_intrinsic_rq_load:
635                new_dest = lower_rq_load(&builder, index, intrinsic, vars);
636                break;
637             case nir_intrinsic_rq_proceed:
638                new_dest = lower_rq_proceed(&builder, index, vars);
639                break;
640             case nir_intrinsic_rq_terminate:
641                lower_rq_terminate(&builder, index, intrinsic, vars);
642                break;
643             default:
644                unreachable("Unsupported ray query intrinsic!");
645             }
646 
647             if (new_dest)
648                nir_def_rewrite_uses(&intrinsic->def, new_dest);
649 
650             nir_instr_remove(instr);
651             nir_instr_free(instr);
652 
653             progress = true;
654          }
655       }
656 
657       nir_metadata_preserve(function->impl, nir_metadata_none);
658    }
659 
660    ralloc_free(query_ht);
661 
662    if (progress) {
663       NIR_PASS(_, shader, nir_lower_global_vars_to_local);
664       NIR_PASS(_, shader, nir_lower_vars_to_ssa);
665 
666       NIR_PASS(_, shader, nir_opt_constant_folding);
667       NIR_PASS(_, shader, nir_opt_cse);
668       NIR_PASS(_, shader, nir_opt_dce);
669    }
670 
671    return progress;
672 }
673