1 /*
2 * Copyright © Microsoft Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24 #ifdef _GAMING_XBOX
25 #ifdef _GAMING_XBOX_SCARLETT
26 #include <d3dx12_xs.h>
27 #else
28 #include <d3dx12_x.h>
29 #endif
30 #endif
31
32 #include "d3d12_pipeline_state.h"
33 #include "d3d12_compiler.h"
34 #include "d3d12_context.h"
35 #include "d3d12_screen.h"
36 #ifndef _GAMING_XBOX
37 #include <directx/d3dx12_pipeline_state_stream.h>
38 #endif
39
40 #include "util/hash_table.h"
41 #include "util/set.h"
42 #include "util/u_memory.h"
43 #include "util/u_prim.h"
44
45 #include <dxguids/dxguids.h>
46
/* Graphics PSO cache entry stored in ctx->pso_cache.
 *
 * `key` is a copy of the pipeline state the PSO was built from; its address
 * is also what gets inserted as the hash-table key, so the key lives exactly
 * as long as the entry. Keys are hashed and compared as raw bytes
 * (sizeof(struct d3d12_gfx_pipeline_state)). */
struct d3d12_gfx_pso_entry {
   struct d3d12_gfx_pipeline_state key;
   ID3D12PipelineState *pso; /* reference released in delete_gfx_entry() */
};
51
/* Compute PSO cache entry stored in ctx->compute_pso_cache.
 *
 * `key` is a copy of the compute pipeline state the PSO was built from and
 * doubles as the hash-table key (inserted by address). Keys are hashed and
 * compared as raw bytes (sizeof(struct d3d12_compute_pipeline_state)). */
