1 /*
2 * Copyright © 2023 Valve Corporation
3 * SPDX-License-Identifier: MIT
4 */
5
6 #include "nir/nir.h"
7 #include "nir/nir_builder.h"
8
9 #include "lvp_nir_ray_tracing.h"
10 #include "lvp_acceleration_structure.h"
11 #include "lvp_private.h"
12
13 #include "spirv/spirv.h"
14
15 #include "util/hash_table.h"
16
/* One lowered ray-query field: the NIR variable that backs it together with
 * the number of ray queries that share that variable.
 */
typedef struct {
   /* Backing variable; it is array-typed whenever array_length is not 1. */
   nir_variable *variable;
   /* Number of ray queries stored in `variable` (1 means a plain variable). */
   unsigned array_length;
} rq_variable;
21
22 static rq_variable *
rq_variable_create(void * ctx,nir_shader * shader,unsigned array_length,const struct glsl_type * type,const char * name)23 rq_variable_create(void *ctx, nir_shader *shader, unsigned array_length,
24 const struct glsl_type *type, const char *name)
25 {
26 rq_variable *result = ralloc(ctx, rq_variable);
27 result->array_length = array_length;
28
29 const struct glsl_type *variable_type = type;
30 if (array_length != 1)
31 variable_type = glsl_array_type(type, array_length, glsl_get_explicit_stride(type));
32
33 result->variable = nir_variable_create(shader, nir_var_shader_temp, variable_type, name);
34
35 return result;
36 }
37
38 static nir_def *
nir_load_array(nir_builder * b,nir_variable * array,nir_def * index)39 nir_load_array(nir_builder *b, nir_variable *array, nir_def *index)
40 {
41 return nir_load_deref(b, nir_build_deref_array(b, nir_build_deref_var(b, array), index));
42 }
43
44 static void
nir_store_array(nir_builder * b,nir_variable * array,nir_def * index,nir_def * value,unsigned writemask)45 nir_store_array(nir_builder *b, nir_variable *array, nir_def *index, nir_def *value,
46 unsigned writemask)
47 {
48 nir_store_deref(b, nir_build_deref_array(b, nir_build_deref_var(b, array), index), value,
49 writemask);
50 }
51
52 static nir_deref_instr *
rq_deref_var(nir_builder * b,nir_def * index,rq_variable * var)53 rq_deref_var(nir_builder *b, nir_def *index, rq_variable *var)
54 {
55 if (var->array_length == 1)
56 return nir_build_deref_var(b, var->variable);
57
58 return nir_build_deref_array(b, nir_build_deref_var(b, var->variable), index);
59 }
60
61 static nir_def *
rq_load_var(nir_builder * b,nir_def * index,rq_variable * var)62 rq_load_var(nir_builder *b, nir_def *index, rq_variable *var)
63 {
64 if (var->array_length == 1)
65 return nir_load_var(b, var->variable);
66
67 return nir_load_array(b, var->variable, index);
68 }
69
70 static void
rq_store_var(nir_builder * b,nir_def * index,rq_variable * var,nir_def * value,unsigned writemask)71 rq_store_var(nir_builder *b, nir_def *index, rq_variable *var, nir_def *value,
72 unsigned writemask)
73 {
74 if (var->array_length == 1) {
75 nir_store_var(b, var->variable, value, writemask);
76 } else {
77 nir_store_array(b, var->variable, index, value, writemask);
78 }
79 }
80
81 static void
rq_copy_var(nir_builder * b,nir_def * index,rq_variable * dst,rq_variable * src,unsigned mask)82 rq_copy_var(nir_builder *b, nir_def *index, rq_variable *dst, rq_variable *src, unsigned mask)
83 {
84 rq_store_var(b, index, dst, rq_load_var(b, index, src), mask);
85 }
86
87 static nir_def *
rq_load_array(nir_builder * b,nir_def * index,rq_variable * var,nir_def * array_index)88 rq_load_array(nir_builder *b, nir_def *index, rq_variable *var, nir_def *array_index)
89 {
90 if (var->array_length == 1)
91 return nir_load_array(b, var->variable, array_index);
92
93 return nir_load_deref(
94 b,
95 nir_build_deref_array(
96 b, nir_build_deref_array(b, nir_build_deref_var(b, var->variable), index), array_index));
97 }
98
99 static void
rq_store_array(nir_builder * b,nir_def * index,rq_variable * var,nir_def * array_index,nir_def * value,unsigned writemask)100 rq_store_array(nir_builder *b, nir_def *index, rq_variable *var, nir_def *array_index,
101 nir_def *value, unsigned writemask)
102 {
103 if (var->array_length == 1) {
104 nir_store_array(b, var->variable, array_index, value, writemask);
105 } else {
106 nir_store_deref(
107 b,
108 nir_build_deref_array(
109 b, nir_build_deref_array(b, nir_build_deref_var(b, var->variable), index), array_index),
110 value, writemask);
111 }
112 }
113
/* Traversal state of a ray query.  lower_rq_proceed() hands derefs of these
 * variables to lvp_build_ray_traversal() through lvp_ray_traversal_vars.
 */
