/*------------------------------------------------------------------------
 * Vulkan Conformance Tests
 * ------------------------
 *
 * Copyright (c) 2023 LunarG, Inc.
 * Copyright (c) 2023 Nintendo
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 *//*!
 * \file
 * \brief Wrapper that can construct monolithic pipeline or use
 *        VK_EXT_shader_object for compute pipeline construction.
 *//*--------------------------------------------------------------------*/
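
// Typical usage (an illustrative sketch only; the descriptor set layout,
// program binary name and dispatch dimensions below are placeholders and
// do not come from this file):
//
//     ComputePipelineWrapper pipeline(vk, device, constructionType,
//                                     context.getBinaryCollection().get("comp"));
//     pipeline.setDescriptorSetLayout(*descriptorSetLayout);
//     pipeline.buildPipeline();
//     pipeline.bind(*cmdBuffer);
//     vk.cmdDispatch(*cmdBuffer, 1u, 1u, 1u);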

#include "vkComputePipelineConstructionUtil.hpp"
#include "vkQueryUtil.hpp"
#include "vkObjUtil.hpp"

namespace vk
{

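// Check that VK_EXT_shader_object is supported when a shader-object based
// construction type is requested; plain pipeline construction needs no
// additional extension support.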
void checkShaderObjectRequirements(const InstanceInterface &vki, VkPhysicalDevice physicalDevice,
                                   ComputePipelineConstructionType computePipelineConstructionType)
{
    if (computePipelineConstructionType == COMPUTE_PIPELINE_CONSTRUCTION_TYPE_PIPELINE)
        return;

    const auto &supportedExtensions = enumerateCachedDeviceExtensionProperties(vki, physicalDevice);
    if (!isExtensionStructSupported(supportedExtensions, RequiredExtension("VK_EXT_shader_object")))
        TCU_THROW(NotSupportedError, "VK_EXT_shader_object not supported");
}

struct ComputePipelineWrapper::InternalData
{
    const DeviceInterface &vk;
    VkDevice device;
    const ComputePipelineConstructionType pipelineConstructionType;

    // initialize with most common values
    InternalData(const DeviceInterface &vkd, VkDevice vkDevice, const ComputePipelineConstructionType constructionType)
        : vk(vkd)
        , device(vkDevice)
        , pipelineConstructionType(constructionType)
    {
    }
};

ComputePipelineWrapper::ComputePipelineWrapper(const DeviceInterface &vk, VkDevice device,
                                               const ComputePipelineConstructionType pipelineConstructionType)
    : m_internalData(new ComputePipelineWrapper::InternalData(vk, device, pipelineConstructionType))
    , m_programBinary(DE_NULL)
    , m_specializationInfo{}
    , m_pipelineCreateFlags((VkPipelineCreateFlags)0u)
    , m_pipelineCreatePNext(DE_NULL)
    , m_subgroupSize(0)
{
}

ComputePipelineWrapper::ComputePipelineWrapper(const DeviceInterface &vk, VkDevice device,
                                               const ComputePipelineConstructionType pipelineConstructionType,
                                               const ProgramBinary &programBinary)
    : m_internalData(new ComputePipelineWrapper::InternalData(vk, device, pipelineConstructionType))
    , m_programBinary(&programBinary)
    , m_specializationInfo{}
    , m_pipelineCreateFlags((VkPipelineCreateFlags)0u)
    , m_pipelineCreatePNext(DE_NULL)
    , m_subgroupSize(0)
{
}

ComputePipelineWrapper::ComputePipelineWrapper(const ComputePipelineWrapper &rhs) noexcept
    : m_internalData(rhs.m_internalData)
    , m_programBinary(rhs.m_programBinary)
    , m_descriptorSetLayouts(rhs.m_descriptorSetLayouts)
    , m_specializationInfo(rhs.m_specializationInfo)
    , m_pipelineCreateFlags(rhs.m_pipelineCreateFlags)
    , m_pipelineCreatePNext(rhs.m_pipelineCreatePNext)
    , m_subgroupSize(rhs.m_subgroupSize)
{
    DE_ASSERT(rhs.m_pipeline.get() == DE_NULL);
#ifndef CTS_USES_VULKANSC
    DE_ASSERT(rhs.m_shader.get() == DE_NULL);
#endif
}

ComputePipelineWrapper::ComputePipelineWrapper(ComputePipelineWrapper &&rhs) noexcept
    : m_internalData(rhs.m_internalData)
    , m_programBinary(rhs.m_programBinary)
    , m_descriptorSetLayouts(rhs.m_descriptorSetLayouts)
    , m_specializationInfo(rhs.m_specializationInfo)
    , m_pipelineCreateFlags(rhs.m_pipelineCreateFlags)
    , m_pipelineCreatePNext(rhs.m_pipelineCreatePNext)
    , m_subgroupSize(rhs.m_subgroupSize)
{
    DE_ASSERT(rhs.m_pipeline.get() == DE_NULL);
#ifndef CTS_USES_VULKANSC
    DE_ASSERT(rhs.m_shader.get() == DE_NULL);
#endif
}

ComputePipelineWrapper &ComputePipelineWrapper::operator=(const ComputePipelineWrapper &rhs) noexcept
{
    m_internalData = rhs.m_internalData;
    m_programBinary = rhs.m_programBinary;
    m_descriptorSetLayouts = rhs.m_descriptorSetLayouts;
    m_specializationInfo = rhs.m_specializationInfo;
    m_pipelineCreateFlags = rhs.m_pipelineCreateFlags;
    m_pipelineCreatePNext = rhs.m_pipelineCreatePNext;
    DE_ASSERT(rhs.m_pipeline.get() == DE_NULL);
#ifndef CTS_USES_VULKANSC
    DE_ASSERT(rhs.m_shader.get() == DE_NULL);
#endif
    m_subgroupSize = rhs.m_subgroupSize;
    return *this;
}

ComputePipelineWrapper &ComputePipelineWrapper::operator=(ComputePipelineWrapper &&rhs) noexcept
{
    m_internalData = std::move(rhs.m_internalData);
    m_programBinary = rhs.m_programBinary;
    m_descriptorSetLayouts = std::move(rhs.m_descriptorSetLayouts);
    m_specializationInfo = rhs.m_specializationInfo;
    m_pipelineCreateFlags = rhs.m_pipelineCreateFlags;
    m_pipelineCreatePNext = rhs.m_pipelineCreatePNext;
    DE_ASSERT(rhs.m_pipeline.get() == DE_NULL);
#ifndef CTS_USES_VULKANSC
    DE_ASSERT(rhs.m_shader.get() == DE_NULL);
#endif
    m_subgroupSize = rhs.m_subgroupSize;
    return *this;
}

void ComputePipelineWrapper::setDescriptorSetLayout(VkDescriptorSetLayout descriptorSetLayout)
{
    m_descriptorSetLayouts = {descriptorSetLayout};
}

void ComputePipelineWrapper::setDescriptorSetLayouts(uint32_t setLayoutCount,
                                                     const VkDescriptorSetLayout *descriptorSetLayouts)
{
    m_descriptorSetLayouts.assign(descriptorSetLayouts, descriptorSetLayouts + setLayoutCount);
}

void ComputePipelineWrapper::setSpecializationInfo(VkSpecializationInfo specializationInfo)
{
    m_specializationInfo = specializationInfo;
}

void ComputePipelineWrapper::setPipelineCreateFlags(VkPipelineCreateFlags pipelineCreateFlags)
{
    m_pipelineCreateFlags = pipelineCreateFlags;
}

void ComputePipelineWrapper::setPipelineCreatePNext(void *pipelineCreatePNext)
{
    m_pipelineCreatePNext = pipelineCreatePNext;
}

void ComputePipelineWrapper::setSubgroupSize(uint32_t subgroupSize)
{
    m_subgroupSize = subgroupSize;
}
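
// Create either a monolithic compute pipeline (VkPipeline) or a compute shader
// object (VkShaderEXT), depending on the construction type chosen at wrapper
// creation time.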
void ComputePipelineWrapper::buildPipeline(void)
{
    const auto &vk = m_internalData->vk;
    const auto &device = m_internalData->device;

    VkSpecializationInfo *specializationInfo = m_specializationInfo.mapEntryCount > 0 ? &m_specializationInfo : DE_NULL;
    if (m_internalData->pipelineConstructionType == COMPUTE_PIPELINE_CONSTRUCTION_TYPE_PIPELINE)
    {
        DE_ASSERT(m_pipeline.get() == DE_NULL);
        const Unique<VkShaderModule> shaderModule(createShaderModule(vk, device, *m_programBinary));
        buildPipelineLayout();
        m_pipeline =
            vk::makeComputePipeline(vk, device, *m_pipelineLayout, m_pipelineCreateFlags, m_pipelineCreatePNext,
                                    *shaderModule, 0u, specializationInfo, 0, m_subgroupSize);
    }
    else
    {
#ifndef CTS_USES_VULKANSC
        DE_ASSERT(m_shader.get() == DE_NULL);
        buildPipelineLayout();

        VkShaderRequiredSubgroupSizeCreateInfoEXT subgroupSizeCreateInfo = {
            VK_STRUCTURE_TYPE_SHADER_REQUIRED_SUBGROUP_SIZE_CREATE_INFO_EXT, // VkStructureType sType;
            DE_NULL,                                                         // void* pNext;
            m_subgroupSize,                                                  // uint32_t requiredSubgroupSize;
        };

        vk::VkShaderCreateFlagsEXT flags = 0u;
        if (m_pipelineCreateFlags & vk::VK_PIPELINE_CREATE_DISPATCH_BASE)
            flags |= vk::VK_SHADER_CREATE_DISPATCH_BASE_BIT_EXT;

        const auto createFlags2 = findStructure<VkPipelineCreateFlags2CreateInfoKHR>(m_pipelineCreatePNext);
        if (createFlags2 && (createFlags2->flags & vk::VK_PIPELINE_CREATE_2_DISPATCH_BASE_BIT_KHR))
            flags |= vk::VK_SHADER_CREATE_DISPATCH_BASE_BIT_EXT;

        vk::VkShaderCreateInfoEXT createInfo = {
            vk::VK_STRUCTURE_TYPE_SHADER_CREATE_INFO_EXT,            // VkStructureType sType;
            m_subgroupSize != 0 ? &subgroupSizeCreateInfo : DE_NULL, // const void* pNext;
            flags,                                                   // VkShaderCreateFlagsEXT flags;
            vk::VK_SHADER_STAGE_COMPUTE_BIT,                         // VkShaderStageFlagBits stage;
            0u,                                                      // VkShaderStageFlags nextStage;
            vk::VK_SHADER_CODE_TYPE_SPIRV_EXT,                       // VkShaderCodeTypeEXT codeType;
            m_programBinary->getSize(),                              // size_t codeSize;
            m_programBinary->getBinary(),                            // const void* pCode;
            "main",                                                  // const char* pName;
            (uint32_t)m_descriptorSetLayouts.size(),                 // uint32_t setLayoutCount;
            m_descriptorSetLayouts.data(),                           // VkDescriptorSetLayout* pSetLayouts;
            0u,                                                      // uint32_t pushConstantRangeCount;
            DE_NULL,                                                 // const VkPushConstantRange* pPushConstantRanges;
            specializationInfo,                                      // const VkSpecializationInfo* pSpecializationInfo;
        };

        m_shader = createShader(vk, device, createInfo);

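        // For the "binary" construction type, round-trip the shader through its
        // binary form: query the binary data of the shader object created above
        // and create a new shader object from that binary.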
        if (m_internalData->pipelineConstructionType == COMPUTE_PIPELINE_CONSTRUCTION_TYPE_SHADER_OBJECT_BINARY)
        {
            size_t dataSize;
            vk.getShaderBinaryDataEXT(device, *m_shader, &dataSize, DE_NULL);
            std::vector<uint8_t> data(dataSize);
            vk.getShaderBinaryDataEXT(device, *m_shader, &dataSize, data.data());

            createInfo.codeType = vk::VK_SHADER_CODE_TYPE_BINARY_EXT;
            createInfo.codeSize = dataSize;
            createInfo.pCode = data.data();

            m_shader = createShader(vk, device, createInfo);
        }
#endif
    }
}

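// Bind the built pipeline or shader object to the given command buffer.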
void ComputePipelineWrapper::bind(VkCommandBuffer commandBuffer)
{
    if (m_internalData->pipelineConstructionType == COMPUTE_PIPELINE_CONSTRUCTION_TYPE_PIPELINE)
    {
        m_internalData->vk.cmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, m_pipeline.get());
    }
    else
    {
#ifndef CTS_USES_VULKANSC
        const vk::VkShaderStageFlagBits stage = vk::VK_SHADER_STAGE_COMPUTE_BIT;
        m_internalData->vk.cmdBindShadersEXT(commandBuffer, 1, &stage, &*m_shader);
#endif
    }
}

void ComputePipelineWrapper::buildPipelineLayout(void)
{
    m_pipelineLayout = makePipelineLayout(m_internalData->vk, m_internalData->device, m_descriptorSetLayouts);
}

VkPipelineLayout ComputePipelineWrapper::getPipelineLayout(void)
{
    return *m_pipelineLayout;
}

} // namespace vk