struct d3d12_compute_pso_entry {
   struct d3d12_compute_pipeline_state key;
   ID3D12PipelineState *pso; /* reference released in delete_compute_entry() */
};
56
57 static const char *
get_semantic_name(int location,int driver_location,unsigned * index)58 get_semantic_name(int location, int driver_location, unsigned *index)
59 {
60 *index = 0; /* Default index */
61
62 switch (location) {
63
64 case VARYING_SLOT_POS:
65 return "SV_Position";
66
67 case VARYING_SLOT_FACE:
68 return "SV_IsFrontFace";
69
70 case VARYING_SLOT_CLIP_DIST1:
71 *index = 1;
72 FALLTHROUGH;
73 case VARYING_SLOT_CLIP_DIST0:
74 return "SV_ClipDistance";
75
76 case VARYING_SLOT_CULL_DIST1:
77 *index = 1;
78 FALLTHROUGH;
79 case VARYING_SLOT_CULL_DIST0:
80 return "SV_CullDistance";
81
82 case VARYING_SLOT_PRIMITIVE_ID:
83 return "SV_PrimitiveID";
84
85 case VARYING_SLOT_VIEWPORT:
86 return "SV_ViewportArrayIndex";
87
88 case VARYING_SLOT_LAYER:
89 return "SV_RenderTargetArrayIndex";
90
91 default: {
92 *index = driver_location;
93 return "TEXCOORD";
94 }
95 }
96 }
97
98 static nir_variable *
find_so_variable(nir_shader * s,int location,unsigned location_frac,unsigned num_components)99 find_so_variable(nir_shader *s, int location, unsigned location_frac, unsigned num_components)
100 {
101 nir_foreach_variable_with_modes(var, s, nir_var_shader_out) {
102 if (var->data.location != location || var->data.location_frac > location_frac)
103 continue;
104 unsigned var_num_components = var->data.compact ?
105 glsl_get_length(var->type) : glsl_get_components(var->type);
106 if (var->data.location_frac <= location_frac &&
107 var->data.location_frac + var_num_components >= location_frac + num_components)
108 return var;
109 }
110 return nullptr;
111 }
112
113 static void
fill_so_declaration(const struct pipe_stream_output_info * info,nir_shader * last_vertex_stage,D3D12_SO_DECLARATION_ENTRY * entries,UINT * num_entries,UINT * strides,UINT * num_strides)114 fill_so_declaration(const struct pipe_stream_output_info *info,
115 nir_shader *last_vertex_stage,
116 D3D12_SO_DECLARATION_ENTRY *entries, UINT *num_entries,
117 UINT *strides, UINT *num_strides)
118 {
119 int next_offset[PIPE_MAX_VERTEX_STREAMS] = { 0 };
120
121 *num_entries = 0;
122
123 for (unsigned i = 0; i < info->num_outputs; i++) {
124 const struct pipe_stream_output *output = &info->output[i];
125 const int buffer = output->output_buffer;
126 unsigned index;
127
128 /* Mesa doesn't store entries for gl_SkipComponents in the Outputs[]
129 * array. Instead, it simply increments DstOffset for the following
130 * input by the number of components that should be skipped.
131 *
132 * DirectX12 requires that we create gap entries.
133 */
134 int skip_components = output->dst_offset - next_offset[buffer];
135
136 if (skip_components > 0) {
137 entries[*num_entries].Stream = output->stream;
138 entries[*num_entries].SemanticName = NULL;
139 entries[*num_entries].SemanticIndex = 0;
140 entries[*num_entries].StartComponent = 0;
141 entries[*num_entries].ComponentCount = skip_components;
142 entries[*num_entries].OutputSlot = buffer;
143 (*num_entries)++;
144 }
145
146 next_offset[buffer] = output->dst_offset + output->num_components;
147
148 entries[*num_entries].Stream = output->stream;
149 nir_variable *var = find_so_variable(last_vertex_stage,
150 output->register_index, output->start_component, output->num_components);
151 assert((var->data.stream & ~NIR_STREAM_PACKED) == output->stream);
152 unsigned location = var->data.location;
153 if (location == VARYING_SLOT_CLIP_DIST0 || location == VARYING_SLOT_CLIP_DIST1) {
154 unsigned component = (location - VARYING_SLOT_CLIP_DIST0) * 4 + var->data.location_frac;
155 if (component >= last_vertex_stage->info.clip_distance_array_size)
156 location = VARYING_SLOT_CULL_DIST0 + (component - last_vertex_stage->info.clip_distance_array_size) / 4;
157 }
158 entries[*num_entries].SemanticName = get_semantic_name(location, var->data.driver_location, &index);
159 entries[*num_entries].SemanticIndex = index;
160 entries[*num_entries].StartComponent = output->start_component - var->data.location_frac;
161 entries[*num_entries].ComponentCount = output->num_components;
162 entries[*num_entries].OutputSlot = buffer;
163 (*num_entries)++;
164 }
165
166 for (unsigned i = 0; i < PIPE_MAX_VERTEX_STREAMS; i++)
167 strides[i] = info->stride[i] * 4;
168 *num_strides = PIPE_MAX_VERTEX_STREAMS;
169 }
170
171 static bool
depth_bias(struct d3d12_rasterizer_state * state,enum mesa_prim reduced_prim)172 depth_bias(struct d3d12_rasterizer_state *state, enum mesa_prim reduced_prim)
173 {
174 /* glPolygonOffset is supposed to be only enabled when rendering polygons.
175 * In d3d12 case, all polygons (and quads) are lowered to triangles */
176 if (reduced_prim != MESA_PRIM_TRIANGLES)
177 return false;
178
179 unsigned fill_mode = state->base.cull_face == PIPE_FACE_FRONT ? state->base.fill_back
180 : state->base.fill_front;
181
182 switch (fill_mode) {
183 case PIPE_POLYGON_MODE_FILL:
184 return state->base.offset_tri;
185
186 case PIPE_POLYGON_MODE_LINE:
187 return state->base.offset_line;
188
189 case PIPE_POLYGON_MODE_POINT:
190 return state->base.offset_point;
191
192 default:
193 unreachable("unexpected fill mode");
194 }
195 }
196
197 static D3D12_PRIMITIVE_TOPOLOGY_TYPE
topology_type(enum mesa_prim reduced_prim)198 topology_type(enum mesa_prim reduced_prim)
199 {
200 switch (reduced_prim) {
201 case MESA_PRIM_POINTS:
202 return D3D12_PRIMITIVE_TOPOLOGY_TYPE_POINT;
203
204 case MESA_PRIM_LINES:
205 return D3D12_PRIMITIVE_TOPOLOGY_TYPE_LINE;
206
207 case MESA_PRIM_TRIANGLES:
208 return D3D12_PRIMITIVE_TOPOLOGY_TYPE_TRIANGLE;
209
210 case MESA_PRIM_PATCHES:
211 return D3D12_PRIMITIVE_TOPOLOGY_TYPE_PATCH;
212
213 default:
214 debug_printf("mesa_prim: %s\n", u_prim_name(reduced_prim));
215 unreachable("unexpected enum mesa_prim");
216 }
217 }
218
219 DXGI_FORMAT
d3d12_rtv_format(struct d3d12_context * ctx,unsigned index)220 d3d12_rtv_format(struct d3d12_context *ctx, unsigned index)
221 {
222 DXGI_FORMAT fmt = ctx->gfx_pipeline_state.rtv_formats[index];
223
224 if (ctx->gfx_pipeline_state.blend->desc.RenderTarget[0].LogicOpEnable &&
225 !ctx->gfx_pipeline_state.has_float_rtv) {
226 switch (fmt) {
227 case DXGI_FORMAT_R8G8B8A8_SNORM:
228 case DXGI_FORMAT_R8G8B8A8_UNORM:
229 case DXGI_FORMAT_B8G8R8A8_UNORM:
230 case DXGI_FORMAT_B8G8R8X8_UNORM:
231 return DXGI_FORMAT_R8G8B8A8_UINT;
232 default:
233 unreachable("unsupported logic-op format");
234 }
235 }
236
237 return fmt;
238 }
239
240 static void
copy_input_attribs(const D3D12_INPUT_ELEMENT_DESC * ves_elements,D3D12_INPUT_ELEMENT_DESC * ia_elements,D3D12_INPUT_LAYOUT_DESC * ia_desc,nir_shader * vs)241 copy_input_attribs(const D3D12_INPUT_ELEMENT_DESC *ves_elements, D3D12_INPUT_ELEMENT_DESC *ia_elements,
242 D3D12_INPUT_LAYOUT_DESC *ia_desc, nir_shader *vs)
243 {
244 uint32_t vert_input_count = 0;
245 int32_t ves_element_count = -1;
246 int var_loc = -1;
247 nir_foreach_shader_in_variable(var, vs) {
248 assert(vert_input_count < D3D12_VS_INPUT_REGISTER_COUNT);
249
250 if (var->data.location != var_loc)
251 ves_element_count++;
252 var_loc = var->data.location;
253
254 for (uint32_t i = 0; i < glsl_count_attribute_slots(var->type, false); ++i) {
255 ia_elements[vert_input_count] = ves_elements[ves_element_count++];
256 ia_elements[vert_input_count].SemanticIndex = vert_input_count;
257 var->data.driver_location = vert_input_count++;
258 }
259 --ves_element_count;
260 }
261
262 if (vert_input_count > 0) {
263 ia_desc->pInputElementDescs = ia_elements;
264 ia_desc->NumElements = vert_input_count;
265 }
266 }
267
268 static ID3D12PipelineState *
create_gfx_pipeline_state(struct d3d12_context * ctx)269 create_gfx_pipeline_state(struct d3d12_context *ctx)
270 {
271 struct d3d12_screen *screen = d3d12_screen(ctx->base.screen);
272 struct d3d12_gfx_pipeline_state *state = &ctx->gfx_pipeline_state;
273 enum mesa_prim reduced_prim = state->prim_type == MESA_PRIM_PATCHES ?
274 MESA_PRIM_PATCHES : u_reduced_prim(state->prim_type);
275 D3D12_SO_DECLARATION_ENTRY entries[PIPE_MAX_SO_OUTPUTS];
276 UINT strides[PIPE_MAX_VERTEX_STREAMS] = { 0 };
277 D3D12_INPUT_ELEMENT_DESC input_attribs[PIPE_MAX_ATTRIBS * 4];
278 UINT num_entries = 0, num_strides = 0;
279
280 CD3DX12_PIPELINE_STATE_STREAM3 pso_desc;
281 pso_desc.pRootSignature = state->root_signature;
282
283 nir_shader *last_vertex_stage_nir = NULL;
284
285 if (state->stages[PIPE_SHADER_VERTEX]) {
286 auto shader = state->stages[PIPE_SHADER_VERTEX];
287 pso_desc.VS = D3D12_SHADER_BYTECODE { shader->bytecode, shader->bytecode_length };
288 last_vertex_stage_nir = shader->nir;
289 }
290
291 if (state->stages[PIPE_SHADER_TESS_CTRL]) {
292 auto shader = state->stages[PIPE_SHADER_TESS_CTRL];
293 pso_desc.HS = D3D12_SHADER_BYTECODE{ shader->bytecode, shader->bytecode_length };
294 last_vertex_stage_nir = shader->nir;
295 }
296
297 if (state->stages[PIPE_SHADER_TESS_EVAL]) {
298 auto shader = state->stages[PIPE_SHADER_TESS_EVAL];
299 pso_desc.DS = D3D12_SHADER_BYTECODE{ shader->bytecode, shader->bytecode_length };
300 last_vertex_stage_nir = shader->nir;
301 }
302
303 if (state->stages[PIPE_SHADER_GEOMETRY]) {
304 auto shader = state->stages[PIPE_SHADER_GEOMETRY];
305 pso_desc.GS = D3D12_SHADER_BYTECODE{ shader->bytecode, shader->bytecode_length };
306 last_vertex_stage_nir = shader->nir;
307 }
308
309 bool last_vertex_stage_writes_pos = (last_vertex_stage_nir->info.outputs_written & VARYING_BIT_POS) != 0;
310 if (last_vertex_stage_writes_pos && state->stages[PIPE_SHADER_FRAGMENT] &&
311 !state->rast->base.rasterizer_discard) {
312 auto shader = state->stages[PIPE_SHADER_FRAGMENT];
313 pso_desc.PS = D3D12_SHADER_BYTECODE{ shader->bytecode, shader->bytecode_length };
314 }
315
316 if (state->num_so_targets)
317 fill_so_declaration(&state->so_info, last_vertex_stage_nir, entries, &num_entries, strides, &num_strides);
318
319 D3D12_STREAM_OUTPUT_DESC& stream_output_desc = (D3D12_STREAM_OUTPUT_DESC&)pso_desc.StreamOutput;
320 stream_output_desc.NumEntries = num_entries;
321 stream_output_desc.pSODeclaration = entries;
322 stream_output_desc.RasterizedStream = state->rast->base.rasterizer_discard ? D3D12_SO_NO_RASTERIZED_STREAM : 0;
323 stream_output_desc.NumStrides = num_strides;
324 stream_output_desc.pBufferStrides = strides;
325 pso_desc.StreamOutput = stream_output_desc;
326
327 D3D12_BLEND_DESC& blend_state = (D3D12_BLEND_DESC&)pso_desc.BlendState;
328 blend_state = state->blend->desc;
329 if (state->has_float_rtv)
330 blend_state.RenderTarget[0].LogicOpEnable = false;
331
332 (d3d12_depth_stencil_desc_type&)pso_desc.DepthStencilState = state->zsa->desc;
333 pso_desc.SampleMask = state->sample_mask;
334
335 D3D12_RASTERIZER_DESC& rast = (D3D12_RASTERIZER_DESC&)pso_desc.RasterizerState;
336 rast = state->rast->desc;
337
338 if (reduced_prim != MESA_PRIM_TRIANGLES)
339 rast.CullMode = D3D12_CULL_MODE_NONE;
340
341 if (depth_bias(state->rast, reduced_prim)) {
342 rast.DepthBias = state->rast->base.offset_units * 2;
343 rast.DepthBiasClamp = state->rast->base.offset_clamp;
344 rast.SlopeScaledDepthBias = state->rast->base.offset_scale;
345 }
346 D3D12_INPUT_LAYOUT_DESC& input_layout = (D3D12_INPUT_LAYOUT_DESC&)pso_desc.InputLayout;
347 input_layout.pInputElementDescs = state->ves->elements;
348 input_layout.NumElements = state->ves->num_elements;
349 copy_input_attribs(state->ves->elements, input_attribs, &input_layout, state->stages[PIPE_SHADER_VERTEX]->nir);
350
351 pso_desc.IBStripCutValue = state->ib_strip_cut_value;
352
353 pso_desc.PrimitiveTopologyType = topology_type(reduced_prim);
354
355 D3D12_RT_FORMAT_ARRAY& render_targets = (D3D12_RT_FORMAT_ARRAY&)pso_desc.RTVFormats;
356 render_targets.NumRenderTargets = state->num_cbufs;
357 for (unsigned i = 0; i < state->num_cbufs; ++i)
358 render_targets.RTFormats[i] = d3d12_rtv_format(ctx, i);
359 pso_desc.DSVFormat = state->dsv_format;
360
361 DXGI_SAMPLE_DESC& samples = (DXGI_SAMPLE_DESC&)pso_desc.SampleDesc;
362 samples.Count = state->samples;
363 if (state->num_cbufs || state->dsv_format != DXGI_FORMAT_UNKNOWN) {
364 if (!state->zsa->desc.DepthEnable &&
365 !state->zsa->desc.StencilEnable &&
366 !state->rast->desc.MultisampleEnable &&
367 state->samples > 1) {
368 rast.ForcedSampleCount = 1;
369 pso_desc.DSVFormat = DXGI_FORMAT_UNKNOWN;
370 }
371 }
372 #ifndef _GAMING_XBOX
373 else if (state->samples > 1 &&
374 !(screen->opts19.SupportedSampleCountsWithNoOutputs & (1 << state->samples))) {
375 samples.Count = 1;
376 rast.ForcedSampleCount = state->samples;
377 }
378 #endif
379 samples.Quality = 0;
380
381 pso_desc.NodeMask = 0;
382
383 D3D12_CACHED_PIPELINE_STATE& cached_pso = (D3D12_CACHED_PIPELINE_STATE&)pso_desc.CachedPSO;
384 cached_pso.pCachedBlob = NULL;
385 cached_pso.CachedBlobSizeInBytes = 0;
386
387 pso_desc.Flags = D3D12_PIPELINE_STATE_FLAG_NONE;
388
389 ID3D12PipelineState *ret;
390
391 if (screen->opts14.IndependentFrontAndBackStencilRefMaskSupported) {
392 D3D12_PIPELINE_STATE_STREAM_DESC pso_stream_desc{
393 sizeof(pso_desc),
394 &pso_desc
395 };
396
397 if (FAILED(screen->dev->CreatePipelineState(&pso_stream_desc,
398 IID_PPV_ARGS(&ret)))) {
399 debug_printf("D3D12: CreateGraphicsPipelineState failed!\n");
400 return NULL;
401 }
402 }
403 else {
404 D3D12_GRAPHICS_PIPELINE_STATE_DESC v0desc = pso_desc.GraphicsDescV0();
405 if (FAILED(screen->dev->CreateGraphicsPipelineState(&v0desc,
406 IID_PPV_ARGS(&ret)))) {
407 debug_printf("D3D12: CreateGraphicsPipelineState failed!\n");
408 return NULL;
409 }
410 }
411
412 return ret;
413 }
414
415 static uint32_t
hash_gfx_pipeline_state(const void * key)416 hash_gfx_pipeline_state(const void *key)
417 {
418 return _mesa_hash_data(key, sizeof(struct d3d12_gfx_pipeline_state));
419 }
420
421 static bool
equals_gfx_pipeline_state(const void * a,const void * b)422 equals_gfx_pipeline_state(const void *a, const void *b)
423 {
424 return memcmp(a, b, sizeof(struct d3d12_gfx_pipeline_state)) == 0;
425 }
426
427 ID3D12PipelineState *
d3d12_get_gfx_pipeline_state(struct d3d12_context * ctx)428 d3d12_get_gfx_pipeline_state(struct d3d12_context *ctx)
429 {
430 uint32_t hash = hash_gfx_pipeline_state(&ctx->gfx_pipeline_state);
431 struct hash_entry *entry = _mesa_hash_table_search_pre_hashed(ctx->pso_cache, hash,
432 &ctx->gfx_pipeline_state);
433 if (!entry) {
434 struct d3d12_gfx_pso_entry *data = (struct d3d12_gfx_pso_entry *)MALLOC(sizeof(struct d3d12_gfx_pso_entry));
435 if (!data)
436 return NULL;
437
438 data->key = ctx->gfx_pipeline_state;
439 data->pso = create_gfx_pipeline_state(ctx);
440 if (!data->pso) {
441 FREE(data);
442 return NULL;
443 }
444
445 entry = _mesa_hash_table_insert_pre_hashed(ctx->pso_cache, hash, &data->key, data);
446 assert(entry);
447 }
448
449 return ((struct d3d12_gfx_pso_entry *)(entry->data))->pso;
450 }
451
452 void
d3d12_gfx_pipeline_state_cache_init(struct d3d12_context * ctx)453 d3d12_gfx_pipeline_state_cache_init(struct d3d12_context *ctx)
454 {
455 ctx->pso_cache = _mesa_hash_table_create(NULL, NULL, equals_gfx_pipeline_state);
456 }
457
458 static void
delete_gfx_entry(struct hash_entry * entry)459 delete_gfx_entry(struct hash_entry *entry)
460 {
461 struct d3d12_gfx_pso_entry *data = (struct d3d12_gfx_pso_entry *)entry->data;
462 data->pso->Release();
463 FREE(data);
464 }
465
466 static void
remove_gfx_entry(struct d3d12_context * ctx,struct hash_entry * entry)467 remove_gfx_entry(struct d3d12_context *ctx, struct hash_entry *entry)
468 {
469 struct d3d12_gfx_pso_entry *data = (struct d3d12_gfx_pso_entry *)entry->data;
470
471 if (ctx->current_gfx_pso == data->pso)
472 ctx->current_gfx_pso = NULL;
473 _mesa_hash_table_remove(ctx->pso_cache, entry);
474 delete_gfx_entry(entry);
475 }
476
477 void
d3d12_gfx_pipeline_state_cache_destroy(struct d3d12_context * ctx)478 d3d12_gfx_pipeline_state_cache_destroy(struct d3d12_context *ctx)
479 {
480 _mesa_hash_table_destroy(ctx->pso_cache, delete_gfx_entry);
481 }
482
483 void
d3d12_gfx_pipeline_state_cache_invalidate(struct d3d12_context * ctx,const void * state)484 d3d12_gfx_pipeline_state_cache_invalidate(struct d3d12_context *ctx, const void *state)
485 {
486 hash_table_foreach(ctx->pso_cache, entry) {
487 const struct d3d12_gfx_pipeline_state *key = (struct d3d12_gfx_pipeline_state *)entry->key;
488 if (key->blend == state || key->zsa == state || key->rast == state)
489 remove_gfx_entry(ctx, entry);
490 }
491 }
492
493 void
d3d12_gfx_pipeline_state_cache_invalidate_shader(struct d3d12_context * ctx,enum pipe_shader_type stage,struct d3d12_shader_selector * selector)494 d3d12_gfx_pipeline_state_cache_invalidate_shader(struct d3d12_context *ctx,
495 enum pipe_shader_type stage,
496 struct d3d12_shader_selector *selector)
497 {
498 struct d3d12_shader *shader = selector->first;
499
500 while (shader) {
501 hash_table_foreach(ctx->pso_cache, entry) {
502 const struct d3d12_gfx_pipeline_state *key = (struct d3d12_gfx_pipeline_state *)entry->key;
503 if (key->stages[stage] == shader)
504 remove_gfx_entry(ctx, entry);
505 }
506 shader = shader->next_variant;
507 }
508 }
509
510 static ID3D12PipelineState *
create_compute_pipeline_state(struct d3d12_context * ctx)511 create_compute_pipeline_state(struct d3d12_context *ctx)
512 {
513 struct d3d12_screen *screen = d3d12_screen(ctx->base.screen);
514 struct d3d12_compute_pipeline_state *state = &ctx->compute_pipeline_state;
515
516 D3D12_COMPUTE_PIPELINE_STATE_DESC pso_desc = { 0 };
517 pso_desc.pRootSignature = state->root_signature;
518
519 if (state->stage) {
520 auto shader = state->stage;
521 pso_desc.CS.BytecodeLength = shader->bytecode_length;
522 pso_desc.CS.pShaderBytecode = shader->bytecode;
523 }
524
525 pso_desc.NodeMask = 0;
526
527 pso_desc.CachedPSO.pCachedBlob = NULL;
528 pso_desc.CachedPSO.CachedBlobSizeInBytes = 0;
529
530 pso_desc.Flags = D3D12_PIPELINE_STATE_FLAG_NONE;
531
532 ID3D12PipelineState *ret;
533 if (FAILED(screen->dev->CreateComputePipelineState(&pso_desc,
534 IID_PPV_ARGS(&ret)))) {
535 debug_printf("D3D12: CreateComputePipelineState failed!\n");
536 return NULL;
537 }
538
539 return ret;
540 }
541
542 static uint32_t
hash_compute_pipeline_state(const void * key)543 hash_compute_pipeline_state(const void *key)
544 {
545 return _mesa_hash_data(key, sizeof(struct d3d12_compute_pipeline_state));
546 }
547
548 static bool
equals_compute_pipeline_state(const void * a,const void * b)549 equals_compute_pipeline_state(const void *a, const void *b)
550 {
551 return memcmp(a, b, sizeof(struct d3d12_compute_pipeline_state)) == 0;
552 }
553
554 ID3D12PipelineState *
d3d12_get_compute_pipeline_state(struct d3d12_context * ctx)555 d3d12_get_compute_pipeline_state(struct d3d12_context *ctx)
556 {
557 uint32_t hash = hash_compute_pipeline_state(&ctx->compute_pipeline_state);
558 struct hash_entry *entry = _mesa_hash_table_search_pre_hashed(ctx->compute_pso_cache, hash,
559 &ctx->compute_pipeline_state);
560 if (!entry) {
561 struct d3d12_compute_pso_entry *data = (struct d3d12_compute_pso_entry *)MALLOC(sizeof(struct d3d12_compute_pso_entry));
562 if (!data)
563 return NULL;
564
565 data->key = ctx->compute_pipeline_state;
566 data->pso = create_compute_pipeline_state(ctx);
567 if (!data->pso) {
568 FREE(data);
569 return NULL;
570 }
571
572 entry = _mesa_hash_table_insert_pre_hashed(ctx->compute_pso_cache, hash, &data->key, data);
573 assert(entry);
574 }
575
576 return ((struct d3d12_compute_pso_entry *)(entry->data))->pso;
577 }
578
579 void
d3d12_compute_pipeline_state_cache_init(struct d3d12_context * ctx)580 d3d12_compute_pipeline_state_cache_init(struct d3d12_context *ctx)
581 {
582 ctx->compute_pso_cache = _mesa_hash_table_create(NULL, NULL, equals_compute_pipeline_state);
583 }
584
585 static void
delete_compute_entry(struct hash_entry * entry)586 delete_compute_entry(struct hash_entry *entry)
587 {
588 struct d3d12_compute_pso_entry *data = (struct d3d12_compute_pso_entry *)entry->data;
589 data->pso->Release();
590 FREE(data);
591 }
592
593 static void
remove_compute_entry(struct d3d12_context * ctx,struct hash_entry * entry)594 remove_compute_entry(struct d3d12_context *ctx, struct hash_entry *entry)
595 {
596 struct d3d12_compute_pso_entry *data = (struct d3d12_compute_pso_entry *)entry->data;
597
598 if (ctx->current_compute_pso == data->pso)
599 ctx->current_compute_pso = NULL;
600 _mesa_hash_table_remove(ctx->compute_pso_cache, entry);
601 delete_compute_entry(entry);
602 }
603
604 void
d3d12_compute_pipeline_state_cache_destroy(struct d3d12_context * ctx)605 d3d12_compute_pipeline_state_cache_destroy(struct d3d12_context *ctx)
606 {
607 _mesa_hash_table_destroy(ctx->compute_pso_cache, delete_compute_entry);
608 }
609
610 void
d3d12_compute_pipeline_state_cache_invalidate_shader(struct d3d12_context * ctx,struct d3d12_shader_selector * selector)611 d3d12_compute_pipeline_state_cache_invalidate_shader(struct d3d12_context *ctx,
612 struct d3d12_shader_selector *selector)
613 {
614 struct d3d12_shader *shader = selector->first;
615
616 while (shader) {
617 hash_table_foreach(ctx->compute_pso_cache, entry) {
618 const struct d3d12_compute_pipeline_state *key = (struct d3d12_compute_pipeline_state *)entry->key;
619 if (key->stage == shader)
620 remove_compute_entry(ctx, entry);
621 }
622 shader = shader->next_variant;
623 }
624 }
625