xref: /aosp_15_r20/external/mesa3d/src/gallium/drivers/d3d12/d3d12_pipeline_state.cpp (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright © Microsoft Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #ifdef _GAMING_XBOX
25 #ifdef _GAMING_XBOX_SCARLETT
26 #include <d3dx12_xs.h>
27 #else
28 #include <d3dx12_x.h>
29 #endif
30 #endif
31 
32 #include "d3d12_pipeline_state.h"
33 #include "d3d12_compiler.h"
34 #include "d3d12_context.h"
35 #include "d3d12_screen.h"
36 #ifndef _GAMING_XBOX
37 #include <directx/d3dx12_pipeline_state_stream.h>
38 #endif
39 
40 #include "util/hash_table.h"
41 #include "util/set.h"
42 #include "util/u_memory.h"
43 #include "util/u_prim.h"
44 
45 #include <dxguids/dxguids.h>
46 
47 struct d3d12_gfx_pso_entry {
48    struct d3d12_gfx_pipeline_state key;
49    ID3D12PipelineState *pso;
50 };
51 
52 struct d3d12_compute_pso_entry {
53    struct d3d12_compute_pipeline_state key;
54    ID3D12PipelineState *pso;
55 };
56 
57 static const char *
get_semantic_name(int location,int driver_location,unsigned * index)58 get_semantic_name(int location, int driver_location, unsigned *index)
59 {
60    *index = 0; /* Default index */
61 
62    switch (location) {
63 
64    case VARYING_SLOT_POS:
65       return "SV_Position";
66 
67     case VARYING_SLOT_FACE:
68       return "SV_IsFrontFace";
69 
70    case VARYING_SLOT_CLIP_DIST1:
71       *index = 1;
72       FALLTHROUGH;
73    case VARYING_SLOT_CLIP_DIST0:
74       return "SV_ClipDistance";
75 
76    case VARYING_SLOT_CULL_DIST1:
77       *index = 1;
78       FALLTHROUGH;
79    case VARYING_SLOT_CULL_DIST0:
80       return "SV_CullDistance";
81 
82    case VARYING_SLOT_PRIMITIVE_ID:
83       return "SV_PrimitiveID";
84 
85    case VARYING_SLOT_VIEWPORT:
86       return "SV_ViewportArrayIndex";
87 
88    case VARYING_SLOT_LAYER:
89       return "SV_RenderTargetArrayIndex";
90 
91    default: {
92          *index = driver_location;
93          return "TEXCOORD";
94       }
95    }
96 }
97 
98 static nir_variable *
find_so_variable(nir_shader * s,int location,unsigned location_frac,unsigned num_components)99 find_so_variable(nir_shader *s, int location, unsigned location_frac, unsigned num_components)
100 {
101    nir_foreach_variable_with_modes(var, s, nir_var_shader_out) {
102       if (var->data.location != location || var->data.location_frac > location_frac)
103          continue;
104       unsigned var_num_components = var->data.compact ?
105          glsl_get_length(var->type) : glsl_get_components(var->type);
106       if (var->data.location_frac <= location_frac &&
107           var->data.location_frac + var_num_components >= location_frac + num_components)
108          return var;
109    }
110    return nullptr;
111 }
112 
113 static void
fill_so_declaration(const struct pipe_stream_output_info * info,nir_shader * last_vertex_stage,D3D12_SO_DECLARATION_ENTRY * entries,UINT * num_entries,UINT * strides,UINT * num_strides)114 fill_so_declaration(const struct pipe_stream_output_info *info,
115                     nir_shader *last_vertex_stage,
116                     D3D12_SO_DECLARATION_ENTRY *entries, UINT *num_entries,
117                     UINT *strides, UINT *num_strides)
118 {
119    int next_offset[PIPE_MAX_VERTEX_STREAMS] = { 0 };
120 
121    *num_entries = 0;
122 
123    for (unsigned i = 0; i < info->num_outputs; i++) {
124       const struct pipe_stream_output *output = &info->output[i];
125       const int buffer = output->output_buffer;
126       unsigned index;
127 
128       /* Mesa doesn't store entries for gl_SkipComponents in the Outputs[]
129        * array.  Instead, it simply increments DstOffset for the following
130        * input by the number of components that should be skipped.
131        *
132        * DirectX12 requires that we create gap entries.
133        */
134       int skip_components = output->dst_offset - next_offset[buffer];
135 
136       if (skip_components > 0) {
137          entries[*num_entries].Stream = output->stream;
138          entries[*num_entries].SemanticName = NULL;
139          entries[*num_entries].SemanticIndex = 0;
140          entries[*num_entries].StartComponent = 0;
141          entries[*num_entries].ComponentCount = skip_components;
142          entries[*num_entries].OutputSlot = buffer;
143          (*num_entries)++;
144       }
145 
146       next_offset[buffer] = output->dst_offset + output->num_components;
147 
148       entries[*num_entries].Stream = output->stream;
149       nir_variable *var = find_so_variable(last_vertex_stage,
150          output->register_index, output->start_component, output->num_components);
151       assert((var->data.stream & ~NIR_STREAM_PACKED) == output->stream);
152       unsigned location = var->data.location;
153       if (location == VARYING_SLOT_CLIP_DIST0 || location == VARYING_SLOT_CLIP_DIST1) {
154          unsigned component = (location - VARYING_SLOT_CLIP_DIST0) * 4 + var->data.location_frac;
155          if (component >= last_vertex_stage->info.clip_distance_array_size)
156             location = VARYING_SLOT_CULL_DIST0 + (component - last_vertex_stage->info.clip_distance_array_size) / 4;
157       }
158       entries[*num_entries].SemanticName = get_semantic_name(location, var->data.driver_location, &index);
159       entries[*num_entries].SemanticIndex = index;
160       entries[*num_entries].StartComponent = output->start_component - var->data.location_frac;
161       entries[*num_entries].ComponentCount = output->num_components;
162       entries[*num_entries].OutputSlot = buffer;
163       (*num_entries)++;
164    }
165 
166    for (unsigned i = 0; i < PIPE_MAX_VERTEX_STREAMS; i++)
167       strides[i] = info->stride[i] * 4;
168    *num_strides = PIPE_MAX_VERTEX_STREAMS;
169 }
170 
171 static bool
depth_bias(struct d3d12_rasterizer_state * state,enum mesa_prim reduced_prim)172 depth_bias(struct d3d12_rasterizer_state *state, enum mesa_prim reduced_prim)
173 {
174    /* glPolygonOffset is supposed to be only enabled when rendering polygons.
175     * In d3d12 case, all polygons (and quads) are lowered to triangles */
176    if (reduced_prim != MESA_PRIM_TRIANGLES)
177       return false;
178 
179    unsigned fill_mode = state->base.cull_face == PIPE_FACE_FRONT ? state->base.fill_back
180                                                                  : state->base.fill_front;
181 
182    switch (fill_mode) {
183    case PIPE_POLYGON_MODE_FILL:
184       return state->base.offset_tri;
185 
186    case PIPE_POLYGON_MODE_LINE:
187       return state->base.offset_line;
188 
189    case PIPE_POLYGON_MODE_POINT:
190       return state->base.offset_point;
191 
192    default:
193       unreachable("unexpected fill mode");
194    }
195 }
196 
197 static D3D12_PRIMITIVE_TOPOLOGY_TYPE
topology_type(enum mesa_prim reduced_prim)198 topology_type(enum mesa_prim reduced_prim)
199 {
200    switch (reduced_prim) {
201    case MESA_PRIM_POINTS:
202       return D3D12_PRIMITIVE_TOPOLOGY_TYPE_POINT;
203 
204    case MESA_PRIM_LINES:
205       return D3D12_PRIMITIVE_TOPOLOGY_TYPE_LINE;
206 
207    case MESA_PRIM_TRIANGLES:
208       return D3D12_PRIMITIVE_TOPOLOGY_TYPE_TRIANGLE;
209 
210    case MESA_PRIM_PATCHES:
211       return D3D12_PRIMITIVE_TOPOLOGY_TYPE_PATCH;
212 
213    default:
214       debug_printf("mesa_prim: %s\n", u_prim_name(reduced_prim));
215       unreachable("unexpected enum mesa_prim");
216    }
217 }
218 
219 DXGI_FORMAT
d3d12_rtv_format(struct d3d12_context * ctx,unsigned index)220 d3d12_rtv_format(struct d3d12_context *ctx, unsigned index)
221 {
222    DXGI_FORMAT fmt = ctx->gfx_pipeline_state.rtv_formats[index];
223 
224    if (ctx->gfx_pipeline_state.blend->desc.RenderTarget[0].LogicOpEnable &&
225        !ctx->gfx_pipeline_state.has_float_rtv) {
226       switch (fmt) {
227       case DXGI_FORMAT_R8G8B8A8_SNORM:
228       case DXGI_FORMAT_R8G8B8A8_UNORM:
229       case DXGI_FORMAT_B8G8R8A8_UNORM:
230       case DXGI_FORMAT_B8G8R8X8_UNORM:
231          return DXGI_FORMAT_R8G8B8A8_UINT;
232       default:
233          unreachable("unsupported logic-op format");
234       }
235    }
236 
237    return fmt;
238 }
239 
240 static void
copy_input_attribs(const D3D12_INPUT_ELEMENT_DESC * ves_elements,D3D12_INPUT_ELEMENT_DESC * ia_elements,D3D12_INPUT_LAYOUT_DESC * ia_desc,nir_shader * vs)241 copy_input_attribs(const D3D12_INPUT_ELEMENT_DESC *ves_elements, D3D12_INPUT_ELEMENT_DESC *ia_elements,
242                    D3D12_INPUT_LAYOUT_DESC *ia_desc, nir_shader *vs)
243 {
244    uint32_t vert_input_count = 0;
245    int32_t ves_element_count = -1;
246    int var_loc = -1;
247    nir_foreach_shader_in_variable(var, vs) {
248       assert(vert_input_count < D3D12_VS_INPUT_REGISTER_COUNT);
249 
250       if (var->data.location != var_loc)
251          ves_element_count++;
252       var_loc = var->data.location;
253 
254       for (uint32_t i = 0; i < glsl_count_attribute_slots(var->type, false); ++i) {
255          ia_elements[vert_input_count] = ves_elements[ves_element_count++];
256          ia_elements[vert_input_count].SemanticIndex = vert_input_count;
257          var->data.driver_location = vert_input_count++;
258       }
259       --ves_element_count;
260    }
261 
262    if (vert_input_count > 0) {
263       ia_desc->pInputElementDescs = ia_elements;
264       ia_desc->NumElements = vert_input_count;
265    }
266 }
267 
268 static ID3D12PipelineState *
create_gfx_pipeline_state(struct d3d12_context * ctx)269 create_gfx_pipeline_state(struct d3d12_context *ctx)
270 {
271    struct d3d12_screen *screen = d3d12_screen(ctx->base.screen);
272    struct d3d12_gfx_pipeline_state *state = &ctx->gfx_pipeline_state;
273    enum mesa_prim reduced_prim = state->prim_type == MESA_PRIM_PATCHES ?
274       MESA_PRIM_PATCHES : u_reduced_prim(state->prim_type);
275    D3D12_SO_DECLARATION_ENTRY entries[PIPE_MAX_SO_OUTPUTS];
276    UINT strides[PIPE_MAX_VERTEX_STREAMS] = { 0 };
277    D3D12_INPUT_ELEMENT_DESC input_attribs[PIPE_MAX_ATTRIBS * 4];
278    UINT num_entries = 0, num_strides = 0;
279 
280    CD3DX12_PIPELINE_STATE_STREAM3 pso_desc;
281    pso_desc.pRootSignature = state->root_signature;
282 
283    nir_shader *last_vertex_stage_nir = NULL;
284 
285    if (state->stages[PIPE_SHADER_VERTEX]) {
286       auto shader = state->stages[PIPE_SHADER_VERTEX];
287       pso_desc.VS = D3D12_SHADER_BYTECODE { shader->bytecode, shader->bytecode_length };
288       last_vertex_stage_nir = shader->nir;
289    }
290 
291    if (state->stages[PIPE_SHADER_TESS_CTRL]) {
292       auto shader = state->stages[PIPE_SHADER_TESS_CTRL];
293       pso_desc.HS = D3D12_SHADER_BYTECODE{ shader->bytecode, shader->bytecode_length };
294       last_vertex_stage_nir = shader->nir;
295    }
296 
297    if (state->stages[PIPE_SHADER_TESS_EVAL]) {
298       auto shader = state->stages[PIPE_SHADER_TESS_EVAL];
299       pso_desc.DS = D3D12_SHADER_BYTECODE{ shader->bytecode, shader->bytecode_length };
300       last_vertex_stage_nir = shader->nir;
301    }
302 
303    if (state->stages[PIPE_SHADER_GEOMETRY]) {
304       auto shader = state->stages[PIPE_SHADER_GEOMETRY];
305       pso_desc.GS = D3D12_SHADER_BYTECODE{ shader->bytecode, shader->bytecode_length };
306       last_vertex_stage_nir = shader->nir;
307    }
308 
309    bool last_vertex_stage_writes_pos = (last_vertex_stage_nir->info.outputs_written & VARYING_BIT_POS) != 0;
310    if (last_vertex_stage_writes_pos && state->stages[PIPE_SHADER_FRAGMENT] &&
311        !state->rast->base.rasterizer_discard) {
312       auto shader = state->stages[PIPE_SHADER_FRAGMENT];
313       pso_desc.PS = D3D12_SHADER_BYTECODE{ shader->bytecode, shader->bytecode_length };
314    }
315 
316    if (state->num_so_targets)
317       fill_so_declaration(&state->so_info, last_vertex_stage_nir, entries, &num_entries, strides, &num_strides);
318 
319    D3D12_STREAM_OUTPUT_DESC& stream_output_desc = (D3D12_STREAM_OUTPUT_DESC&)pso_desc.StreamOutput;
320    stream_output_desc.NumEntries = num_entries;
321    stream_output_desc.pSODeclaration = entries;
322    stream_output_desc.RasterizedStream = state->rast->base.rasterizer_discard ? D3D12_SO_NO_RASTERIZED_STREAM : 0;
323    stream_output_desc.NumStrides = num_strides;
324    stream_output_desc.pBufferStrides = strides;
325    pso_desc.StreamOutput = stream_output_desc;
326 
327    D3D12_BLEND_DESC& blend_state = (D3D12_BLEND_DESC&)pso_desc.BlendState;
328    blend_state = state->blend->desc;
329    if (state->has_float_rtv)
330       blend_state.RenderTarget[0].LogicOpEnable = false;
331 
332    (d3d12_depth_stencil_desc_type&)pso_desc.DepthStencilState = state->zsa->desc;
333    pso_desc.SampleMask = state->sample_mask;
334 
335    D3D12_RASTERIZER_DESC& rast = (D3D12_RASTERIZER_DESC&)pso_desc.RasterizerState;
336    rast = state->rast->desc;
337 
338    if (reduced_prim != MESA_PRIM_TRIANGLES)
339       rast.CullMode = D3D12_CULL_MODE_NONE;
340 
341    if (depth_bias(state->rast, reduced_prim)) {
342       rast.DepthBias = state->rast->base.offset_units * 2;
343       rast.DepthBiasClamp = state->rast->base.offset_clamp;
344       rast.SlopeScaledDepthBias = state->rast->base.offset_scale;
345    }
346    D3D12_INPUT_LAYOUT_DESC& input_layout = (D3D12_INPUT_LAYOUT_DESC&)pso_desc.InputLayout;
347    input_layout.pInputElementDescs = state->ves->elements;
348    input_layout.NumElements = state->ves->num_elements;
349    copy_input_attribs(state->ves->elements, input_attribs, &input_layout, state->stages[PIPE_SHADER_VERTEX]->nir);
350 
351    pso_desc.IBStripCutValue = state->ib_strip_cut_value;
352 
353    pso_desc.PrimitiveTopologyType = topology_type(reduced_prim);
354 
355    D3D12_RT_FORMAT_ARRAY& render_targets = (D3D12_RT_FORMAT_ARRAY&)pso_desc.RTVFormats;
356    render_targets.NumRenderTargets = state->num_cbufs;
357    for (unsigned i = 0; i < state->num_cbufs; ++i)
358       render_targets.RTFormats[i] = d3d12_rtv_format(ctx, i);
359    pso_desc.DSVFormat = state->dsv_format;
360 
361    DXGI_SAMPLE_DESC& samples = (DXGI_SAMPLE_DESC&)pso_desc.SampleDesc;
362    samples.Count = state->samples;
363    if (state->num_cbufs || state->dsv_format != DXGI_FORMAT_UNKNOWN) {
364       if (!state->zsa->desc.DepthEnable &&
365           !state->zsa->desc.StencilEnable &&
366           !state->rast->desc.MultisampleEnable &&
367           state->samples > 1) {
368          rast.ForcedSampleCount = 1;
369          pso_desc.DSVFormat = DXGI_FORMAT_UNKNOWN;
370       }
371    }
372 #ifndef _GAMING_XBOX
373    else if (state->samples > 1 &&
374               !(screen->opts19.SupportedSampleCountsWithNoOutputs & (1 << state->samples))) {
375       samples.Count = 1;
376       rast.ForcedSampleCount = state->samples;
377    }
378 #endif
379    samples.Quality = 0;
380 
381    pso_desc.NodeMask = 0;
382 
383    D3D12_CACHED_PIPELINE_STATE& cached_pso = (D3D12_CACHED_PIPELINE_STATE&)pso_desc.CachedPSO;
384    cached_pso.pCachedBlob = NULL;
385    cached_pso.CachedBlobSizeInBytes = 0;
386 
387    pso_desc.Flags = D3D12_PIPELINE_STATE_FLAG_NONE;
388 
389    ID3D12PipelineState *ret;
390 
391    if (screen->opts14.IndependentFrontAndBackStencilRefMaskSupported) {
392       D3D12_PIPELINE_STATE_STREAM_DESC pso_stream_desc{
393           sizeof(pso_desc),
394           &pso_desc
395       };
396 
397       if (FAILED(screen->dev->CreatePipelineState(&pso_stream_desc,
398                                                   IID_PPV_ARGS(&ret)))) {
399          debug_printf("D3D12: CreateGraphicsPipelineState failed!\n");
400          return NULL;
401       }
402    }
403    else {
404       D3D12_GRAPHICS_PIPELINE_STATE_DESC v0desc = pso_desc.GraphicsDescV0();
405       if (FAILED(screen->dev->CreateGraphicsPipelineState(&v0desc,
406                                                        IID_PPV_ARGS(&ret)))) {
407          debug_printf("D3D12: CreateGraphicsPipelineState failed!\n");
408          return NULL;
409       }
410    }
411 
412    return ret;
413 }
414 
415 static uint32_t
hash_gfx_pipeline_state(const void * key)416 hash_gfx_pipeline_state(const void *key)
417 {
418    return _mesa_hash_data(key, sizeof(struct d3d12_gfx_pipeline_state));
419 }
420 
421 static bool
equals_gfx_pipeline_state(const void * a,const void * b)422 equals_gfx_pipeline_state(const void *a, const void *b)
423 {
424    return memcmp(a, b, sizeof(struct d3d12_gfx_pipeline_state)) == 0;
425 }
426 
427 ID3D12PipelineState *
d3d12_get_gfx_pipeline_state(struct d3d12_context * ctx)428 d3d12_get_gfx_pipeline_state(struct d3d12_context *ctx)
429 {
430    uint32_t hash = hash_gfx_pipeline_state(&ctx->gfx_pipeline_state);
431    struct hash_entry *entry = _mesa_hash_table_search_pre_hashed(ctx->pso_cache, hash,
432                                                                  &ctx->gfx_pipeline_state);
433    if (!entry) {
434       struct d3d12_gfx_pso_entry *data = (struct d3d12_gfx_pso_entry *)MALLOC(sizeof(struct d3d12_gfx_pso_entry));
435       if (!data)
436          return NULL;
437 
438       data->key = ctx->gfx_pipeline_state;
439       data->pso = create_gfx_pipeline_state(ctx);
440       if (!data->pso) {
441          FREE(data);
442          return NULL;
443       }
444 
445       entry = _mesa_hash_table_insert_pre_hashed(ctx->pso_cache, hash, &data->key, data);
446       assert(entry);
447    }
448 
449    return ((struct d3d12_gfx_pso_entry *)(entry->data))->pso;
450 }
451 
452 void
d3d12_gfx_pipeline_state_cache_init(struct d3d12_context * ctx)453 d3d12_gfx_pipeline_state_cache_init(struct d3d12_context *ctx)
454 {
455    ctx->pso_cache = _mesa_hash_table_create(NULL, NULL, equals_gfx_pipeline_state);
456 }
457 
458 static void
delete_gfx_entry(struct hash_entry * entry)459 delete_gfx_entry(struct hash_entry *entry)
460 {
461    struct d3d12_gfx_pso_entry *data = (struct d3d12_gfx_pso_entry *)entry->data;
462    data->pso->Release();
463    FREE(data);
464 }
465 
466 static void
remove_gfx_entry(struct d3d12_context * ctx,struct hash_entry * entry)467 remove_gfx_entry(struct d3d12_context *ctx, struct hash_entry *entry)
468 {
469    struct d3d12_gfx_pso_entry *data = (struct d3d12_gfx_pso_entry *)entry->data;
470 
471    if (ctx->current_gfx_pso == data->pso)
472       ctx->current_gfx_pso = NULL;
473    _mesa_hash_table_remove(ctx->pso_cache, entry);
474    delete_gfx_entry(entry);
475 }
476 
477 void
d3d12_gfx_pipeline_state_cache_destroy(struct d3d12_context * ctx)478 d3d12_gfx_pipeline_state_cache_destroy(struct d3d12_context *ctx)
479 {
480    _mesa_hash_table_destroy(ctx->pso_cache, delete_gfx_entry);
481 }
482 
483 void
d3d12_gfx_pipeline_state_cache_invalidate(struct d3d12_context * ctx,const void * state)484 d3d12_gfx_pipeline_state_cache_invalidate(struct d3d12_context *ctx, const void *state)
485 {
486    hash_table_foreach(ctx->pso_cache, entry) {
487       const struct d3d12_gfx_pipeline_state *key = (struct d3d12_gfx_pipeline_state *)entry->key;
488       if (key->blend == state || key->zsa == state || key->rast == state)
489          remove_gfx_entry(ctx, entry);
490    }
491 }
492 
493 void
d3d12_gfx_pipeline_state_cache_invalidate_shader(struct d3d12_context * ctx,enum pipe_shader_type stage,struct d3d12_shader_selector * selector)494 d3d12_gfx_pipeline_state_cache_invalidate_shader(struct d3d12_context *ctx,
495                                                  enum pipe_shader_type stage,
496                                                  struct d3d12_shader_selector *selector)
497 {
498    struct d3d12_shader *shader = selector->first;
499 
500    while (shader) {
501       hash_table_foreach(ctx->pso_cache, entry) {
502          const struct d3d12_gfx_pipeline_state *key = (struct d3d12_gfx_pipeline_state *)entry->key;
503          if (key->stages[stage] == shader)
504             remove_gfx_entry(ctx, entry);
505       }
506       shader = shader->next_variant;
507    }
508 }
509 
510 static ID3D12PipelineState *
create_compute_pipeline_state(struct d3d12_context * ctx)511 create_compute_pipeline_state(struct d3d12_context *ctx)
512 {
513    struct d3d12_screen *screen = d3d12_screen(ctx->base.screen);
514    struct d3d12_compute_pipeline_state *state = &ctx->compute_pipeline_state;
515 
516    D3D12_COMPUTE_PIPELINE_STATE_DESC pso_desc = { 0 };
517    pso_desc.pRootSignature = state->root_signature;
518 
519    if (state->stage) {
520       auto shader = state->stage;
521       pso_desc.CS.BytecodeLength = shader->bytecode_length;
522       pso_desc.CS.pShaderBytecode = shader->bytecode;
523    }
524 
525    pso_desc.NodeMask = 0;
526 
527    pso_desc.CachedPSO.pCachedBlob = NULL;
528    pso_desc.CachedPSO.CachedBlobSizeInBytes = 0;
529 
530    pso_desc.Flags = D3D12_PIPELINE_STATE_FLAG_NONE;
531 
532    ID3D12PipelineState *ret;
533    if (FAILED(screen->dev->CreateComputePipelineState(&pso_desc,
534                                                       IID_PPV_ARGS(&ret)))) {
535       debug_printf("D3D12: CreateComputePipelineState failed!\n");
536       return NULL;
537    }
538 
539    return ret;
540 }
541 
542 static uint32_t
hash_compute_pipeline_state(const void * key)543 hash_compute_pipeline_state(const void *key)
544 {
545    return _mesa_hash_data(key, sizeof(struct d3d12_compute_pipeline_state));
546 }
547 
548 static bool
equals_compute_pipeline_state(const void * a,const void * b)549 equals_compute_pipeline_state(const void *a, const void *b)
550 {
551    return memcmp(a, b, sizeof(struct d3d12_compute_pipeline_state)) == 0;
552 }
553 
554 ID3D12PipelineState *
d3d12_get_compute_pipeline_state(struct d3d12_context * ctx)555 d3d12_get_compute_pipeline_state(struct d3d12_context *ctx)
556 {
557    uint32_t hash = hash_compute_pipeline_state(&ctx->compute_pipeline_state);
558    struct hash_entry *entry = _mesa_hash_table_search_pre_hashed(ctx->compute_pso_cache, hash,
559                                                                  &ctx->compute_pipeline_state);
560    if (!entry) {
561       struct d3d12_compute_pso_entry *data = (struct d3d12_compute_pso_entry *)MALLOC(sizeof(struct d3d12_compute_pso_entry));
562       if (!data)
563          return NULL;
564 
565       data->key = ctx->compute_pipeline_state;
566       data->pso = create_compute_pipeline_state(ctx);
567       if (!data->pso) {
568          FREE(data);
569          return NULL;
570       }
571 
572       entry = _mesa_hash_table_insert_pre_hashed(ctx->compute_pso_cache, hash, &data->key, data);
573       assert(entry);
574    }
575 
576    return ((struct d3d12_compute_pso_entry *)(entry->data))->pso;
577 }
578 
579 void
d3d12_compute_pipeline_state_cache_init(struct d3d12_context * ctx)580 d3d12_compute_pipeline_state_cache_init(struct d3d12_context *ctx)
581 {
582    ctx->compute_pso_cache = _mesa_hash_table_create(NULL, NULL, equals_compute_pipeline_state);
583 }
584 
585 static void
delete_compute_entry(struct hash_entry * entry)586 delete_compute_entry(struct hash_entry *entry)
587 {
588    struct d3d12_compute_pso_entry *data = (struct d3d12_compute_pso_entry *)entry->data;
589    data->pso->Release();
590    FREE(data);
591 }
592 
593 static void
remove_compute_entry(struct d3d12_context * ctx,struct hash_entry * entry)594 remove_compute_entry(struct d3d12_context *ctx, struct hash_entry *entry)
595 {
596    struct d3d12_compute_pso_entry *data = (struct d3d12_compute_pso_entry *)entry->data;
597 
598    if (ctx->current_compute_pso == data->pso)
599       ctx->current_compute_pso = NULL;
600    _mesa_hash_table_remove(ctx->compute_pso_cache, entry);
601    delete_compute_entry(entry);
602 }
603 
604 void
d3d12_compute_pipeline_state_cache_destroy(struct d3d12_context * ctx)605 d3d12_compute_pipeline_state_cache_destroy(struct d3d12_context *ctx)
606 {
607    _mesa_hash_table_destroy(ctx->compute_pso_cache, delete_compute_entry);
608 }
609 
610 void
d3d12_compute_pipeline_state_cache_invalidate_shader(struct d3d12_context * ctx,struct d3d12_shader_selector * selector)611 d3d12_compute_pipeline_state_cache_invalidate_shader(struct d3d12_context *ctx,
612                                                      struct d3d12_shader_selector *selector)
613 {
614    struct d3d12_shader *shader = selector->first;
615 
616    while (shader) {
617       hash_table_foreach(ctx->compute_pso_cache, entry) {
618          const struct d3d12_compute_pipeline_state *key = (struct d3d12_compute_pipeline_state *)entry->key;
619          if (key->stage == shader)
620             remove_compute_entry(ctx, entry);
621       }
622       shader = shader->next_variant;
623    }
624 }
625