/*
 * Copyright © 2024 Valve Corporation
 *
 * SPDX-License-Identifier: MIT
 */

#include "vk_log.h"

#include "radv_device.h"
#include "radv_entrypoints.h"
#include "radv_physical_device.h"
#include "radv_pipeline_cache.h"
#include "radv_pipeline_compute.h"
#include "radv_pipeline_graphics.h"
#include "radv_shader_object.h"

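/* Releases one compiled shader variant. The radv_shader_binary is only owned
 * by the shader object when the shader was compiled from SPIR-V; for
 * VK_SHADER_CODE_TYPE_BINARY_EXT the binary points into the application
 * provided data and must not be freed here.
 */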
static void
radv_shader_object_destroy_variant(struct radv_device *device, VkShaderCodeTypeEXT code_type,
                                   struct radv_shader *shader, struct radv_shader_binary *binary)
{
   if (shader)
      radv_shader_unref(device, shader);

   if (code_type == VK_SHADER_CODE_TYPE_SPIRV_EXT)
      free(binary);
}

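/* Destroys a shader object and every precompiled variant it owns (VS-as-LS,
 * VS/TES-as-ES, the GS copy shader and the main shader).
 */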
static void
radv_shader_object_destroy(struct radv_device *device, struct radv_shader_object *shader_obj,
                           const VkAllocationCallbacks *pAllocator)
{
   radv_shader_object_destroy_variant(device, shader_obj->code_type, shader_obj->as_ls.shader,
                                      shader_obj->as_ls.binary);
   radv_shader_object_destroy_variant(device, shader_obj->code_type, shader_obj->as_es.shader,
                                      shader_obj->as_es.binary);
   radv_shader_object_destroy_variant(device, shader_obj->code_type, shader_obj->gs.copy_shader,
                                      shader_obj->gs.copy_binary);
   radv_shader_object_destroy_variant(device, shader_obj->code_type, shader_obj->shader, shader_obj->binary);

   vk_object_base_finish(&shader_obj->base);
   vk_free2(&device->vk.alloc, pAllocator, shader_obj);
}

VKAPI_ATTR void VKAPI_CALL
radv_DestroyShaderEXT(VkDevice _device, VkShaderEXT shader, const VkAllocationCallbacks *pAllocator)
{
   VK_FROM_HANDLE(radv_device, device, _device);
   VK_FROM_HANDLE(radv_shader_object, shader_obj, shader);

   if (!shader)
      return;

   radv_shader_object_destroy(device, shader_obj, pAllocator);
}

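/* Translates a VkShaderCreateInfoEXT into the internal radv_shader_stage
 * representation: SPIR-V code and entrypoint, specialization info, the
 * descriptor set/push constant layout, and shader key flags (required
 * subgroup size, full subgroups, mesh shaders without a task shader).
 */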
static void
radv_shader_stage_init(const VkShaderCreateInfoEXT *sinfo, struct radv_shader_stage *out_stage)
{
   uint16_t dynamic_shader_stages = 0;

   memset(out_stage, 0, sizeof(*out_stage));

   out_stage->stage = vk_to_mesa_shader_stage(sinfo->stage);
   out_stage->next_stage = MESA_SHADER_NONE;
   out_stage->entrypoint = sinfo->pName;
   out_stage->spec_info = sinfo->pSpecializationInfo;
   out_stage->feedback.flags = VK_PIPELINE_CREATION_FEEDBACK_VALID_BIT;
   out_stage->spirv.data = (const char *)sinfo->pCode;
   out_stage->spirv.size = sinfo->codeSize;

   for (uint32_t i = 0; i < sinfo->setLayoutCount; i++) {
      VK_FROM_HANDLE(radv_descriptor_set_layout, set_layout, sinfo->pSetLayouts[i]);

      if (set_layout == NULL)
         continue;

      out_stage->layout.num_sets = MAX2(i + 1, out_stage->layout.num_sets);
      out_stage->layout.set[i].layout = set_layout;

      out_stage->layout.set[i].dynamic_offset_start = out_stage->layout.dynamic_offset_count;
      out_stage->layout.dynamic_offset_count += set_layout->dynamic_offset_count;

      dynamic_shader_stages |= set_layout->dynamic_shader_stages;
   }

   if (out_stage->layout.dynamic_offset_count && (dynamic_shader_stages & sinfo->stage)) {
      out_stage->layout.use_dynamic_descriptors = true;
   }

   for (unsigned i = 0; i < sinfo->pushConstantRangeCount; ++i) {
      const VkPushConstantRange *range = sinfo->pPushConstantRanges + i;
      out_stage->layout.push_constant_size = MAX2(out_stage->layout.push_constant_size, range->offset + range->size);
   }

   out_stage->layout.push_constant_size = align(out_stage->layout.push_constant_size, 16);

   const VkShaderRequiredSubgroupSizeCreateInfoEXT *const subgroup_size =
      vk_find_struct_const(sinfo->pNext, SHADER_REQUIRED_SUBGROUP_SIZE_CREATE_INFO_EXT);

   if (subgroup_size) {
      if (subgroup_size->requiredSubgroupSize == 32)
         out_stage->key.subgroup_required_size = RADV_REQUIRED_WAVE32;
      else if (subgroup_size->requiredSubgroupSize == 64)
         out_stage->key.subgroup_required_size = RADV_REQUIRED_WAVE64;
      else
         unreachable("Unsupported required subgroup size.");
   }

   if (sinfo->flags & VK_SHADER_CREATE_REQUIRE_FULL_SUBGROUPS_BIT_EXT) {
      out_stage->key.subgroup_require_full = 1;
   }

   if (out_stage->stage == MESA_SHADER_MESH) {
      out_stage->key.has_task_shader = !(sinfo->flags & VK_SHADER_CREATE_NO_TASK_SHADER_BIT_EXT);
   }
}

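/* Compiles an unlinked graphics shader object. Because the code generated
 * for a stage depends on which stage consumes its outputs, one variant may
 * be compiled per possible next stage (see below).
 */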
static VkResult
radv_shader_object_init_graphics(struct radv_shader_object *shader_obj, struct radv_device *device,
                                 const VkShaderCreateInfoEXT *pCreateInfo)
{
   const struct radv_physical_device *pdev = radv_device_physical(device);
   gl_shader_stage stage = vk_to_mesa_shader_stage(pCreateInfo->stage);
   struct radv_shader_stage stages[MESA_VULKAN_SHADER_STAGES];

   for (unsigned i = 0; i < MESA_VULKAN_SHADER_STAGES; i++) {
      stages[i].entrypoint = NULL;
      stages[i].nir = NULL;
      stages[i].spirv.size = 0;
      stages[i].next_stage = MESA_SHADER_NONE;
   }

   radv_shader_stage_init(pCreateInfo, &stages[stage]);

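   /* Shader objects are used with fully dynamic graphics state, so compile
    * with a VS prolog, a PS epilog and the rasterization state treated as
    * dynamic.
    */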
   struct radv_graphics_state_key gfx_state = {0};

   gfx_state.vs.has_prolog = true;
   gfx_state.ps.has_epilog = true;
   gfx_state.dynamic_rasterization_samples = true;
   gfx_state.unknown_rast_prim = true;
   gfx_state.dynamic_provoking_vtx_mode = true;
   gfx_state.dynamic_line_rast_mode = true;

   if (pdev->info.gfx_level >= GFX11)
      gfx_state.ps.exports_mrtz_via_epilog = true;

   for (uint32_t i = 0; i < MAX_RTS; i++)
      gfx_state.ps.epilog.color_map[i] = i;

   struct radv_shader *shader = NULL;
   struct radv_shader_binary *binary = NULL;

   VkShaderStageFlags next_stages = pCreateInfo->nextStage;
   if (!next_stages) {
      /* When next stage is 0, gather all valid next stages. */
      switch (pCreateInfo->stage) {
      case VK_SHADER_STAGE_VERTEX_BIT:
         next_stages |=
            VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT | VK_SHADER_STAGE_GEOMETRY_BIT | VK_SHADER_STAGE_FRAGMENT_BIT;
         break;
      case VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT:
         next_stages |= VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT;
         break;
      case VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT:
         next_stages |= VK_SHADER_STAGE_GEOMETRY_BIT | VK_SHADER_STAGE_FRAGMENT_BIT;
         break;
      case VK_SHADER_STAGE_GEOMETRY_BIT:
      case VK_SHADER_STAGE_MESH_BIT_EXT:
         next_stages |= VK_SHADER_STAGE_FRAGMENT_BIT;
         break;
      case VK_SHADER_STAGE_TASK_BIT_EXT:
         next_stages |= VK_SHADER_STAGE_MESH_BIT_EXT;
         break;
      case VK_SHADER_STAGE_FRAGMENT_BIT:
      case VK_SHADER_STAGE_COMPUTE_BIT:
         break;
      default:
         unreachable("Invalid shader stage");
      }
   }

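   /* Fragment shaders have no next stage, so a single variant is enough.
    * Otherwise compile one variant per possible next stage and store it in
    * the matching slot (VS-as-LS before tessellation, VS/TES-as-ES before
    * geometry, or the plain hardware stage).
    */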
   if (!next_stages) {
      struct radv_shader *shaders[MESA_VULKAN_SHADER_STAGES] = {NULL};
      struct radv_shader_binary *binaries[MESA_VULKAN_SHADER_STAGES] = {NULL};

      radv_graphics_shaders_compile(device, NULL, stages, &gfx_state, true, false, false, NULL, false, shaders,
                                    binaries, &shader_obj->gs.copy_shader, &shader_obj->gs.copy_binary);

      shader = shaders[stage];
      binary = binaries[stage];

      ralloc_free(stages[stage].nir);

      shader_obj->shader = shader;
      shader_obj->binary = binary;
   } else {
      radv_foreach_stage(next_stage, next_stages)
      {
         struct radv_shader *shaders[MESA_VULKAN_SHADER_STAGES] = {NULL};
         struct radv_shader_binary *binaries[MESA_VULKAN_SHADER_STAGES] = {NULL};

         radv_shader_stage_init(pCreateInfo, &stages[stage]);
         stages[stage].next_stage = next_stage;

         radv_graphics_shaders_compile(device, NULL, stages, &gfx_state, true, false, false, NULL, false, shaders,
                                       binaries, &shader_obj->gs.copy_shader, &shader_obj->gs.copy_binary);

         shader = shaders[stage];
         binary = binaries[stage];

         ralloc_free(stages[stage].nir);

         if (stage == MESA_SHADER_VERTEX) {
            if (next_stage == MESA_SHADER_TESS_CTRL) {
               shader_obj->as_ls.shader = shader;
               shader_obj->as_ls.binary = binary;
            } else if (next_stage == MESA_SHADER_GEOMETRY) {
               shader_obj->as_es.shader = shader;
               shader_obj->as_es.binary = binary;
            } else {
               shader_obj->shader = shader;
               shader_obj->binary = binary;
            }
         } else if (stage == MESA_SHADER_TESS_EVAL) {
            if (next_stage == MESA_SHADER_GEOMETRY) {
               shader_obj->as_es.shader = shader;
               shader_obj->as_es.binary = binary;
            } else {
               shader_obj->shader = shader;
               shader_obj->binary = binary;
            }
         } else {
            shader_obj->shader = shader;
            shader_obj->binary = binary;
         }
      }
   }

   return VK_SUCCESS;
}

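/* Compute shaders have no next stage and no graphics variants, so a single
 * compilation is enough.
 */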
static VkResult
radv_shader_object_init_compute(struct radv_shader_object *shader_obj, struct radv_device *device,
                                const VkShaderCreateInfoEXT *pCreateInfo)
{
   struct radv_shader_binary *cs_binary;
   struct radv_shader_stage stage = {0};

   radv_shader_stage_init(pCreateInfo, &stage);

   struct radv_shader *cs_shader = radv_compile_cs(device, NULL, &stage, true, false, false, &cs_binary);

   ralloc_free(stage.nir);

   shader_obj->shader = cs_shader;
   shader_obj->binary = cs_binary;

   return VK_SUCCESS;
}

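/* Gathers the descriptor set/push constant layout from the create info,
 * mirroring what radv_shader_stage_init computes for the compile path. Used
 * to fill the layout-related fields of the shader object.
 */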
static void
radv_get_shader_layout(const VkShaderCreateInfoEXT *pCreateInfo, struct radv_shader_layout *layout)
{
   uint16_t dynamic_shader_stages = 0;

   memset(layout, 0, sizeof(*layout));

   layout->dynamic_offset_count = 0;

   for (uint32_t i = 0; i < pCreateInfo->setLayoutCount; i++) {
      VK_FROM_HANDLE(radv_descriptor_set_layout, set_layout, pCreateInfo->pSetLayouts[i]);

      if (set_layout == NULL)
         continue;

      layout->num_sets = MAX2(i + 1, layout->num_sets);

      layout->set[i].layout = set_layout;
      layout->set[i].dynamic_offset_start = layout->dynamic_offset_count;

      layout->dynamic_offset_count += set_layout->dynamic_offset_count;
      dynamic_shader_stages |= set_layout->dynamic_shader_stages;
   }

   if (layout->dynamic_offset_count && (dynamic_shader_stages & pCreateInfo->stage)) {
      layout->use_dynamic_descriptors = true;
   }

   layout->push_constant_size = 0;

   for (unsigned i = 0; i < pCreateInfo->pushConstantRangeCount; ++i) {
      const VkPushConstantRange *range = pCreateInfo->pPushConstantRanges + i;
      layout->push_constant_size = MAX2(layout->push_constant_size, range->offset + range->size);
   }

   layout->push_constant_size = align(layout->push_constant_size, 16);
}

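/* Deserializes one shader binary from a blob written by
 * radv_write_shader_binary: a SHA-1 digest, a 32-bit size, then the
 * serialized radv_shader_binary. The digest is verified before the binary is
 * used, and *binary_out ends up pointing into the blob (i.e. into the
 * application provided data), which is why it is not freed on destruction.
 */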
static VkResult
radv_shader_object_init_binary(struct radv_device *device, struct blob_reader *blob, struct radv_shader **shader_out,
                               struct radv_shader_binary **binary_out)
{
   const char *binary_sha1 = blob_read_bytes(blob, SHA1_DIGEST_LENGTH);
   const uint32_t binary_size = blob_read_uint32(blob);
   const struct radv_shader_binary *binary = blob_read_bytes(blob, binary_size);
   unsigned char sha1[SHA1_DIGEST_LENGTH];

   _mesa_sha1_compute(binary, binary->total_size, sha1);
   if (memcmp(sha1, binary_sha1, SHA1_DIGEST_LENGTH))
      return VK_ERROR_INCOMPATIBLE_SHADER_BINARY_EXT;

   *shader_out = radv_shader_create(device, NULL, binary, true);
   *binary_out = (struct radv_shader_binary *)binary;

   return VK_SUCCESS;
}

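/* Initializes a shader object either from a serialized binary (the format
 * written by radv_GetShaderBinaryDataEXT: cache UUID, main binary, then any
 * stage-specific variants) or by compiling the given SPIR-V.
 */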
static VkResult
radv_shader_object_init(struct radv_shader_object *shader_obj, struct radv_device *device,
                        const VkShaderCreateInfoEXT *pCreateInfo)
{
   const struct radv_physical_device *pdev = radv_device_physical(device);
   struct radv_shader_layout layout;
   VkResult result;

   radv_get_shader_layout(pCreateInfo, &layout);

   shader_obj->stage = vk_to_mesa_shader_stage(pCreateInfo->stage);
   shader_obj->code_type = pCreateInfo->codeType;
   shader_obj->push_constant_size = layout.push_constant_size;
   shader_obj->dynamic_offset_count = layout.dynamic_offset_count;

   if (pCreateInfo->codeType == VK_SHADER_CODE_TYPE_BINARY_EXT) {
      if (pCreateInfo->codeSize < VK_UUID_SIZE + sizeof(uint32_t)) {
         return VK_ERROR_INCOMPATIBLE_SHADER_BINARY_EXT;
      }

      struct blob_reader blob;
      blob_reader_init(&blob, pCreateInfo->pCode, pCreateInfo->codeSize);

      const uint8_t *cache_uuid = blob_read_bytes(&blob, VK_UUID_SIZE);

      if (memcmp(cache_uuid, pdev->cache_uuid, VK_UUID_SIZE))
         return VK_ERROR_INCOMPATIBLE_SHADER_BINARY_EXT;

      const bool has_main_binary = blob_read_uint32(&blob);

      if (has_main_binary) {
         result = radv_shader_object_init_binary(device, &blob, &shader_obj->shader, &shader_obj->binary);
         if (result != VK_SUCCESS)
            return result;
      }

      if (shader_obj->stage == MESA_SHADER_VERTEX) {
         const bool has_es_binary = blob_read_uint32(&blob);
         if (has_es_binary) {
            result =
               radv_shader_object_init_binary(device, &blob, &shader_obj->as_es.shader, &shader_obj->as_es.binary);
            if (result != VK_SUCCESS)
               return result;
         }

         const bool has_ls_binary = blob_read_uint32(&blob);
         if (has_ls_binary) {
            result =
               radv_shader_object_init_binary(device, &blob, &shader_obj->as_ls.shader, &shader_obj->as_ls.binary);
            if (result != VK_SUCCESS)
               return result;
         }
      } else if (shader_obj->stage == MESA_SHADER_TESS_EVAL) {
         const bool has_es_binary = blob_read_uint32(&blob);
         if (has_es_binary) {
            result =
               radv_shader_object_init_binary(device, &blob, &shader_obj->as_es.shader, &shader_obj->as_es.binary);
            if (result != VK_SUCCESS)
               return result;
         }
      } else if (shader_obj->stage == MESA_SHADER_GEOMETRY) {
         const bool has_gs_copy_binary = blob_read_uint32(&blob);
         if (has_gs_copy_binary) {
            result =
               radv_shader_object_init_binary(device, &blob, &shader_obj->gs.copy_shader, &shader_obj->gs.copy_binary);
            if (result != VK_SUCCESS)
               return result;
         }
      }
   } else {
      assert(pCreateInfo->codeType == VK_SHADER_CODE_TYPE_SPIRV_EXT);

      if (pCreateInfo->stage == VK_SHADER_STAGE_COMPUTE_BIT) {
         result = radv_shader_object_init_compute(shader_obj, device, pCreateInfo);
      } else {
         result = radv_shader_object_init_graphics(shader_obj, device, pCreateInfo);
      }

      if (result != VK_SUCCESS)
         return result;
   }

   return VK_SUCCESS;
}

static VkResult
radv_shader_object_create(VkDevice _device, const VkShaderCreateInfoEXT *pCreateInfo,
                          const VkAllocationCallbacks *pAllocator, VkShaderEXT *pShader)
{
   VK_FROM_HANDLE(radv_device, device, _device);
   struct radv_shader_object *shader_obj;
   VkResult result;

   shader_obj = vk_zalloc2(&device->vk.alloc, pAllocator, sizeof(*shader_obj), 8, VK_SYSTEM_ALLOCATION_SCOPE_OBJECT);
   if (shader_obj == NULL)
      return vk_error(device, VK_ERROR_OUT_OF_HOST_MEMORY);

   vk_object_base_init(&device->vk, &shader_obj->base, VK_OBJECT_TYPE_SHADER_EXT);

   result = radv_shader_object_init(shader_obj, device, pCreateInfo);
   if (result != VK_SUCCESS) {
      radv_shader_object_destroy(device, shader_obj, pAllocator);
      return result;
   }

   *pShader = radv_shader_object_to_handle(shader_obj);

   return VK_SUCCESS;
}

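/* Compiles all stages of a linked group of shaders together, then creates
 * one shader object per create info and hands each the shader/binary that
 * was compiled for its stage.
 */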
static VkResult
radv_shader_object_create_linked(VkDevice _device, uint32_t createInfoCount, const VkShaderCreateInfoEXT *pCreateInfos,
                                 const VkAllocationCallbacks *pAllocator, VkShaderEXT *pShaders)
{
   VK_FROM_HANDLE(radv_device, device, _device);
   const struct radv_physical_device *pdev = radv_device_physical(device);
   struct radv_shader_stage stages[MESA_VULKAN_SHADER_STAGES];

   for (unsigned i = 0; i < MESA_VULKAN_SHADER_STAGES; i++) {
      stages[i].entrypoint = NULL;
      stages[i].nir = NULL;
      stages[i].spirv.size = 0;
      stages[i].next_stage = MESA_SHADER_NONE;
   }

   struct radv_graphics_state_key gfx_state = {0};

   gfx_state.vs.has_prolog = true;
   gfx_state.ps.has_epilog = true;
   gfx_state.dynamic_rasterization_samples = true;
   gfx_state.unknown_rast_prim = true;
   gfx_state.dynamic_provoking_vtx_mode = true;
   gfx_state.dynamic_line_rast_mode = true;

   if (pdev->info.gfx_level >= GFX11)
      gfx_state.ps.exports_mrtz_via_epilog = true;

   for (uint32_t i = 0; i < MAX_RTS; i++)
      gfx_state.ps.epilog.color_map[i] = i;

   for (unsigned i = 0; i < createInfoCount; i++) {
      const VkShaderCreateInfoEXT *pCreateInfo = &pCreateInfos[i];
      gl_shader_stage s = vk_to_mesa_shader_stage(pCreateInfo->stage);

      radv_shader_stage_init(pCreateInfo, &stages[s]);
   }

   /* Determine next stage. */
   for (unsigned i = 0; i < MESA_VULKAN_SHADER_STAGES; i++) {
      if (!stages[i].entrypoint)
         continue;

      switch (stages[i].stage) {
      case MESA_SHADER_VERTEX:
         if (stages[MESA_SHADER_TESS_CTRL].entrypoint) {
            stages[i].next_stage = MESA_SHADER_TESS_CTRL;
         } else if (stages[MESA_SHADER_GEOMETRY].entrypoint) {
            stages[i].next_stage = MESA_SHADER_GEOMETRY;
         } else if (stages[MESA_SHADER_FRAGMENT].entrypoint) {
            stages[i].next_stage = MESA_SHADER_FRAGMENT;
         }
         break;
      case MESA_SHADER_TESS_CTRL:
         stages[i].next_stage = MESA_SHADER_TESS_EVAL;
         break;
      case MESA_SHADER_TESS_EVAL:
         if (stages[MESA_SHADER_GEOMETRY].entrypoint) {
            stages[i].next_stage = MESA_SHADER_GEOMETRY;
         } else if (stages[MESA_SHADER_FRAGMENT].entrypoint) {
            stages[i].next_stage = MESA_SHADER_FRAGMENT;
         }
         break;
      case MESA_SHADER_GEOMETRY:
      case MESA_SHADER_MESH:
         if (stages[MESA_SHADER_FRAGMENT].entrypoint) {
            stages[i].next_stage = MESA_SHADER_FRAGMENT;
         }
         break;
      case MESA_SHADER_FRAGMENT:
         stages[i].next_stage = MESA_SHADER_NONE;
         break;
      case MESA_SHADER_TASK:
         stages[i].next_stage = MESA_SHADER_MESH;
         break;
      default:
         assert(0);
      }
   }

   struct radv_shader *shaders[MESA_VULKAN_SHADER_STAGES] = {NULL};
   struct radv_shader_binary *binaries[MESA_VULKAN_SHADER_STAGES] = {NULL};
   struct radv_shader *gs_copy_shader = NULL;
   struct radv_shader_binary *gs_copy_binary = NULL;

   radv_graphics_shaders_compile(device, NULL, stages, &gfx_state, true, false, false, NULL, false, shaders, binaries,
                                 &gs_copy_shader, &gs_copy_binary);

   for (unsigned i = 0; i < createInfoCount; i++) {
      const VkShaderCreateInfoEXT *pCreateInfo = &pCreateInfos[i];
      gl_shader_stage s = vk_to_mesa_shader_stage(pCreateInfo->stage);
      struct radv_shader_object *shader_obj;

      shader_obj = vk_zalloc2(&device->vk.alloc, pAllocator, sizeof(*shader_obj), 8, VK_SYSTEM_ALLOCATION_SCOPE_OBJECT);
      if (shader_obj == NULL)
         return vk_error(device, VK_ERROR_OUT_OF_HOST_MEMORY);

      vk_object_base_init(&device->vk, &shader_obj->base, VK_OBJECT_TYPE_SHADER_EXT);

      shader_obj->stage = s;
      shader_obj->code_type = pCreateInfo->codeType;
      shader_obj->push_constant_size = stages[s].layout.push_constant_size;
      shader_obj->dynamic_offset_count = stages[s].layout.dynamic_offset_count;

      if (s == MESA_SHADER_VERTEX) {
         if (stages[s].next_stage == MESA_SHADER_TESS_CTRL) {
            shader_obj->as_ls.shader = shaders[s];
            shader_obj->as_ls.binary = binaries[s];
         } else if (stages[s].next_stage == MESA_SHADER_GEOMETRY) {
            shader_obj->as_es.shader = shaders[s];
            shader_obj->as_es.binary = binaries[s];
         } else {
            shader_obj->shader = shaders[s];
            shader_obj->binary = binaries[s];
         }
      } else if (s == MESA_SHADER_TESS_EVAL) {
         if (stages[s].next_stage == MESA_SHADER_GEOMETRY) {
            shader_obj->as_es.shader = shaders[s];
            shader_obj->as_es.binary = binaries[s];
         } else {
            shader_obj->shader = shaders[s];
            shader_obj->binary = binaries[s];
         }
      } else {
         shader_obj->shader = shaders[s];
         shader_obj->binary = binaries[s];
      }

      if (s == MESA_SHADER_GEOMETRY) {
         shader_obj->gs.copy_shader = gs_copy_shader;
         shader_obj->gs.copy_binary = gs_copy_binary;
      }

      ralloc_free(stages[s].nir);

      pShaders[i] = radv_shader_object_to_handle(shader_obj);
   }

   return VK_SUCCESS;
}

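/* Shaders are compiled as a linked group only when more than one is created
 * in the call, from SPIR-V, with VK_SHADER_CREATE_LINK_STAGE_BIT_EXT (which
 * the spec requires on every stage of a linked group, as asserted below),
 * and the VS/TES heuristics below don't disable it.
 */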
static bool
radv_shader_object_linking_enabled(uint32_t createInfoCount, const VkShaderCreateInfoEXT *pCreateInfos)
{
   const bool has_linked_spirv = createInfoCount > 1 &&
                                 !!(pCreateInfos[0].flags & VK_SHADER_CREATE_LINK_STAGE_BIT_EXT) &&
                                 pCreateInfos[0].codeType == VK_SHADER_CODE_TYPE_SPIRV_EXT;

   if (!has_linked_spirv)
      return false;

   /* Gather the available shader stages. */
   VkShaderStageFlagBits stages = 0;
   for (unsigned i = 0; i < createInfoCount; i++) {
      const VkShaderCreateInfoEXT *pCreateInfo = &pCreateInfos[i];
      stages |= pCreateInfo->stage;
   }

   for (unsigned i = 0; i < createInfoCount; i++) {
      const VkShaderCreateInfoEXT *pCreateInfo = &pCreateInfos[i];

      /* Force-disable shader linking when the next stage of VS/TES isn't present, because the
       * driver would need to compile all shaders twice due to shader variants. This is probably
       * less optimal than compiling unlinked shaders.
       */
      if ((pCreateInfo->stage & VK_SHADER_STAGE_VERTEX_BIT) &&
          (pCreateInfo->nextStage & (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT | VK_SHADER_STAGE_GEOMETRY_BIT)) &&
          !(stages & (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT | VK_SHADER_STAGE_GEOMETRY_BIT)))
         return false;

      if ((pCreateInfo->stage & VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT) &&
          (pCreateInfo->nextStage & VK_SHADER_STAGE_GEOMETRY_BIT) && !(stages & VK_SHADER_STAGE_GEOMETRY_BIT))
         return false;

      assert(pCreateInfo->flags & VK_SHADER_CREATE_LINK_STAGE_BIT_EXT);
   }

   return true;
}

VKAPI_ATTR VkResult VKAPI_CALL
radv_CreateShadersEXT(VkDevice _device, uint32_t createInfoCount, const VkShaderCreateInfoEXT *pCreateInfos,
                      const VkAllocationCallbacks *pAllocator, VkShaderEXT *pShaders)
{
   VkResult result = VK_SUCCESS;
   unsigned i = 0;

   if (radv_shader_object_linking_enabled(createInfoCount, pCreateInfos))
      return radv_shader_object_create_linked(_device, createInfoCount, pCreateInfos, pAllocator, pShaders);

   for (; i < createInfoCount; i++) {
      VkResult r;

      r = radv_shader_object_create(_device, &pCreateInfos[i], pAllocator, &pShaders[i]);
      if (r != VK_SUCCESS) {
         result = r;
         pShaders[i] = VK_NULL_HANDLE;
      }
   }

   for (; i < createInfoCount; ++i)
      pShaders[i] = VK_NULL_HANDLE;

   return result;
}

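/* Serialized size of one binary: a 32-bit has_binary flag, and when the
 * binary is present, its SHA-1 digest, a 32-bit size and the binary data.
 */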
static size_t
radv_get_shader_binary_size(const struct radv_shader_binary *binary)
{
   size_t size = sizeof(uint32_t); /* has_binary */

   if (binary)
      size += SHA1_DIGEST_LENGTH + 4 + ALIGN(binary->total_size, 4);

   return size;
}

static size_t
radv_get_shader_object_size(const struct radv_shader_object *shader_obj)
{
   size_t size = VK_UUID_SIZE;

   size += radv_get_shader_binary_size(shader_obj->binary);

   if (shader_obj->stage == MESA_SHADER_VERTEX) {
      size += radv_get_shader_binary_size(shader_obj->as_es.binary);
      size += radv_get_shader_binary_size(shader_obj->as_ls.binary);
   } else if (shader_obj->stage == MESA_SHADER_TESS_EVAL) {
      size += radv_get_shader_binary_size(shader_obj->as_es.binary);
   } else if (shader_obj->stage == MESA_SHADER_GEOMETRY) {
      size += radv_get_shader_binary_size(shader_obj->gs.copy_binary);
   }

   return size;
}

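/* Writes one binary to the blob in the format consumed by
 * radv_shader_object_init_binary.
 */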
static void
radv_write_shader_binary(struct blob *blob, const struct radv_shader_binary *binary)
{
   unsigned char binary_sha1[SHA1_DIGEST_LENGTH];

   blob_write_uint32(blob, !!binary);

   if (binary) {
      _mesa_sha1_compute(binary, binary->total_size, binary_sha1);

      blob_write_bytes(blob, binary_sha1, sizeof(binary_sha1));
      blob_write_uint32(blob, binary->total_size);
      blob_write_bytes(blob, binary, binary->total_size);
   }
}

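/* Implements the usual Vulkan two-call idiom: with pData == NULL only the
 * required size is returned. The layout written here (cache UUID, main
 * binary, stage-specific variants) must match what radv_shader_object_init
 * expects when recreating a shader from a binary.
 */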
VKAPI_ATTR VkResult VKAPI_CALL
radv_GetShaderBinaryDataEXT(VkDevice _device, VkShaderEXT shader, size_t *pDataSize, void *pData)
{
   VK_FROM_HANDLE(radv_device, device, _device);
   VK_FROM_HANDLE(radv_shader_object, shader_obj, shader);
   const struct radv_physical_device *pdev = radv_device_physical(device);
   const size_t size = radv_get_shader_object_size(shader_obj);

   if (!pData) {
      *pDataSize = size;
      return VK_SUCCESS;
   }

   if (*pDataSize < size) {
      *pDataSize = 0;
      return VK_INCOMPLETE;
   }

   struct blob blob;
   blob_init_fixed(&blob, pData, *pDataSize);
   blob_write_bytes(&blob, pdev->cache_uuid, VK_UUID_SIZE);

   radv_write_shader_binary(&blob, shader_obj->binary);

   if (shader_obj->stage == MESA_SHADER_VERTEX) {
      radv_write_shader_binary(&blob, shader_obj->as_es.binary);
      radv_write_shader_binary(&blob, shader_obj->as_ls.binary);
   } else if (shader_obj->stage == MESA_SHADER_TESS_EVAL) {
      radv_write_shader_binary(&blob, shader_obj->as_es.binary);
   } else if (shader_obj->stage == MESA_SHADER_GEOMETRY) {
      radv_write_shader_binary(&blob, shader_obj->gs.copy_binary);
   }

   assert(!blob.out_of_memory);

   return VK_SUCCESS;
}