xref: /aosp_15_r20/external/deqp/external/vulkancts/framework/vulkan/vkComputePipelineConstructionUtil.cpp (revision 35238bce31c2a825756842865a792f8cf7f89930)
1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2023 LunarG, Inc.
6  * Copyright (c) 2023 Nintendo
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  *      http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  *
20  *//*!
21  * \file
22  * \brief Wrapper that can construct monolithic pipeline or use
23           VK_EXT_shader_object for compute pipeline construction.
24  *//*--------------------------------------------------------------------*/
25 
26 #include "vkComputePipelineConstructionUtil.hpp"
27 #include "vkQueryUtil.hpp"
28 #include "vkObjUtil.hpp"
29 
30 namespace vk
31 {
32 
checkShaderObjectRequirements(const InstanceInterface & vki,VkPhysicalDevice physicalDevice,ComputePipelineConstructionType computePipelineConstructionType)33 void checkShaderObjectRequirements(const InstanceInterface &vki, VkPhysicalDevice physicalDevice,
34                                    ComputePipelineConstructionType computePipelineConstructionType)
35 {
36     if (computePipelineConstructionType == COMPUTE_PIPELINE_CONSTRUCTION_TYPE_PIPELINE)
37         return;
38 
39     const auto &supportedExtensions = enumerateCachedDeviceExtensionProperties(vki, physicalDevice);
40     if (!isExtensionStructSupported(supportedExtensions, RequiredExtension("VK_EXT_shader_object")))
41         TCU_THROW(NotSupportedError, "VK_EXT_shader_object not supported");
42 }
43 
44 struct ComputePipelineWrapper::InternalData
45 {
46     const DeviceInterface &vk;
47     VkDevice device;
48     const ComputePipelineConstructionType pipelineConstructionType;
49 
50     // initialize with most common values
InternalDatavk::ComputePipelineWrapper::InternalData51     InternalData(const DeviceInterface &vkd, VkDevice vkDevice, const ComputePipelineConstructionType constructionType)
52         : vk(vkd)
53         , device(vkDevice)
54         , pipelineConstructionType(constructionType)
55     {
56     }
57 };
58 
ComputePipelineWrapper(const DeviceInterface & vk,VkDevice device,const ComputePipelineConstructionType pipelineConstructionType)59 ComputePipelineWrapper::ComputePipelineWrapper(const DeviceInterface &vk, VkDevice device,
60                                                const ComputePipelineConstructionType pipelineConstructionType)
61     : m_internalData(new ComputePipelineWrapper::InternalData(vk, device, pipelineConstructionType))
62     , m_programBinary(DE_NULL)
63     , m_specializationInfo{}
64     , m_pipelineCreateFlags((VkPipelineCreateFlags)0u)
65     , m_pipelineCreatePNext(DE_NULL)
66     , m_subgroupSize(0)
67 {
68 }
69 
ComputePipelineWrapper(const DeviceInterface & vk,VkDevice device,const ComputePipelineConstructionType pipelineConstructionType,const ProgramBinary & programBinary)70 ComputePipelineWrapper::ComputePipelineWrapper(const DeviceInterface &vk, VkDevice device,
71                                                const ComputePipelineConstructionType pipelineConstructionType,
72                                                const ProgramBinary &programBinary)
73     : m_internalData(new ComputePipelineWrapper::InternalData(vk, device, pipelineConstructionType))
74     , m_programBinary(&programBinary)
75     , m_specializationInfo{}
76     , m_pipelineCreateFlags((VkPipelineCreateFlags)0u)
77     , m_pipelineCreatePNext(DE_NULL)
78     , m_subgroupSize(0)
79 {
80 }
81 
// Copy constructor: duplicates the construction parameters of rhs.
// Members that are absent from the initializer list (pipeline, pipeline layout, shader) are
// default-constructed, so rhs is expected not to have built its pipeline or shader yet.
ComputePipelineWrapper::ComputePipelineWrapper(const ComputePipelineWrapper &rhs) noexcept
    : m_internalData(rhs.m_internalData)
    , m_programBinary(rhs.m_programBinary)
    , m_descriptorSetLayouts(rhs.m_descriptorSetLayouts)
    , m_specializationInfo(rhs.m_specializationInfo)
    , m_pipelineCreateFlags(rhs.m_pipelineCreateFlags)
    , m_pipelineCreatePNext(rhs.m_pipelineCreatePNext)
    , m_subgroupSize(rhs.m_subgroupSize)
{
    // Built objects are never carried over to the copy.
    DE_ASSERT(rhs.m_pipeline.get() == DE_NULL);
#ifndef CTS_USES_VULKANSC
    DE_ASSERT(rhs.m_shader.get() == DE_NULL);
#endif // CTS_USES_VULKANSC
}
96 
ComputePipelineWrapper(ComputePipelineWrapper && rhs)97 ComputePipelineWrapper::ComputePipelineWrapper(ComputePipelineWrapper &&rhs) noexcept
98     : m_internalData(rhs.m_internalData)
99     , m_programBinary(rhs.m_programBinary)
100     , m_descriptorSetLayouts(rhs.m_descriptorSetLayouts)
101     , m_specializationInfo(rhs.m_specializationInfo)
102     , m_pipelineCreateFlags(rhs.m_pipelineCreateFlags)
103     , m_pipelineCreatePNext(rhs.m_pipelineCreatePNext)
104     , m_subgroupSize(rhs.m_subgroupSize)
105 {
106     DE_ASSERT(rhs.m_pipeline.get() == DE_NULL);
107 #ifndef CTS_USES_VULKANSC
108     DE_ASSERT(rhs.m_shader.get() == DE_NULL);
109 #endif
110 }
111 
// Copy assignment: overwrites the construction parameters of this wrapper with those of rhs.
// The pipeline, pipeline layout and shader members of this object are not touched; rhs is
// expected not to have built its pipeline or shader yet.
ComputePipelineWrapper &ComputePipelineWrapper::operator=(const ComputePipelineWrapper &rhs) noexcept
{
    DE_ASSERT(rhs.m_pipeline.get() == DE_NULL);
#ifndef CTS_USES_VULKANSC
    DE_ASSERT(rhs.m_shader.get() == DE_NULL);
#endif // CTS_USES_VULKANSC

    // Shared construction data and shader source.
    m_internalData  = rhs.m_internalData;
    m_programBinary = rhs.m_programBinary;

    // Pipeline / shader creation parameters.
    m_descriptorSetLayouts = rhs.m_descriptorSetLayouts;
    m_specializationInfo   = rhs.m_specializationInfo;
    m_pipelineCreateFlags  = rhs.m_pipelineCreateFlags;
    m_pipelineCreatePNext  = rhs.m_pipelineCreatePNext;
    m_subgroupSize         = rhs.m_subgroupSize;

    return *this;
}
127 
operator =(ComputePipelineWrapper && rhs)128 ComputePipelineWrapper &ComputePipelineWrapper::operator=(ComputePipelineWrapper &&rhs) noexcept
129 {
130     m_internalData         = std::move(rhs.m_internalData);
131     m_programBinary        = rhs.m_programBinary;
132     m_descriptorSetLayouts = std::move(rhs.m_descriptorSetLayouts);
133     m_specializationInfo   = rhs.m_specializationInfo;
134     m_pipelineCreateFlags  = rhs.m_pipelineCreateFlags;
135     m_pipelineCreatePNext  = rhs.m_pipelineCreatePNext;
136     DE_ASSERT(rhs.m_pipeline.get() == DE_NULL);
137 #ifndef CTS_USES_VULKANSC
138     DE_ASSERT(rhs.m_shader.get() == DE_NULL);
139 #endif
140     m_subgroupSize = rhs.m_subgroupSize;
141     return *this;
142 }
143 
setDescriptorSetLayout(VkDescriptorSetLayout descriptorSetLayout)144 void ComputePipelineWrapper::setDescriptorSetLayout(VkDescriptorSetLayout descriptorSetLayout)
145 {
146     m_descriptorSetLayouts = {descriptorSetLayout};
147 }
148 
setDescriptorSetLayouts(uint32_t setLayoutCount,const VkDescriptorSetLayout * descriptorSetLayouts)149 void ComputePipelineWrapper::setDescriptorSetLayouts(uint32_t setLayoutCount,
150                                                      const VkDescriptorSetLayout *descriptorSetLayouts)
151 {
152     m_descriptorSetLayouts.assign(descriptorSetLayouts, descriptorSetLayouts + setLayoutCount);
153 }
154 
setSpecializationInfo(VkSpecializationInfo specializationInfo)155 void ComputePipelineWrapper::setSpecializationInfo(VkSpecializationInfo specializationInfo)
156 {
157     m_specializationInfo = specializationInfo;
158 }
159 
setPipelineCreateFlags(VkPipelineCreateFlags pipelineCreateFlags)160 void ComputePipelineWrapper::setPipelineCreateFlags(VkPipelineCreateFlags pipelineCreateFlags)
161 {
162     m_pipelineCreateFlags = pipelineCreateFlags;
163 }
164 
setPipelineCreatePNext(void * pipelineCreatePNext)165 void ComputePipelineWrapper::setPipelineCreatePNext(void *pipelineCreatePNext)
166 {
167     m_pipelineCreatePNext = pipelineCreatePNext;
168 }
169 
setSubgroupSize(uint32_t subgroupSize)170 void ComputePipelineWrapper::setSubgroupSize(uint32_t subgroupSize)
171 {
172     m_subgroupSize = subgroupSize;
173 }
buildPipeline(void)174 void ComputePipelineWrapper::buildPipeline(void)
175 {
176     const auto &vk     = m_internalData->vk;
177     const auto &device = m_internalData->device;
178 
179     VkSpecializationInfo *specializationInfo = m_specializationInfo.mapEntryCount > 0 ? &m_specializationInfo : DE_NULL;
180     if (m_internalData->pipelineConstructionType == COMPUTE_PIPELINE_CONSTRUCTION_TYPE_PIPELINE)
181     {
182         DE_ASSERT(m_pipeline.get() == DE_NULL);
183         const Unique<VkShaderModule> shaderModule(createShaderModule(vk, device, *m_programBinary));
184         buildPipelineLayout();
185         m_pipeline =
186             vk::makeComputePipeline(vk, device, *m_pipelineLayout, m_pipelineCreateFlags, m_pipelineCreatePNext,
187                                     *shaderModule, 0u, specializationInfo, 0, m_subgroupSize);
188     }
189     else
190     {
191 #ifndef CTS_USES_VULKANSC
192         DE_ASSERT(m_shader.get() == DE_NULL);
193         buildPipelineLayout();
194 
195         VkShaderRequiredSubgroupSizeCreateInfoEXT subgroupSizeCreateInfo = {
196             VK_STRUCTURE_TYPE_SHADER_REQUIRED_SUBGROUP_SIZE_CREATE_INFO_EXT, // VkStructureType sType;
197             DE_NULL,                                                         // void* pNext;
198             m_subgroupSize,                                                  // uint32_t requiredSubgroupSize;
199         };
200 
201         vk::VkShaderCreateFlagsEXT flags = 0u;
202         if (m_pipelineCreateFlags & vk::VK_PIPELINE_CREATE_DISPATCH_BASE)
203             flags |= vk::VK_SHADER_CREATE_DISPATCH_BASE_BIT_EXT;
204 
205         const auto createFlags2 = findStructure<VkPipelineCreateFlags2CreateInfoKHR>(m_pipelineCreatePNext);
206         if (createFlags2 && (createFlags2->flags & vk::VK_PIPELINE_CREATE_2_DISPATCH_BASE_BIT_KHR))
207             flags |= vk::VK_SHADER_CREATE_DISPATCH_BASE_BIT_EXT;
208 
209         vk::VkShaderCreateInfoEXT createInfo = {
210             vk::VK_STRUCTURE_TYPE_SHADER_CREATE_INFO_EXT,            // VkStructureType sType;
211             m_subgroupSize != 0 ? &subgroupSizeCreateInfo : DE_NULL, // const void* pNext;
212             flags,                                                   // VkShaderCreateFlagsEXT flags;
213             vk::VK_SHADER_STAGE_COMPUTE_BIT,                         // VkShaderStageFlagBits stage;
214             0u,                                                      // VkShaderStageFlags nextStage;
215             vk::VK_SHADER_CODE_TYPE_SPIRV_EXT,                       // VkShaderCodeTypeEXT codeType;
216             m_programBinary->getSize(),                              // size_t codeSize;
217             m_programBinary->getBinary(),                            // const void* pCode;
218             "main",                                                  // const char* pName;
219             (uint32_t)m_descriptorSetLayouts.size(),                 // uint32_t setLayoutCount;
220             m_descriptorSetLayouts.data(),                           // VkDescriptorSetLayout* pSetLayouts;
221             0u,                                                      // uint32_t pushConstantRangeCount;
222             DE_NULL,                                                 // const VkPushConstantRange* pPushConstantRanges;
223             specializationInfo,                                      // const VkSpecializationInfo* pSpecializationInfo;
224         };
225 
226         m_shader = createShader(vk, device, createInfo);
227 
228         if (m_internalData->pipelineConstructionType == COMPUTE_PIPELINE_CONSTRUCTION_TYPE_SHADER_OBJECT_BINARY)
229         {
230             size_t dataSize;
231             vk.getShaderBinaryDataEXT(device, *m_shader, &dataSize, DE_NULL);
232             std::vector<uint8_t> data(dataSize);
233             vk.getShaderBinaryDataEXT(device, *m_shader, &dataSize, data.data());
234 
235             createInfo.codeType = vk::VK_SHADER_CODE_TYPE_BINARY_EXT;
236             createInfo.codeSize = dataSize;
237             createInfo.pCode    = data.data();
238 
239             m_shader = createShader(vk, device, createInfo);
240         }
241 #endif
242     }
243 }
244 
bind(VkCommandBuffer commandBuffer)245 void ComputePipelineWrapper::bind(VkCommandBuffer commandBuffer)
246 {
247     if (m_internalData->pipelineConstructionType == COMPUTE_PIPELINE_CONSTRUCTION_TYPE_PIPELINE)
248     {
249         m_internalData->vk.cmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, m_pipeline.get());
250     }
251     else
252     {
253 #ifndef CTS_USES_VULKANSC
254         const vk::VkShaderStageFlagBits stage = vk::VK_SHADER_STAGE_COMPUTE_BIT;
255         m_internalData->vk.cmdBindShadersEXT(commandBuffer, 1, &stage, &*m_shader);
256 #endif
257     }
258 }
259 
buildPipelineLayout(void)260 void ComputePipelineWrapper::buildPipelineLayout(void)
261 {
262     m_pipelineLayout = makePipelineLayout(m_internalData->vk, m_internalData->device, m_descriptorSetLayouts);
263 }
264 
getPipelineLayout(void)265 VkPipelineLayout ComputePipelineWrapper::getPipelineLayout(void)
266 {
267     return *m_pipelineLayout;
268 }
269 
270 } // namespace vk
271