struct ray_query_traversal_vars {
   /* vec3 variables; lower_rq_initialize() seeds them with the query's
    * origin and direction.
    */
   rq_variable *origin;
   rq_variable *direction;

   /* 64-bit variable, initialised to the same value as root_bvh_base. */
   rq_variable *bvh_base;
   /* uint variable, initialised to LVP_BVH_ROOT_NODE. */
   rq_variable *current_node;

   /* uint variables; stack_base starts at -1 and stack_ptr at 0. */
   rq_variable *stack_base;
   rq_variable *stack_ptr;
   /* Array of 24 * 2 uints per ray query. */
   rq_variable *stack;
};
125
/* Per-intersection state.  struct ray_query_vars keeps two copies of it: the
 * closest (committed) intersection and the current candidate.
 */
struct ray_query_intersection_vars {
   rq_variable *primitive_id;
   /* lower_rq_load() returns the low 24 bits as the geometry index. */
   rq_variable *geometry_id_and_flags;
   /* 64-bit address used as the base for lvp_bvh_instance_node field loads. */
   rq_variable *instance_addr;
   /* Holds an enum rq_intersection_type value. */
   rq_variable *intersection_type;
   rq_variable *opaque;
   rq_variable *frontface;
   /* lower_rq_load() returns the low 24 bits as the instance SBT index. */
   rq_variable *sbt_offset_and_flags;
   /* vec2 variable. */
   rq_variable *barycentrics;
   /* float variable. */
   rq_variable *t;
};
137
/* All variables a single ray query variable (or array of ray queries) is
 * lowered to.  One instance is stored in the hash table per ray query
 * variable.
 */
struct ray_query_vars {
   /* Values captured by lower_rq_initialize(). */
   rq_variable *root_bvh_base;
   rq_variable *flags;
   /* Stored as the rq_initialize cull mask operand shifted left by 24. */
   rq_variable *cull_mask;
   rq_variable *origin;
   rq_variable *tmin;
   rq_variable *direction;

   /* Boolean; cleared by rq_terminate and by terminate-on-first-hit, and
    * returned as the result of rq_proceed.
    */
   rq_variable *incomplete;

   /* `closest` is read for committed rq_load queries, `candidate` otherwise. */
   struct ray_query_intersection_vars closest;
   struct ray_query_intersection_vars candidate;

   struct ray_query_traversal_vars trav;
};
153
/* Build "<base_name><name>" as a string allocated on the ralloc context `ctx`.
 * Both `ctx` and `base_name` must be in scope at the expansion site.
 *
 * ralloc_asprintf() is used instead of a hand-rolled
 * ralloc_size() + strcpy() + strcat() chain: it sizes the buffer itself and
 * does not write through a NULL pointer if the allocation fails.
 */
