xref: /aosp_15_r20/external/mesa3d/src/gallium/drivers/d3d12/d3d12_compute_transforms.cpp (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright © Microsoft Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include "d3d12_compute_transforms.h"
25 #include "d3d12_nir_passes.h"
26 #include "d3d12_query.h"
27 #include "d3d12_screen.h"
28 
29 #include "nir.h"
30 #include "nir_builder.h"
31 
32 #include "util/u_memory.h"
33 
/* Builds a compute shader that pre-processes an indirect draw argument buffer:
 * one invocation per draw record extracts base vertex / base instance / draw ID
 * and writes them as a prefix in front of a copy of the original draw
 * arguments, so later stages can consume them where the API doesn't expose
 * them directly.
 *
 * NOTE(review): every other transform builder in this file is static; this one
 * presumably should be too — confirm nothing outside this file links to it.
 */
nir_shader *
get_indirect_draw_base_vertex_transform(const nir_shader_compiler_options *options, const d3d12_compute_transform_key *args)
{
   nir_builder b = nir_builder_init_simple_shader(MESA_SHADER_COMPUTE, options, "TransformIndirectDrawBaseVertex");

   /* When the draw count comes from a GPU buffer, bind it as UBO 0 so the
    * invocation ID can be range-checked below. */
   if (args->base_vertex.dynamic_count) {
      nir_variable *count_ubo = nir_variable_create(b.shader, nir_var_mem_ubo,
         glsl_uint_type(), "in_count");
      count_ubo->data.driver_location = 0;
   }

   /* SSBO 0: original indirect args; SSBO 1: transformed output. */
   nir_variable *input_ssbo = nir_variable_create(b.shader, nir_var_mem_ssbo,
      glsl_array_type(glsl_uint_type(), 0, 0), "input");
   nir_variable *output_ssbo = nir_variable_create(b.shader, nir_var_mem_ssbo,
      input_ssbo->type, "output");
   input_ssbo->data.driver_location = 0;
   output_ssbo->data.driver_location = 1;

   /* One invocation per draw record. */
   nir_def *draw_id = nir_channel(&b, nir_load_global_invocation_id(&b, 32), 0);
   if (args->base_vertex.dynamic_count) {
      nir_def *count = nir_load_ubo(&b, 1, 32, nir_imm_int(&b, 1), nir_imm_int(&b, 0),
         (gl_access_qualifier)0, 4, 0, 0, 4);
      /* Skip invocations past the dynamic draw count. */
      nir_push_if(&b, nir_ilt(&b, draw_id, count));
   }

   /* Root-constant uvec4: x = input record stride, y = input base offset,
    * z = base draw ID (w unused here). */
   nir_variable *stride_ubo = NULL;
   nir_def *in_stride_offset_and_base_drawid = d3d12_get_state_var(&b, D3D12_STATE_VAR_TRANSFORM_GENERIC0, "d3d12_Stride",
      glsl_uvec4_type(), &stride_ubo);
   nir_def *in_offset = nir_iadd(&b, nir_channel(&b, in_stride_offset_and_base_drawid, 1),
      nir_imul(&b, nir_channel(&b, in_stride_offset_and_base_drawid, 0), draw_id));
   nir_def *in_data0 = nir_load_ssbo(&b, 4, 32, nir_imm_int(&b, 0), in_offset, (gl_access_qualifier)0, 4, 0);

   /* Indexed records carry 5 uints: base vertex is component 3 and base
    * instance needs a second load; non-indexed records carry 4 uints with
    * base vertex at component 2 and base instance at component 3. */
   nir_def *in_data1 = NULL;
   nir_def *base_vertex = NULL, *base_instance = NULL;
   if (args->base_vertex.indexed) {
      nir_def *in_offset1 = nir_iadd(&b, in_offset, nir_imm_int(&b, 16));
      in_data1 = nir_load_ssbo(&b, 1, 32, nir_imm_int(&b, 0), in_offset1, (gl_access_qualifier)0, 4, 0);
      base_vertex = nir_channel(&b, in_data0, 3);
      base_instance = in_data1;
   } else {
      base_vertex = nir_channel(&b, in_data0, 2);
      base_instance = nir_channel(&b, in_data0, 3);
   }

   /* 4 additional uints for base vertex, base instance, draw ID, and a bool for indexed draw */
   unsigned out_stride = sizeof(uint32_t) * ((args->base_vertex.indexed ? 5 : 4) + 4);

   nir_def *out_offset = nir_imul(&b, draw_id, nir_imm_int(&b, out_stride));
   /* Prefix: [base vertex, base instance, global draw ID, indexed flag]. */
   nir_def *out_data0 = nir_vec4(&b, base_vertex, base_instance,
      nir_iadd(&b, draw_id, nir_channel(&b, in_stride_offset_and_base_drawid, 2)),
      nir_imm_int(&b, args->base_vertex.indexed ? -1 : 0));
   nir_def *out_data1 = in_data0;

   nir_store_ssbo(&b, out_data0, nir_imm_int(&b, 1), out_offset, 0xf, (gl_access_qualifier)0, 4, 0);
   /* Followed by a copy of the original draw arguments. */
   nir_store_ssbo(&b, out_data1, nir_imm_int(&b, 1), nir_iadd(&b, out_offset, nir_imm_int(&b, 16)),
      (1u << out_data1->num_components) - 1, (gl_access_qualifier)0, 4, 0);
   if (args->base_vertex.indexed)
      nir_store_ssbo(&b, in_data1, nir_imm_int(&b, 1), nir_iadd(&b, out_offset, nir_imm_int(&b, 32)), 1, (gl_access_qualifier)0, 4, 0);

   if (args->base_vertex.dynamic_count)
      nir_pop_if(&b, NULL);

   nir_validate_shader(b.shader, "creation");
   b.shader->info.num_ssbos = 2;
   b.shader->info.num_ubos = (args->base_vertex.dynamic_count ? 1 : 0);

   return b.shader;
}
102 
103 static struct nir_shader *
get_fake_so_buffer_copy_back(const nir_shader_compiler_options * options,const d3d12_compute_transform_key * key)104 get_fake_so_buffer_copy_back(const nir_shader_compiler_options *options, const d3d12_compute_transform_key *key)
105 {
106    nir_builder b = nir_builder_init_simple_shader(MESA_SHADER_COMPUTE, options, "FakeSOBufferCopyBack");
107 
108    nir_variable *output_so_data_var = nir_variable_create(b.shader, nir_var_mem_ssbo,
109       glsl_array_type(glsl_uint_type(), 0, 0), "output_data");
110    nir_variable *input_so_data_var = nir_variable_create(b.shader, nir_var_mem_ssbo, output_so_data_var->type, "input_data");
111    output_so_data_var->data.driver_location = 0;
112    input_so_data_var->data.driver_location = 1;
113 
114    /* UBO is [fake SO filled size, fake SO vertex count, 1, 1, original SO filled size] */
115    nir_variable *input_ubo = nir_variable_create(b.shader, nir_var_mem_ubo,
116       glsl_array_type(glsl_uint_type(), 5, 0), "input_ubo");
117    input_ubo->data.driver_location = 0;
118 
119    nir_def *original_so_filled_size = nir_load_ubo(&b, 1, 32, nir_imm_int(&b, 0), nir_imm_int(&b, 4 * sizeof(uint32_t)),
120       (gl_access_qualifier)0, 4, 0, 4 * sizeof(uint32_t), 4);
121 
122    nir_variable *state_var = nullptr;
123    nir_def *fake_so_multiplier = d3d12_get_state_var(&b, D3D12_STATE_VAR_TRANSFORM_GENERIC0, "fake_so_multiplier", glsl_uint_type(), &state_var);
124 
125    nir_def *vertex_offset = nir_imul(&b, nir_imm_int(&b, key->fake_so_buffer_copy_back.stride),
126       nir_channel(&b, nir_load_global_invocation_id(&b, 32), 0));
127 
128    nir_def *output_offset_base = nir_iadd(&b, original_so_filled_size, vertex_offset);
129    nir_def *input_offset_base = nir_imul(&b, vertex_offset, fake_so_multiplier);
130 
131    for (unsigned i = 0; i < key->fake_so_buffer_copy_back.num_ranges; ++i) {
132       auto& output = key->fake_so_buffer_copy_back.ranges[i];
133       assert(output.size % 4 == 0 && output.offset % 4 == 0);
134       nir_def *field_offset = nir_imm_int(&b, output.offset);
135       nir_def *output_offset = nir_iadd(&b, output_offset_base, field_offset);
136       nir_def *input_offset = nir_iadd(&b, input_offset_base, field_offset);
137 
138       for (unsigned loaded = 0; loaded < output.size; loaded += 16) {
139          unsigned to_load = MIN2(output.size, 16);
140          unsigned components = to_load / 4;
141          nir_def *loaded_data = nir_load_ssbo(&b, components, 32, nir_imm_int(&b, 1),
142             nir_iadd(&b, input_offset, nir_imm_int(&b, loaded)), (gl_access_qualifier)0, 4, 0);
143          nir_store_ssbo(&b, loaded_data, nir_imm_int(&b, 0),
144             nir_iadd(&b, output_offset, nir_imm_int(&b, loaded)), (1u << components) - 1, (gl_access_qualifier)0, 4, 0);
145       }
146    }
147 
148    nir_validate_shader(b.shader, "creation");
149    b.shader->info.num_ssbos = 2;
150    b.shader->info.num_ubos = 1;
151 
152    return b.shader;
153 }
154 
155 static struct nir_shader *
get_fake_so_buffer_vertex_count(const nir_shader_compiler_options * options)156 get_fake_so_buffer_vertex_count(const nir_shader_compiler_options *options)
157 {
158    nir_builder b = nir_builder_init_simple_shader(MESA_SHADER_COMPUTE, options, "FakeSOBufferVertexCount");
159 
160    nir_variable_create(b.shader, nir_var_mem_ssbo, glsl_array_type(glsl_uint_type(), 0, 0), "fake_so");
161    nir_def *fake_buffer_filled_size = nir_load_ssbo(&b, 1, 32, nir_imm_int(&b, 0), nir_imm_int(&b, 0), (gl_access_qualifier)0, 4, 0);
162 
163    nir_variable *real_so_var = nir_variable_create(b.shader, nir_var_mem_ssbo,
164       glsl_array_type(glsl_uint_type(), 0, 0), "real_so");
165    real_so_var->data.driver_location = 1;
166    nir_def *real_buffer_filled_size = nir_load_ssbo(&b, 1, 32, nir_imm_int(&b, 1), nir_imm_int(&b, 0), (gl_access_qualifier)0, 4, 0);
167 
168    nir_variable *state_var = nullptr;
169    nir_def *state_var_data = d3d12_get_state_var(&b, D3D12_STATE_VAR_TRANSFORM_GENERIC0, "state_var", glsl_uvec4_type(), &state_var);
170    nir_def *stride = nir_channel(&b, state_var_data, 0);
171    nir_def *fake_so_multiplier = nir_channel(&b, state_var_data, 1);
172 
173    nir_def *real_so_bytes_added = nir_idiv(&b, fake_buffer_filled_size, fake_so_multiplier);
174    nir_def *vertex_count = nir_idiv(&b, real_so_bytes_added, stride);
175    nir_def *to_write_to_fake_buffer = nir_vec4(&b, vertex_count, nir_imm_int(&b, 1), nir_imm_int(&b, 1), real_buffer_filled_size);
176    nir_store_ssbo(&b, to_write_to_fake_buffer, nir_imm_int(&b, 0), nir_imm_int(&b, 4), 0xf, (gl_access_qualifier)0, 4, 0);
177 
178    nir_def *updated_filled_size = nir_iadd(&b, real_buffer_filled_size, real_so_bytes_added);
179    nir_store_ssbo(&b, updated_filled_size, nir_imm_int(&b, 1), nir_imm_int(&b, 0), 1, (gl_access_qualifier)0, 4, 0);
180 
181    nir_validate_shader(b.shader, "creation");
182    b.shader->info.num_ssbos = 2;
183    b.shader->info.num_ubos = 0;
184 
185    return b.shader;
186 }
187 
188 static struct nir_shader *
get_draw_auto(const nir_shader_compiler_options * options)189 get_draw_auto(const nir_shader_compiler_options *options)
190 {
191    nir_builder b = nir_builder_init_simple_shader(MESA_SHADER_COMPUTE, options, "DrawAuto");
192 
193    nir_variable_create(b.shader, nir_var_mem_ssbo, glsl_array_type(glsl_uint_type(), 0, 0), "ssbo");
194    nir_def *buffer_filled_size = nir_load_ssbo(&b, 1, 32, nir_imm_int(&b, 0), nir_imm_int(&b, 0), (gl_access_qualifier)0, 4, 0);
195 
196    nir_variable *state_var = nullptr;
197    nir_def *state_var_data = d3d12_get_state_var(&b, D3D12_STATE_VAR_TRANSFORM_GENERIC0, "state_var", glsl_uvec4_type(), &state_var);
198    nir_def *stride = nir_channel(&b, state_var_data, 0);
199    nir_def *vb_offset = nir_channel(&b, state_var_data, 1);
200 
201    nir_def *vb_bytes = nir_bcsel(&b, nir_ilt(&b, vb_offset, buffer_filled_size),
202       nir_isub(&b, buffer_filled_size, vb_offset), nir_imm_int(&b, 0));
203 
204    nir_def *vertex_count = nir_idiv(&b, vb_bytes, stride);
205    nir_def *to_write = nir_vec4(&b, vertex_count, nir_imm_int(&b, 1), nir_imm_int(&b, 0), nir_imm_int(&b, 0));
206    nir_store_ssbo(&b, to_write, nir_imm_int(&b, 0), nir_imm_int(&b, 4), 0xf, (gl_access_qualifier)0, 4, 0);
207 
208    nir_validate_shader(b.shader, "creation");
209    b.shader->info.num_ssbos = 1;
210    b.shader->info.num_ubos = 0;
211 
212    return b.shader;
213 }
214 
/* Builds a compute shader that resolves raw D3D12 query heap data (always
 * 64-bit fields) into the result layout the caller wants: either in place
 * into the first input buffer, or accumulated across up to 4 sub-queries
 * into a separate 32- or 64-bit output buffer. */
static struct nir_shader *
get_query_resolve(const nir_shader_compiler_options *options, const d3d12_compute_transform_key *key)
{
   nir_builder b = nir_builder_init_simple_shader(MESA_SHADER_COMPUTE, options, "QueryResolve");

   /* Destination element width; raw query data is always 64-bit. */
   uint32_t bit_size = key->query_resolve.is_64bit ? 64 : 32;
   const struct glsl_type *value_type = glsl_uintN_t_type(bit_size);

   /* In-place resolves only support a single 64-bit subquery; multiple
    * subqueries only occur for the two query types listed below. */
   assert(!key->query_resolve.is_resolve_in_place ||
          (key->query_resolve.is_64bit && key->query_resolve.num_subqueries == 1));
   assert(key->query_resolve.num_subqueries == 1 ||
          key->query_resolve.pipe_query_type == PIPE_QUERY_PRIMITIVES_GENERATED ||
          key->query_resolve.pipe_query_type == PIPE_QUERY_SO_OVERFLOW_ANY_PREDICATE);
   assert(key->query_resolve.num_subqueries <= 4);

   nir_variable *inputs[4];
   for (uint32_t i = 0; i < key->query_resolve.num_subqueries; ++i) {
      /* Inputs are always 64-bit */
      inputs[i] = nir_variable_create(b.shader, nir_var_mem_ssbo, glsl_array_type(glsl_uint64_t_type(), 0, 8), "input");
      inputs[i]->data.binding = i;
   }
   /* In-place resolves write back into the first input buffer. */
   nir_variable *output = inputs[0];
   if (!key->query_resolve.is_resolve_in_place) {
      output = nir_variable_create(b.shader, nir_var_mem_ssbo, glsl_array_type(value_type, 0, bit_size / 8), "output");
      output->data.binding = key->query_resolve.num_subqueries;
   }

   /* How many entries in each sub-query is passed via root constants */
   nir_variable *state_var = nullptr, *state_var1 = nullptr;
   nir_def *state_var_data = d3d12_get_state_var(&b, D3D12_STATE_VAR_TRANSFORM_GENERIC0, "state_var", glsl_uvec4_type(), &state_var);
   nir_def *state_var_data1 = d3d12_get_state_var(&b, D3D12_STATE_VAR_TRANSFORM_GENERIC1, "state_var1", glsl_uvec4_type(), &state_var1);

   /* For in-place resolves, we resolve each field of the query. Otherwise, resolve one field into the dest */
   nir_variable *results[sizeof(D3D12_QUERY_DATA_PIPELINE_STATISTICS) / sizeof(UINT64)];
   uint32_t num_result_values = 1;

   if (key->query_resolve.is_resolve_in_place) {
      if (key->query_resolve.pipe_query_type == PIPE_QUERY_PIPELINE_STATISTICS)
         num_result_values = sizeof(D3D12_QUERY_DATA_PIPELINE_STATISTICS) / sizeof(UINT64);
      else if (key->query_resolve.pipe_query_type == PIPE_QUERY_SO_STATISTICS)
         num_result_values = sizeof(D3D12_QUERY_DATA_SO_STATISTICS) / sizeof(UINT64);
   }

   /* Timestamps accumulate in 64-bit regardless of the destination width. */
   uint32_t var_bit_size = key->query_resolve.pipe_query_type == PIPE_QUERY_TIME_ELAPSED ||
                           key->query_resolve.pipe_query_type == PIPE_QUERY_TIMESTAMP ? 64 : bit_size;
   for (uint32_t i = 0; i < num_result_values; ++i) {
      results[i] = nir_local_variable_create(b.impl, glsl_uintN_t_type(var_bit_size), "result");
      nir_store_var(&b, results[i], nir_imm_intN_t(&b, 0, var_bit_size), 1);
   }

   /* For each subquery... */
   for (uint32_t i = 0; i < key->query_resolve.num_subqueries; ++i) {
      nir_def *num_results = nir_channel(&b, state_var_data, i);

      uint32_t subquery_index = key->query_resolve.num_subqueries == 1 ?
         key->query_resolve.single_subquery_index : i;
      /* base_offset/stride are in units of 64-bit fields, not bytes. */
      uint32_t base_offset = 0;
      uint32_t stride = 0;
      switch (key->query_resolve.pipe_query_type) {
      case PIPE_QUERY_OCCLUSION_COUNTER:
      case PIPE_QUERY_OCCLUSION_PREDICATE:
      case PIPE_QUERY_OCCLUSION_PREDICATE_CONSERVATIVE:
      case PIPE_QUERY_TIMESTAMP:
         stride = 1;
         break;
      case PIPE_QUERY_TIME_ELAPSED:
         stride = 2;
         break;
      case PIPE_QUERY_SO_STATISTICS:
      case PIPE_QUERY_PRIMITIVES_EMITTED:
      case PIPE_QUERY_SO_OVERFLOW_PREDICATE:
      case PIPE_QUERY_SO_OVERFLOW_ANY_PREDICATE:
         stride = sizeof(D3D12_QUERY_DATA_SO_STATISTICS) / sizeof(UINT64);
         break;
      case PIPE_QUERY_PRIMITIVES_GENERATED:
         /* Subquery 0 reads SO stats; the others read pipeline stats and
          * pick out the relevant primitive counter field. */
         if (subquery_index == 0)
            stride = sizeof(D3D12_QUERY_DATA_SO_STATISTICS) / sizeof(UINT64);
         else
            stride = sizeof(D3D12_QUERY_DATA_PIPELINE_STATISTICS) / sizeof(UINT64);
         if (!key->query_resolve.is_resolve_in_place) {
            if (subquery_index == 1)
               base_offset = offsetof(D3D12_QUERY_DATA_PIPELINE_STATISTICS, GSPrimitives) / sizeof(UINT64);
            else if (subquery_index == 2)
               base_offset = offsetof(D3D12_QUERY_DATA_PIPELINE_STATISTICS, IAPrimitives) / sizeof(UINT64);
         }
         break;
      case PIPE_QUERY_PIPELINE_STATISTICS:
         stride = sizeof(D3D12_QUERY_DATA_PIPELINE_STATISTICS) / sizeof(UINT64);
         break;
      default:
         unreachable("Unhandled query resolve");
      }

      if (!key->query_resolve.is_resolve_in_place && key->query_resolve.num_subqueries == 1)
         base_offset = key->query_resolve.single_result_field_offset;

      nir_def *base_array_index = nir_imm_int(&b, base_offset);

      /* For each query result in this subquery... */
      nir_variable *loop_counter = nir_local_variable_create(b.impl, glsl_uint_type(), "loop_counter");
      nir_store_var(&b, loop_counter, nir_imm_int(&b, 0), 1);
      nir_loop *loop = nir_push_loop(&b);

      /* Loop exit: counter reached the root-constant entry count. */
      nir_def *loop_counter_value = nir_load_var(&b, loop_counter);
      nir_if *nif = nir_push_if(&b, nir_ieq(&b, loop_counter_value, num_results));
      nir_jump(&b, nir_jump_break);
      nir_pop_if(&b, nif);

      /* For each field in the query result, accumulate */
      nir_def *array_index = nir_iadd(&b, nir_imul_imm(&b, loop_counter_value, stride), base_array_index);
      for (uint32_t j = 0; j < num_result_values; ++j) {
         nir_def *new_value;
         if (key->query_resolve.pipe_query_type == PIPE_QUERY_TIME_ELAPSED) {
            /* Elapsed time is stored as [begin, end] timestamp pairs. */
            assert(j == 0 && i == 0);
            nir_def *start = nir_load_ssbo(&b, 1, 64, nir_imm_int(&b, i), nir_imul_imm(&b, array_index, 8));
            nir_def *end = nir_load_ssbo(&b, 1, 64, nir_imm_int(&b, i), nir_imul_imm(&b, nir_iadd_imm(&b, array_index, 1), 8));
            new_value = nir_iadd(&b, nir_load_var(&b, results[j]), nir_isub(&b, end, start));
         } else if (key->query_resolve.pipe_query_type == PIPE_QUERY_SO_OVERFLOW_ANY_PREDICATE ||
                    key->query_resolve.pipe_query_type == PIPE_QUERY_SO_OVERFLOW_PREDICATE) {
            /* These predicates are true if the primitives emitted != primitives stored */
            assert(j == 0);
            nir_def *val_a = nir_load_ssbo(&b, 1, 64, nir_imm_int(&b, i), nir_imul_imm(&b, array_index, 8));
            nir_def *val_b = nir_load_ssbo(&b, 1, 64, nir_imm_int(&b, i), nir_imul_imm(&b, nir_iadd_imm(&b, array_index, 1), 8));
            new_value = nir_ior(&b, nir_load_var(&b, results[j]), nir_u2uN(&b, nir_ine(&b, val_a, val_b), var_bit_size));
         } else {
            /* Plain counter: load field j and add it to the running total. */
            new_value = nir_u2uN(&b, nir_load_ssbo(&b, 1, 64, nir_imm_int(&b, i), nir_imul_imm(&b, nir_iadd_imm(&b, array_index, j), 8)), var_bit_size);
            new_value = nir_iadd(&b, nir_load_var(&b, results[j]), new_value);
         }
         nir_store_var(&b, results[j], new_value, 1);
      }

      nir_store_var(&b, loop_counter, nir_iadd_imm(&b, loop_counter_value, 1), 1);
      nir_pop_loop(&b, loop);
   }

   /* Results are accumulated, now store the final values */
   nir_def *output_base_index = nir_channel(&b, state_var_data1, 0);
   for (uint32_t i = 0; i < num_result_values; ++i) {
      /* When resolving in-place, resolve each field, otherwise just write the one result */
      uint32_t field_offset = key->query_resolve.is_resolve_in_place ? i : 0;

      /* When resolving time elapsed in-place, write [0, time], as the only special case */
      if (key->query_resolve.is_resolve_in_place &&
          key->query_resolve.pipe_query_type == PIPE_QUERY_TIME_ELAPSED) {
         nir_store_ssbo(&b, nir_imm_int64(&b, 0), nir_imm_int(&b, output->data.binding),
                        nir_imul_imm(&b, output_base_index, bit_size / 8), 1, (gl_access_qualifier)0, bit_size / 8, 0);
         field_offset++;
      }
      nir_def *result_val = nir_load_var(&b, results[i]);
      if (!key->query_resolve.is_resolve_in_place &&
          (key->query_resolve.pipe_query_type == PIPE_QUERY_TIME_ELAPSED ||
           key->query_resolve.pipe_query_type == PIPE_QUERY_TIMESTAMP)) {
         /* Convert GPU ticks to nanoseconds via the timestamp frequency. */
         result_val = nir_f2u64(&b, nir_fmul_imm(&b, nir_u2f32(&b, result_val), key->query_resolve.timestamp_multiplier));

         /* Narrow to the destination width with the requested signedness. */
         if (!key->query_resolve.is_64bit) {
            nir_alu_type rounding_type = key->query_resolve.is_signed ? nir_type_int : nir_type_uint;
            nir_alu_type src_round = (nir_alu_type)(rounding_type | 64);
            nir_alu_type dst_round = (nir_alu_type)(rounding_type | bit_size);
            result_val = nir_convert_alu_types(&b, bit_size, result_val, src_round, dst_round, nir_rounding_mode_undef, true);
         }
      }
      nir_store_ssbo(&b, result_val, nir_imm_int(&b, output->data.binding),
                     nir_imul_imm(&b, nir_iadd_imm(&b, output_base_index, field_offset), bit_size / 8),
                     1, (gl_access_qualifier)0, bit_size / 8, 0);
   }

   nir_validate_shader(b.shader, "creation");
   b.shader->info.num_ssbos = key->query_resolve.num_subqueries + !key->query_resolve.is_resolve_in_place;
   b.shader->info.num_ubos = 0;

   NIR_PASS_V(b.shader, nir_lower_convert_alu_types, NULL);

   return b.shader;
}
389 
/* Dispatches to the shader builder matching the transform type in the key.
 * Returns a freshly built NIR shader owned by the caller. */
static struct nir_shader *
create_compute_transform(const nir_shader_compiler_options *options, const d3d12_compute_transform_key *key)
{
   switch (key->type) {
   case d3d12_compute_transform_type::base_vertex:
      return get_indirect_draw_base_vertex_transform(options, key);
   case d3d12_compute_transform_type::fake_so_buffer_copy_back:
      return get_fake_so_buffer_copy_back(options, key);
   case d3d12_compute_transform_type::fake_so_buffer_vertex_count:
      return get_fake_so_buffer_vertex_count(options);
   case d3d12_compute_transform_type::draw_auto:
      return get_draw_auto(options);
   case d3d12_compute_transform_type::query_resolve:
      return get_query_resolve(options, key);
   default:
      unreachable("Invalid transform");
   }
}
408 
/* Cache entry pairing a transform key with its compiled shader variant.
 * The key field doubles as the hash-table key (inserted by address). */
struct compute_transform
{
   d3d12_compute_transform_key key;     /* copy of the lookup key */
   d3d12_shader_selector *shader;       /* compiled compute shader */
};
414 
415 d3d12_shader_selector *
d3d12_get_compute_transform(struct d3d12_context * ctx,const d3d12_compute_transform_key * key)416 d3d12_get_compute_transform(struct d3d12_context *ctx, const d3d12_compute_transform_key *key)
417 {
418    struct hash_entry *entry = _mesa_hash_table_search(ctx->compute_transform_cache, key);
419    if (!entry) {
420       compute_transform *data = (compute_transform *)MALLOC(sizeof(compute_transform));
421       if (!data)
422          return NULL;
423 
424       const nir_shader_compiler_options *options = &d3d12_screen(ctx->base.screen)->nir_options;
425 
426       memcpy(&data->key, key, sizeof(*key));
427       nir_shader *s = create_compute_transform(options, key);
428       if (!s) {
429          FREE(data);
430          return NULL;
431       }
432       struct pipe_compute_state shader_args = { PIPE_SHADER_IR_NIR, s };
433       data->shader = d3d12_create_compute_shader(ctx, &shader_args);
434       if (!data->shader) {
435          ralloc_free(s);
436          FREE(data);
437          return NULL;
438       }
439 
440       data->shader->is_variant = true;
441       entry = _mesa_hash_table_insert(ctx->compute_transform_cache, &data->key, data);
442       assert(entry);
443    }
444 
445    return ((struct compute_transform *)entry->data)->shader;
446 }
447 
/* Hash callback for the transform cache: hashes the raw key bytes. */
static uint32_t
hash_compute_transform_key(const void *key)
{
   return _mesa_hash_data(key, sizeof(struct d3d12_compute_transform_key));
}
453 
/* Equality callback for the transform cache: byte-wise key comparison
 * (keys are memcpy'd whole, so padding is compared consistently). */
static bool
equals_compute_transform_key(const void *a, const void *b)
{
   return memcmp(a, b, sizeof(struct d3d12_compute_transform_key)) == 0;
}
459 
460 void
d3d12_compute_transform_cache_init(struct d3d12_context * ctx)461 d3d12_compute_transform_cache_init(struct d3d12_context *ctx)
462 {
463    ctx->compute_transform_cache = _mesa_hash_table_create(NULL,
464                                                           hash_compute_transform_key,
465                                                           equals_compute_transform_key);
466 }
467 
468 static void
delete_entry(struct hash_entry * entry)469 delete_entry(struct hash_entry *entry)
470 {
471    struct compute_transform *data = (struct compute_transform *)entry->data;
472    d3d12_shader_free(data->shader);
473    FREE(data);
474 }
475 
/* Destroys the transform cache, freeing every cached shader via delete_entry. */
void
d3d12_compute_transform_cache_destroy(struct d3d12_context *ctx)
{
   _mesa_hash_table_destroy(ctx->compute_transform_cache, delete_entry);
}
481 
482 void
d3d12_save_compute_transform_state(struct d3d12_context * ctx,d3d12_compute_transform_save_restore * save)483 d3d12_save_compute_transform_state(struct d3d12_context *ctx, d3d12_compute_transform_save_restore *save)
484 {
485    if (ctx->current_predication)
486       ctx->cmdlist->SetPredication(nullptr, 0, D3D12_PREDICATION_OP_EQUAL_ZERO);
487 
488    memset(save, 0, sizeof(*save));
489    save->cs = ctx->compute_state;
490 
491    pipe_resource_reference(&save->cbuf0.buffer, ctx->cbufs[PIPE_SHADER_COMPUTE][1].buffer);
492    save->cbuf0 = ctx->cbufs[PIPE_SHADER_COMPUTE][1];
493 
494    for (unsigned i = 0; i < ARRAY_SIZE(save->ssbos); ++i) {
495       pipe_resource_reference(&save->ssbos[i].buffer, ctx->ssbo_views[PIPE_SHADER_COMPUTE][i].buffer);
496       save->ssbos[i] = ctx->ssbo_views[PIPE_SHADER_COMPUTE][i];
497    }
498 
499    save->queries_disabled = ctx->queries_disabled;
500    ctx->base.set_active_query_state(&ctx->base, false);
501 }
502 
/* Restores the compute state saved by d3d12_save_compute_transform_state
 * and re-enables predication/queries if they were active. */
void
d3d12_restore_compute_transform_state(struct d3d12_context *ctx, d3d12_compute_transform_save_restore *save)
{
   ctx->base.set_active_query_state(&ctx->base, !save->queries_disabled);

   ctx->base.bind_compute_state(&ctx->base, save->cs);

   /* take_ownership=true hands the reference taken in save back to the
    * context, so no explicit unref is needed here. */
   ctx->base.set_constant_buffer(&ctx->base, PIPE_SHADER_COMPUTE, 1, true, &save->cbuf0);
   ctx->base.set_shader_buffers(&ctx->base, PIPE_SHADER_COMPUTE, 0, ARRAY_SIZE(save->ssbos), save->ssbos, (1u << ARRAY_SIZE(save->ssbos)) - 1);

   if (ctx->current_predication)
      d3d12_enable_predication(ctx);
}
516