xref: /aosp_15_r20/external/mesa3d/src/gallium/drivers/d3d12/d3d12_tcs_variant.cpp (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright © Microsoft Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include "nir.h"
25 #include "nir_builder.h"
26 #include "d3d12_context.h"
27 #include "d3d12_compiler.h"
28 #include "d3d12_nir_passes.h"
29 #include "d3d12_screen.h"
30 
31 static uint32_t
hash_tcs_variant_key(const void * key)32 hash_tcs_variant_key(const void *key)
33 {
34    d3d12_tcs_variant_key *v = (d3d12_tcs_variant_key*)key;
35    uint32_t hash = _mesa_hash_data(v, offsetof(d3d12_tcs_variant_key, varyings));
36    if (v->varyings)
37       hash = _mesa_hash_data_with_seed(v->varyings->slots, sizeof(v->varyings->slots[0]) * v->varyings->max, hash);
38    return hash;
39 }
40 
41 static bool
equals_tcs_variant_key(const void * a,const void * b)42 equals_tcs_variant_key(const void *a, const void *b)
43 {
44    return memcmp(a, b, sizeof(d3d12_tcs_variant_key)) == 0;
45 }
46 
47 void
d3d12_tcs_variant_cache_init(struct d3d12_context * ctx)48 d3d12_tcs_variant_cache_init(struct d3d12_context *ctx)
49 {
50    ctx->tcs_variant_cache = _mesa_hash_table_create(NULL, NULL, equals_tcs_variant_key);
51 }
52 
53 static void
delete_entry(struct hash_entry * entry)54 delete_entry(struct hash_entry *entry)
55 {
56    d3d12_shader_free((d3d12_shader_selector *)entry->data);
57 }
58 
59 void
d3d12_tcs_variant_cache_destroy(struct d3d12_context * ctx)60 d3d12_tcs_variant_cache_destroy(struct d3d12_context *ctx)
61 {
62    _mesa_hash_table_destroy(ctx->tcs_variant_cache, delete_entry);
63 }
64 
65 static void
copy_vars(nir_builder * b,nir_deref_instr * dst,nir_deref_instr * src)66 copy_vars(nir_builder *b, nir_deref_instr *dst, nir_deref_instr *src)
67 {
68    assert(glsl_get_bare_type(dst->type) == glsl_get_bare_type(src->type));
69    if (glsl_type_is_struct(dst->type)) {
70       for (unsigned i = 0; i < glsl_get_length(dst->type); ++i) {
71          copy_vars(b, nir_build_deref_struct(b, dst, i), nir_build_deref_struct(b, src, i));
72       }
73    } else if (glsl_type_is_array_or_matrix(dst->type)) {
74       copy_vars(b, nir_build_deref_array_wildcard(b, dst), nir_build_deref_array_wildcard(b, src));
75    } else {
76       nir_copy_deref(b, dst, src);
77    }
78 }
79 
80 static struct d3d12_shader_selector *
create_tess_ctrl_shader_variant(struct d3d12_context * ctx,struct d3d12_tcs_variant_key * key)81 create_tess_ctrl_shader_variant(struct d3d12_context *ctx, struct d3d12_tcs_variant_key *key)
82 {
83    nir_builder b = nir_builder_init_simple_shader(MESA_SHADER_TESS_CTRL, &d3d12_screen(ctx->base.screen)->nir_options, "passthrough");
84    nir_shader *nir = b.shader;
85 
86    nir_def *invocation_id = nir_load_invocation_id(&b);
87    uint64_t varying_mask = key->varyings->mask;
88 
89    while(varying_mask) {
90       int var_idx = u_bit_scan64(&varying_mask);
91       auto slot = &key->varyings->slots[var_idx];
92       unsigned frac_mask = slot->location_frac_mask;
93       while (frac_mask) {
94          int frac = u_bit_scan(&frac_mask);
95          auto var = &slot->vars[frac];
96          const struct glsl_type *type = glsl_array_type(slot->types[frac], key->vertices_out, 0);
97 
98          char buf[1024];
99          snprintf(buf, sizeof(buf), "in_%d", var->driver_location);
100          nir_variable *in = nir_variable_create(nir, nir_var_shader_in, type, buf);
101          snprintf(buf, sizeof(buf), "out_%d", var->driver_location);
102          nir_variable *out = nir_variable_create(nir, nir_var_shader_out, type, buf);
103          out->data.location = in->data.location = var_idx;
104          out->data.location_frac = in->data.location_frac = frac;
105          out->data.driver_location = in->data.driver_location = var->driver_location;
106 
107          for (unsigned i = 0; i < key->vertices_out; i++) {
108             nir_if *start_block = nir_push_if(&b, nir_ieq_imm(&b, invocation_id, i));
109             nir_deref_instr *in_array_var = nir_build_deref_array(&b, nir_build_deref_var(&b, in), invocation_id);
110             nir_deref_instr *out_array_var = nir_build_deref_array_imm(&b, nir_build_deref_var(&b, out), i);
111             copy_vars(&b, out_array_var, in_array_var);
112             nir_pop_if(&b, start_block);
113          }
114       }
115    }
116    nir_variable *gl_TessLevelInner = nir_variable_create(nir, nir_var_shader_out, glsl_array_type(glsl_float_type(), 2, 0), "gl_TessLevelInner");
117    gl_TessLevelInner->data.location = VARYING_SLOT_TESS_LEVEL_INNER;
118    gl_TessLevelInner->data.patch = 1;
119    gl_TessLevelInner->data.compact = 1;
120    nir_variable *gl_TessLevelOuter = nir_variable_create(nir, nir_var_shader_out, glsl_array_type(glsl_float_type(), 4, 0), "gl_TessLevelOuter");
121    gl_TessLevelOuter->data.location = VARYING_SLOT_TESS_LEVEL_OUTER;
122    gl_TessLevelOuter->data.patch = 1;
123    gl_TessLevelOuter->data.compact = 1;
124 
125    nir_variable *state_var_inner = NULL, *state_var_outer = NULL;
126    nir_def *load_inner = d3d12_get_state_var(&b, D3D12_STATE_VAR_DEFAULT_INNER_TESS_LEVEL, "d3d12_TessLevelInner", glsl_vec_type(2), &state_var_inner);
127    nir_def *load_outer = d3d12_get_state_var(&b, D3D12_STATE_VAR_DEFAULT_OUTER_TESS_LEVEL, "d3d12_TessLevelOuter", glsl_vec4_type(), &state_var_outer);
128 
129    for (unsigned i = 0; i < 2; i++) {
130       nir_deref_instr *store_idx = nir_build_deref_array_imm(&b, nir_build_deref_var(&b, gl_TessLevelInner), i);
131       nir_store_deref(&b, store_idx, nir_channel(&b, load_inner, i), 0xff);
132    }
133    for (unsigned i = 0; i < 4; i++) {
134       nir_deref_instr *store_idx = nir_build_deref_array_imm(&b, nir_build_deref_var(&b, gl_TessLevelOuter), i);
135       nir_store_deref(&b, store_idx, nir_channel(&b, load_outer, i), 0xff);
136    }
137 
138    nir->info.tess.tcs_vertices_out = key->vertices_out;
139    nir_validate_shader(nir, "created");
140    NIR_PASS_V(nir, nir_lower_var_copies);
141 
142    struct pipe_shader_state templ;
143 
144    templ.type = PIPE_SHADER_IR_NIR;
145    templ.ir.nir = nir;
146    templ.stream_output.num_outputs = 0;
147 
148    d3d12_shader_selector *tcs = d3d12_create_shader(ctx, PIPE_SHADER_TESS_CTRL, &templ);
149    if (tcs) {
150       tcs->is_variant = true;
151       tcs->tcs_key = *key;
152    }
153    return tcs;
154 }
155 
156 d3d12_shader_selector *
d3d12_get_tcs_variant(struct d3d12_context * ctx,struct d3d12_tcs_variant_key * key)157 d3d12_get_tcs_variant(struct d3d12_context *ctx, struct d3d12_tcs_variant_key *key)
158 {
159    uint32_t hash = hash_tcs_variant_key(key);
160    struct hash_entry *entry = _mesa_hash_table_search_pre_hashed(ctx->tcs_variant_cache,
161       hash, key);
162    if (!entry) {
163       d3d12_shader_selector *tcs = create_tess_ctrl_shader_variant(ctx, key);
164       entry = _mesa_hash_table_insert_pre_hashed(ctx->tcs_variant_cache,
165          hash, &tcs->tcs_key, tcs);
166       assert(entry);
167    }
168 
169    return (d3d12_shader_selector *)entry->data;
170 }
171