#define VAR_NAME(name) ralloc_asprintf(ctx, "%s%s", base_name, name)
156
157 static struct ray_query_traversal_vars
init_ray_query_traversal_vars(void * ctx,nir_shader * shader,unsigned array_length,const char * base_name)158 init_ray_query_traversal_vars(void *ctx, nir_shader *shader, unsigned array_length,
159 const char *base_name)
160 {
161 struct ray_query_traversal_vars result;
162
163 const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
164
165 result.origin = rq_variable_create(ctx, shader, array_length, vec3_type, VAR_NAME("_origin"));
166 result.direction =
167 rq_variable_create(ctx, shader, array_length, vec3_type, VAR_NAME("_direction"));
168
169 result.bvh_base =
170 rq_variable_create(ctx, shader, array_length, glsl_uint64_t_type(), VAR_NAME("_bvh_base"));
171 result.current_node =
172 rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_current_node"));
173 result.stack_base =
174 rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_stack_base"));
175 result.stack_ptr = rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_stack_ptr"));
176 result.stack = rq_variable_create(ctx, shader, array_length, glsl_array_type(glsl_uint_type(), 24 * 2, 0), VAR_NAME("_stack"));
177 return result;
178 }
179
180 static struct ray_query_intersection_vars
init_ray_query_intersection_vars(void * ctx,nir_shader * shader,unsigned array_length,const char * base_name)181 init_ray_query_intersection_vars(void *ctx, nir_shader *shader, unsigned array_length,
182 const char *base_name)
183 {
184 struct ray_query_intersection_vars result;
185
186 const struct glsl_type *vec2_type = glsl_vector_type(GLSL_TYPE_FLOAT, 2);
187
188 result.primitive_id =
189 rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_primitive_id"));
190 result.geometry_id_and_flags = rq_variable_create(ctx, shader, array_length, glsl_uint_type(),
191 VAR_NAME("_geometry_id_and_flags"));
192 result.instance_addr = rq_variable_create(ctx, shader, array_length, glsl_uint64_t_type(),
193 VAR_NAME("_instance_addr"));
194 result.intersection_type = rq_variable_create(ctx, shader, array_length, glsl_uint_type(),
195 VAR_NAME("_intersection_type"));
196 result.opaque =
197 rq_variable_create(ctx, shader, array_length, glsl_bool_type(), VAR_NAME("_opaque"));
198 result.frontface =
199 rq_variable_create(ctx, shader, array_length, glsl_bool_type(), VAR_NAME("_frontface"));
200 result.sbt_offset_and_flags = rq_variable_create(ctx, shader, array_length, glsl_uint_type(),
201 VAR_NAME("_sbt_offset_and_flags"));
202 result.barycentrics =
203 rq_variable_create(ctx, shader, array_length, vec2_type, VAR_NAME("_barycentrics"));
204 result.t = rq_variable_create(ctx, shader, array_length, glsl_float_type(), VAR_NAME("_t"));
205
206 return result;
207 }
208
209 static void
init_ray_query_vars(nir_shader * shader,unsigned array_length,struct ray_query_vars * dst,const char * base_name)210 init_ray_query_vars(nir_shader *shader, unsigned array_length, struct ray_query_vars *dst,
211 const char *base_name)
212 {
213 void *ctx = dst;
214 const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
215
216 dst->root_bvh_base = rq_variable_create(dst, shader, array_length, glsl_uint64_t_type(),
217 VAR_NAME("_root_bvh_base"));
218 dst->flags = rq_variable_create(dst, shader, array_length, glsl_uint_type(), VAR_NAME("_flags"));
219 dst->cull_mask =
220 rq_variable_create(dst, shader, array_length, glsl_uint_type(), VAR_NAME("_cull_mask"));
221 dst->origin = rq_variable_create(dst, shader, array_length, vec3_type, VAR_NAME("_origin"));
222 dst->tmin = rq_variable_create(dst, shader, array_length, glsl_float_type(), VAR_NAME("_tmin"));
223 dst->direction =
224 rq_variable_create(dst, shader, array_length, vec3_type, VAR_NAME("_direction"));
225
226 dst->incomplete =
227 rq_variable_create(dst, shader, array_length, glsl_bool_type(), VAR_NAME("_incomplete"));
228
229 dst->closest = init_ray_query_intersection_vars(dst, shader, array_length, VAR_NAME("_closest"));
230 dst->candidate =
231 init_ray_query_intersection_vars(dst, shader, array_length, VAR_NAME("_candidate"));
232
233 dst->trav = init_ray_query_traversal_vars(dst, shader, array_length, VAR_NAME("_top"));
234 }
235
236 #undef VAR_NAME
237
238 static void
lower_ray_query(nir_shader * shader,nir_variable * ray_query,struct hash_table * ht)239 lower_ray_query(nir_shader *shader, nir_variable *ray_query, struct hash_table *ht)
240 {
241 struct ray_query_vars *vars = ralloc(ht, struct ray_query_vars);
242
243 unsigned array_length = 1;
244 if (glsl_type_is_array(ray_query->type))
245 array_length = glsl_get_length(ray_query->type);
246
247 init_ray_query_vars(shader, array_length, vars, ray_query->name == NULL ? "" : ray_query->name);
248
249 _mesa_hash_table_insert(ht, ray_query, vars);
250 }
251
252 static void
copy_candidate_to_closest(nir_builder * b,nir_def * index,struct ray_query_vars * vars)253 copy_candidate_to_closest(nir_builder *b, nir_def *index, struct ray_query_vars *vars)
254 {
255 rq_copy_var(b, index, vars->closest.barycentrics, vars->candidate.barycentrics, 0x3);
256 rq_copy_var(b, index, vars->closest.geometry_id_and_flags, vars->candidate.geometry_id_and_flags,
257 0x1);
258 rq_copy_var(b, index, vars->closest.instance_addr, vars->candidate.instance_addr, 0x1);
259 rq_copy_var(b, index, vars->closest.intersection_type, vars->candidate.intersection_type, 0x1);
260 rq_copy_var(b, index, vars->closest.opaque, vars->candidate.opaque, 0x1);
261 rq_copy_var(b, index, vars->closest.frontface, vars->candidate.frontface, 0x1);
262 rq_copy_var(b, index, vars->closest.sbt_offset_and_flags, vars->candidate.sbt_offset_and_flags,
263 0x1);
264 rq_copy_var(b, index, vars->closest.primitive_id, vars->candidate.primitive_id, 0x1);
265 rq_copy_var(b, index, vars->closest.t, vars->candidate.t, 0x1);
266 }
267
268 static void
insert_terminate_on_first_hit(nir_builder * b,nir_def * index,struct ray_query_vars * vars,bool break_on_terminate)269 insert_terminate_on_first_hit(nir_builder *b, nir_def *index, struct ray_query_vars *vars,
270 bool break_on_terminate)
271 {
272 nir_def *terminate_on_first_hit =
273 nir_test_mask(b, rq_load_var(b, index, vars->flags), SpvRayFlagsTerminateOnFirstHitKHRMask);
274 nir_push_if(b, terminate_on_first_hit);
275 {
276 rq_store_var(b, index, vars->incomplete, nir_imm_false(b), 0x1);
277 if (break_on_terminate)
278 nir_jump(b, nir_jump_break);
279 }
280 nir_pop_if(b, NULL);
281 }
282
283 static void
lower_rq_confirm_intersection(nir_builder * b,nir_def * index,nir_intrinsic_instr * instr,struct ray_query_vars * vars)284 lower_rq_confirm_intersection(nir_builder *b, nir_def *index, nir_intrinsic_instr *instr,
285 struct ray_query_vars *vars)
286 {
287 copy_candidate_to_closest(b, index, vars);
288 insert_terminate_on_first_hit(b, index, vars, false);
289 }
290
291 static void
lower_rq_generate_intersection(nir_builder * b,nir_def * index,nir_intrinsic_instr * instr,struct ray_query_vars * vars)292 lower_rq_generate_intersection(nir_builder *b, nir_def *index, nir_intrinsic_instr *instr,
293 struct ray_query_vars *vars)
294 {
295 nir_push_if(b, nir_iand(b, nir_fge(b, rq_load_var(b, index, vars->closest.t), instr->src[1].ssa),
296 nir_fge(b, instr->src[1].ssa, rq_load_var(b, index, vars->tmin))));
297 {
298 copy_candidate_to_closest(b, index, vars);
299 insert_terminate_on_first_hit(b, index, vars, false);
300 rq_store_var(b, index, vars->closest.t, instr->src[1].ssa, 0x1);
301 }
302 nir_pop_if(b, NULL);
303 }
304
/* Values stored in the intersection_type variables.  lower_rq_load() returns
 * the stored value unchanged for committed intersections and the stored value
 * minus one for candidates.
 */
