xref: /aosp_15_r20/external/mesa3d/src/amd/vulkan/tests/radv_nir_lower_hit_attrib_derefs_tests.cpp (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright © 2023 Valve Corporation
3  * SPDX-License-Identifier: MIT
4  */
5 
6 #include "nir/radv_nir.h"
7 #include "tests/nir_test.h"
8 #include "radv_constants.h"
9 
10 class radv_nir_lower_hit_attrib_derefs_test : public nir_test {
11 protected:
radv_nir_lower_hit_attrib_derefs_test()12    radv_nir_lower_hit_attrib_derefs_test(): nir_test("radv_nir_lower_hit_attrib_derefs_test")
13    {
14       b->shader->info.stage = MESA_SHADER_INTERSECTION;
15    }
16 
17    void validate(uint32_t used_bits[RADV_MAX_HIT_ATTRIB_DWORDS], uint32_t used_dwords, bool constant_fold);
18 };
19 
20 void
validate(uint32_t used_bits[RADV_MAX_HIT_ATTRIB_DWORDS],uint32_t used_dwords,bool constant_fold)21 radv_nir_lower_hit_attrib_derefs_test::validate(uint32_t used_bits[RADV_MAX_HIT_ATTRIB_DWORDS], uint32_t used_dwords,
22                                                 bool constant_fold)
23 {
24    EXPECT_TRUE(radv_nir_lower_hit_attrib_derefs(b->shader));
25    nir_validate_shader(b->shader, "After radv_nir_lower_hit_attrib_derefs");
26 
27    srand(918763498);
28 
29    uint32_t values[RADV_MAX_HIT_ATTRIB_DWORDS];
30    for (uint32_t i = 0; i < ARRAY_SIZE(values); i++)
31       values[i] = ((uint32_t)rand() ^ ((uint32_t)rand() << 1)) & used_bits[i];
32 
33    nir_function_impl *impl = nir_shader_get_entrypoint(b->shader);
34 
35    if (constant_fold) {
36       nir_foreach_block (block, impl) {
37          nir_foreach_instr_safe (instr, block) {
38             if (instr->type != nir_instr_type_intrinsic)
39                continue;
40 
41             nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
42             if (intr->intrinsic != nir_intrinsic_load_hit_attrib_amd)
43                continue;
44 
45             b->cursor = nir_after_instr(instr);
46             nir_def *value = nir_imm_int(b, values[nir_intrinsic_base(intr)]);
47             nir_def_rewrite_uses(&intr->def, value);
48          }
49       }
50       NIR_PASS(_, b->shader, nir_opt_constant_folding);
51    }
52 
53    NIR_PASS(_, b->shader, nir_opt_dce);
54 
55    uint32_t stored_dwords = 0;
56    nir_foreach_block (block, impl) {
57       nir_foreach_instr_safe (instr, block) {
58          /* Make sure that all ray_hit_attrib variables have been lowered. */
59          if (instr->type == nir_instr_type_deref) {
60             EXPECT_FALSE(nir_deref_mode_is(nir_instr_as_deref(instr), nir_var_ray_hit_attrib));
61          }
62 
63          if (instr->type != nir_instr_type_intrinsic)
64             continue;
65 
66          nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
67          if (intr->intrinsic != nir_intrinsic_store_hit_attrib_amd)
68             continue;
69 
70          uint32_t base = nir_intrinsic_base(intr);
71          EXPECT_LT(base, RADV_MAX_HIT_ATTRIB_DWORDS);
72          stored_dwords |= BITFIELD_BIT(base);
73 
74          bool is_const = nir_src_is_const(intr->src[0]);
75          EXPECT_TRUE(is_const || !constant_fold);
76          if (!is_const)
77             continue;
78 
79          uint32_t src = nir_src_as_uint(intr->src[0]);
80          EXPECT_EQ(src, values[base] & used_bits[base]);
81       }
82    }
83 
84    EXPECT_EQ(stored_dwords, used_dwords);
85 }
86 
/* One hit attribute variable per basic type, each loaded and stored back.
 *
 * Expected dword layout (see masks below): vec3 in dwords 0-2, dword 3 is
 * padding, uint64_t in dwords 4-5, uint16_t and uint8_t share dword 6, and
 * only bit 0 of the bool in dword 7 is expected to round-trip.
 */
TEST_F(radv_nir_lower_hit_attrib_derefs_test, types)
{
   /* Brace-init elements are evaluated left to right, so the variables are
    * created in this exact order. */
   nir_variable *vars[] = {
      nir_variable_create(b->shader, nir_var_ray_hit_attrib, glsl_vec_type(3), "vec3"),
      nir_variable_create(b->shader, nir_var_ray_hit_attrib, glsl_uint64_t_type(), "uint64_t"),
      nir_variable_create(b->shader, nir_var_ray_hit_attrib, glsl_uint16_t_type(), "uint16_t"),
      nir_variable_create(b->shader, nir_var_ray_hit_attrib, glsl_uint8_t_type(), "uint8_t"),
      nir_variable_create(b->shader, nir_var_ray_hit_attrib, glsl_bool_type(), "bool"),
   };

   /* Read each attribute and write the value straight back. */
   for (nir_variable *var : vars) {
      nir_def *value = nir_load_var(b, var);
      const unsigned write_mask = (1u << value->num_components) - 1;
      nir_store_var(b, var, value, write_mask);
   }

   uint32_t masks[RADV_MAX_HIT_ATTRIB_DWORDS] = {
      0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, /* vec3 */
      0,                                  /* padding */
      0xFFFFFFFF, 0xFFFFFFFF,             /* uint64_t */
      0xFFFFFFFF,                         /* uint16_t uint8_t */
      1,                                  /* bool */
   };

   /* Every dword except the padding dword 3 must be stored. */
   const uint32_t expected_stores = 0b11110111;
   validate(masks, expected_stores, true);
}
121 
/* A uint array covering every hit attribute dword, accessed with constant
 * indices: each element is loaded and stored back, so every bit of every
 * dword is expected to round-trip.
 */
TEST_F(radv_nir_lower_hit_attrib_derefs_test, array)
{
   nir_variable *array_var =
      nir_variable_create(b->shader, nir_var_ray_hit_attrib,
                          glsl_array_type(glsl_uint_type(), RADV_MAX_HIT_ATTRIB_DWORDS, 0), "uint32_t[]");

   for (uint32_t i = 0; i < RADV_MAX_HIT_ATTRIB_DWORDS; i++) {
      nir_def *load = nir_load_array_var_imm(b, array_var, i);
      nir_store_array_var_imm(b, array_var, i, load, 0x1);
   }

   /* Derive the expectations from RADV_MAX_HIT_ATTRIB_DWORDS instead of
    * hard-coding 8 entries / 0xFF, so they always match the array above. */
   uint32_t masks[RADV_MAX_HIT_ATTRIB_DWORDS];
   for (uint32_t i = 0; i < RADV_MAX_HIT_ATTRIB_DWORDS; i++)
      masks[i] = 0xFFFFFFFF;
   validate(masks, BITFIELD_MASK(RADV_MAX_HIT_ATTRIB_DWORDS), true);
}
138 
/* Same uint array as the "array" test, but indexed by a non-constant value
 * (the local invocation index). Constant folding is disabled, and the test
 * expects every dword of the array to be stored.
 */
TEST_F(radv_nir_lower_hit_attrib_derefs_test, dynamic_array)
{
   nir_variable *array_var =
      nir_variable_create(b->shader, nir_var_ray_hit_attrib,
                          glsl_array_type(glsl_uint_type(), RADV_MAX_HIT_ATTRIB_DWORDS, 0), "uint32_t[]");

   nir_def *index = nir_load_local_invocation_index(b);
   nir_def *load = nir_load_array_var(b, array_var, index);
   nir_store_array_var(b, array_var, index, load, 0x1);

   /* Derive the expectations from RADV_MAX_HIT_ATTRIB_DWORDS instead of
    * hard-coding 8 entries / 0xFF, so they always match the array above. */
   uint32_t masks[RADV_MAX_HIT_ATTRIB_DWORDS];
   for (uint32_t i = 0; i < RADV_MAX_HIT_ATTRIB_DWORDS; i++)
      masks[i] = 0xFFFFFFFF;
   validate(masks, BITFIELD_MASK(RADV_MAX_HIT_ATTRIB_DWORDS), false);
}
154 
/* The attributes from the "types" test packed into one struct variable and
 * accessed through struct member derefs; the expected dword masks are the
 * same as in that test.
 */
TEST_F(radv_nir_lower_hit_attrib_derefs_test, struct)
{
   glsl_struct_field fields[] = {
      glsl_struct_field(glsl_vec_type(3), "vec3"),
      glsl_struct_field(glsl_uint64_t_type(), "uint64_t"),
      glsl_struct_field(glsl_uint16_t_type(), "uint16_t"),
      glsl_struct_field(glsl_uint8_t_type(), "uint8_t"),
      glsl_struct_field(glsl_bool_type(), "bool"),
   };
   const unsigned num_fields = ARRAY_SIZE(fields);

   const glsl_type *hit_attrib_type = glsl_struct_type(fields, num_fields, "hit_attrib_struct", false);
   nir_variable *hit_attrib =
      nir_variable_create(b->shader, nir_var_ray_hit_attrib, hit_attrib_type, "hit_attrib_struct");

   nir_deref_instr *root = nir_build_deref_var(b, hit_attrib);

   /* Read each member and write the value straight back. */
   for (unsigned field = 0; field < num_fields; field++) {
      nir_deref_instr *member = nir_build_deref_struct(b, root, field);
      nir_def *value = nir_load_deref(b, member);
      nir_store_deref(b, member, value, (1u << value->num_components) - 1);
   }

   uint32_t masks[RADV_MAX_HIT_ATTRIB_DWORDS] = {
      0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, /* vec3 */
      0,                                  /* padding */
      0xFFFFFFFF, 0xFFFFFFFF,             /* uint64_t */
      0xFFFFFFFF,                         /* uint16_t uint8_t */
      1,                                  /* bool */
   };

   /* Every dword except the padding dword 3 must be stored. */
   const uint32_t expected_stores = 0b11110111;
   validate(masks, expected_stores, true);
}
191 
/* A struct whose only member is a uint array covering every hit attribute
 * dword; each element is loaded and stored back through struct + array derefs.
 */
TEST_F(radv_nir_lower_hit_attrib_derefs_test, array_inside_struct)
{
   glsl_struct_field field =
      glsl_struct_field(glsl_array_type(glsl_uint_type(), RADV_MAX_HIT_ATTRIB_DWORDS, 0), "array");

   const glsl_type *var_type = glsl_struct_type(&field, 1, "hit_attrib_struct", false);
   nir_variable *struct_var = nir_variable_create(b->shader, nir_var_ray_hit_attrib, var_type, "hit_attrib_struct");

   nir_deref_instr *var_deref = nir_build_deref_var(b, struct_var);
   nir_deref_instr *member_deref = nir_build_deref_struct(b, var_deref, 0);

   for (uint32_t i = 0; i < RADV_MAX_HIT_ATTRIB_DWORDS; i++) {
      nir_deref_instr *element_deref = nir_build_deref_array_imm(b, member_deref, i);
      nir_def *load = nir_load_deref(b, element_deref);
      nir_store_deref(b, element_deref, load, (1 << load->num_components) - 1);
   }

   /* Derive the expectations from RADV_MAX_HIT_ATTRIB_DWORDS instead of
    * hard-coding 8 entries / 0xFF, so they always match the array above. */
   uint32_t masks[RADV_MAX_HIT_ATTRIB_DWORDS];
   for (uint32_t i = 0; i < RADV_MAX_HIT_ATTRIB_DWORDS; i++)
      masks[i] = 0xFFFFFFFF;
   validate(masks, BITFIELD_MASK(RADV_MAX_HIT_ATTRIB_DWORDS), true);
}
214 
/* An array of single-uint structs covering every hit attribute dword; each
 * element's member is loaded and stored back through array + struct derefs.
 */
TEST_F(radv_nir_lower_hit_attrib_derefs_test, struct_inside_array)
{
   glsl_struct_field field = glsl_struct_field(glsl_uint_type(), "array");
   const glsl_type *struct_type = glsl_struct_type(&field, 1, "hit_attrib_struct", false);
   nir_variable *array_var =
      nir_variable_create(b->shader, nir_var_ray_hit_attrib,
                          glsl_array_type(struct_type, RADV_MAX_HIT_ATTRIB_DWORDS, 0), "hit_attrib_struct[]");

   nir_deref_instr *var_deref = nir_build_deref_var(b, array_var);
   for (uint32_t i = 0; i < RADV_MAX_HIT_ATTRIB_DWORDS; i++) {
      nir_deref_instr *element_deref = nir_build_deref_array_imm(b, var_deref, i);
      nir_deref_instr *member_deref = nir_build_deref_struct(b, element_deref, 0);

      nir_def *load = nir_load_deref(b, member_deref);
      nir_store_deref(b, member_deref, load, (1 << load->num_components) - 1);
   }

   /* Derive the expectations from RADV_MAX_HIT_ATTRIB_DWORDS instead of
    * hard-coding 8 entries / 0xFF, so they always match the array above. */
   uint32_t masks[RADV_MAX_HIT_ATTRIB_DWORDS];
   for (uint32_t i = 0; i < RADV_MAX_HIT_ATTRIB_DWORDS; i++)
      masks[i] = 0xFFFFFFFF;
   validate(masks, BITFIELD_MASK(RADV_MAX_HIT_ATTRIB_DWORDS), true);
}
237