1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2021 The Khronos Group Inc.
6  *
7  * Licensed under the Apache License, Version 2.0 (the "License");
8  * you may not use this file except in compliance with the License.
9  * You may obtain a copy of the License at
10  *
11  *      http://www.apache.org/licenses/LICENSE-2.0
12  *
13  * Unless required by applicable law or agreed to in writing, software
14  * distributed under the License is distributed on an "AS IS" BASIS,
15  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16  * See the License for the specific language governing permissions and
17  * limitations under the License.
18  *
19  *//*!
20  * \file
21  * \brief Acceleration Structure binding tests
22  *//*--------------------------------------------------------------------*/
23 
24 #include "vktBindingDescriptorUpdateASTests.hpp"
25 
26 #include "vkDefs.hpp"
27 
28 #include "vktTestCase.hpp"
29 #include "vktTestGroupUtil.hpp"
30 #include "vkCmdUtil.hpp"
31 #include "vkObjUtil.hpp"
32 #include "vkBuilderUtil.hpp"
33 #include "vkBarrierUtil.hpp"
34 #include "vkBufferWithMemory.hpp"
35 #include "vkImageWithMemory.hpp"
36 #include "vkTypeUtil.hpp"
37 #include "vkImageUtil.hpp"
38 #include "deRandom.hpp"
39 #include "tcuTexture.hpp"
40 #include "tcuTextureUtil.hpp"
41 #include "tcuTestLog.hpp"
42 #include "tcuImageCompare.hpp"
43 #include "tcuCommandLine.hpp"
44 
45 #include "vkRayTracingUtil.hpp"
46 
47 namespace vkt
48 {
49 namespace BindingModel
50 {
51 namespace
52 {
53 using namespace vk;
54 using namespace vkt;
55 
56 static const VkFlags ALL_RAY_TRACING_STAGES = VK_SHADER_STAGE_RAYGEN_BIT_KHR | VK_SHADER_STAGE_ANY_HIT_BIT_KHR |
57                                               VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR | VK_SHADER_STAGE_MISS_BIT_KHR |
58                                               VK_SHADER_STAGE_INTERSECTION_BIT_KHR | VK_SHADER_STAGE_CALLABLE_BIT_KHR;
59 
//! Test flavour selector (ray query variant or ray tracing variant).
enum TestType
{
    TEST_TYPE_USING_RAY_QUERY   = 0,
    TEST_TYPE_USING_RAY_TRACING = 1,
};
65 
//! Mechanism used to write the acceleration structure descriptor.
enum UpdateMethod
{
    UPDATE_METHOD_NORMAL             = 0, //!< vkUpdateDescriptorSets
    UPDATE_METHOD_WITH_TEMPLATE      = 1, //!< descriptor update template: vkUpdateDescriptorSetWithTemplate
    UPDATE_METHOD_WITH_PUSH          = 2, //!< push descriptors: vkCmdPushDescriptorSetKHR
    UPDATE_METHOD_WITH_PUSH_TEMPLATE = 3, //!< push descriptor template: vkCmdPushDescriptorSetWithTemplateKHR

    UPDATE_METHOD_LAST = 4 //!< Number of update methods; not a valid method itself.
};
75 
// Default test image dimensions.
const uint32_t TEST_WIDTH  = 16u;
const uint32_t TEST_HEIGHT = 16u;

// Scale factor (2^20) used when comparing fixed-point integer results against float expectations.
const uint32_t FIXED_POINT_DIVISOR = 1u << 20;

// Z coordinates of the two opposite corners of the test quad.
const float PLAIN_Z0 = 2.0f;
const float PLAIN_Z1 = 4.0f;
81 
82 struct TestParams;
83 
84 typedef void (*CheckSupportFunc)(Context &context, const TestParams &testParams);
85 typedef void (*InitProgramsFunc)(SourceCollections &programCollection, const TestParams &testParams);
86 typedef const std::string (*ShaderBodyTextFunc)(const TestParams &testParams);
87 
88 struct TestParams
89 {
90     uint32_t width;
91     uint32_t height;
92     uint32_t depth;
93     TestType testType;
94     UpdateMethod updateMethod;
95     VkShaderStageFlagBits stage;
96     VkFormat format;
97     CheckSupportFunc pipelineCheckSupport;
98     InitProgramsFunc pipelineInitPrograms;
99     ShaderBodyTextFunc testConfigShaderBodyText;
100 };
101 
getShaderGroupHandleSize(const InstanceInterface & vki,const VkPhysicalDevice physicalDevice)102 static uint32_t getShaderGroupHandleSize(const InstanceInterface &vki, const VkPhysicalDevice physicalDevice)
103 {
104     de::MovePtr<RayTracingProperties> rayTracingPropertiesKHR;
105 
106     rayTracingPropertiesKHR = makeRayTracingProperties(vki, physicalDevice);
107 
108     return rayTracingPropertiesKHR->getShaderGroupHandleSize();
109 }
110 
getShaderGroupBaseAlignment(const InstanceInterface & vki,const VkPhysicalDevice physicalDevice)111 static uint32_t getShaderGroupBaseAlignment(const InstanceInterface &vki, const VkPhysicalDevice physicalDevice)
112 {
113     de::MovePtr<RayTracingProperties> rayTracingPropertiesKHR;
114 
115     rayTracingPropertiesKHR = makeRayTracingProperties(vki, physicalDevice);
116 
117     return rayTracingPropertiesKHR->getShaderGroupBaseAlignment();
118 }
119 
getVkBuffer(const de::MovePtr<BufferWithMemory> & buffer)120 static VkBuffer getVkBuffer(const de::MovePtr<BufferWithMemory> &buffer)
121 {
122     VkBuffer result = (buffer.get() == DE_NULL) ? DE_NULL : buffer->get();
123 
124     return result;
125 }
126 
makeStridedDeviceAddressRegion(const DeviceInterface & vkd,const VkDevice device,VkBuffer buffer,uint32_t stride,uint32_t count)127 static VkStridedDeviceAddressRegionKHR makeStridedDeviceAddressRegion(const DeviceInterface &vkd, const VkDevice device,
128                                                                       VkBuffer buffer, uint32_t stride, uint32_t count)
129 {
130     if (buffer == DE_NULL)
131     {
132         return makeStridedDeviceAddressRegionKHR(0, 0, 0);
133     }
134     else
135     {
136         return makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, buffer, 0), stride,
137                                                  stride * count);
138     }
139 }
140 
makePipelineLayout(const DeviceInterface & vk,const VkDevice device,const VkDescriptorSetLayout descriptorSetLayout0,const VkDescriptorSetLayout descriptorSetLayout1,const VkDescriptorSetLayout descriptorSetLayoutOpt=DE_NULL)141 static Move<VkPipelineLayout> makePipelineLayout(const DeviceInterface &vk, const VkDevice device,
142                                                  const VkDescriptorSetLayout descriptorSetLayout0,
143                                                  const VkDescriptorSetLayout descriptorSetLayout1,
144                                                  const VkDescriptorSetLayout descriptorSetLayoutOpt = DE_NULL)
145 {
146     std::vector<VkDescriptorSetLayout> descriptorSetLayouts;
147 
148     descriptorSetLayouts.push_back(descriptorSetLayout0);
149     descriptorSetLayouts.push_back(descriptorSetLayout1);
150 
151     if (descriptorSetLayoutOpt != DE_NULL)
152         descriptorSetLayouts.push_back(descriptorSetLayoutOpt);
153 
154     return makePipelineLayout(vk, device, (uint32_t)descriptorSetLayouts.size(), descriptorSetLayouts.data());
155 }
156 
makeWriteDescriptorSetAccelerationStructureKHR(const VkAccelerationStructureKHR * accelerationStructureKHR)157 static VkWriteDescriptorSetAccelerationStructureKHR makeWriteDescriptorSetAccelerationStructureKHR(
158     const VkAccelerationStructureKHR *accelerationStructureKHR)
159 {
160     const VkWriteDescriptorSetAccelerationStructureKHR result = {
161         VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET_ACCELERATION_STRUCTURE_KHR, //  VkStructureType sType;
162         DE_NULL,                                                           //  const void* pNext;
163         1u,                                                                //  uint32_t accelerationStructureCount;
164         accelerationStructureKHR //  const VkAccelerationStructureKHR* pAccelerationStructures;
165     };
166 
167     return result;
168 }
169 
isPushUpdateMethod(const UpdateMethod updateMethod)170 static bool isPushUpdateMethod(const UpdateMethod updateMethod)
171 {
172     switch (updateMethod)
173     {
174     case UPDATE_METHOD_NORMAL:
175         return false;
176     case UPDATE_METHOD_WITH_TEMPLATE:
177         return false;
178     case UPDATE_METHOD_WITH_PUSH:
179         return true;
180     case UPDATE_METHOD_WITH_PUSH_TEMPLATE:
181         return true;
182     default:
183         TCU_THROW(InternalError, "Unknown update method");
184     }
185 }
186 
isTemplateUpdateMethod(const UpdateMethod updateMethod)187 static bool isTemplateUpdateMethod(const UpdateMethod updateMethod)
188 {
189     switch (updateMethod)
190     {
191     case UPDATE_METHOD_NORMAL:
192         return false;
193     case UPDATE_METHOD_WITH_TEMPLATE:
194         return true;
195     case UPDATE_METHOD_WITH_PUSH:
196         return false;
197     case UPDATE_METHOD_WITH_PUSH_TEMPLATE:
198         return true;
199     default:
200         TCU_THROW(InternalError, "Unknown update method");
201     }
202 }
203 
makeDescriptorSet(const DeviceInterface & vki,const VkDevice device,const VkDescriptorPool descriptorPool,const VkDescriptorSetLayout setLayout,UpdateMethod updateMethod)204 static Move<VkDescriptorSet> makeDescriptorSet(const DeviceInterface &vki, const VkDevice device,
205                                                const VkDescriptorPool descriptorPool,
206                                                const VkDescriptorSetLayout setLayout, UpdateMethod updateMethod)
207 {
208     const bool pushUpdateMethod         = isPushUpdateMethod(updateMethod);
209     Move<VkDescriptorSet> descriptorSet = pushUpdateMethod ?
210                                               vk::Move<vk::VkDescriptorSet>() :
211                                               vk::makeDescriptorSet(vki, device, descriptorPool, setLayout, DE_NULL);
212 
213     return descriptorSet;
214 }
215 
makeImageCreateInfo(VkFormat format,uint32_t width,uint32_t height,uint32_t depth,VkImageType imageType=VK_IMAGE_TYPE_3D,VkImageUsageFlags usageFlags=VK_IMAGE_USAGE_STORAGE_BIT|VK_IMAGE_USAGE_TRANSFER_SRC_BIT|VK_IMAGE_USAGE_TRANSFER_DST_BIT)216 static VkImageCreateInfo makeImageCreateInfo(VkFormat format, uint32_t width, uint32_t height, uint32_t depth,
217                                              VkImageType imageType        = VK_IMAGE_TYPE_3D,
218                                              VkImageUsageFlags usageFlags = VK_IMAGE_USAGE_STORAGE_BIT |
219                                                                             VK_IMAGE_USAGE_TRANSFER_SRC_BIT |
220                                                                             VK_IMAGE_USAGE_TRANSFER_DST_BIT)
221 {
222     const VkImageCreateInfo imageCreateInfo = {
223         VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO, // VkStructureType sType;
224         DE_NULL,                             // const void* pNext;
225         (VkImageCreateFlags)0u,              // VkImageCreateFlags flags;
226         imageType,                           // VkImageType imageType;
227         format,                              // VkFormat format;
228         makeExtent3D(width, height, depth),  // VkExtent3D extent;
229         1u,                                  // uint32_t mipLevels;
230         1u,                                  // uint32_t arrayLayers;
231         VK_SAMPLE_COUNT_1_BIT,               // VkSampleCountFlagBits samples;
232         VK_IMAGE_TILING_OPTIMAL,             // VkImageTiling tiling;
233         usageFlags,                          // VkImageUsageFlags usage;
234         VK_SHARING_MODE_EXCLUSIVE,           // VkSharingMode sharingMode;
235         0u,                                  // uint32_t queueFamilyIndexCount;
236         DE_NULL,                             // const uint32_t* pQueueFamilyIndices;
237         VK_IMAGE_LAYOUT_UNDEFINED            // VkImageLayout initialLayout;
238     };
239 
240     return imageCreateInfo;
241 }
242 
getMissPassthrough(void)243 static const std::string getMissPassthrough(void)
244 {
245     std::ostringstream src;
246 
247     src << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_460) << "\n"
248         << "#extension GL_EXT_ray_tracing : require\n"
249         << "\n"
250         << "layout(location = 0) rayPayloadInEXT vec3 hitValue;\n"
251         << "\n"
252         << "void main()\n"
253         << "{\n"
254         << "}\n";
255 
256     return src.str();
257 }
258 
getHitPassthrough(void)259 static const std::string getHitPassthrough(void)
260 {
261     std::ostringstream src;
262 
263     src << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_460) << "\n"
264         << "#extension GL_EXT_ray_tracing : require\n"
265         << "hitAttributeEXT vec3 attribs;\n"
266         << "layout(location = 0) rayPayloadInEXT vec3 hitValue;\n"
267         << "\n"
268         << "void main()\n"
269         << "{\n"
270         << "}\n";
271 
272     return src.str();
273 }
274 
getGraphicsPassthrough(void)275 static const std::string getGraphicsPassthrough(void)
276 {
277     std::ostringstream src;
278 
279     src << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_460) << "\n"
280         << "\n"
281         << "void main(void)\n"
282         << "{\n"
283         << "}\n";
284 
285     return src.str();
286 }
287 
getVertexPassthrough(void)288 static const std::string getVertexPassthrough(void)
289 {
290     std::ostringstream src;
291 
292     src << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_460) << "\n"
293         << "\n"
294         << "layout(location = 0) in vec4 in_position;\n"
295         << "\n"
296         << "void main(void)\n"
297         << "{\n"
298         << "  gl_Position = in_position;\n"
299         << "}\n";
300 
301     return src.str();
302 }
303 
getDescriptorSetLayoutCreateFlags(const UpdateMethod updateMethod)304 static VkDescriptorSetLayoutCreateFlags getDescriptorSetLayoutCreateFlags(const UpdateMethod updateMethod)
305 {
306     vk::VkDescriptorSetLayoutCreateFlags extraFlags = 0;
307 
308     if (updateMethod == UPDATE_METHOD_WITH_PUSH_TEMPLATE || updateMethod == UPDATE_METHOD_WITH_PUSH)
309     {
310         extraFlags |= vk::VK_DESCRIPTOR_SET_LAYOUT_CREATE_PUSH_DESCRIPTOR_BIT_KHR;
311     }
312 
313     return extraFlags;
314 }
315 
316 class BindingAcceleratioStructureTestInstance : public TestInstance
317 {
318 public:
319     BindingAcceleratioStructureTestInstance(Context &context, const TestParams &testParams);
~BindingAcceleratioStructureTestInstance()320     virtual ~BindingAcceleratioStructureTestInstance()
321     {
322     }
323     virtual tcu::TestStatus iterate(void);
324 
325 protected:
326     virtual void initPipeline(void)                            = 0;
327     virtual uint32_t getExtraAccelerationDescriptorCount(void) = 0;
328     virtual VkShaderStageFlags getShaderStageFlags(void)       = 0;
329     virtual VkPipelineBindPoint getPipelineBindPoint(void)     = 0;
330 
331     virtual void fillCommandBuffer(VkCommandBuffer commandBuffer) = 0;
332 
333     virtual const VkAccelerationStructureKHR *createAccelerationStructures(Context &context, TestParams &testParams);
334     virtual void buildAccelerationStructures(Context &context, TestParams &testParams, VkCommandBuffer commandBuffer);
335     virtual bool verify(BufferWithMemory *resultBuffer, Context &context, TestParams &testParams);
336 
337     TestParams m_testParams;
338 
339     std::vector<de::SharedPtr<BottomLevelAccelerationStructure>> m_bottomAccelerationStructures;
340     de::SharedPtr<TopLevelAccelerationStructure> m_topAccelerationStructure;
341 
342     Move<VkDescriptorPool> m_descriptorPool;
343 
344     Move<VkDescriptorSetLayout> m_descriptorSetLayoutImg;
345     Move<VkDescriptorSet> m_descriptorSetImg;
346 
347     Move<VkDescriptorSetLayout> m_descriptorSetLayoutAS;
348     Move<VkDescriptorSet> m_descriptorSetAS;
349 
350     Move<VkPipelineLayout> m_pipelineLayout;
351     Move<VkPipeline> m_pipeline;
352 
353     Move<VkDescriptorUpdateTemplate> m_updateTemplate;
354 };
355 
BindingAcceleratioStructureTestInstance(Context & context,const TestParams & testParams)356 BindingAcceleratioStructureTestInstance::BindingAcceleratioStructureTestInstance(Context &context,
357                                                                                  const TestParams &testParams)
358     : TestInstance(context)
359     , m_testParams(testParams)
360     , m_bottomAccelerationStructures()
361     , m_topAccelerationStructure()
362     , m_descriptorPool()
363     , m_descriptorSetLayoutImg()
364     , m_descriptorSetImg()
365     , m_descriptorSetLayoutAS()
366     , m_descriptorSetAS()
367     , m_pipelineLayout()
368     , m_pipeline()
369     , m_updateTemplate()
370 {
371 }
372 
iterate(void)373 tcu::TestStatus BindingAcceleratioStructureTestInstance::iterate(void)
374 {
375     const DeviceInterface &vkd      = m_context.getDeviceInterface();
376     const VkDevice device           = m_context.getDevice();
377     const VkQueue queue             = m_context.getUniversalQueue();
378     Allocator &allocator            = m_context.getDefaultAllocator();
379     const uint32_t queueFamilyIndex = m_context.getUniversalQueueFamilyIndex();
380     const bool templateUpdateMethod = isTemplateUpdateMethod(m_testParams.updateMethod);
381     const bool pushUpdateMethod     = isPushUpdateMethod(m_testParams.updateMethod);
382 
383     const uint32_t width                    = m_testParams.width;
384     const uint32_t height                   = m_testParams.height;
385     const uint32_t depth                    = m_testParams.depth;
386     const VkImageCreateInfo imageCreateInfo = makeImageCreateInfo(m_testParams.format, width, height, depth);
387     const VkImageSubresourceRange imageSubresourceRange =
388         makeImageSubresourceRange(VK_IMAGE_ASPECT_COLOR_BIT, 0u, 1u, 0u, 1u);
389     const de::MovePtr<ImageWithMemory> image = de::MovePtr<ImageWithMemory>(
390         new ImageWithMemory(vkd, device, allocator, imageCreateInfo, MemoryRequirement::Any));
391     const Move<VkImageView> imageView =
392         makeImageView(vkd, device, **image, VK_IMAGE_VIEW_TYPE_3D, m_testParams.format, imageSubresourceRange);
393 
394     const uint32_t pixelSize = mapVkFormat(m_testParams.format).getPixelSize();
395     const VkBufferCreateInfo resultBufferCreateInfo =
396         makeBufferCreateInfo(width * height * depth * pixelSize, VK_BUFFER_USAGE_TRANSFER_DST_BIT);
397     const VkImageSubresourceLayers resultBufferImageSubresourceLayers =
398         makeImageSubresourceLayers(VK_IMAGE_ASPECT_COLOR_BIT, 0u, 0u, 1u);
399     const VkBufferImageCopy resultBufferImageRegion =
400         makeBufferImageCopy(makeExtent3D(width, height, depth), resultBufferImageSubresourceLayers);
401     de::MovePtr<BufferWithMemory> resultBuffer = de::MovePtr<BufferWithMemory>(
402         new BufferWithMemory(vkd, device, allocator, resultBufferCreateInfo, MemoryRequirement::HostVisible));
403     const VkDescriptorImageInfo resultImageInfo = makeDescriptorImageInfo(DE_NULL, *imageView, VK_IMAGE_LAYOUT_GENERAL);
404 
405     const Move<VkCommandPool> commandPool = createCommandPool(vkd, device, 0, queueFamilyIndex);
406     const Move<VkCommandBuffer> commandBuffer =
407         allocateCommandBuffer(vkd, device, *commandPool, VK_COMMAND_BUFFER_LEVEL_PRIMARY);
408     const VkAccelerationStructureKHR *topAccelerationStructurePtr =
409         createAccelerationStructures(m_context, m_testParams);
410     const VkWriteDescriptorSetAccelerationStructureKHR writeDescriptorSetAccelerationStructure =
411         makeWriteDescriptorSetAccelerationStructureKHR(topAccelerationStructurePtr);
412     const uint32_t accelerationStructureDescriptorCount = 1 + getExtraAccelerationDescriptorCount();
413     uint32_t updateCount                                = 0;
414 
415     m_descriptorPool = DescriptorPoolBuilder()
416                            .addType(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE)
417                            .addType(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, accelerationStructureDescriptorCount)
418                            .build(vkd, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT,
419                                   1u + accelerationStructureDescriptorCount);
420 
421     m_descriptorSetLayoutImg = DescriptorSetLayoutBuilder()
422                                    .addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, getShaderStageFlags())
423                                    .build(vkd, device);
424     m_descriptorSetImg = makeDescriptorSet(vkd, device, *m_descriptorPool, *m_descriptorSetLayoutImg);
425 
426     DescriptorSetUpdateBuilder()
427         .writeSingle(*m_descriptorSetImg, DescriptorSetUpdateBuilder::Location::binding(0u),
428                      VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, &resultImageInfo)
429         .update(vkd, device);
430 
431     m_descriptorSetLayoutAS =
432         DescriptorSetLayoutBuilder()
433             .addSingleBinding(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, getShaderStageFlags())
434             .build(vkd, device, getDescriptorSetLayoutCreateFlags(m_testParams.updateMethod));
435     m_descriptorSetAS =
436         makeDescriptorSet(vkd, device, *m_descriptorPool, *m_descriptorSetLayoutAS, m_testParams.updateMethod);
437 
438     initPipeline();
439 
440     if (m_testParams.updateMethod == UPDATE_METHOD_NORMAL)
441     {
442         DescriptorSetUpdateBuilder()
443             .writeSingle(*m_descriptorSetAS, DescriptorSetUpdateBuilder::Location::binding(0u),
444                          VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, &writeDescriptorSetAccelerationStructure)
445             .update(vkd, device);
446 
447         updateCount++;
448     }
449 
450     if (templateUpdateMethod)
451     {
452         const VkDescriptorUpdateTemplateType updateTemplateType =
453             isPushUpdateMethod(m_testParams.updateMethod) ? VK_DESCRIPTOR_UPDATE_TEMPLATE_TYPE_PUSH_DESCRIPTORS_KHR :
454                                                             VK_DESCRIPTOR_UPDATE_TEMPLATE_TYPE_DESCRIPTOR_SET;
455         const VkDescriptorUpdateTemplateEntry updateTemplateEntry = {
456             0,                                             //  uint32_t dstBinding;
457             0,                                             //  uint32_t dstArrayElement;
458             1,                                             //  uint32_t descriptorCount;
459             VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, //  VkDescriptorType descriptorType;
460             0,                                             //  uintptr_t offset;
461             0,                                             //  uintptr_t stride;
462         };
463         const VkDescriptorUpdateTemplateCreateInfo templateCreateInfo = {
464             VK_STRUCTURE_TYPE_DESCRIPTOR_UPDATE_TEMPLATE_CREATE_INFO_KHR, //  VkStructureType sType;
465             DE_NULL,                                                      //  const void* pNext;
466             0,                        //  VkDescriptorUpdateTemplateCreateFlags flags;
467             1,                        //  uint32_t descriptorUpdateEntryCount;
468             &updateTemplateEntry,     //  const VkDescriptorUpdateTemplateEntry* pDescriptorUpdateEntries;
469             updateTemplateType,       //  VkDescriptorUpdateTemplateType templateType;
470             *m_descriptorSetLayoutAS, //  VkDescriptorSetLayout descriptorSetLayout;
471             getPipelineBindPoint(),   //  VkPipelineBindPoint pipelineBindPoint;
472             *m_pipelineLayout,        //  VkPipelineLayout pipelineLayout;
473             0,                        //  uint32_t set;
474         };
475 
476         m_updateTemplate = vk::createDescriptorUpdateTemplate(vkd, device, &templateCreateInfo);
477 
478         if (!pushUpdateMethod)
479         {
480             vkd.updateDescriptorSetWithTemplate(device, *m_descriptorSetAS, *m_updateTemplate,
481                                                 topAccelerationStructurePtr);
482 
483             updateCount++;
484         }
485     }
486 
487     beginCommandBuffer(vkd, *commandBuffer, 0u);
488     {
489         {
490             const VkClearValue clearValue = makeClearValueColorU32(0u, 0u, 0u, 0u);
491             const VkImageMemoryBarrier preImageBarrier =
492                 makeImageMemoryBarrier(0u, VK_ACCESS_TRANSFER_WRITE_BIT, VK_IMAGE_LAYOUT_UNDEFINED,
493                                        VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, **image, imageSubresourceRange);
494             const VkImageMemoryBarrier postImageBarrier = makeImageMemoryBarrier(
495                 VK_ACCESS_TRANSFER_WRITE_BIT,
496                 VK_ACCESS_ACCELERATION_STRUCTURE_READ_BIT_KHR | VK_ACCESS_ACCELERATION_STRUCTURE_WRITE_BIT_KHR,
497                 VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, VK_IMAGE_LAYOUT_GENERAL, **image, imageSubresourceRange);
498 
499             cmdPipelineImageMemoryBarrier(vkd, *commandBuffer, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT,
500                                           VK_PIPELINE_STAGE_TRANSFER_BIT, &preImageBarrier);
501             vkd.cmdClearColorImage(*commandBuffer, **image, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, &clearValue.color, 1,
502                                    &imageSubresourceRange);
503             cmdPipelineImageMemoryBarrier(vkd, *commandBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT,
504                                           VK_PIPELINE_STAGE_ACCELERATION_STRUCTURE_BUILD_BIT_KHR, &postImageBarrier);
505 
506             vkd.cmdBindDescriptorSets(*commandBuffer, getPipelineBindPoint(), *m_pipelineLayout, 1, 1,
507                                       &m_descriptorSetImg.get(), 0, DE_NULL);
508         }
509 
510         switch (m_testParams.updateMethod)
511         {
512         case UPDATE_METHOD_NORMAL: // fallthrough
513         case UPDATE_METHOD_WITH_TEMPLATE:
514         {
515             vkd.cmdBindDescriptorSets(*commandBuffer, getPipelineBindPoint(), *m_pipelineLayout, 0, 1,
516                                       &m_descriptorSetAS.get(), 0, DE_NULL);
517 
518             break;
519         }
520 
521         case UPDATE_METHOD_WITH_PUSH:
522         {
523             DescriptorSetUpdateBuilder()
524                 .writeSingle(*m_descriptorSetAS, DescriptorSetUpdateBuilder::Location::binding(0u),
525                              VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, &writeDescriptorSetAccelerationStructure)
526                 .updateWithPush(vkd, *commandBuffer, getPipelineBindPoint(), *m_pipelineLayout, 0, 0, 1);
527 
528             updateCount++;
529 
530             break;
531         }
532 
533         case UPDATE_METHOD_WITH_PUSH_TEMPLATE:
534         {
535             vkd.cmdPushDescriptorSetWithTemplateKHR(*commandBuffer, *m_updateTemplate, *m_pipelineLayout, 0,
536                                                     topAccelerationStructurePtr);
537 
538             updateCount++;
539 
540             break;
541         }
542 
543         default:
544             TCU_THROW(InternalError, "Unknown update method");
545         }
546 
547         {
548             const VkMemoryBarrier preTraceMemoryBarrier = makeMemoryBarrier(
549                 VK_ACCESS_ACCELERATION_STRUCTURE_WRITE_BIT_KHR, VK_ACCESS_ACCELERATION_STRUCTURE_READ_BIT_KHR);
550             const VkPipelineStageFlags dstStageFlags = VK_PIPELINE_STAGE_ACCELERATION_STRUCTURE_BUILD_BIT_KHR;
551 
552             buildAccelerationStructures(m_context, m_testParams, *commandBuffer);
553 
554             cmdPipelineMemoryBarrier(vkd, *commandBuffer, VK_PIPELINE_STAGE_ACCELERATION_STRUCTURE_BUILD_BIT_KHR,
555                                      dstStageFlags, &preTraceMemoryBarrier);
556         }
557 
558         fillCommandBuffer(*commandBuffer);
559 
560         {
561             const VkMemoryBarrier postTestMemoryBarrier =
562                 makeMemoryBarrier(VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT, VK_ACCESS_TRANSFER_READ_BIT);
563 
564             cmdPipelineMemoryBarrier(vkd, *commandBuffer, VK_PIPELINE_STAGE_ALL_COMMANDS_BIT,
565                                      VK_PIPELINE_STAGE_TRANSFER_BIT, &postTestMemoryBarrier);
566         }
567 
568         vkd.cmdCopyImageToBuffer(*commandBuffer, **image, VK_IMAGE_LAYOUT_GENERAL, **resultBuffer, 1u,
569                                  &resultBufferImageRegion);
570     }
571     endCommandBuffer(vkd, *commandBuffer);
572 
573     if (updateCount != 1)
574         TCU_THROW(InternalError, "Invalid descriptor update");
575 
576     submitCommandsAndWait(vkd, device, queue, commandBuffer.get());
577 
578     invalidateMappedMemoryRange(vkd, device, resultBuffer->getAllocation().getMemory(),
579                                 resultBuffer->getAllocation().getOffset(), VK_WHOLE_SIZE);
580 
581     if (verify(resultBuffer.get(), m_context, m_testParams))
582         return tcu::TestStatus::pass("Pass");
583     else
584         return tcu::TestStatus::fail("Fail");
585 }
586 
createAccelerationStructures(Context & context,TestParams & testParams)587 const VkAccelerationStructureKHR *BindingAcceleratioStructureTestInstance::createAccelerationStructures(
588     Context &context, TestParams &testParams)
589 {
590     DE_UNREF(testParams);
591 
592     const DeviceInterface &vkd = context.getDeviceInterface();
593     const VkDevice device      = context.getDevice();
594     Allocator &allocator       = context.getDefaultAllocator();
595     de::MovePtr<BottomLevelAccelerationStructure> rayQueryBottomLevelAccelerationStructure =
596         makeBottomLevelAccelerationStructure();
597     de::MovePtr<TopLevelAccelerationStructure> rayQueryTopLevelAccelerationStructure =
598         makeTopLevelAccelerationStructure();
599     std::vector<tcu::Vec3> geometryData;
600 
601     // Generate in-plain square starting at (0,0,PLAIN_Z0) and ending at (1,1,PLAIN_Z1).
602     // Vertices 1,0 and 0,1 by Z axis are in the middle between PLAIN_Z0 and PLAIN_Z1
603     geometryData.push_back(tcu::Vec3(0.0f, 0.0f, PLAIN_Z0));
604     geometryData.push_back(tcu::Vec3(1.0f, 0.0f, (PLAIN_Z0 + PLAIN_Z1) / 2.0f));
605     geometryData.push_back(tcu::Vec3(0.0f, 1.0f, (PLAIN_Z0 + PLAIN_Z1) / 2.0f));
606     geometryData.push_back(tcu::Vec3(1.0f, 1.0f, PLAIN_Z1));
607     geometryData.push_back(tcu::Vec3(0.0f, 1.0f, (PLAIN_Z0 + PLAIN_Z1) / 2.0f));
608     geometryData.push_back(tcu::Vec3(1.0f, 0.0f, (PLAIN_Z0 + PLAIN_Z1) / 2.0f));
609 
610     rayQueryBottomLevelAccelerationStructure->setGeometryCount(1u);
611     rayQueryBottomLevelAccelerationStructure->addGeometry(geometryData, true);
612     rayQueryBottomLevelAccelerationStructure->create(vkd, device, allocator, 0);
613     m_bottomAccelerationStructures.push_back(
614         de::SharedPtr<BottomLevelAccelerationStructure>(rayQueryBottomLevelAccelerationStructure.release()));
615 
616     m_topAccelerationStructure =
617         de::SharedPtr<TopLevelAccelerationStructure>(rayQueryTopLevelAccelerationStructure.release());
618     m_topAccelerationStructure->addInstance(m_bottomAccelerationStructures.back());
619     m_topAccelerationStructure->create(vkd, device, allocator);
620 
621     return m_topAccelerationStructure.get()->getPtr();
622 }
623 
buildAccelerationStructures(Context & context,TestParams & testParams,VkCommandBuffer commandBuffer)624 void BindingAcceleratioStructureTestInstance::buildAccelerationStructures(Context &context, TestParams &testParams,
625                                                                           VkCommandBuffer commandBuffer)
626 {
627     DE_UNREF(testParams);
628 
629     const DeviceInterface &vkd = context.getDeviceInterface();
630     const VkDevice device      = context.getDevice();
631 
632     for (size_t blStructNdx = 0; blStructNdx < m_bottomAccelerationStructures.size(); ++blStructNdx)
633         m_bottomAccelerationStructures[blStructNdx]->build(vkd, device, commandBuffer);
634 
635     m_topAccelerationStructure->build(vkd, device, commandBuffer);
636 }
637 
verify(BufferWithMemory * resultBuffer,Context & context,TestParams & testParams)638 bool BindingAcceleratioStructureTestInstance::verify(BufferWithMemory *resultBuffer, Context &context,
639                                                      TestParams &testParams)
640 {
641     tcu::TestLog &log        = context.getTestContext().getLog();
642     const uint32_t width     = testParams.width;
643     const uint32_t height    = testParams.height;
644     const int32_t *retrieved = (int32_t *)resultBuffer->getAllocation().getHostPtr();
645     uint32_t failures        = 0;
646     uint32_t pos             = 0;
647     std::vector<int32_t> expected;
648 
649     expected.reserve(width * height);
650 
651     for (uint32_t y = 0; y < height; ++y)
652     {
653         const float expectedY = deFloatMix(PLAIN_Z0, PLAIN_Z1, (0.5f + float(y)) / float(height));
654 
655         for (uint32_t x = 0; x < width; ++x)
656         {
657             const float expectedX   = deFloatMix(PLAIN_Z0, PLAIN_Z1, (0.5f + float(x)) / float(width));
658             const int32_t expectedV = int32_t(float(FIXED_POINT_DIVISOR / 2) * (expectedX + expectedY));
659 
660             expected.push_back(expectedV);
661         }
662     }
663 
664     for (uint32_t y = 0; y < height; ++y)
665         for (uint32_t x = 0; x < width; ++x)
666         {
667             if (retrieved[pos] != expected[pos])
668             {
669                 failures++;
670 
671                 if (failures < 10)
672                 {
673                     const int32_t expectedValue  = expected[pos];
674                     const int32_t retrievedValue = retrieved[pos];
675 
676                     log << tcu::TestLog::Message << "At (" << x << "," << y << ") "
677                         << "expected " << std::fixed << std::setprecision(6) << std::setw(8)
678                         << float(expectedValue) / float(FIXED_POINT_DIVISOR) << " (" << expectedValue << ") "
679                         << "retrieved " << std::fixed << std::setprecision(6) << std::setw(8)
680                         << float(retrievedValue) / float(FIXED_POINT_DIVISOR) << " (" << retrievedValue << ") "
681                         << tcu::TestLog::EndMessage;
682                 }
683             }
684 
685             pos++;
686         }
687 
688     if (failures != 0)
689     {
690         for (uint32_t dumpNdx = 0; dumpNdx < 2; ++dumpNdx)
691         {
692             const int32_t *data  = (dumpNdx == 0) ? expected.data() : retrieved;
693             const char *dataName = (dumpNdx == 0) ? "Expected" : "Retrieved";
694             std::ostringstream css;
695 
696             pos = 0;
697 
698             for (uint32_t y = 0; y < height; ++y)
699             {
700                 for (uint32_t x = 0; x < width; ++x)
701                 {
702                     if (expected[pos] != retrieved[pos])
703                         css << std::fixed << std::setprecision(6) << std::setw(8)
704                             << float(data[pos]) / float(FIXED_POINT_DIVISOR) << ",";
705                     else
706                         css << "________,";
707 
708                     pos++;
709                 }
710 
711                 css << std::endl;
712             }
713 
714             log << tcu::TestLog::Message << dataName << ":" << tcu::TestLog::EndMessage;
715             log << tcu::TestLog::Message << css.str() << tcu::TestLog::EndMessage;
716         }
717     }
718 
719     return (failures == 0);
720 }
721 
722 class BindingAcceleratioStructureGraphicsTestInstance : public BindingAcceleratioStructureTestInstance
723 {
724 public:
725     static void checkSupport(Context &context, const TestParams &testParams);
726     static void initPrograms(SourceCollections &programCollection, const TestParams &testParams);
727 
728     BindingAcceleratioStructureGraphicsTestInstance(Context &context, const TestParams &testParams);
~BindingAcceleratioStructureGraphicsTestInstance()729     virtual ~BindingAcceleratioStructureGraphicsTestInstance()
730     {
731     }
732 
733 protected:
734     virtual void initPipeline(void) override;
735     virtual void fillCommandBuffer(VkCommandBuffer commandBuffer) override;
736 
737     void initVertexBuffer(void);
738     Move<VkPipeline> makeGraphicsPipeline(void);
739 
getExtraAccelerationDescriptorCount(void)740     virtual uint32_t getExtraAccelerationDescriptorCount(void) override
741     {
742         return 0;
743     }
getShaderStageFlags(void)744     virtual VkShaderStageFlags getShaderStageFlags(void) override
745     {
746         return VK_SHADER_STAGE_ALL_GRAPHICS;
747     }
getPipelineBindPoint(void)748     virtual VkPipelineBindPoint getPipelineBindPoint(void) override
749     {
750         return VK_PIPELINE_BIND_POINT_GRAPHICS;
751     }
752 
753     VkFormat m_framebufferFormat;
754     Move<VkImage> m_framebufferImage;
755     de::MovePtr<Allocation> m_framebufferImageAlloc;
756     Move<VkImageView> m_framebufferAttachment;
757 
758     Move<VkShaderModule> m_vertShaderModule;
759     Move<VkShaderModule> m_geomShaderModule;
760     Move<VkShaderModule> m_tescShaderModule;
761     Move<VkShaderModule> m_teseShaderModule;
762     Move<VkShaderModule> m_fragShaderModule;
763 
764     Move<VkRenderPass> m_renderPass;
765     Move<VkFramebuffer> m_framebuffer;
766 
767     uint32_t m_vertexCount;
768     Move<VkBuffer> m_vertexBuffer;
769     de::MovePtr<Allocation> m_vertexBufferAlloc;
770 };
771 
BindingAcceleratioStructureGraphicsTestInstance(Context & context,const TestParams & testParams)772 BindingAcceleratioStructureGraphicsTestInstance::BindingAcceleratioStructureGraphicsTestInstance(
773     Context &context, const TestParams &testParams)
774     : BindingAcceleratioStructureTestInstance(context, testParams)
775     , m_framebufferFormat(VK_FORMAT_R8G8B8A8_UNORM)
776     , m_framebufferImage()
777     , m_framebufferImageAlloc()
778     , m_framebufferAttachment()
779     , m_vertShaderModule()
780     , m_geomShaderModule()
781     , m_tescShaderModule()
782     , m_teseShaderModule()
783     , m_fragShaderModule()
784     , m_renderPass()
785     , m_framebuffer()
786     , m_vertexCount(0)
787     , m_vertexBuffer()
788     , m_vertexBufferAlloc()
789 {
790 }
791 
checkSupport(Context & context,const TestParams & testParams)792 void BindingAcceleratioStructureGraphicsTestInstance::checkSupport(Context &context, const TestParams &testParams)
793 {
794     switch (testParams.stage)
795     {
796     case VK_SHADER_STAGE_VERTEX_BIT:
797     case VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT:
798     case VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT:
799     case VK_SHADER_STAGE_GEOMETRY_BIT:
800         context.requireDeviceCoreFeature(DEVICE_CORE_FEATURE_VERTEX_PIPELINE_STORES_AND_ATOMICS);
801         break;
802     default:
803         break;
804     }
805 
806     switch (testParams.stage)
807     {
808     case VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT:
809     case VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT:
810         context.requireDeviceCoreFeature(DEVICE_CORE_FEATURE_TESSELLATION_SHADER);
811         break;
812     case VK_SHADER_STAGE_GEOMETRY_BIT:
813         context.requireDeviceCoreFeature(DEVICE_CORE_FEATURE_GEOMETRY_SHADER);
814         break;
815     default:
816         break;
817     }
818 }
819 
initPrograms(SourceCollections & programCollection,const TestParams & testParams)820 void BindingAcceleratioStructureGraphicsTestInstance::initPrograms(SourceCollections &programCollection,
821                                                                    const TestParams &testParams)
822 {
823     const vk::ShaderBuildOptions buildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_4, 0u, true);
824     const std::string testShaderBody = testParams.testConfigShaderBodyText(testParams);
825 
826     switch (testParams.stage)
827     {
828     case VK_SHADER_STAGE_VERTEX_BIT:
829     {
830         {
831             std::ostringstream src;
832             src << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_460) << "\n"
833                 << "#extension GL_EXT_ray_query : require\n"
834                 << "#extension GL_EXT_ray_tracing : require\n"
835                 << "\n"
836                 << "layout(set = 0, binding = 0) uniform accelerationStructureEXT tlas;\n"
837                 << "layout(set = 1, binding = 0, r32i) uniform iimage3D result;\n"
838                 << "\n"
839                 << "void testFunc(ivec3 pos, ivec3 size)\n"
840                 << "{\n"
841                 << testShaderBody << "}\n"
842                 << "\n"
843                 << "void main(void)\n"
844                 << "{\n"
845                 << "  const int   posId    = int(gl_VertexIndex / 3);\n"
846                 << "  const int   vertId   = int(gl_VertexIndex % 3);\n"
847                 << "  const ivec3 size     = ivec3(" << testParams.width << ", " << testParams.height << ", 1);\n"
848                 << "  const ivec3 pos      = ivec3(posId % size.x, posId / size.x, 0);\n"
849                 << "\n"
850                 << "  if (vertId == 0)\n"
851                 << "  {\n"
852                 << "    testFunc(pos, size);\n"
853                 << "  }\n"
854                 << "}\n";
855 
856             programCollection.glslSources.add("vert") << glu::VertexSource(src.str()) << buildOptions;
857         }
858 
859         programCollection.glslSources.add("frag") << glu::FragmentSource(getGraphicsPassthrough()) << buildOptions;
860 
861         break;
862     }
863 
864     case VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT:
865     {
866         {
867             std::ostringstream src;
868             src << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_460) << "\n"
869                 << "\n"
870                 << "layout(location = 0) in vec4 in_position;\n"
871                 << "out gl_PerVertex\n"
872                 << "{\n"
873                 << "  vec4 gl_Position;\n"
874                 << "};\n"
875                 << "\n"
876                 << "void main(void)\n"
877                 << "{\n"
878                 << "  gl_Position = in_position;\n"
879                 << "}\n";
880 
881             programCollection.glslSources.add("vert") << glu::VertexSource(src.str()) << buildOptions;
882         }
883 
884         {
885             std::ostringstream src;
886             src << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_460) << "\n"
887                 << "#extension GL_EXT_tessellation_shader : require\n"
888                 << "#extension GL_EXT_ray_query : require\n"
889                 << "\n"
890                 << "layout(set = 0, binding = 0) uniform accelerationStructureEXT tlas;\n"
891                 << "layout(set = 1, binding = 0, r32i) uniform iimage3D result;\n"
892                 << "\n"
893                 << "in gl_PerVertex\n"
894                 << "{\n"
895                 << "  vec4 gl_Position;\n"
896                 << "} gl_in[];\n"
897                 << "layout(vertices = 3) out;\n"
898                 << "out gl_PerVertex\n"
899                 << "{\n"
900                 << "  vec4 gl_Position;\n"
901                 << "} gl_out[];\n"
902                 << "\n"
903                 << "void testFunc(ivec3 pos, ivec3 size)\n"
904                 << "{\n"
905                 << testShaderBody << "}\n"
906                 << "\n"
907                 << "void main(void)\n"
908                 << "{\n"
909                 << "\n"
910                 << "  if (gl_InvocationID == 0)\n"
911                 << "  {\n"
912                 << "    const ivec3 size = ivec3(" << testParams.width << ", " << testParams.height << ", 1);\n"
913                 << "    int index = int(gl_in[gl_InvocationID].gl_Position.z);\n"
914                 << "    int x = index % size.x;\n"
915                 << "    int y = index / size.y;\n"
916                 << "    const ivec3 pos = ivec3(x, y, 0);\n"
917                 << "    testFunc(pos, size);\n"
918                 << "  }\n"
919                 << "\n"
920                 << "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
921                 << "  gl_TessLevelInner[0] = 1;\n"
922                 << "  gl_TessLevelInner[1] = 1;\n"
923                 << "  gl_TessLevelOuter[gl_InvocationID] = 1;\n"
924                 << "}\n";
925 
926             programCollection.glslSources.add("tesc") << glu::TessellationControlSource(src.str()) << buildOptions;
927         }
928 
929         {
930             std::ostringstream src;
931             src << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_460) << "\n"
932                 << "#extension GL_EXT_tessellation_shader : require\n"
933                 << "layout(triangles, equal_spacing, ccw) in;\n"
934                 << "\n"
935                 << "in gl_PerVertex\n"
936                 << "{\n"
937                 << "  vec4 gl_Position;\n"
938                 << "} gl_in[];\n"
939                 << "\n"
940                 << "void main(void)\n"
941                 << "{\n"
942                 << "  gl_Position = gl_in[0].gl_Position;\n"
943                 << "}\n";
944 
945             programCollection.glslSources.add("tese") << glu::TessellationEvaluationSource(src.str()) << buildOptions;
946         }
947 
948         break;
949     }
950 
951     case VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT:
952     {
953         {
954             std::ostringstream src;
955             src << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_460) << "\n"
956                 << "\n"
957                 << "layout(location = 0) in vec4 in_position;\n"
958                 << "out gl_PerVertex"
959                 << "{\n"
960                 << "  vec4 gl_Position;\n"
961                 << "};\n"
962                 << "\n"
963                 << "void main(void)\n"
964                 << "{\n"
965                 << "  gl_Position = in_position;\n"
966                 << "}\n";
967 
968             programCollection.glslSources.add("vert") << glu::VertexSource(src.str()) << buildOptions;
969         }
970 
971         {
972             std::ostringstream src;
973             src << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_460) << "\n"
974                 << "#extension GL_EXT_tessellation_shader : require\n"
975                 << "\n"
976                 << "in gl_PerVertex\n"
977                 << "{\n"
978                 << "  vec4 gl_Position;\n"
979                 << "} gl_in[];\n"
980                 << "layout(vertices = 3) out;\n"
981                 << "out gl_PerVertex\n"
982                 << "{\n"
983                 << "  vec4 gl_Position;\n"
984                 << "} gl_out[];\n"
985                 << "\n"
986                 << "void main(void)\n"
987                 << "{\n"
988                 << "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
989                 << "  gl_TessLevelInner[0] = 1;\n"
990                 << "  gl_TessLevelInner[1] = 1;\n"
991                 << "  gl_TessLevelOuter[gl_InvocationID] = 1;\n"
992                 << "}\n";
993 
994             programCollection.glslSources.add("tesc") << glu::TessellationControlSource(src.str()) << buildOptions;
995         }
996 
997         {
998             std::ostringstream src;
999             src << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_460) << "\n"
1000                 << "#extension GL_EXT_tessellation_shader : require\n"
1001                 << "#extension GL_EXT_ray_query : require\n"
1002                 << "\n"
1003                 << "layout(set = 0, binding = 0) uniform accelerationStructureEXT tlas;\n"
1004                 << "layout(set = 1, binding = 0, r32i) uniform iimage3D result;\n"
1005                 << "\n"
1006                 << "layout(triangles, equal_spacing, ccw) in;\n"
1007                 << "in gl_PerVertex\n"
1008                 << "{\n"
1009                 << "  vec4 gl_Position;\n"
1010                 << "} gl_in[];\n"
1011                 << "\n"
1012                 << "void testFunc(ivec3 pos, ivec3 size)\n"
1013                 << "{\n"
1014                 << testShaderBody << "}\n"
1015                 << "\n"
1016                 << "void main(void)\n"
1017                 << "{\n"
1018                 << "    const ivec3 size = ivec3(" << testParams.width << ", " << testParams.height << ", 1);\n"
1019                 << "    int index = int(gl_in[0].gl_Position.z);\n"
1020                 << "    int x = index % size.x;\n"
1021                 << "    int y = index / size.y;\n"
1022                 << "    const ivec3 pos = ivec3(x, y, 0);\n"
1023                 << "    testFunc(pos, size);\n"
1024                 << "    gl_Position = gl_in[0].gl_Position;\n"
1025                 << "}\n";
1026 
1027             programCollection.glslSources.add("tese") << glu::TessellationEvaluationSource(src.str()) << buildOptions;
1028         }
1029 
1030         break;
1031     }
1032 
1033     case VK_SHADER_STAGE_GEOMETRY_BIT:
1034     {
1035         programCollection.glslSources.add("vert") << glu::VertexSource(getVertexPassthrough()) << buildOptions;
1036 
1037         {
1038             std::ostringstream src;
1039             src << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_460) << "\n"
1040                 << "#extension GL_EXT_ray_query : require\n"
1041                 << "\n"
1042                 << "layout(triangles) in;\n"
1043                 << "layout(points, max_vertices = 1) out;\n"
1044                 << "\n"
1045                 << "layout(set = 0, binding = 0) uniform accelerationStructureEXT tlas;\n"
1046                 << "layout(set = 1, binding = 0, r32i) uniform iimage3D result;\n"
1047                 << "\n"
1048                 << "void testFunc(ivec3 pos, ivec3 size)\n"
1049                 << "{\n"
1050                 << testShaderBody << "}\n"
1051                 << "\n"
1052                 << "void main(void)\n"
1053                 << "{\n"
1054                 << "  const int   posId    = int(gl_PrimitiveIDIn);\n"
1055                 << "  const ivec3 size     = ivec3(" << testParams.width << ", " << testParams.height << ", 1);\n"
1056                 << "  const ivec3 pos      = ivec3(posId % size.x, posId / size.x, 0);\n"
1057                 << "\n"
1058                 << "  testFunc(pos, size);\n"
1059                 << "  gl_PointSize = 1.0;\n"
1060                 << "}\n";
1061 
1062             programCollection.glslSources.add("geom") << glu::GeometrySource(src.str()) << buildOptions;
1063         }
1064 
1065         break;
1066     }
1067 
1068     case VK_SHADER_STAGE_FRAGMENT_BIT:
1069     {
1070         programCollection.glslSources.add("vert") << glu::VertexSource(getVertexPassthrough()) << buildOptions;
1071 
1072         {
1073             std::ostringstream src;
1074             src << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_460) << "\n"
1075                 << "#extension GL_EXT_ray_query : require\n"
1076                 << "\n"
1077                 << "layout(set = 0, binding = 0) uniform accelerationStructureEXT tlas;\n"
1078                 << "layout(set = 1, binding = 0, r32i) uniform iimage3D result;\n"
1079                 << "\n"
1080                 << "void testFunc(ivec3 pos, ivec3 size)\n"
1081                 << "{\n"
1082                 << testShaderBody << "}\n"
1083                 << "\n"
1084                 << "void main(void)\n"
1085                 << "{\n"
1086                 << "  const ivec3 size     = ivec3(" << testParams.width << ", " << testParams.height << ", 1);\n"
1087                 << "  const ivec3 pos      = ivec3(int(gl_FragCoord.x - 0.5f), int(gl_FragCoord.y - 0.5f), 0);\n"
1088                 << "\n"
1089                 << "  testFunc(pos, size);\n"
1090                 << "}\n";
1091 
1092             programCollection.glslSources.add("frag") << glu::FragmentSource(src.str()) << buildOptions;
1093         }
1094 
1095         break;
1096     }
1097 
1098     default:
1099         TCU_THROW(InternalError, "Unknown stage");
1100     }
1101 }
1102 
initVertexBuffer(void)1103 void BindingAcceleratioStructureGraphicsTestInstance::initVertexBuffer(void)
1104 {
1105     const DeviceInterface &vkd = m_context.getDeviceInterface();
1106     const VkDevice device      = m_context.getDevice();
1107     const uint32_t width       = m_testParams.width;
1108     const uint32_t height      = m_testParams.height;
1109     Allocator &allocator       = m_context.getDefaultAllocator();
1110     std::vector<tcu::Vec4> vertices;
1111 
1112     switch (m_testParams.stage)
1113     {
1114     case VK_SHADER_STAGE_VERTEX_BIT:
1115     case VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT:
1116     case VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT:
1117     case VK_SHADER_STAGE_GEOMETRY_BIT:
1118     {
1119         float z = 0.0f;
1120 
1121         vertices.reserve(3 * height * width);
1122 
1123         for (uint32_t y = 0; y < height; ++y)
1124             for (uint32_t x = 0; x < width; ++x)
1125             {
1126                 const float x0 = float(x + 0) / float(width);
1127                 const float y0 = float(y + 0) / float(height);
1128                 const float x1 = float(x + 1) / float(width);
1129                 const float y1 = float(y + 1) / float(height);
1130                 const float xm = (x0 + x1) / 2.0f;
1131                 const float ym = (y0 + y1) / 2.0f;
1132 
1133                 vertices.push_back(tcu::Vec4(x0, y0, z, 1.0f));
1134                 vertices.push_back(tcu::Vec4(xm, y1, z, 1.0f));
1135                 vertices.push_back(tcu::Vec4(x1, ym, z, 1.0f));
1136 
1137                 z += 1.f;
1138             }
1139 
1140         break;
1141     }
1142 
1143     case VK_SHADER_STAGE_FRAGMENT_BIT:
1144     {
1145         const float z     = 1.0f;
1146         const tcu::Vec4 a = tcu::Vec4(-1.0f, -1.0f, z, 1.0f);
1147         const tcu::Vec4 b = tcu::Vec4(+1.0f, -1.0f, z, 1.0f);
1148         const tcu::Vec4 c = tcu::Vec4(-1.0f, +1.0f, z, 1.0f);
1149         const tcu::Vec4 d = tcu::Vec4(+1.0f, +1.0f, z, 1.0f);
1150 
1151         vertices.push_back(a);
1152         vertices.push_back(b);
1153         vertices.push_back(c);
1154 
1155         vertices.push_back(b);
1156         vertices.push_back(c);
1157         vertices.push_back(d);
1158 
1159         break;
1160     }
1161 
1162     default:
1163         TCU_THROW(InternalError, "Unknown stage");
1164     }
1165 
1166     // Initialize vertex buffer
1167     {
1168         const VkDeviceSize vertexBufferSize = sizeof(vertices[0][0]) * vertices[0].SIZE * vertices.size();
1169         const VkBufferCreateInfo vertexBufferCreateInfo =
1170             makeBufferCreateInfo(vertexBufferSize, VK_BUFFER_USAGE_VERTEX_BUFFER_BIT);
1171 
1172         m_vertexCount       = static_cast<uint32_t>(vertices.size());
1173         m_vertexBuffer      = createBuffer(vkd, device, &vertexBufferCreateInfo);
1174         m_vertexBufferAlloc = bindBuffer(vkd, device, allocator, *m_vertexBuffer, vk::MemoryRequirement::HostVisible);
1175 
1176         deMemcpy(m_vertexBufferAlloc->getHostPtr(), vertices.data(), (size_t)vertexBufferSize);
1177         flushAlloc(vkd, device, *m_vertexBufferAlloc);
1178     }
1179 }
1180 
makeGraphicsPipeline(void)1181 Move<VkPipeline> BindingAcceleratioStructureGraphicsTestInstance::makeGraphicsPipeline(void)
1182 {
1183     const DeviceInterface &vkd = m_context.getDeviceInterface();
1184     const VkDevice device      = m_context.getDevice();
1185     const bool tessStageTest   = (m_testParams.stage == VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT ||
1186                                 m_testParams.stage == VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT);
1187     const VkPrimitiveTopology topology =
1188         tessStageTest ? VK_PRIMITIVE_TOPOLOGY_PATCH_LIST : VK_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST;
1189     const uint32_t patchControlPoints = tessStageTest ? 3 : 0;
1190     const std::vector<VkViewport> viewports(1, makeViewport(m_testParams.width, m_testParams.height));
1191     const std::vector<VkRect2D> scissors(1, makeRect2D(m_testParams.width, m_testParams.height));
1192 
1193     return vk::makeGraphicsPipeline(vkd, device, *m_pipelineLayout, *m_vertShaderModule, *m_tescShaderModule,
1194                                     *m_teseShaderModule, *m_geomShaderModule, *m_fragShaderModule, *m_renderPass,
1195                                     viewports, scissors, topology, 0, patchControlPoints);
1196 }
1197 
initPipeline(void)1198 void BindingAcceleratioStructureGraphicsTestInstance::initPipeline(void)
1199 {
1200     const DeviceInterface &vkd       = m_context.getDeviceInterface();
1201     const VkDevice device            = m_context.getDevice();
1202     Allocator &allocator             = m_context.getDefaultAllocator();
1203     vk::BinaryCollection &collection = m_context.getBinaryCollection();
1204     VkShaderStageFlags shaders       = static_cast<VkShaderStageFlags>(0);
1205     uint32_t shaderCount             = 0;
1206 
1207     if (collection.contains("vert"))
1208         shaders |= VK_SHADER_STAGE_VERTEX_BIT;
1209     if (collection.contains("geom"))
1210         shaders |= VK_SHADER_STAGE_GEOMETRY_BIT;
1211     if (collection.contains("tesc"))
1212         shaders |= VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT;
1213     if (collection.contains("tese"))
1214         shaders |= VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT;
1215     if (collection.contains("frag"))
1216         shaders |= VK_SHADER_STAGE_FRAGMENT_BIT;
1217 
1218     for (BinaryCollection::Iterator it = collection.begin(); it != collection.end(); ++it)
1219         shaderCount++;
1220 
1221     if (shaderCount != (uint32_t)dePop32(shaders))
1222         TCU_THROW(InternalError, "Unused shaders detected in the collection");
1223 
1224     if (0 != (shaders & VK_SHADER_STAGE_VERTEX_BIT))
1225         m_vertShaderModule = createShaderModule(vkd, device, collection.get("vert"), 0);
1226     if (0 != (shaders & VK_SHADER_STAGE_GEOMETRY_BIT))
1227         m_geomShaderModule = createShaderModule(vkd, device, collection.get("geom"), 0);
1228     if (0 != (shaders & VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT))
1229         m_tescShaderModule = createShaderModule(vkd, device, collection.get("tesc"), 0);
1230     if (0 != (shaders & VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT))
1231         m_teseShaderModule = createShaderModule(vkd, device, collection.get("tese"), 0);
1232     if (0 != (shaders & VK_SHADER_STAGE_FRAGMENT_BIT))
1233         m_fragShaderModule = createShaderModule(vkd, device, collection.get("frag"), 0);
1234 
1235     m_framebufferImage      = makeImage(vkd, device,
1236                                         makeImageCreateInfo(m_framebufferFormat, m_testParams.width, m_testParams.height, 1u,
1237                                                             VK_IMAGE_TYPE_2D, VK_IMAGE_USAGE_COLOR_ATTACHMENT_BIT));
1238     m_framebufferImageAlloc = bindImage(vkd, device, allocator, *m_framebufferImage, MemoryRequirement::Any);
1239     m_framebufferAttachment =
1240         makeImageView(vkd, device, *m_framebufferImage, VK_IMAGE_VIEW_TYPE_2D, m_framebufferFormat,
1241                       makeImageSubresourceRange(VK_IMAGE_ASPECT_COLOR_BIT, 0u, 1u, 0u, 1u));
1242     m_renderPass = makeRenderPass(vkd, device, m_framebufferFormat);
1243     m_framebuffer =
1244         makeFramebuffer(vkd, device, *m_renderPass, *m_framebufferAttachment, m_testParams.width, m_testParams.height);
1245     m_pipelineLayout = makePipelineLayout(vkd, device, m_descriptorSetLayoutAS.get(), m_descriptorSetLayoutImg.get());
1246     m_pipeline       = makeGraphicsPipeline();
1247 
1248     initVertexBuffer();
1249 }
1250 
fillCommandBuffer(VkCommandBuffer commandBuffer)1251 void BindingAcceleratioStructureGraphicsTestInstance::fillCommandBuffer(VkCommandBuffer commandBuffer)
1252 {
1253     const DeviceInterface &vkd            = m_context.getDeviceInterface();
1254     const VkDeviceSize vertexBufferOffset = 0;
1255 
1256     vkd.cmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_GRAPHICS, *m_pipeline);
1257     vkd.cmdBindVertexBuffers(commandBuffer, 0u, 1u, &m_vertexBuffer.get(), &vertexBufferOffset);
1258 
1259     beginRenderPass(vkd, commandBuffer, *m_renderPass, *m_framebuffer,
1260                     makeRect2D(0, 0, m_testParams.width, m_testParams.height), tcu::UVec4());
1261 
1262     vkd.cmdDraw(commandBuffer, m_vertexCount, 1u, 0u, 0u);
1263 
1264     endRenderPass(vkd, commandBuffer);
1265 }
1266 
1267 class BindingAcceleratioStructureComputeTestInstance : public BindingAcceleratioStructureTestInstance
1268 {
1269 public:
1270     BindingAcceleratioStructureComputeTestInstance(Context &context, const TestParams &testParams);
1271 
~BindingAcceleratioStructureComputeTestInstance()1272     virtual ~BindingAcceleratioStructureComputeTestInstance()
1273     {
1274     }
1275 
1276     static void checkSupport(Context &context, const TestParams &testParams);
1277     static void initPrograms(SourceCollections &programCollection, const TestParams &testParams);
1278 
1279 protected:
1280     virtual void initPipeline(void) override;
1281     virtual void fillCommandBuffer(VkCommandBuffer commandBuffer) override;
1282 
getExtraAccelerationDescriptorCount(void)1283     virtual uint32_t getExtraAccelerationDescriptorCount(void) override
1284     {
1285         return 0;
1286     }
getShaderStageFlags(void)1287     virtual VkShaderStageFlags getShaderStageFlags(void) override
1288     {
1289         return VK_SHADER_STAGE_COMPUTE_BIT;
1290     }
getPipelineBindPoint(void)1291     virtual VkPipelineBindPoint getPipelineBindPoint(void) override
1292     {
1293         return VK_PIPELINE_BIND_POINT_COMPUTE;
1294     }
1295 
1296     Move<VkShaderModule> m_shaderModule;
1297 };
1298 
BindingAcceleratioStructureComputeTestInstance(Context & context,const TestParams & testParams)1299 BindingAcceleratioStructureComputeTestInstance::BindingAcceleratioStructureComputeTestInstance(
1300     Context &context, const TestParams &testParams)
1301     : BindingAcceleratioStructureTestInstance(context, testParams)
1302     , m_shaderModule()
1303 {
1304 }
1305 
checkSupport(Context & context,const TestParams & testParams)1306 void BindingAcceleratioStructureComputeTestInstance::checkSupport(Context &context, const TestParams &testParams)
1307 {
1308     DE_UNREF(context);
1309     DE_UNREF(testParams);
1310 }
1311 
initPrograms(SourceCollections & programCollection,const TestParams & testParams)1312 void BindingAcceleratioStructureComputeTestInstance::initPrograms(SourceCollections &programCollection,
1313                                                                   const TestParams &testParams)
1314 {
1315     const vk::ShaderBuildOptions buildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_4, 0u, true);
1316     const std::string testShaderBody = testParams.testConfigShaderBodyText(testParams);
1317     const std::string testBody       = "  ivec3       pos      = ivec3(gl_WorkGroupID);\n"
1318                                        "  ivec3       size     = ivec3(gl_NumWorkGroups);\n" +
1319                                  testShaderBody;
1320 
1321     switch (testParams.stage)
1322     {
1323     case VK_SHADER_STAGE_COMPUTE_BIT:
1324     {
1325         std::stringstream css;
1326         css << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_460) << "\n"
1327             << "#extension GL_EXT_ray_query : require\n"
1328             << "\n"
1329             << "layout(set = 0, binding = 0) uniform accelerationStructureEXT tlas;\n"
1330             << "layout(set = 1, binding = 0, r32i) uniform iimage3D result;\n"
1331             << "\n"
1332             << "void main()\n"
1333             << "{\n"
1334             << testBody << "}\n";
1335 
1336         programCollection.glslSources.add("comp") << glu::ComputeSource(css.str()) << buildOptions;
1337 
1338         break;
1339     }
1340 
1341     default:
1342         TCU_THROW(InternalError, "Unknown stage");
1343     }
1344 }
1345 
initPipeline(void)1346 void BindingAcceleratioStructureComputeTestInstance::initPipeline(void)
1347 {
1348     const DeviceInterface &vkd       = m_context.getDeviceInterface();
1349     const VkDevice device            = m_context.getDevice();
1350     vk::BinaryCollection &collection = m_context.getBinaryCollection();
1351 
1352     m_shaderModule   = createShaderModule(vkd, device, collection.get("comp"), 0);
1353     m_pipelineLayout = makePipelineLayout(vkd, device, m_descriptorSetLayoutAS.get(), m_descriptorSetLayoutImg.get());
1354     m_pipeline       = makeComputePipeline(vkd, device, *m_pipelineLayout, *m_shaderModule);
1355 }
1356 
fillCommandBuffer(VkCommandBuffer commandBuffer)1357 void BindingAcceleratioStructureComputeTestInstance::fillCommandBuffer(VkCommandBuffer commandBuffer)
1358 {
1359     const DeviceInterface &vkd = m_context.getDeviceInterface();
1360 
1361     vkd.cmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, m_pipeline.get());
1362 
1363     vkd.cmdDispatch(commandBuffer, m_testParams.width, m_testParams.height, 1);
1364 }
1365 
1366 class BindingAcceleratioStructureRayTracingTestInstance : public BindingAcceleratioStructureTestInstance
1367 {
1368 public:
1369     BindingAcceleratioStructureRayTracingTestInstance(Context &context, const TestParams &testParams);
~BindingAcceleratioStructureRayTracingTestInstance()1370     virtual ~BindingAcceleratioStructureRayTracingTestInstance()
1371     {
1372     }
1373 
1374     static void checkSupport(Context &context, const TestParams &testParams);
1375     static void initPrograms(SourceCollections &programCollection, const TestParams &testParams);
1376 
1377 protected:
1378     virtual void initPipeline(void) override;
1379     virtual void fillCommandBuffer(VkCommandBuffer commandBuffer) override;
1380 
1381     de::MovePtr<BufferWithMemory> createShaderBindingTable(const InstanceInterface &vki, const DeviceInterface &vkd,
1382                                                            const VkDevice device, const VkPhysicalDevice physicalDevice,
1383                                                            const VkPipeline pipeline, Allocator &allocator,
1384                                                            de::MovePtr<RayTracingPipeline> &rayTracingPipeline,
1385                                                            const uint32_t group);
1386 
getExtraAccelerationDescriptorCount(void)1387     virtual uint32_t getExtraAccelerationDescriptorCount(void) override
1388     {
1389         return 1;
1390     }
getShaderStageFlags(void)1391     virtual VkShaderStageFlags getShaderStageFlags(void) override
1392     {
1393         return ALL_RAY_TRACING_STAGES;
1394     }
getPipelineBindPoint(void)1395     virtual VkPipelineBindPoint getPipelineBindPoint(void) override
1396     {
1397         return VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR;
1398     }
1399 
1400     uint32_t m_shaders;
1401     uint32_t m_raygenShaderGroup;
1402     uint32_t m_missShaderGroup;
1403     uint32_t m_hitShaderGroup;
1404     uint32_t m_callableShaderGroup;
1405     uint32_t m_shaderGroupCount;
1406 
1407     Move<VkDescriptorSetLayout> m_descriptorSetLayoutSvc;
1408     Move<VkDescriptorSet> m_descriptorSetSvc;
1409 
1410     de::MovePtr<RayTracingPipeline> m_rayTracingPipeline;
1411 
1412     de::MovePtr<BufferWithMemory> m_raygenShaderBindingTable;
1413     de::MovePtr<BufferWithMemory> m_hitShaderBindingTable;
1414     de::MovePtr<BufferWithMemory> m_missShaderBindingTable;
1415     de::MovePtr<BufferWithMemory> m_callableShaderBindingTable;
1416 
1417     VkStridedDeviceAddressRegionKHR m_raygenShaderBindingTableRegion;
1418     VkStridedDeviceAddressRegionKHR m_missShaderBindingTableRegion;
1419     VkStridedDeviceAddressRegionKHR m_hitShaderBindingTableRegion;
1420     VkStridedDeviceAddressRegionKHR m_callableShaderBindingTableRegion;
1421 
1422     de::SharedPtr<BottomLevelAccelerationStructure> m_bottomLevelAccelerationStructure;
1423     de::SharedPtr<TopLevelAccelerationStructure> m_topLevelAccelerationStructure;
1424 };
1425 
BindingAcceleratioStructureRayTracingTestInstance(Context & context,const TestParams & testParams)1426 BindingAcceleratioStructureRayTracingTestInstance::BindingAcceleratioStructureRayTracingTestInstance(
1427     Context &context, const TestParams &testParams)
1428     : BindingAcceleratioStructureTestInstance(context, testParams)
1429     , m_shaders(0)
1430     , m_raygenShaderGroup(~0u)
1431     , m_missShaderGroup(~0u)
1432     , m_hitShaderGroup(~0u)
1433     , m_callableShaderGroup(~0u)
1434     , m_shaderGroupCount(0)
1435 
1436     , m_descriptorSetLayoutSvc()
1437     , m_descriptorSetSvc()
1438 
1439     , m_rayTracingPipeline()
1440 
1441     , m_raygenShaderBindingTable()
1442     , m_hitShaderBindingTable()
1443     , m_missShaderBindingTable()
1444     , m_callableShaderBindingTable()
1445 
1446     , m_raygenShaderBindingTableRegion()
1447     , m_missShaderBindingTableRegion()
1448     , m_hitShaderBindingTableRegion()
1449     , m_callableShaderBindingTableRegion()
1450 
1451     , m_bottomLevelAccelerationStructure()
1452     , m_topLevelAccelerationStructure()
1453 {
1454 }
1455 
checkSupport(Context & context,const TestParams & testParams)1456 void BindingAcceleratioStructureRayTracingTestInstance::checkSupport(Context &context, const TestParams &testParams)
1457 {
1458     DE_UNREF(testParams);
1459 
1460     context.requireDeviceFunctionality("VK_KHR_ray_tracing_pipeline");
1461 
1462     const VkPhysicalDeviceRayTracingPipelineFeaturesKHR &rayTracingPipelineFeaturesKHR =
1463         context.getRayTracingPipelineFeatures();
1464 
1465     if (rayTracingPipelineFeaturesKHR.rayTracingPipeline == false)
1466         TCU_THROW(NotSupportedError, "Requires VkPhysicalDeviceRayTracingPipelineFeaturesKHR.rayTracingPipeline");
1467 }
1468 
initPrograms(SourceCollections & programCollection,const TestParams & testParams)1469 void BindingAcceleratioStructureRayTracingTestInstance::initPrograms(SourceCollections &programCollection,
1470                                                                      const TestParams &testParams)
1471 {
1472     const vk::ShaderBuildOptions buildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_4, 0u, true);
1473     const std::string testShaderBody = testParams.testConfigShaderBodyText(testParams);
1474     const std::string testBody       = "  ivec3       pos      = ivec3(gl_LaunchIDEXT);\n"
1475                                        "  ivec3       size     = ivec3(gl_LaunchSizeEXT);\n" +
1476                                  testShaderBody;
1477     const std::string commonRayGenerationShader =
1478         std::string(glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_460)) +
1479         "\n"
1480         "#extension GL_EXT_ray_tracing : require\n"
1481         "\n"
1482         "layout(location = 0) rayPayloadEXT vec3 hitValue;\n"
1483         "layout(set = 2, binding = 0) uniform accelerationStructureEXT topLevelAS;\n"
1484         "\n"
1485         "void main()\n"
1486         "{\n"
1487         "  uint  rayFlags = 0;\n"
1488         "  uint  cullMask = 0xFF;\n"
1489         "  float tmin     = 0.0;\n"
1490         "  float tmax     = 9.0;\n"
1491         "  vec3  origin   = vec3((float(gl_LaunchIDEXT.x) + 0.5f) / float(gl_LaunchSizeEXT.x), "
1492         "(float(gl_LaunchIDEXT.y) + 0.5f) / float(gl_LaunchSizeEXT.y), 0.0);\n"
1493         "  vec3  direct   = vec3(0.0, 0.0, -1.0);\n"
1494         "  traceRayEXT(topLevelAS, rayFlags, cullMask, 0, 0, 0, origin, tmin, direct, tmax, 0);\n"
1495         "}\n";
1496 
1497     switch (testParams.stage)
1498     {
1499     case VK_SHADER_STAGE_RAYGEN_BIT_KHR:
1500     {
1501         std::stringstream css;
1502         css << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_460) << "\n"
1503             << "#extension GL_EXT_ray_tracing : require\n"
1504             << "#extension GL_EXT_ray_query : require\n"
1505             << "\n"
1506             << "layout(set = 0, binding = 0) uniform accelerationStructureEXT tlas;\n"
1507             << "layout(set = 1, binding = 0, r32i) uniform iimage3D result;\n"
1508             << "\n"
1509             << "void main()\n"
1510             << "{\n"
1511             << testBody << "}\n";
1512 
1513         programCollection.glslSources.add("rgen") << glu::RaygenSource(css.str()) << buildOptions;
1514 
1515         break;
1516     }
1517 
1518     case VK_SHADER_STAGE_ANY_HIT_BIT_KHR:
1519     {
1520         programCollection.glslSources.add("rgen") << glu::RaygenSource(commonRayGenerationShader) << buildOptions;
1521 
1522         {
1523             std::stringstream css;
1524             css << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_460) << "\n"
1525                 << "#extension GL_EXT_ray_tracing : require\n"
1526                 << "#extension GL_EXT_ray_query : require\n"
1527                 << "\n"
1528                 << "hitAttributeEXT vec3 attribs;\n"
1529                 << "layout(location = 0) rayPayloadInEXT vec3 hitValue;\n"
1530                 << "\n"
1531                 << "layout(set = 0, binding = 0) uniform accelerationStructureEXT tlas;\n"
1532                 << "layout(set = 1, binding = 0, r32i) uniform iimage3D result;\n"
1533                 << "\n"
1534                 << "void main()\n"
1535                 << "{\n"
1536                 << testBody << "}\n";
1537 
1538             programCollection.glslSources.add("ahit") << glu::AnyHitSource(css.str()) << buildOptions;
1539         }
1540 
1541         programCollection.glslSources.add("chit") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1542         programCollection.glslSources.add("miss") << glu::MissSource(getMissPassthrough()) << buildOptions;
1543 
1544         break;
1545     }
1546 
1547     case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR:
1548     {
1549         programCollection.glslSources.add("rgen") << glu::RaygenSource(commonRayGenerationShader) << buildOptions;
1550 
1551         {
1552             std::stringstream css;
1553             css << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_460) << "\n"
1554                 << "#extension GL_EXT_ray_tracing : require\n"
1555                 << "#extension GL_EXT_ray_query : require\n"
1556                 << "\n"
1557                 << "layout(location = 0) rayPayloadInEXT vec3 hitValue;\n"
1558                 << "hitAttributeEXT vec3 attribs;\n"
1559                 << "\n"
1560                 << "layout(set = 0, binding = 0) uniform accelerationStructureEXT tlas;\n"
1561                 << "layout(set = 1, binding = 0, r32i) uniform iimage3D result;\n"
1562                 << "\n"
1563                 << "void main()\n"
1564                 << "{\n"
1565                 << testBody << "}\n";
1566 
1567             programCollection.glslSources.add("chit") << glu::ClosestHitSource(css.str()) << buildOptions;
1568         }
1569 
1570         programCollection.glslSources.add("ahit") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1571         programCollection.glslSources.add("miss") << glu::MissSource(getMissPassthrough()) << buildOptions;
1572 
1573         break;
1574     }
1575 
1576     case VK_SHADER_STAGE_INTERSECTION_BIT_KHR:
1577     {
1578         programCollection.glslSources.add("rgen") << glu::RaygenSource(commonRayGenerationShader) << buildOptions;
1579 
1580         {
1581             std::stringstream css;
1582             css << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_460) << "\n"
1583                 << "#extension GL_EXT_ray_tracing : require\n"
1584                 << "#extension GL_EXT_ray_query : require\n"
1585                 << "hitAttributeEXT vec3 hitAttribute;\n"
1586                 << "\n"
1587                 << "layout(set = 0, binding = 0) uniform accelerationStructureEXT tlas;\n"
1588                 << "layout(set = 1, binding = 0, r32i) uniform iimage3D result;\n"
1589                 << "\n"
1590                 << "void main()\n"
1591                 << "{\n"
1592                 << testBody << "  hitAttribute = vec3(0.0f, 0.0f, 0.0f);\n"
1593                 << "  reportIntersectionEXT(1.0f, 0);\n"
1594                 << "}\n";
1595 
1596             programCollection.glslSources.add("sect") << glu::IntersectionSource(css.str()) << buildOptions;
1597         }
1598 
1599         programCollection.glslSources.add("ahit") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1600         programCollection.glslSources.add("chit") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1601         programCollection.glslSources.add("miss") << glu::MissSource(getMissPassthrough()) << buildOptions;
1602 
1603         break;
1604     }
1605 
1606     case VK_SHADER_STAGE_MISS_BIT_KHR:
1607     {
1608         programCollection.glslSources.add("rgen") << glu::RaygenSource(commonRayGenerationShader) << buildOptions;
1609 
1610         {
1611             std::stringstream css;
1612             css << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_460) << "\n"
1613                 << "#extension GL_EXT_ray_tracing : require\n"
1614                 << "#extension GL_EXT_ray_query : require\n"
1615                 << "layout(location = 0) rayPayloadInEXT vec3 hitValue;\n"
1616                 << "\n"
1617                 << "layout(set = 0, binding = 0) uniform accelerationStructureEXT tlas;\n"
1618                 << "layout(set = 1, binding = 0, r32i) uniform iimage3D result;\n"
1619                 << "\n"
1620                 << "void main()\n"
1621                 << "{\n"
1622                 << testBody << "}\n";
1623 
1624             programCollection.glslSources.add("miss") << glu::MissSource(css.str()) << buildOptions;
1625         }
1626 
1627         programCollection.glslSources.add("ahit") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1628         programCollection.glslSources.add("chit") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1629 
1630         break;
1631     }
1632 
1633     case VK_SHADER_STAGE_CALLABLE_BIT_KHR:
1634     {
1635         {
1636             std::stringstream css;
1637             css << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_460) << "\n"
1638                 << "#extension GL_EXT_ray_tracing : require\n"
1639                 << "#extension GL_EXT_ray_query : require\n"
1640                 << "\n"
1641                 << "layout(location = 0) callableDataEXT float dummy;"
1642                 << "layout(set = 2, binding = 0) uniform accelerationStructureEXT topLevelAS;\n"
1643                 << "\n"
1644                 << "void main()\n"
1645                 << "{\n"
1646                 << "  executeCallableEXT(0, 0);\n"
1647                 << "}\n";
1648 
1649             programCollection.glslSources.add("rgen") << glu::RaygenSource(css.str()) << buildOptions;
1650         }
1651 
1652         {
1653             std::stringstream css;
1654             css << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_460) << "\n"
1655                 << "#extension GL_EXT_ray_tracing : require\n"
1656                 << "#extension GL_EXT_ray_query : require\n"
1657                 << "layout(location = 0) callableDataInEXT float dummy;"
1658                 << "\n"
1659                 << "layout(set = 0, binding = 0) uniform accelerationStructureEXT tlas;\n"
1660                 << "layout(set = 1, binding = 0, r32i) uniform iimage3D result;\n"
1661                 << "\n"
1662                 << "void main()\n"
1663                 << "{\n"
1664                 << testBody << "}\n";
1665 
1666             programCollection.glslSources.add("call") << glu::CallableSource(css.str()) << buildOptions;
1667         }
1668 
1669         programCollection.glslSources.add("ahit") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1670         programCollection.glslSources.add("chit") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1671         programCollection.glslSources.add("miss") << glu::MissSource(getMissPassthrough()) << buildOptions;
1672 
1673         break;
1674     }
1675 
1676     default:
1677         TCU_THROW(InternalError, "Unknown stage");
1678     }
1679 }
1680 
initPipeline(void)1681 void BindingAcceleratioStructureRayTracingTestInstance::initPipeline(void)
1682 {
1683     const InstanceInterface &vki          = m_context.getInstanceInterface();
1684     const DeviceInterface &vkd            = m_context.getDeviceInterface();
1685     const VkDevice device                 = m_context.getDevice();
1686     const VkPhysicalDevice physicalDevice = m_context.getPhysicalDevice();
1687     vk::BinaryCollection &collection      = m_context.getBinaryCollection();
1688     Allocator &allocator                  = m_context.getDefaultAllocator();
1689     const uint32_t shaderGroupHandleSize  = getShaderGroupHandleSize(vki, physicalDevice);
1690     const VkShaderStageFlags hitStages =
1691         VK_SHADER_STAGE_ANY_HIT_BIT_KHR | VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR | VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
1692     uint32_t shaderCount = 0;
1693 
1694     m_shaderGroupCount = 0;
1695 
1696     if (collection.contains("rgen"))
1697         m_shaders |= VK_SHADER_STAGE_RAYGEN_BIT_KHR;
1698     if (collection.contains("ahit"))
1699         m_shaders |= VK_SHADER_STAGE_ANY_HIT_BIT_KHR;
1700     if (collection.contains("chit"))
1701         m_shaders |= VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
1702     if (collection.contains("miss"))
1703         m_shaders |= VK_SHADER_STAGE_MISS_BIT_KHR;
1704     if (collection.contains("sect"))
1705         m_shaders |= VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
1706     if (collection.contains("call"))
1707         m_shaders |= VK_SHADER_STAGE_CALLABLE_BIT_KHR;
1708 
1709     for (BinaryCollection::Iterator it = collection.begin(); it != collection.end(); ++it)
1710         shaderCount++;
1711 
1712     if (shaderCount != (uint32_t)dePop32(m_shaders))
1713         TCU_THROW(InternalError, "Unused shaders detected in the collection");
1714 
1715     if (0 != (m_shaders & VK_SHADER_STAGE_RAYGEN_BIT_KHR))
1716         m_raygenShaderGroup = m_shaderGroupCount++;
1717 
1718     if (0 != (m_shaders & VK_SHADER_STAGE_MISS_BIT_KHR))
1719         m_missShaderGroup = m_shaderGroupCount++;
1720 
1721     if (0 != (m_shaders & hitStages))
1722         m_hitShaderGroup = m_shaderGroupCount++;
1723 
1724     if (0 != (m_shaders & VK_SHADER_STAGE_CALLABLE_BIT_KHR))
1725         m_callableShaderGroup = m_shaderGroupCount++;
1726 
1727     m_rayTracingPipeline = de::newMovePtr<RayTracingPipeline>();
1728 
1729     m_descriptorSetLayoutSvc =
1730         DescriptorSetLayoutBuilder()
1731             .addSingleBinding(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, ALL_RAY_TRACING_STAGES)
1732             .build(vkd, device);
1733     m_descriptorSetSvc = makeDescriptorSet(vkd, device, *m_descriptorPool, *m_descriptorSetLayoutSvc);
1734 
1735     if (0 != (m_shaders & VK_SHADER_STAGE_RAYGEN_BIT_KHR))
1736         m_rayTracingPipeline->addShader(VK_SHADER_STAGE_RAYGEN_BIT_KHR,
1737                                         createShaderModule(vkd, device, collection.get("rgen"), 0),
1738                                         m_raygenShaderGroup);
1739     if (0 != (m_shaders & VK_SHADER_STAGE_ANY_HIT_BIT_KHR))
1740         m_rayTracingPipeline->addShader(VK_SHADER_STAGE_ANY_HIT_BIT_KHR,
1741                                         createShaderModule(vkd, device, collection.get("ahit"), 0), m_hitShaderGroup);
1742     if (0 != (m_shaders & VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR))
1743         m_rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR,
1744                                         createShaderModule(vkd, device, collection.get("chit"), 0), m_hitShaderGroup);
1745     if (0 != (m_shaders & VK_SHADER_STAGE_MISS_BIT_KHR))
1746         m_rayTracingPipeline->addShader(VK_SHADER_STAGE_MISS_BIT_KHR,
1747                                         createShaderModule(vkd, device, collection.get("miss"), 0), m_missShaderGroup);
1748     if (0 != (m_shaders & VK_SHADER_STAGE_INTERSECTION_BIT_KHR))
1749         m_rayTracingPipeline->addShader(VK_SHADER_STAGE_INTERSECTION_BIT_KHR,
1750                                         createShaderModule(vkd, device, collection.get("sect"), 0), m_hitShaderGroup);
1751     if (0 != (m_shaders & VK_SHADER_STAGE_CALLABLE_BIT_KHR))
1752         m_rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR,
1753                                         createShaderModule(vkd, device, collection.get("call"), 0),
1754                                         m_callableShaderGroup);
1755 
1756     m_pipelineLayout = makePipelineLayout(vkd, device, m_descriptorSetLayoutAS.get(), m_descriptorSetLayoutImg.get(),
1757                                           m_descriptorSetLayoutSvc.get());
1758     m_pipeline       = m_rayTracingPipeline->createPipeline(vkd, device, *m_pipelineLayout);
1759 
1760     m_raygenShaderBindingTable   = createShaderBindingTable(vki, vkd, device, physicalDevice, *m_pipeline, allocator,
1761                                                             m_rayTracingPipeline, m_raygenShaderGroup);
1762     m_missShaderBindingTable     = createShaderBindingTable(vki, vkd, device, physicalDevice, *m_pipeline, allocator,
1763                                                             m_rayTracingPipeline, m_missShaderGroup);
1764     m_hitShaderBindingTable      = createShaderBindingTable(vki, vkd, device, physicalDevice, *m_pipeline, allocator,
1765                                                             m_rayTracingPipeline, m_hitShaderGroup);
1766     m_callableShaderBindingTable = createShaderBindingTable(vki, vkd, device, physicalDevice, *m_pipeline, allocator,
1767                                                             m_rayTracingPipeline, m_callableShaderGroup);
1768 
1769     m_raygenShaderBindingTableRegion =
1770         makeStridedDeviceAddressRegion(vkd, device, getVkBuffer(m_raygenShaderBindingTable), shaderGroupHandleSize, 1);
1771     m_missShaderBindingTableRegion =
1772         makeStridedDeviceAddressRegion(vkd, device, getVkBuffer(m_missShaderBindingTable), shaderGroupHandleSize, 1);
1773     m_hitShaderBindingTableRegion =
1774         makeStridedDeviceAddressRegion(vkd, device, getVkBuffer(m_hitShaderBindingTable), shaderGroupHandleSize, 1);
1775     m_callableShaderBindingTableRegion = makeStridedDeviceAddressRegion(
1776         vkd, device, getVkBuffer(m_callableShaderBindingTable), shaderGroupHandleSize, 1);
1777 }
1778 
fillCommandBuffer(VkCommandBuffer commandBuffer)1779 void BindingAcceleratioStructureRayTracingTestInstance::fillCommandBuffer(VkCommandBuffer commandBuffer)
1780 {
1781     const DeviceInterface &vkd = m_context.getDeviceInterface();
1782     const VkDevice device      = m_context.getDevice();
1783     Allocator &allocator       = m_context.getDefaultAllocator();
1784     de::MovePtr<BottomLevelAccelerationStructure> bottomLevelAccelerationStructure =
1785         makeBottomLevelAccelerationStructure();
1786     de::MovePtr<TopLevelAccelerationStructure> topLevelAccelerationStructure = makeTopLevelAccelerationStructure();
1787 
1788     m_bottomLevelAccelerationStructure =
1789         de::SharedPtr<BottomLevelAccelerationStructure>(bottomLevelAccelerationStructure.release());
1790     m_bottomLevelAccelerationStructure->setDefaultGeometryData(m_testParams.stage);
1791     m_bottomLevelAccelerationStructure->createAndBuild(vkd, device, commandBuffer, allocator);
1792 
1793     m_topLevelAccelerationStructure =
1794         de::SharedPtr<TopLevelAccelerationStructure>(topLevelAccelerationStructure.release());
1795     m_topLevelAccelerationStructure->setInstanceCount(1);
1796     m_topLevelAccelerationStructure->addInstance(m_bottomLevelAccelerationStructure);
1797     m_topLevelAccelerationStructure->createAndBuild(vkd, device, commandBuffer, allocator);
1798 
1799     const TopLevelAccelerationStructure *topLevelAccelerationStructurePtr = m_topLevelAccelerationStructure.get();
1800     const VkWriteDescriptorSetAccelerationStructureKHR accelerationStructureWriteDescriptorSet =
1801         makeWriteDescriptorSetAccelerationStructureKHR(topLevelAccelerationStructurePtr->getPtr());
1802 
1803     DescriptorSetUpdateBuilder()
1804         .writeSingle(*m_descriptorSetSvc, DescriptorSetUpdateBuilder::Location::binding(0u),
1805                      VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, &accelerationStructureWriteDescriptorSet)
1806         .update(vkd, device);
1807 
1808     vkd.cmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, *m_pipelineLayout, 2, 1,
1809                               &m_descriptorSetSvc.get(), 0, DE_NULL);
1810 
1811     vkd.cmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, m_pipeline.get());
1812 
1813     cmdTraceRays(vkd, commandBuffer, &m_raygenShaderBindingTableRegion, &m_missShaderBindingTableRegion,
1814                  &m_hitShaderBindingTableRegion, &m_callableShaderBindingTableRegion, m_testParams.width,
1815                  m_testParams.height, 1);
1816 }
1817 
createShaderBindingTable(const InstanceInterface & vki,const DeviceInterface & vkd,const VkDevice device,const VkPhysicalDevice physicalDevice,const VkPipeline pipeline,Allocator & allocator,de::MovePtr<RayTracingPipeline> & rayTracingPipeline,const uint32_t group)1818 de::MovePtr<BufferWithMemory> BindingAcceleratioStructureRayTracingTestInstance::createShaderBindingTable(
1819     const InstanceInterface &vki, const DeviceInterface &vkd, const VkDevice device,
1820     const VkPhysicalDevice physicalDevice, const VkPipeline pipeline, Allocator &allocator,
1821     de::MovePtr<RayTracingPipeline> &rayTracingPipeline, const uint32_t group)
1822 {
1823     de::MovePtr<BufferWithMemory> shaderBindingTable;
1824 
1825     if (group < m_shaderGroupCount)
1826     {
1827         const uint32_t shaderGroupHandleSize    = getShaderGroupHandleSize(vki, physicalDevice);
1828         const uint32_t shaderGroupBaseAlignment = getShaderGroupBaseAlignment(vki, physicalDevice);
1829 
1830         shaderBindingTable = rayTracingPipeline->createShaderBindingTable(
1831             vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, group, 1u);
1832     }
1833 
1834     return shaderBindingTable;
1835 }
1836 
1837 class BindingAcceleratioStructureRayTracingRayTracingTestInstance : public BindingAcceleratioStructureTestInstance
1838 {
1839 public:
1840     BindingAcceleratioStructureRayTracingRayTracingTestInstance(Context &context, const TestParams &testParams);
~BindingAcceleratioStructureRayTracingRayTracingTestInstance()1841     virtual ~BindingAcceleratioStructureRayTracingRayTracingTestInstance()
1842     {
1843     }
1844 
1845     static void checkSupport(Context &context, const TestParams &testParams);
1846     static void initPrograms(SourceCollections &programCollection, const TestParams &testParams);
1847 
1848 protected:
1849     virtual void initPipeline(void) override;
1850     virtual void fillCommandBuffer(VkCommandBuffer commandBuffer) override;
1851 
1852     void calcShaderGroup(uint32_t &shaderGroupCounter, const VkShaderStageFlags shaders1,
1853                          const VkShaderStageFlags shaders2, const VkShaderStageFlags shaderStageFlags,
1854                          uint32_t &shaderGroup, uint32_t &shaderGroupCount) const;
1855 
1856     de::MovePtr<BufferWithMemory> createShaderBindingTable(const InstanceInterface &vki, const DeviceInterface &vkd,
1857                                                            const VkDevice device, const VkPhysicalDevice physicalDevice,
1858                                                            const VkPipeline pipeline, Allocator &allocator,
1859                                                            de::MovePtr<RayTracingPipeline> &rayTracingPipeline,
1860                                                            const uint32_t group, const uint32_t groupCount = 1);
1861 
getExtraAccelerationDescriptorCount(void)1862     virtual uint32_t getExtraAccelerationDescriptorCount(void) override
1863     {
1864         return 1;
1865     }
getShaderStageFlags(void)1866     virtual VkShaderStageFlags getShaderStageFlags(void) override
1867     {
1868         return ALL_RAY_TRACING_STAGES;
1869     }
getPipelineBindPoint(void)1870     virtual VkPipelineBindPoint getPipelineBindPoint(void) override
1871     {
1872         return VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR;
1873     }
1874 
1875     uint32_t m_shaders;
1876     uint32_t m_raygenShaderGroup;
1877     uint32_t m_missShaderGroup;
1878     uint32_t m_hitShaderGroup;
1879     uint32_t m_callableShaderGroup;
1880     uint32_t m_shaderGroupCount;
1881 
1882     Move<VkDescriptorSetLayout> m_descriptorSetLayoutSvc;
1883     Move<VkDescriptorSet> m_descriptorSetSvc;
1884 
1885     de::MovePtr<RayTracingPipeline> m_rayTracingPipeline;
1886 
1887     de::MovePtr<BufferWithMemory> m_raygenShaderBindingTable;
1888     de::MovePtr<BufferWithMemory> m_hitShaderBindingTable;
1889     de::MovePtr<BufferWithMemory> m_missShaderBindingTable;
1890     de::MovePtr<BufferWithMemory> m_callableShaderBindingTable;
1891 
1892     VkStridedDeviceAddressRegionKHR m_raygenShaderBindingTableRegion;
1893     VkStridedDeviceAddressRegionKHR m_missShaderBindingTableRegion;
1894     VkStridedDeviceAddressRegionKHR m_hitShaderBindingTableRegion;
1895     VkStridedDeviceAddressRegionKHR m_callableShaderBindingTableRegion;
1896 
1897     de::SharedPtr<BottomLevelAccelerationStructure> m_bottomLevelAccelerationStructure;
1898     de::SharedPtr<TopLevelAccelerationStructure> m_topLevelAccelerationStructure;
1899 };
1900 
1901 BindingAcceleratioStructureRayTracingRayTracingTestInstance::
BindingAcceleratioStructureRayTracingRayTracingTestInstance(Context & context,const TestParams & testParams)1902     BindingAcceleratioStructureRayTracingRayTracingTestInstance(Context &context, const TestParams &testParams)
1903     : BindingAcceleratioStructureTestInstance(context, testParams)
1904     , m_shaders(0)
1905     , m_raygenShaderGroup(~0u)
1906     , m_missShaderGroup(~0u)
1907     , m_hitShaderGroup(~0u)
1908     , m_callableShaderGroup(~0u)
1909     , m_shaderGroupCount(0)
1910 
1911     , m_descriptorSetLayoutSvc()
1912     , m_descriptorSetSvc()
1913 
1914     , m_rayTracingPipeline()
1915 
1916     , m_raygenShaderBindingTable()
1917     , m_hitShaderBindingTable()
1918     , m_missShaderBindingTable()
1919     , m_callableShaderBindingTable()
1920 
1921     , m_raygenShaderBindingTableRegion()
1922     , m_missShaderBindingTableRegion()
1923     , m_hitShaderBindingTableRegion()
1924     , m_callableShaderBindingTableRegion()
1925 
1926     , m_bottomLevelAccelerationStructure()
1927     , m_topLevelAccelerationStructure()
1928 {
1929 }
1930 
checkSupport(Context & context,const TestParams & testParams)1931 void BindingAcceleratioStructureRayTracingRayTracingTestInstance::checkSupport(Context &context,
1932                                                                                const TestParams &testParams)
1933 {
1934     DE_UNREF(testParams);
1935 
1936     context.requireDeviceFunctionality("VK_KHR_ray_tracing_pipeline");
1937 
1938     const VkPhysicalDeviceRayTracingPipelineFeaturesKHR &rayTracingPipelineFeaturesKHR =
1939         context.getRayTracingPipelineFeatures();
1940 
1941     if (rayTracingPipelineFeaturesKHR.rayTracingPipeline == false)
1942         TCU_THROW(NotSupportedError, "Requires VkPhysicalDeviceRayTracingPipelineFeaturesKHR.rayTracingPipeline");
1943     const VkPhysicalDeviceRayTracingPipelinePropertiesKHR &rayTracingPipelinePropertiesKHR =
1944         context.getRayTracingPipelineProperties();
1945     if (rayTracingPipelinePropertiesKHR.maxRayRecursionDepth < 2 &&
1946         testParams.testType == TEST_TYPE_USING_RAY_TRACING &&
1947         (testParams.stage == VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR || testParams.stage == VK_SHADER_STAGE_MISS_BIT_KHR))
1948         TCU_THROW(NotSupportedError, "rayTracingPipelinePropertiesKHR.maxRayRecursionDepth is smaller than required");
1949 }
1950 
initPrograms(SourceCollections & programCollection,const TestParams & testParams)1951 void BindingAcceleratioStructureRayTracingRayTracingTestInstance::initPrograms(SourceCollections &programCollection,
1952                                                                                const TestParams &testParams)
1953 {
1954     const vk::ShaderBuildOptions buildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_4, 0u, true);
1955     const std::string testShaderBody = testParams.testConfigShaderBodyText(testParams);
1956     const std::string testBody       = "  ivec3       pos      = ivec3(gl_LaunchIDEXT);\n"
1957                                        "  ivec3       size     = ivec3(gl_LaunchSizeEXT);\n" +
1958                                  testShaderBody;
1959     const std::string testOutClosestHitShader = std::string(glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_460)) +
1960                                                 "\n"
1961                                                 "#extension GL_EXT_ray_tracing : require\n"
1962                                                 "\n"
1963                                                 "hitAttributeEXT vec3 attribs;\n"
1964                                                 "layout(location = 0) rayPayloadInEXT vec3 hitValue;\n"
1965                                                 "layout(set = 1, binding = 0, r32i) uniform iimage3D result;\n"
1966                                                 "\n"
1967                                                 "void main()\n"
1968                                                 "{\n" +
1969                                                 testBody + "}\n";
1970     const std::string testInShaderFragment =
1971         "  uint  rayFlags = 0;\n"
1972         "  uint  cullMask = 0xFF;\n"
1973         "  float tmin     = 0.0;\n"
1974         "  float tmax     = 9.0;\n"
1975         "  vec3  origin   = vec3((float(gl_LaunchIDEXT.x) + 0.5f) / float(gl_LaunchSizeEXT.x), "
1976         "(float(gl_LaunchIDEXT.y) + 0.5f) / float(gl_LaunchSizeEXT.y), 0.0);\n"
1977         "  vec3  direct   = vec3(0.0, 0.0, 1.0);\n"
1978         "\n"
1979         "  traceRayEXT(topLevelAS, rayFlags, cullMask, 1, 0, 1, origin, tmin, direct, tmax, 0);\n";
1980     const std::string commonRayGenerationShader =
1981         std::string(glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_460)) +
1982         "\n"
1983         "#extension GL_EXT_ray_tracing : require\n"
1984         "\n"
1985         "layout(location = 0) rayPayloadEXT vec3 hitValue;\n"
1986         "layout(set = 1, binding = 0, r32i) uniform iimage3D result;\n"
1987         "layout(set = 2, binding = 0) uniform accelerationStructureEXT topLevelAS;\n"
1988         "\n"
1989         "void main()\n"
1990         "{\n"
1991         "  uint  rayFlags = 0;\n"
1992         "  uint  cullMask = 0xFF;\n"
1993         "  float tmin     = 0.0;\n"
1994         "  float tmax     = 9.0;\n"
1995         "  vec3  origin   = vec3((float(gl_LaunchIDEXT.x) + 0.5f) / float(gl_LaunchSizeEXT.x), "
1996         "(float(gl_LaunchIDEXT.y) + 0.5f) / float(gl_LaunchSizeEXT.y), 0.0);\n"
1997         "  vec3  direct   = vec3(0.0, 0.0, -1.0);\n"
1998         "\n"
1999         "  traceRayEXT(topLevelAS, rayFlags, cullMask, 0, 0, 0, origin, tmin, direct, tmax, 0);\n"
2000         "}\n";
2001 
2002     programCollection.glslSources.add("chit0") << glu::ClosestHitSource(testOutClosestHitShader) << buildOptions;
2003     programCollection.glslSources.add("ahit0") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
2004     programCollection.glslSources.add("miss0") << glu::MissSource(getMissPassthrough()) << buildOptions;
2005 
2006     switch (testParams.stage)
2007     {
2008     case VK_SHADER_STAGE_RAYGEN_BIT_KHR:
2009     {
2010         {
2011             std::stringstream css;
2012             css << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_460) << "\n"
2013                 << "#extension GL_EXT_ray_tracing : require\n"
2014                 << "\n"
2015                 << "layout(location = 0) rayPayloadEXT vec3 hitValue;\n"
2016                 << "layout(set = 0, binding = 0) uniform accelerationStructureEXT topLevelAS;\n"
2017                 << "layout(set = 1, binding = 0, r32i) uniform iimage3D result;\n"
2018                 << "\n"
2019                 << "void main()\n"
2020                 << "{\n"
2021                 << testInShaderFragment << "}\n";
2022 
2023             programCollection.glslSources.add("rgen") << glu::RaygenSource(css.str()) << buildOptions;
2024         }
2025 
2026         programCollection.glslSources.add("chit") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
2027         programCollection.glslSources.add("ahit") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
2028         programCollection.glslSources.add("miss") << glu::MissSource(getMissPassthrough()) << buildOptions;
2029 
2030         break;
2031     }
2032 
2033     case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR:
2034     {
2035         programCollection.glslSources.add("rgen") << glu::RaygenSource(commonRayGenerationShader) << buildOptions;
2036 
2037         {
2038             std::stringstream css;
2039             css << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_460) << "\n"
2040                 << "#extension GL_EXT_ray_tracing : require\n"
2041                 << "\n"
2042                 << "layout(location = 0) rayPayloadInEXT vec3 hitValue;\n"
2043                 << "hitAttributeEXT vec3 attribs;\n"
2044                 << "\n"
2045                 << "layout(set = 0, binding = 0) uniform accelerationStructureEXT topLevelAS;\n"
2046                 << "layout(set = 1, binding = 0, r32i) uniform iimage3D result;\n"
2047                 << "\n"
2048                 << "void main()\n"
2049                 << "{\n"
2050                 << testInShaderFragment << "}\n";
2051 
2052             programCollection.glslSources.add("chit") << glu::ClosestHitSource(css.str()) << buildOptions;
2053         }
2054 
2055         programCollection.glslSources.add("ahit") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
2056         programCollection.glslSources.add("miss") << glu::MissSource(getMissPassthrough()) << buildOptions;
2057 
2058         break;
2059     }
2060 
2061     case VK_SHADER_STAGE_MISS_BIT_KHR:
2062     {
2063         programCollection.glslSources.add("rgen") << glu::RaygenSource(commonRayGenerationShader) << buildOptions;
2064 
2065         {
2066             std::stringstream css;
2067             css << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_460) << "\n"
2068                 << "#extension GL_EXT_ray_tracing : require\n"
2069                 << "layout(location = 0) rayPayloadInEXT vec3 hitValue;\n"
2070                 << "\n"
2071                 << "layout(set = 0, binding = 0) uniform accelerationStructureEXT topLevelAS;\n"
2072                 << "layout(set = 1, binding = 0, r32i) uniform iimage3D result;\n"
2073                 << "\n"
2074                 << "void main()\n"
2075                 << "{\n"
2076                 << testInShaderFragment << "}\n";
2077 
2078             programCollection.glslSources.add("miss") << glu::MissSource(css.str()) << buildOptions;
2079         }
2080 
2081         programCollection.glslSources.add("ahit") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
2082         programCollection.glslSources.add("chit") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
2083 
2084         break;
2085     }
2086 
2087     default:
2088         TCU_THROW(InternalError, "Unknown stage");
2089     }
2090 }
2091 
calcShaderGroup(uint32_t & shaderGroupCounter,const VkShaderStageFlags shaders1,const VkShaderStageFlags shaders2,const VkShaderStageFlags shaderStageFlags,uint32_t & shaderGroup,uint32_t & shaderGroupCount) const2092 void BindingAcceleratioStructureRayTracingRayTracingTestInstance::calcShaderGroup(
2093     uint32_t &shaderGroupCounter, const VkShaderStageFlags shaders1, const VkShaderStageFlags shaders2,
2094     const VkShaderStageFlags shaderStageFlags, uint32_t &shaderGroup, uint32_t &shaderGroupCount) const
2095 {
2096     const uint32_t shader1Count = ((shaders1 & shaderStageFlags) != 0) ? 1 : 0;
2097     const uint32_t shader2Count = ((shaders2 & shaderStageFlags) != 0) ? 1 : 0;
2098 
2099     shaderGroupCount = shader1Count + shader2Count;
2100 
2101     if (shaderGroupCount != 0)
2102     {
2103         shaderGroup = shaderGroupCounter;
2104         shaderGroupCounter += shaderGroupCount;
2105     }
2106 }
2107 
initPipeline(void)2108 void BindingAcceleratioStructureRayTracingRayTracingTestInstance::initPipeline(void)
2109 {
2110     const InstanceInterface &vki          = m_context.getInstanceInterface();
2111     const DeviceInterface &vkd            = m_context.getDeviceInterface();
2112     const VkDevice device                 = m_context.getDevice();
2113     const VkPhysicalDevice physicalDevice = m_context.getPhysicalDevice();
2114     vk::BinaryCollection &collection      = m_context.getBinaryCollection();
2115     Allocator &allocator                  = m_context.getDefaultAllocator();
2116     const uint32_t shaderGroupHandleSize  = getShaderGroupHandleSize(vki, physicalDevice);
2117     const VkShaderStageFlags hitStages =
2118         VK_SHADER_STAGE_ANY_HIT_BIT_KHR | VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR | VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
2119     uint32_t shaderCount            = 0;
2120     VkShaderStageFlags shaders0     = static_cast<VkShaderStageFlags>(0);
2121     uint32_t raygenShaderGroupCount = 0;
2122     uint32_t hitShaderGroupCount    = 0;
2123     uint32_t missShaderGroupCount   = 0;
2124 
2125     if (collection.contains("rgen"))
2126         m_shaders |= VK_SHADER_STAGE_RAYGEN_BIT_KHR;
2127     if (collection.contains("ahit"))
2128         m_shaders |= VK_SHADER_STAGE_ANY_HIT_BIT_KHR;
2129     if (collection.contains("chit"))
2130         m_shaders |= VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
2131     if (collection.contains("miss"))
2132         m_shaders |= VK_SHADER_STAGE_MISS_BIT_KHR;
2133 
2134     if (collection.contains("ahit0"))
2135         shaders0 |= VK_SHADER_STAGE_ANY_HIT_BIT_KHR;
2136     if (collection.contains("chit0"))
2137         shaders0 |= VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
2138     if (collection.contains("miss0"))
2139         shaders0 |= VK_SHADER_STAGE_MISS_BIT_KHR;
2140 
2141     for (BinaryCollection::Iterator it = collection.begin(); it != collection.end(); ++it)
2142         shaderCount++;
2143 
2144     if (shaderCount != (uint32_t)(dePop32(m_shaders) + dePop32(shaders0)))
2145         TCU_THROW(InternalError, "Unused shaders detected in the collection");
2146 
2147     calcShaderGroup(m_shaderGroupCount, m_shaders, shaders0, VK_SHADER_STAGE_RAYGEN_BIT_KHR, m_raygenShaderGroup,
2148                     raygenShaderGroupCount);
2149     calcShaderGroup(m_shaderGroupCount, m_shaders, shaders0, VK_SHADER_STAGE_MISS_BIT_KHR, m_missShaderGroup,
2150                     missShaderGroupCount);
2151     calcShaderGroup(m_shaderGroupCount, m_shaders, shaders0, hitStages, m_hitShaderGroup, hitShaderGroupCount);
2152 
2153     m_rayTracingPipeline = de::newMovePtr<RayTracingPipeline>();
2154 
2155     m_descriptorSetLayoutSvc =
2156         DescriptorSetLayoutBuilder()
2157             .addSingleBinding(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, ALL_RAY_TRACING_STAGES)
2158             .build(vkd, device);
2159     m_descriptorSetSvc = makeDescriptorSet(vkd, device, *m_descriptorPool, *m_descriptorSetLayoutSvc);
2160 
2161     if (0 != (m_shaders & VK_SHADER_STAGE_RAYGEN_BIT_KHR))
2162         m_rayTracingPipeline->addShader(VK_SHADER_STAGE_RAYGEN_BIT_KHR,
2163                                         createShaderModule(vkd, device, collection.get("rgen"), 0),
2164                                         m_raygenShaderGroup);
2165     if (0 != (m_shaders & VK_SHADER_STAGE_ANY_HIT_BIT_KHR))
2166         m_rayTracingPipeline->addShader(VK_SHADER_STAGE_ANY_HIT_BIT_KHR,
2167                                         createShaderModule(vkd, device, collection.get("ahit"), 0), m_hitShaderGroup);
2168     if (0 != (m_shaders & VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR))
2169         m_rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR,
2170                                         createShaderModule(vkd, device, collection.get("chit"), 0), m_hitShaderGroup);
2171     if (0 != (m_shaders & VK_SHADER_STAGE_MISS_BIT_KHR))
2172         m_rayTracingPipeline->addShader(VK_SHADER_STAGE_MISS_BIT_KHR,
2173                                         createShaderModule(vkd, device, collection.get("miss"), 0), m_missShaderGroup);
2174 
2175     // The "chit" and "miss" cases both generate more rays from their shaders.
2176     if (m_testParams.testType == TEST_TYPE_USING_RAY_TRACING &&
2177         (m_testParams.stage == VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR ||
2178          m_testParams.stage == VK_SHADER_STAGE_MISS_BIT_KHR))
2179         m_rayTracingPipeline->setMaxRecursionDepth(2u);
2180 
2181     if (0 != (shaders0 & VK_SHADER_STAGE_ANY_HIT_BIT_KHR))
2182         m_rayTracingPipeline->addShader(VK_SHADER_STAGE_ANY_HIT_BIT_KHR,
2183                                         createShaderModule(vkd, device, collection.get("ahit0"), 0),
2184                                         m_hitShaderGroup + 1);
2185     if (0 != (shaders0 & VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR))
2186         m_rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR,
2187                                         createShaderModule(vkd, device, collection.get("chit0"), 0),
2188                                         m_hitShaderGroup + 1);
2189     if (0 != (shaders0 & VK_SHADER_STAGE_MISS_BIT_KHR))
2190         m_rayTracingPipeline->addShader(VK_SHADER_STAGE_MISS_BIT_KHR,
2191                                         createShaderModule(vkd, device, collection.get("miss0"), 0),
2192                                         m_missShaderGroup + 1);
2193 
2194     m_pipelineLayout = makePipelineLayout(vkd, device, m_descriptorSetLayoutAS.get(), m_descriptorSetLayoutImg.get(),
2195                                           m_descriptorSetLayoutSvc.get());
2196     m_pipeline       = m_rayTracingPipeline->createPipeline(vkd, device, *m_pipelineLayout);
2197 
2198     m_raygenShaderBindingTable =
2199         createShaderBindingTable(vki, vkd, device, physicalDevice, *m_pipeline, allocator, m_rayTracingPipeline,
2200                                  m_raygenShaderGroup, raygenShaderGroupCount);
2201     m_missShaderBindingTable = createShaderBindingTable(vki, vkd, device, physicalDevice, *m_pipeline, allocator,
2202                                                         m_rayTracingPipeline, m_missShaderGroup, missShaderGroupCount);
2203     m_hitShaderBindingTable  = createShaderBindingTable(vki, vkd, device, physicalDevice, *m_pipeline, allocator,
2204                                                         m_rayTracingPipeline, m_hitShaderGroup, hitShaderGroupCount);
2205 
2206     m_raygenShaderBindingTableRegion = makeStridedDeviceAddressRegion(
2207         vkd, device, getVkBuffer(m_raygenShaderBindingTable), shaderGroupHandleSize, raygenShaderGroupCount);
2208     m_missShaderBindingTableRegion = makeStridedDeviceAddressRegion(vkd, device, getVkBuffer(m_missShaderBindingTable),
2209                                                                     shaderGroupHandleSize, missShaderGroupCount);
2210     m_hitShaderBindingTableRegion  = makeStridedDeviceAddressRegion(vkd, device, getVkBuffer(m_hitShaderBindingTable),
2211                                                                     shaderGroupHandleSize, hitShaderGroupCount);
2212     m_callableShaderBindingTableRegion = makeStridedDeviceAddressRegion(vkd, device, DE_NULL, 0, 0);
2213 }
2214 
fillCommandBuffer(VkCommandBuffer commandBuffer)2215 void BindingAcceleratioStructureRayTracingRayTracingTestInstance::fillCommandBuffer(VkCommandBuffer commandBuffer)
2216 {
2217     const DeviceInterface &vkd = m_context.getDeviceInterface();
2218     const VkDevice device      = m_context.getDevice();
2219     Allocator &allocator       = m_context.getDefaultAllocator();
2220     de::MovePtr<BottomLevelAccelerationStructure> bottomLevelAccelerationStructure =
2221         makeBottomLevelAccelerationStructure();
2222     de::MovePtr<TopLevelAccelerationStructure> topLevelAccelerationStructure = makeTopLevelAccelerationStructure();
2223 
2224     m_bottomLevelAccelerationStructure =
2225         de::SharedPtr<BottomLevelAccelerationStructure>(bottomLevelAccelerationStructure.release());
2226     m_bottomLevelAccelerationStructure->setDefaultGeometryData(m_testParams.stage);
2227     m_bottomLevelAccelerationStructure->createAndBuild(vkd, device, commandBuffer, allocator);
2228 
2229     m_topLevelAccelerationStructure =
2230         de::SharedPtr<TopLevelAccelerationStructure>(topLevelAccelerationStructure.release());
2231     m_topLevelAccelerationStructure->setInstanceCount(1);
2232     m_topLevelAccelerationStructure->addInstance(m_bottomLevelAccelerationStructure);
2233     m_topLevelAccelerationStructure->createAndBuild(vkd, device, commandBuffer, allocator);
2234 
2235     const TopLevelAccelerationStructure *topLevelAccelerationStructurePtr = m_topLevelAccelerationStructure.get();
2236     const VkWriteDescriptorSetAccelerationStructureKHR accelerationStructureWriteDescriptorSet =
2237         makeWriteDescriptorSetAccelerationStructureKHR(topLevelAccelerationStructurePtr->getPtr());
2238 
2239     DescriptorSetUpdateBuilder()
2240         .writeSingle(*m_descriptorSetSvc, DescriptorSetUpdateBuilder::Location::binding(0u),
2241                      VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, &accelerationStructureWriteDescriptorSet)
2242         .update(vkd, device);
2243 
2244     vkd.cmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, *m_pipelineLayout, 2, 1,
2245                               &m_descriptorSetSvc.get(), 0, DE_NULL);
2246 
2247     vkd.cmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, m_pipeline.get());
2248 
2249     cmdTraceRays(vkd, commandBuffer, &m_raygenShaderBindingTableRegion, &m_missShaderBindingTableRegion,
2250                  &m_hitShaderBindingTableRegion, &m_callableShaderBindingTableRegion, m_testParams.width,
2251                  m_testParams.height, 1);
2252 }
2253 
createShaderBindingTable(const InstanceInterface & vki,const DeviceInterface & vkd,const VkDevice device,const VkPhysicalDevice physicalDevice,const VkPipeline pipeline,Allocator & allocator,de::MovePtr<RayTracingPipeline> & rayTracingPipeline,const uint32_t group,const uint32_t groupCount)2254 de::MovePtr<BufferWithMemory> BindingAcceleratioStructureRayTracingRayTracingTestInstance::createShaderBindingTable(
2255     const InstanceInterface &vki, const DeviceInterface &vkd, const VkDevice device,
2256     const VkPhysicalDevice physicalDevice, const VkPipeline pipeline, Allocator &allocator,
2257     de::MovePtr<RayTracingPipeline> &rayTracingPipeline, const uint32_t group, const uint32_t groupCount)
2258 {
2259     de::MovePtr<BufferWithMemory> shaderBindingTable;
2260 
2261     if (group < m_shaderGroupCount)
2262     {
2263         const uint32_t shaderGroupHandleSize    = getShaderGroupHandleSize(vki, physicalDevice);
2264         const uint32_t shaderGroupBaseAlignment = getShaderGroupBaseAlignment(vki, physicalDevice);
2265 
2266         shaderBindingTable = rayTracingPipeline->createShaderBindingTable(
2267             vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, group, groupCount);
2268     }
2269 
2270     return shaderBindingTable;
2271 }
2272 
getRayQueryShaderBodyText(const TestParams & testParams)2273 const std::string getRayQueryShaderBodyText(const TestParams &testParams)
2274 {
2275     DE_UNREF(testParams);
2276 
2277     const std::string result =
2278         "  const float mult     = " + de::toString(FIXED_POINT_DIVISOR) +
2279         ".0f;\n"
2280         "  uint        rayFlags = 0;\n"
2281         "  uint        cullMask = 0xFF;\n"
2282         "  float       tmin     = 0.0;\n"
2283         "  float       tmax     = 9.0;\n"
2284         "  vec3        origin   = vec3((float(pos.x) + 0.5f) / float(size.x), (float(pos.y) + 0.5f) / float(size.y), "
2285         "0.0);\n"
2286         "  vec3        direct   = vec3(0.0, 0.0, 1.0);\n"
2287         "  int         value    = 0;\n"
2288         "  rayQueryEXT rayQuery;\n"
2289         "\n"
2290         "  rayQueryInitializeEXT(rayQuery, tlas, rayFlags, cullMask, origin, tmin, direct, tmax);\n"
2291         "\n"
2292         "  while(rayQueryProceedEXT(rayQuery))\n"
2293         "  {\n"
2294         "    if (rayQueryGetIntersectionTypeEXT(rayQuery, false) == gl_RayQueryCandidateIntersectionTriangleEXT)\n"
2295         "    {\n"
2296         "      const float t = rayQueryGetIntersectionTEXT(rayQuery, false);"
2297         "\n"
2298         "      value = int(round(mult * t));\n"
2299         "    }\n"
2300         "  }\n"
2301         "\n"
2302         "  imageStore(result, pos, ivec4(value, 0, 0, 0));\n";
2303 
2304     return result;
2305 }
2306 
getRayTracingShaderBodyText(const TestParams & testParams)2307 const std::string getRayTracingShaderBodyText(const TestParams &testParams)
2308 {
2309     DE_UNREF(testParams);
2310 
2311     const std::string result = "  const float mult     = " + de::toString(FIXED_POINT_DIVISOR) +
2312                                ".0f;\n"
2313                                "  int         value    = int(round(mult * gl_HitTEXT));\n"
2314                                "\n"
2315                                "  imageStore(result, pos, ivec4(value, 0, 0, 0));\n";
2316 
2317     return result;
2318 }
2319 
2320 class BindingAccelerationStructureTestCase : public TestCase
2321 {
2322 public:
2323     BindingAccelerationStructureTestCase(tcu::TestContext &context, const char *name, const TestParams testParams);
2324     ~BindingAccelerationStructureTestCase(void);
2325 
2326     virtual void checkSupport(Context &context) const;
2327     virtual void initPrograms(SourceCollections &programCollection) const;
2328     virtual TestInstance *createInstance(Context &context) const;
2329 
2330 private:
2331     TestParams m_testParams;
2332 };
2333 
BindingAccelerationStructureTestCase(tcu::TestContext & context,const char * name,const TestParams testParams)2334 BindingAccelerationStructureTestCase::BindingAccelerationStructureTestCase(tcu::TestContext &context, const char *name,
2335                                                                            const TestParams testParams)
2336     : vkt::TestCase(context, name)
2337     , m_testParams(testParams)
2338 {
2339 }
2340 
~BindingAccelerationStructureTestCase(void)2341 BindingAccelerationStructureTestCase::~BindingAccelerationStructureTestCase(void)
2342 {
2343 }
2344 
checkSupport(Context & context) const2345 void BindingAccelerationStructureTestCase::checkSupport(Context &context) const
2346 {
2347     context.requireDeviceFunctionality("VK_KHR_acceleration_structure");
2348 
2349     const VkPhysicalDeviceAccelerationStructureFeaturesKHR &accelerationStructureFeaturesKHR =
2350         context.getAccelerationStructureFeatures();
2351     if (accelerationStructureFeaturesKHR.accelerationStructure == false)
2352         TCU_THROW(TestError, "Requires VkPhysicalDeviceAccelerationStructureFeaturesKHR.accelerationStructure");
2353 
2354     switch (m_testParams.testType)
2355     {
2356     case TEST_TYPE_USING_RAY_QUERY:
2357     {
2358         context.requireDeviceFunctionality("VK_KHR_ray_query");
2359 
2360         const VkPhysicalDeviceRayQueryFeaturesKHR &rayQueryFeaturesKHR = context.getRayQueryFeatures();
2361 
2362         if (rayQueryFeaturesKHR.rayQuery == false)
2363             TCU_THROW(NotSupportedError, "Requires VkPhysicalDeviceRayQueryFeaturesKHR.rayQuery");
2364 
2365         break;
2366     }
2367 
2368     case TEST_TYPE_USING_RAY_TRACING:
2369     {
2370         context.requireDeviceFunctionality("VK_KHR_ray_tracing_pipeline");
2371 
2372         const VkPhysicalDeviceRayTracingPipelineFeaturesKHR &rayTracingPipelineFeaturesKHR =
2373             context.getRayTracingPipelineFeatures();
2374 
2375         if (rayTracingPipelineFeaturesKHR.rayTracingPipeline == false)
2376             TCU_THROW(NotSupportedError, "Requires VkPhysicalDeviceRayTracingPipelineFeaturesKHR.rayTracingPipeline");
2377 
2378         break;
2379     }
2380 
2381     default:
2382         TCU_THROW(InternalError, "Unknown test type");
2383     }
2384 
2385     switch (m_testParams.updateMethod)
2386     {
2387     case UPDATE_METHOD_NORMAL:
2388     {
2389         break;
2390     }
2391 
2392     case UPDATE_METHOD_WITH_TEMPLATE:
2393     {
2394         context.requireDeviceFunctionality("VK_KHR_descriptor_update_template");
2395 
2396         break;
2397     }
2398 
2399     case UPDATE_METHOD_WITH_PUSH:
2400     {
2401         context.requireDeviceFunctionality("VK_KHR_push_descriptor");
2402 
2403         break;
2404     }
2405 
2406     case UPDATE_METHOD_WITH_PUSH_TEMPLATE:
2407     {
2408         context.requireDeviceFunctionality("VK_KHR_push_descriptor");
2409         context.requireDeviceFunctionality("VK_KHR_descriptor_update_template");
2410 
2411         break;
2412     }
2413 
2414     default:
2415         TCU_THROW(InternalError, "Unknown update method");
2416     }
2417 
2418     m_testParams.pipelineCheckSupport(context, m_testParams);
2419 }
2420 
createInstance(Context & context) const2421 TestInstance *BindingAccelerationStructureTestCase::createInstance(Context &context) const
2422 {
2423     switch (m_testParams.testType)
2424     {
2425     case TEST_TYPE_USING_RAY_QUERY:
2426     {
2427         switch (m_testParams.stage)
2428         {
2429         case VK_SHADER_STAGE_VERTEX_BIT:
2430         case VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT:
2431         case VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT:
2432         case VK_SHADER_STAGE_GEOMETRY_BIT:
2433         case VK_SHADER_STAGE_FRAGMENT_BIT:
2434         {
2435             return new BindingAcceleratioStructureGraphicsTestInstance(context, m_testParams);
2436         }
2437 
2438         case VK_SHADER_STAGE_COMPUTE_BIT:
2439         {
2440             return new BindingAcceleratioStructureComputeTestInstance(context, m_testParams);
2441         }
2442 
2443         case VK_SHADER_STAGE_RAYGEN_BIT_KHR:
2444         case VK_SHADER_STAGE_ANY_HIT_BIT_KHR:
2445         case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR:
2446         case VK_SHADER_STAGE_MISS_BIT_KHR:
2447         case VK_SHADER_STAGE_INTERSECTION_BIT_KHR:
2448         case VK_SHADER_STAGE_CALLABLE_BIT_KHR:
2449         {
2450             return new BindingAcceleratioStructureRayTracingTestInstance(context, m_testParams);
2451         }
2452 
2453         default:
2454             TCU_THROW(InternalError, "Unknown shader stage");
2455         }
2456     }
2457 
2458     case TEST_TYPE_USING_RAY_TRACING:
2459     {
2460         return new BindingAcceleratioStructureRayTracingRayTracingTestInstance(context, m_testParams);
2461     }
2462 
2463     default:
2464         TCU_THROW(InternalError, "Unknown shader stage");
2465     }
2466 }
2467 
initPrograms(SourceCollections & programCollection) const2468 void BindingAccelerationStructureTestCase::initPrograms(SourceCollections &programCollection) const
2469 {
2470     m_testParams.pipelineInitPrograms(programCollection, m_testParams);
2471 }
2472 
getPipelineRayQueryCheckSupport(const VkShaderStageFlagBits stage)2473 static inline CheckSupportFunc getPipelineRayQueryCheckSupport(const VkShaderStageFlagBits stage)
2474 {
2475     switch (stage)
2476     {
2477     case VK_SHADER_STAGE_VERTEX_BIT:
2478     case VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT:
2479     case VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT:
2480     case VK_SHADER_STAGE_GEOMETRY_BIT:
2481     case VK_SHADER_STAGE_FRAGMENT_BIT:
2482         return BindingAcceleratioStructureGraphicsTestInstance::checkSupport;
2483 
2484     case VK_SHADER_STAGE_COMPUTE_BIT:
2485         return BindingAcceleratioStructureComputeTestInstance::checkSupport;
2486 
2487     case VK_SHADER_STAGE_RAYGEN_BIT_KHR:
2488     case VK_SHADER_STAGE_ANY_HIT_BIT_KHR:
2489     case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR:
2490     case VK_SHADER_STAGE_MISS_BIT_KHR:
2491     case VK_SHADER_STAGE_INTERSECTION_BIT_KHR:
2492     case VK_SHADER_STAGE_CALLABLE_BIT_KHR:
2493         return BindingAcceleratioStructureRayTracingTestInstance::checkSupport;
2494 
2495     default:
2496         TCU_THROW(InternalError, "Unknown shader stage");
2497     }
2498 }
2499 
getPipelineRayTracingCheckSupport(const VkShaderStageFlagBits stage)2500 static inline CheckSupportFunc getPipelineRayTracingCheckSupport(const VkShaderStageFlagBits stage)
2501 {
2502     DE_UNREF(stage);
2503 
2504     return BindingAcceleratioStructureRayTracingRayTracingTestInstance::checkSupport;
2505 }
2506 
getPipelineRayQueryInitPrograms(const VkShaderStageFlagBits stage)2507 static inline InitProgramsFunc getPipelineRayQueryInitPrograms(const VkShaderStageFlagBits stage)
2508 {
2509     switch (stage)
2510     {
2511     case VK_SHADER_STAGE_VERTEX_BIT:
2512     case VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT:
2513     case VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT:
2514     case VK_SHADER_STAGE_GEOMETRY_BIT:
2515     case VK_SHADER_STAGE_FRAGMENT_BIT:
2516         return BindingAcceleratioStructureGraphicsTestInstance::initPrograms;
2517 
2518     case VK_SHADER_STAGE_COMPUTE_BIT:
2519         return BindingAcceleratioStructureComputeTestInstance::initPrograms;
2520 
2521     case VK_SHADER_STAGE_RAYGEN_BIT_KHR:
2522     case VK_SHADER_STAGE_ANY_HIT_BIT_KHR:
2523     case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR:
2524     case VK_SHADER_STAGE_MISS_BIT_KHR:
2525     case VK_SHADER_STAGE_INTERSECTION_BIT_KHR:
2526     case VK_SHADER_STAGE_CALLABLE_BIT_KHR:
2527         return BindingAcceleratioStructureRayTracingTestInstance::initPrograms;
2528 
2529     default:
2530         TCU_THROW(InternalError, "Unknown shader stage");
2531     }
2532 }
2533 
getPipelineRayTracingInitPrograms(const VkShaderStageFlagBits stage)2534 static inline InitProgramsFunc getPipelineRayTracingInitPrograms(const VkShaderStageFlagBits stage)
2535 {
2536     DE_UNREF(stage);
2537 
2538     return BindingAcceleratioStructureRayTracingRayTracingTestInstance::initPrograms;
2539 }
2540 
getShaderBodyTextFunc(const TestType testType)2541 static inline ShaderBodyTextFunc getShaderBodyTextFunc(const TestType testType)
2542 {
2543     switch (testType)
2544     {
2545     case TEST_TYPE_USING_RAY_QUERY:
2546         return getRayQueryShaderBodyText;
2547     case TEST_TYPE_USING_RAY_TRACING:
2548         return getRayTracingShaderBodyText;
2549     default:
2550         TCU_THROW(InternalError, "Unknown test type");
2551     }
2552 }
2553 
2554 } // namespace
2555 
createDescriptorUpdateASTests(tcu::TestContext & testCtx)2556 tcu::TestCaseGroup *createDescriptorUpdateASTests(tcu::TestContext &testCtx)
2557 {
2558     de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(testCtx, "acceleration_structure"));
2559 
2560     const struct TestTypes
2561     {
2562         TestType testType;
2563         const char *name;
2564     } testTypes[] = {
2565         {TEST_TYPE_USING_RAY_QUERY, "ray_query"},
2566         {TEST_TYPE_USING_RAY_TRACING, "ray_tracing"},
2567     };
2568     const struct UpdateMethods
2569     {
2570         const UpdateMethod method;
2571         const char *name;
2572     } updateMethods[] = {
2573         // Use regular descriptor updates
2574         {UPDATE_METHOD_NORMAL, "regular"},
2575         // Use descriptor update templates
2576         {UPDATE_METHOD_WITH_TEMPLATE, "with_template"},
2577         // Use push descriptor updates
2578         {UPDATE_METHOD_WITH_PUSH, "with_push"},
2579         // Use push descriptor update templates
2580         {UPDATE_METHOD_WITH_PUSH_TEMPLATE, "with_push_template"},
2581     };
2582     const struct PipelineStages
2583     {
2584         VkShaderStageFlagBits stage;
2585         const char *name;
2586         const bool rayTracing;
2587     } pipelineStages[] = {
2588         {VK_SHADER_STAGE_VERTEX_BIT, "vert", false},
2589         {VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT, "tesc", false},
2590         {VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT, "tese", false},
2591         {VK_SHADER_STAGE_GEOMETRY_BIT, "geom", false},
2592         {VK_SHADER_STAGE_FRAGMENT_BIT, "frag", false},
2593         {VK_SHADER_STAGE_COMPUTE_BIT, "comp", false},
2594         {VK_SHADER_STAGE_RAYGEN_BIT_KHR, "rgen", true},
2595         {VK_SHADER_STAGE_ANY_HIT_BIT_KHR, "ahit", false},
2596         {VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR, "chit", true},
2597         {VK_SHADER_STAGE_MISS_BIT_KHR, "miss", true},
2598         {VK_SHADER_STAGE_INTERSECTION_BIT_KHR, "sect", false},
2599         {VK_SHADER_STAGE_CALLABLE_BIT_KHR, "call", false},
2600     };
2601 
2602     for (size_t testTypeNdx = 0; testTypeNdx < DE_LENGTH_OF_ARRAY(testTypes); ++testTypeNdx)
2603     {
2604         de::MovePtr<tcu::TestCaseGroup> testTypeGroup(
2605             new tcu::TestCaseGroup(group->getTestContext(), testTypes[testTypeNdx].name));
2606         const TestType testType                     = testTypes[testTypeNdx].testType;
2607         const ShaderBodyTextFunc shaderBodyTextFunc = getShaderBodyTextFunc(testType);
2608         const uint32_t imageDepth                   = 1;
2609 
2610         for (size_t updateMethodsNdx = 0; updateMethodsNdx < DE_LENGTH_OF_ARRAY(updateMethods); ++updateMethodsNdx)
2611         {
2612             de::MovePtr<tcu::TestCaseGroup> updateMethodsGroup(
2613                 new tcu::TestCaseGroup(group->getTestContext(), updateMethods[updateMethodsNdx].name));
2614             const UpdateMethod updateMethod = updateMethods[updateMethodsNdx].method;
2615 
2616             for (size_t pipelineStageNdx = 0; pipelineStageNdx < DE_LENGTH_OF_ARRAY(pipelineStages); ++pipelineStageNdx)
2617             {
2618                 const VkShaderStageFlagBits stage           = pipelineStages[pipelineStageNdx].stage;
2619                 const CheckSupportFunc pipelineCheckSupport = (testType == TEST_TYPE_USING_RAY_QUERY) ?
2620                                                                   getPipelineRayQueryCheckSupport(stage) :
2621                                                                   getPipelineRayTracingCheckSupport(stage);
2622                 const InitProgramsFunc pipelineInitPrograms = (testType == TEST_TYPE_USING_RAY_QUERY) ?
2623                                                                   getPipelineRayQueryInitPrograms(stage) :
2624                                                                   getPipelineRayTracingInitPrograms(stage);
2625 
2626                 if (testType == TEST_TYPE_USING_RAY_TRACING && !pipelineStages[pipelineStageNdx].rayTracing)
2627                     continue;
2628 
2629                 const TestParams testParams = {
2630                     TEST_WIDTH,           //  uint32_t width;
2631                     TEST_HEIGHT,          //  uint32_t height;
2632                     imageDepth,           //  uint32_t depth;
2633                     testType,             //  TestType testType;
2634                     updateMethod,         //  UpdateMethod updateMethod;
2635                     stage,                //  VkShaderStageFlagBits stage;
2636                     VK_FORMAT_R32_SINT,   //  VkFormat format;
2637                     pipelineCheckSupport, //  CheckSupportFunc pipelineCheckSupport;
2638                     pipelineInitPrograms, //  InitProgramsFunc pipelineInitPrograms;
2639                     shaderBodyTextFunc,   //  ShaderTestTextFunc testConfigShaderBodyText;
2640                 };
2641 
2642                 updateMethodsGroup->addChild(new BindingAccelerationStructureTestCase(
2643                     group->getTestContext(), pipelineStages[pipelineStageNdx].name, testParams));
2644             }
2645 
2646             testTypeGroup->addChild(updateMethodsGroup.release());
2647         }
2648 
2649         group->addChild(testTypeGroup.release());
2650     }
2651 
2652     return group.release();
2653 }
2654 
2655 } // namespace BindingModel
2656 } // namespace vkt
2657