enum rq_intersection_type {
   intersection_type_none,
   intersection_type_triangle,
   intersection_type_aabb
};
310
311 static void
lower_rq_initialize(nir_builder * b,nir_def * index,nir_intrinsic_instr * instr,struct ray_query_vars * vars)312 lower_rq_initialize(nir_builder *b, nir_def *index, nir_intrinsic_instr *instr,
313 struct ray_query_vars *vars)
314 {
315 rq_store_var(b, index, vars->flags, instr->src[2].ssa, 0x1);
316 rq_store_var(b, index, vars->cull_mask, nir_ishl_imm(b, instr->src[3].ssa, 24), 0x1);
317
318 rq_store_var(b, index, vars->origin, instr->src[4].ssa, 0x7);
319 rq_store_var(b, index, vars->trav.origin, instr->src[4].ssa, 0x7);
320
321 rq_store_var(b, index, vars->tmin, instr->src[5].ssa, 0x1);
322
323 rq_store_var(b, index, vars->direction, instr->src[6].ssa, 0x7);
324 rq_store_var(b, index, vars->trav.direction, instr->src[6].ssa, 0x7);
325
326 rq_store_var(b, index, vars->closest.t, instr->src[7].ssa, 0x1);
327 rq_store_var(b, index, vars->closest.intersection_type, nir_imm_int(b, intersection_type_none),
328 0x1);
329
330 nir_def *accel_struct = instr->src[1].ssa;
331 nir_def *bvh_base = accel_struct;
332 if (bvh_base->bit_size != 64) {
333 assert(bvh_base->num_components >= 2);
334 bvh_base = nir_load_ubo(
335 b, 1, 64, nir_channel(b, accel_struct, 0),
336 nir_imul_imm(b, nir_channel(b, accel_struct, 1), sizeof(struct lp_descriptor)), .range = ~0);
337 }
338
339 rq_store_var(b, index, vars->root_bvh_base, bvh_base, 0x1);
340 rq_store_var(b, index, vars->trav.bvh_base, bvh_base, 1);
341
342 rq_store_var(b, index, vars->trav.current_node, nir_imm_int(b, LVP_BVH_ROOT_NODE), 0x1);
343 rq_store_var(b, index, vars->trav.stack_ptr, nir_imm_int(b, 0), 0x1);
344 rq_store_var(b, index, vars->trav.stack_base, nir_imm_int(b, -1), 0x1);
345
346 rq_store_var(b, index, vars->incomplete, nir_ine_imm(b, bvh_base, 0), 0x1);
347 }
348
349 static nir_def *
lower_rq_load(nir_builder * b,nir_def * index,nir_intrinsic_instr * instr,struct ray_query_vars * vars)350 lower_rq_load(nir_builder *b, nir_def *index, nir_intrinsic_instr *instr,
351 struct ray_query_vars *vars)
352 {
353 bool committed = nir_intrinsic_committed(instr);
354 struct ray_query_intersection_vars *intersection = committed ? &vars->closest : &vars->candidate;
355
356 uint32_t column = nir_intrinsic_column(instr);
357
358 nir_ray_query_value value = nir_intrinsic_ray_query_value(instr);
359 switch (value) {
360 case nir_ray_query_value_flags:
361 return rq_load_var(b, index, vars->flags);
362 case nir_ray_query_value_intersection_barycentrics:
363 return rq_load_var(b, index, intersection->barycentrics);
364 case nir_ray_query_value_intersection_candidate_aabb_opaque:
365 return nir_iand(b, rq_load_var(b, index, vars->candidate.opaque),
366 nir_ieq_imm(b, rq_load_var(b, index, vars->candidate.intersection_type),
367 intersection_type_aabb));
368 case nir_ray_query_value_intersection_front_face:
369 return rq_load_var(b, index, intersection->frontface);
370 case nir_ray_query_value_intersection_geometry_index:
371 return nir_iand_imm(b, rq_load_var(b, index, intersection->geometry_id_and_flags), 0xFFFFFF);
372 case nir_ray_query_value_intersection_instance_custom_index: {
373 nir_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr);
374 return nir_iand_imm(b,
375 nir_build_load_global(b, 1, 32,
376 nir_iadd_imm(b, instance_node_addr,
377 offsetof(struct lvp_bvh_instance_node,
378 custom_instance_and_mask))),
379 0xFFFFFF);
380 }
381 case nir_ray_query_value_intersection_instance_id: {
382 nir_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr);
383 return nir_build_load_global(
384 b, 1, 32,
385 nir_iadd_imm(b, instance_node_addr, offsetof(struct lvp_bvh_instance_node, instance_id)));
386 }
387 case nir_ray_query_value_intersection_instance_sbt_index:
388 return nir_iand_imm(b, rq_load_var(b, index, intersection->sbt_offset_and_flags), 0xFFFFFF);
389 case nir_ray_query_value_intersection_object_ray_direction: {
390 nir_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr);
391 nir_def *wto_matrix[3];
392 lvp_load_wto_matrix(b, instance_node_addr, wto_matrix);
393 return lvp_mul_vec3_mat(b, rq_load_var(b, index, vars->direction), wto_matrix, false);
394 }
395 case nir_ray_query_value_intersection_object_ray_origin: {
396 nir_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr);
397 nir_def *wto_matrix[3];
398 lvp_load_wto_matrix(b, instance_node_addr, wto_matrix);
399 return lvp_mul_vec3_mat(b, rq_load_var(b, index, vars->origin), wto_matrix, true);
400 }
401 case nir_ray_query_value_intersection_object_to_world: {
402 nir_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr);
403 nir_def *rows[3];
404 for (unsigned r = 0; r < 3; ++r)
405 rows[r] = nir_build_load_global(
406 b, 4, 32,
407 nir_iadd_imm(b, instance_node_addr,
408 offsetof(struct lvp_bvh_instance_node, otw_matrix) + r * 16));
409
410 return nir_vec3(b, nir_channel(b, rows[0], column), nir_channel(b, rows[1], column),
411 nir_channel(b, rows[2], column));
412 }
413 case nir_ray_query_value_intersection_primitive_index:
414 return rq_load_var(b, index, intersection->primitive_id);
415 case nir_ray_query_value_intersection_t:
416 return rq_load_var(b, index, intersection->t);
417 case nir_ray_query_value_intersection_type: {
418 nir_def *intersection_type = rq_load_var(b, index, intersection->intersection_type);
419 if (!committed)
420 intersection_type = nir_iadd_imm(b, intersection_type, -1);
421
422 return intersection_type;
423 }
424 case nir_ray_query_value_intersection_world_to_object: {
425 nir_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr);
426
427 nir_def *wto_matrix[3];
428 lvp_load_wto_matrix(b, instance_node_addr, wto_matrix);
429
430 nir_def *vals[3];
431 for (unsigned i = 0; i < 3; ++i)
432 vals[i] = nir_channel(b, wto_matrix[i], column);
433
434 return nir_vec(b, vals, 3);
435 }
436 case nir_ray_query_value_tmin:
437 return rq_load_var(b, index, vars->tmin);
438 case nir_ray_query_value_world_ray_direction:
439 return rq_load_var(b, index, vars->direction);
440 case nir_ray_query_value_world_ray_origin:
441 return rq_load_var(b, index, vars->origin);
442 case nir_ray_query_value_intersection_triangle_vertex_positions:
443 return lvp_load_vertex_position(
444 b, rq_load_var(b, index, intersection->instance_addr),
445 rq_load_var(b, index, intersection->primitive_id), column);
446 default:
447 unreachable("Invalid nir_ray_query_value!");
448 }
449
450 return NULL;
451 }
452
/* Context handed to the candidate callbacks through
 * lvp_ray_traversal_args::data.
 */
