xref: /aosp_15_r20/external/mesa3d/src/gallium/drivers/d3d12/d3d12_lower_point_sprite.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright © Microsoft Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include "nir.h"
25 #include "nir_builder.h"
26 #include "util/u_dynarray.h"
27 #include "d3d12_compiler.h"
28 #include "d3d12_nir_passes.h"
29 #include "dxil_nir.h"
30 #include "program/prog_statevars.h"
31 
32 struct output_writes {
33    nir_def *val;
34    nir_deref_instr *deref;
35    unsigned write_mask;
36 };
37 
38 struct lower_state {
39    nir_variable *uniform; /* (1/w, 1/h, pt_sz, max_sz) */
40    nir_variable *pos_out;
41    nir_variable *psiz_out;
42    nir_variable *point_coord_out[10];
43    unsigned num_point_coords;
44 
45    nir_def *point_dir_imm[4];
46    nir_def *point_coord_imm[4];
47 
48    /* Current point primitive */
49    nir_def *point_pos;
50    nir_def *point_size;
51 
52    struct util_dynarray output_writes;
53 
54    bool sprite_origin_lower_left;
55    bool point_size_per_vertex;
56    bool aa_point;
57 };
58 
59 static void
find_outputs(nir_shader * shader,struct lower_state * state)60 find_outputs(nir_shader *shader, struct lower_state *state)
61 {
62    nir_foreach_variable_with_modes(var, shader, nir_var_shader_out) {
63       switch (var->data.location) {
64       case VARYING_SLOT_POS:
65          state->pos_out = var;
66          break;
67       case VARYING_SLOT_PSIZ:
68          state->psiz_out = var;
69          break;
70       }
71    }
72 }
73 
74 static nir_def *
get_point_dir(nir_builder * b,struct lower_state * state,unsigned i)75 get_point_dir(nir_builder *b, struct lower_state *state, unsigned i)
76 {
77    if (state->point_dir_imm[0] == NULL) {
78       state->point_dir_imm[0] = nir_imm_vec2(b, -1, -1);
79       state->point_dir_imm[1] = nir_imm_vec2(b, -1, 1);
80       state->point_dir_imm[2] = nir_imm_vec2(b, 1, -1);
81       state->point_dir_imm[3] = nir_imm_vec2(b, 1, 1);
82    }
83 
84    return state->point_dir_imm[i];
85 }
86 
87 static nir_def *
get_point_coord(nir_builder * b,struct lower_state * state,unsigned i)88 get_point_coord(nir_builder *b, struct lower_state *state, unsigned i)
89 {
90    if (state->point_coord_imm[0] == NULL) {
91       if (state->sprite_origin_lower_left) {
92          state->point_coord_imm[0] = nir_imm_vec4(b, 0, 0, 0, 1);
93          state->point_coord_imm[1] = nir_imm_vec4(b, 0, 1, 0, 1);
94          state->point_coord_imm[2] = nir_imm_vec4(b, 1, 0, 0, 1);
95          state->point_coord_imm[3] = nir_imm_vec4(b, 1, 1, 0, 1);
96       } else {
97          state->point_coord_imm[0] = nir_imm_vec4(b, 0, 1, 0, 1);
98          state->point_coord_imm[1] = nir_imm_vec4(b, 0, 0, 0, 1);
99          state->point_coord_imm[2] = nir_imm_vec4(b, 1, 1, 0, 1);
100          state->point_coord_imm[3] = nir_imm_vec4(b, 1, 0, 0, 1);
101       }
102    }
103 
104    return state->point_coord_imm[i];
105 }
106 
107 /**
108  * scaled_point_size = pointSize * pos.w * ViewportSizeRcp
109  */
110 static void
get_scaled_point_size(nir_builder * b,struct lower_state * state,nir_def ** x,nir_def ** y)111 get_scaled_point_size(nir_builder *b, struct lower_state *state,
112                       nir_def **x, nir_def **y)
113 {
114    /* State uniform contains: (1/ViewportWidth, 1/ViewportHeight, PointSize, MaxPointSize) */
115    nir_def *uniform = nir_load_var(b, state->uniform);
116    nir_def *point_size = state->point_size;
117 
118    /* clamp point-size to valid range */
119    if (point_size && state->point_size_per_vertex) {
120       point_size = nir_fmax(b, point_size, nir_imm_float(b, 1.0f));
121       point_size = nir_fmin(b, point_size, nir_imm_float(b, D3D12_MAX_POINT_SIZE));
122    } else {
123       /* Use static point size (from uniform) if the shader output was not set */
124       point_size = nir_channel(b, uniform, 2);
125    }
126 
127    point_size = nir_fmul(b, point_size, nir_channel(b, state->point_pos, 3));
128    *x = nir_fmul(b, point_size, nir_channel(b, uniform, 0));
129    *y = nir_fmul(b, point_size, nir_channel(b, uniform, 1));
130 }
131 
132 static bool
lower_store(nir_intrinsic_instr * instr,nir_builder * b,struct lower_state * state)133 lower_store(nir_intrinsic_instr *instr, nir_builder *b, struct lower_state *state)
134 {
135    nir_deref_instr *deref = nir_src_as_deref(instr->src[0]);
136    if (nir_deref_mode_is(deref, nir_var_shader_out)) {
137       nir_variable *var = nir_deref_instr_get_variable(deref);
138 
139       switch (var->data.location) {
140       case VARYING_SLOT_POS:
141          state->point_pos = instr->src[1].ssa;
142          break;
143       case VARYING_SLOT_PSIZ:
144          state->point_size = instr->src[1].ssa;
145          break;
146       default: {
147             struct output_writes data = {
148                .val = instr->src[1].ssa,
149                .deref = nir_src_as_deref(instr->src[0]),
150                .write_mask = nir_intrinsic_write_mask(instr),
151             };
152             util_dynarray_append(&state->output_writes, struct output_writes, data);
153             break;
154          }
155       }
156 
157       nir_instr_remove(&instr->instr);
158       return true;
159    }
160 
161    return false;
162 }
163 
164 static bool
lower_emit_vertex(nir_intrinsic_instr * instr,nir_builder * b,struct lower_state * state)165 lower_emit_vertex(nir_intrinsic_instr *instr, nir_builder *b, struct lower_state *state)
166 {
167    unsigned stream_id = nir_intrinsic_stream_id(instr);
168 
169    nir_def *point_width, *point_height;
170    get_scaled_point_size(b, state, &point_width, &point_height);
171 
172    nir_instr_remove(&instr->instr);
173    if (stream_id == 0) {
174       for (unsigned i = 0; i < 4; i++) {
175          /* All outputs need to be emitted for each vertex */
176          util_dynarray_foreach(&state->output_writes, struct output_writes, data) {
177             nir_store_deref(b, data->deref, data->val, data->write_mask);
178          }
179 
180          /* pos = scaled_point_size * point_dir + point_pos */
181          nir_def *point_dir = get_point_dir(b, state, i);
182          nir_def *pos = nir_vec4(b,
183                                      nir_ffma(b,
184                                               point_width,
185                                               nir_channel(b, point_dir, 0),
186                                               nir_channel(b, state->point_pos, 0)),
187                                      nir_ffma(b,
188                                               point_height,
189                                               nir_channel(b, point_dir, 1),
190                                               nir_channel(b, state->point_pos, 1)),
191                                      nir_channel(b, state->point_pos, 2),
192                                      nir_channel(b, state->point_pos, 3));
193          nir_store_var(b, state->pos_out, pos, 0xf);
194 
195          /* point coord */
196          nir_def *point_coord = get_point_coord(b, state, i);
197          for (unsigned j = 0; j < state->num_point_coords; ++j) {
198             unsigned num_channels = glsl_get_components(state->point_coord_out[j]->type);
199             unsigned mask = (1 << num_channels) - 1;
200             nir_store_var(b, state->point_coord_out[j], nir_channels(b, point_coord, mask), mask);
201          }
202 
203          /* EmitVertex */
204          nir_emit_vertex(b, .stream_id = stream_id);
205       }
206 
207       /* EndPrimitive */
208       nir_end_primitive(b, .stream_id = stream_id);
209    }
210 
211    /* Reset everything */
212    state->point_pos = NULL;
213    state->point_size = NULL;
214    util_dynarray_clear(&state->output_writes);
215 
216    return true;
217 }
218 
219 static bool
lower_instr(nir_intrinsic_instr * instr,nir_builder * b,struct lower_state * state)220 lower_instr(nir_intrinsic_instr *instr, nir_builder *b, struct lower_state *state)
221 {
222    b->cursor = nir_before_instr(&instr->instr);
223 
224    if (instr->intrinsic == nir_intrinsic_store_deref) {
225       return lower_store(instr, b, state);
226    } else if (instr->intrinsic == nir_intrinsic_emit_vertex) {
227       return lower_emit_vertex(instr, b, state);
228    } else if (instr->intrinsic == nir_intrinsic_end_primitive) {
229       nir_instr_remove(&instr->instr);
230       return true;
231    }
232 
233    return false;
234 }
235 
236 bool
d3d12_lower_point_sprite(nir_shader * shader,bool sprite_origin_lower_left,bool point_size_per_vertex,unsigned point_coord_enable,uint64_t next_inputs_read)237 d3d12_lower_point_sprite(nir_shader *shader,
238                          bool sprite_origin_lower_left,
239                          bool point_size_per_vertex,
240                          unsigned point_coord_enable,
241                          uint64_t next_inputs_read)
242 {
243    const gl_state_index16 tokens[4] = { STATE_INTERNAL_DRIVER,
244                                         D3D12_STATE_VAR_PT_SPRITE };
245    struct lower_state state;
246    util_dynarray_init(&state.output_writes, shader);
247    bool progress = false;
248 
249    assert(shader->info.gs.output_primitive == MESA_PRIM_POINTS);
250 
251    memset(&state, 0, sizeof(state));
252    find_outputs(shader, &state);
253    state.sprite_origin_lower_left = sprite_origin_lower_left;
254    state.point_size_per_vertex = point_size_per_vertex;
255 
256    /* Create uniform to retrieve inverse of viewport size and point size:
257     * (1/ViewportWidth, 1/ViewportHeight, PointSize, MaxPointSize) */
258    state.uniform = nir_state_variable_create(shader, glsl_vec4_type(),
259                                              "d3d12_ViewportSizeRcp", tokens);
260 
261    /* Create new outputs for point tex coordinates */
262    unsigned count = 0;
263    for (unsigned int sem = 0; sem < ARRAY_SIZE(state.point_coord_out); sem++) {
264       if (point_coord_enable & BITFIELD64_BIT(sem)) {
265          char tmp[100];
266          unsigned location = VARYING_SLOT_TEX0 + sem;
267 
268          snprintf(tmp, ARRAY_SIZE(tmp), "gl_TexCoord%dMESA", count);
269 
270          nir_variable *var = nir_variable_create(shader,
271                                                  nir_var_shader_out,
272                                                  glsl_vec4_type(),
273                                                  tmp);
274          var->data.location = location;
275          state.point_coord_out[count++] = var;
276       }
277    }
278    if (next_inputs_read & VARYING_BIT_PNTC) {
279       nir_variable *pntcoord_var = nir_variable_create(shader,
280                                                        nir_var_shader_out,
281                                                        glsl_vec_type(2),
282                                                        "gl_PointCoordMESA");
283       pntcoord_var->data.location = VARYING_SLOT_PNTC;
284       state.point_coord_out[count++] = pntcoord_var;
285    }
286 
287    state.num_point_coords = count;
288 
289    nir_foreach_function_impl(impl, shader) {
290       nir_builder builder = nir_builder_create(impl);
291       nir_foreach_block(block, impl) {
292          nir_foreach_instr_safe(instr, block) {
293             if (instr->type == nir_instr_type_intrinsic)
294                progress |= lower_instr(nir_instr_as_intrinsic(instr),
295                                        &builder,
296                                        &state);
297          }
298       }
299 
300       nir_metadata_preserve(impl, nir_metadata_control_flow);
301    }
302 
303    util_dynarray_fini(&state.output_writes);
304    shader->info.gs.output_primitive = MESA_PRIM_TRIANGLE_STRIP;
305    shader->info.gs.vertices_out = shader->info.gs.vertices_out * 4 /
306       util_bitcount(shader->info.gs.active_stream_mask);
307    shader->info.gs.active_stream_mask = 1;
308 
309    return progress;
310 }
311