struct traversal_data {
   /* Lowered variables of the ray query being traversed. */
   struct ray_query_vars *vars;
   /* Array index of the ray query; NULL when the ray query deref is not an
    * array deref.
    */
   nir_def *index;
};
457
458 static void
handle_candidate_aabb(nir_builder * b,struct lvp_leaf_intersection * intersection,const struct lvp_ray_traversal_args * args,const struct lvp_ray_flags * ray_flags)459 handle_candidate_aabb(nir_builder *b, struct lvp_leaf_intersection *intersection,
460 const struct lvp_ray_traversal_args *args,
461 const struct lvp_ray_flags *ray_flags)
462 {
463 struct traversal_data *data = args->data;
464 struct ray_query_vars *vars = data->vars;
465 nir_def *index = data->index;
466
467 rq_store_var(b, index, vars->candidate.primitive_id, intersection->primitive_id, 1);
468 rq_store_var(b, index, vars->candidate.geometry_id_and_flags,
469 intersection->geometry_id_and_flags, 1);
470 rq_store_var(b, index, vars->candidate.opaque, intersection->opaque, 0x1);
471 rq_store_var(b, index, vars->candidate.intersection_type, nir_imm_int(b, intersection_type_aabb),
472 0x1);
473
474 nir_jump(b, nir_jump_break);
475 }
476
477 static void
handle_candidate_triangle(nir_builder * b,struct lvp_triangle_intersection * intersection,const struct lvp_ray_traversal_args * args,const struct lvp_ray_flags * ray_flags)478 handle_candidate_triangle(nir_builder *b, struct lvp_triangle_intersection *intersection,
479 const struct lvp_ray_traversal_args *args,
480 const struct lvp_ray_flags *ray_flags)
481 {
482 struct traversal_data *data = args->data;
483 struct ray_query_vars *vars = data->vars;
484 nir_def *index = data->index;
485
486 rq_store_var(b, index, vars->candidate.barycentrics, intersection->barycentrics, 3);
487 rq_store_var(b, index, vars->candidate.primitive_id, intersection->base.primitive_id, 1);
488 rq_store_var(b, index, vars->candidate.geometry_id_and_flags,
489 intersection->base.geometry_id_and_flags, 1);
490 rq_store_var(b, index, vars->candidate.t, intersection->t, 0x1);
491 rq_store_var(b, index, vars->candidate.opaque, intersection->base.opaque, 0x1);
492 rq_store_var(b, index, vars->candidate.frontface, intersection->frontface, 0x1);
493 rq_store_var(b, index, vars->candidate.intersection_type,
494 nir_imm_int(b, intersection_type_triangle), 0x1);
495
496 nir_push_if(b, intersection->base.opaque);
497 {
498 copy_candidate_to_closest(b, index, vars);
499 insert_terminate_on_first_hit(b, index, vars, true);
500 }
501 nir_push_else(b, NULL);
502 {
503 nir_jump(b, nir_jump_break);
504 }
505 nir_pop_if(b, NULL);
506 }
507
508 static nir_def *
lower_rq_proceed(nir_builder * b,nir_def * index,struct ray_query_vars * vars)509 lower_rq_proceed(nir_builder *b, nir_def *index, struct ray_query_vars *vars)
510 {
511 nir_variable *inv_dir =
512 nir_local_variable_create(b->impl, glsl_vector_type(GLSL_TYPE_FLOAT, 3), "inv_dir");
513 nir_store_var(b, inv_dir, nir_frcp(b, rq_load_var(b, index, vars->trav.direction)), 0x7);
514
515 struct lvp_ray_traversal_vars trav_vars = {
516 .tmax = rq_deref_var(b, index, vars->closest.t),
517 .origin = rq_deref_var(b, index, vars->trav.origin),
518 .dir = rq_deref_var(b, index, vars->trav.direction),
519 .inv_dir = nir_build_deref_var(b, inv_dir),
520 .bvh_base = rq_deref_var(b, index, vars->trav.bvh_base),
521 .current_node = rq_deref_var(b, index, vars->trav.current_node),
522 .stack_ptr = rq_deref_var(b, index, vars->trav.stack_ptr),
523 .stack_base = rq_deref_var(b, index, vars->trav.stack_base),
524 .stack = rq_deref_var(b, index, vars->trav.stack),
525 .instance_addr = rq_deref_var(b, index, vars->candidate.instance_addr),
526 .sbt_offset_and_flags = rq_deref_var(b, index, vars->candidate.sbt_offset_and_flags),
527 };
528
529 struct traversal_data data = {
530 .vars = vars,
531 .index = index,
532 };
533
534 struct lvp_ray_traversal_args args = {
535 .root_bvh_base = rq_load_var(b, index, vars->root_bvh_base),
536 .flags = rq_load_var(b, index, vars->flags),
537 .cull_mask = rq_load_var(b, index, vars->cull_mask),
538 .origin = rq_load_var(b, index, vars->origin),
539 .tmin = rq_load_var(b, index, vars->tmin),
540 .dir = rq_load_var(b, index, vars->direction),
541 .vars = trav_vars,
542 .aabb_cb = handle_candidate_aabb,
543 .triangle_cb = handle_candidate_triangle,
544 .data = &data,
545 };
546
547 nir_push_if(b, rq_load_var(b, index, vars->incomplete));
548 {
549 nir_def *incomplete = lvp_build_ray_traversal(b, &args);
550 rq_store_var(b, index, vars->incomplete,
551 nir_iand(b, rq_load_var(b, index, vars->incomplete), incomplete), 1);
552 }
553 nir_pop_if(b, NULL);
554
555 return rq_load_var(b, index, vars->incomplete);
556 }
557
558 static void
lower_rq_terminate(nir_builder * b,nir_def * index,nir_intrinsic_instr * instr,struct ray_query_vars * vars)559 lower_rq_terminate(nir_builder *b, nir_def *index, nir_intrinsic_instr *instr,
560 struct ray_query_vars *vars)
561 {
562 rq_store_var(b, index, vars->incomplete, nir_imm_false(b), 0x1);
563 }
564
565 bool
lvp_nir_lower_ray_queries(struct nir_shader * shader)566 lvp_nir_lower_ray_queries(struct nir_shader *shader)
567 {
568 bool progress = false;
569 struct hash_table *query_ht = _mesa_pointer_hash_table_create(NULL);
570
571 nir_foreach_variable_in_list (var, &shader->variables) {
572 if (!var->data.ray_query)
573 continue;
574
575 lower_ray_query(shader, var, query_ht);
576
577 progress = true;
578 }
579
580 nir_foreach_function (function, shader) {
581 if (!function->impl)
582 continue;
583
584 nir_builder builder = nir_builder_create(function->impl);
585
586 nir_foreach_variable_in_list (var, &function->impl->locals) {
587 if (!var->data.ray_query)
588 continue;
589
590 lower_ray_query(shader, var, query_ht);
591
592 progress = true;
593 }
594
595 nir_foreach_block (block, function->impl) {
596 nir_foreach_instr_safe (instr, block) {
597 if (instr->type != nir_instr_type_intrinsic)
598 continue;
599
600 nir_intrinsic_instr *intrinsic = nir_instr_as_intrinsic(instr);
601
602 if (!nir_intrinsic_is_ray_query(intrinsic->intrinsic))
603 continue;
604
605 nir_deref_instr *ray_query_deref =
606 nir_instr_as_deref(intrinsic->src[0].ssa->parent_instr);
607 nir_def *index = NULL;
608
609 if (ray_query_deref->deref_type == nir_deref_type_array) {
610 index = ray_query_deref->arr.index.ssa;
611 ray_query_deref = nir_instr_as_deref(ray_query_deref->parent.ssa->parent_instr);
612 }
613
614 assert(ray_query_deref->deref_type == nir_deref_type_var);
615
616 struct ray_query_vars *vars =
617 (struct ray_query_vars *)_mesa_hash_table_search(query_ht, ray_query_deref->var)
618 ->data;
619
620 builder.cursor = nir_before_instr(instr);
621
622 nir_def *new_dest = NULL;
623
624 switch (intrinsic->intrinsic) {
625 case nir_intrinsic_rq_confirm_intersection:
626 lower_rq_confirm_intersection(&builder, index, intrinsic, vars);
627 break;
628 case nir_intrinsic_rq_generate_intersection:
629 lower_rq_generate_intersection(&builder, index, intrinsic, vars);
630 break;
631 case nir_intrinsic_rq_initialize:
632 lower_rq_initialize(&builder, index, intrinsic, vars);
633 break;
634 case nir_intrinsic_rq_load:
635 new_dest = lower_rq_load(&builder, index, intrinsic, vars);
636 break;
637 case nir_intrinsic_rq_proceed:
638 new_dest = lower_rq_proceed(&builder, index, vars);
639 break;
640 case nir_intrinsic_rq_terminate:
641 lower_rq_terminate(&builder, index, intrinsic, vars);
642 break;
643 default:
644 unreachable("Unsupported ray query intrinsic!");
645 }
646
647 if (new_dest)
648 nir_def_rewrite_uses(&intrinsic->def, new_dest);
649
650 nir_instr_remove(instr);
651 nir_instr_free(instr);
652
653 progress = true;
654 }
655 }
656
657 nir_metadata_preserve(function->impl, nir_metadata_none);
658 }
659
660 ralloc_free(query_ht);
661
662 if (progress) {
663 NIR_PASS(_, shader, nir_lower_global_vars_to_local);
664 NIR_PASS(_, shader, nir_lower_vars_to_ssa);
665
666 NIR_PASS(_, shader, nir_opt_constant_folding);
667 NIR_PASS(_, shader, nir_opt_cse);
668 NIR_PASS(_, shader, nir_opt_dce);
669 }
670
671 return progress;
672 }
673